///-----------------------------------------------------------------
///   Namespace:      Cozmo
///   Class:          CozmoAgent
///   Description:    ML-Agents agent that steers the virtual Cozmo robot with
///                   discrete movement actions (stop/forward/right/left) and is
///                   rewarded for keeping the image center of gravity reported by
///                   ImageProcessor close to the horizontal center of the image.
///   Author:                                         Date: <29.07.2019>
///   Notes:          Written against the pre-1.0 ML-Agents Agent API
///                   (CollectObservations / AgentAction / AgentReset).
///-----------------------------------------------------------------
///
using MLAgents;
using OpenCvSharp;
using System;
using System.Collections.Generic;
using UnityEngine;

namespace Cozmo
{
    /// <summary>
    /// Agent that chooses one discrete movement per decision and forwards it to a
    /// <see cref="CozmoMovementController"/>. Rewards combine a small per-step time
    /// penalty, a per-action shaping term and a position term derived from the
    /// horizontal center of gravity in the rendered camera image.
    /// </summary>
    public class CozmoAgent : Agent
    {
        // Possible actions (indices of the single discrete action branch).
        private const int STOP = 0;
        private const int FORWARD = 1;
        private const int RIGHT = 2;
        private const int LEFT = 3;

        // Fraction of the half image width that counts as "near" the image center.
        private const float NEAR_AREA_PERCENTAGE_OFFSET = 0.3f;

        [Tooltip("The virtual Cozmo camera")]
        public Camera renderCamera;

        [Tooltip("Reference to the CozmoMovement script")]
        public CozmoMovementController movementController;

        [Tooltip("The time between decisions at inference")]
        public float timeBetweenDecisionsAtInference;

        [Tooltip("The amount of actions that should be remembered by the agent.")]
        public int maxStoredMovementStates = 1;

        [Tooltip("If activated the maxStoredMovementStates has no impact on the training.")]
        public bool useOriginalActionMasking = true;

        private Academy academy;                 // The CozmoAcademy found in the scene
        private float timeSinceDecision;         // Time since the last decision (inference only)
        private ImageProcessor imageProcessor;   // ImageProcessor attached to renderCamera
        private int nearAreaLimit = 0;           // Half-width (px) of the "near the center" band
        private int centerOfImageX = 0;          // Horizontal middle of the image (px)
        private MovementState lastChosenMovement = MovementState.Stop;             // Last executed movement
        private readonly Queue<MovementState> lastActions = new Queue<MovementState>(); // Recent movements, oldest first

        /// <summary>
        /// Caches the academy and the image processor, and precomputes the image
        /// center and the width of the "near" band from the camera's render texture.
        /// </summary>
        private void Start()
        {
            academy = FindObjectOfType<CozmoAcademy>();
            imageProcessor = renderCamera.GetComponent<ImageProcessor>();

            // Integer division is intentional: both values are pixel columns.
            centerOfImageX = renderCamera.targetTexture.width / 2;
            nearAreaLimit = (int)(centerOfImageX * NEAR_AREA_PERCENTAGE_OFFSET);
        }

        /// <summary>Renders the camera and requests decisions on the physics tick.</summary>
        public void FixedUpdate()
        {
            WaitTimeInference();
        }

        /// <summary>
        /// No vector observations are added here; this hook is used only to apply
        /// the action mask, which ML-Agents requires to happen during observation collection.
        /// </summary>
        public override void CollectObservations()
        {
            SetMask();
        }

        /// <summary>
        /// Masks actions for the next decision. Must be called from
        /// <see cref="CollectObservations"/>.
        /// </summary>
        /// <exception cref="ArgumentException">
        /// Thrown when original masking is active and <c>lastChosenMovement</c> is not a known state.
        /// </exception>
        private void SetMask()
        {
            if (useOriginalActionMasking)
            {
                // Stop is never allowed; after a turn, the opposite turn is also forbidden.
                switch (lastChosenMovement)
                {
                    case MovementState.Stop:
                    case MovementState.Forward:
                        SetActionMask(STOP);
                        break;
                    case MovementState.Right:
                        SetActionMask(STOP);
                        SetActionMask(LEFT);
                        break;
                    case MovementState.Left:
                        SetActionMask(STOP);
                        SetActionMask(RIGHT);
                        break;
                    default:
                        throw new ArgumentException("Invalid MovementState.");
                }
                return;
            }

            // Do not allow two stop decisions in a row.
            if (lastChosenMovement == MovementState.Stop)
            {
                SetActionMask(STOP);
            }

            // Forbid the opposite turn if a turn appears anywhere in the remembered history.
            if (lastActions.Contains(MovementState.Right))
            {
                SetActionMask(LEFT);
            }

            if (lastActions.Contains(MovementState.Left))
            {
                SetActionMask(RIGHT);
            }
        }

        /// <summary>
        /// Executes the chosen action, records it, re-renders the camera and assigns
        /// this step's reward.
        /// </summary>
        /// <remarks>
        /// The step reward is the sum of a -0.01 time penalty, a per-action shaping
        /// term and the position reward from <see cref="RewardAgent"/>. These are
        /// accumulated with AddReward; with SetReward each term would overwrite the
        /// previous one, leaving only the position reward.
        /// </remarks>
        /// <param name="vectorAction">Element 0 holds the discrete action index.</param>
        /// <param name="textAction">Unused.</param>
        /// <exception cref="ArgumentException">Thrown for an unknown action index.</exception>
        public override void AgentAction(float[] vectorAction, string textAction)
        {
            int action = Mathf.FloorToInt(vectorAction[0]);

            // Small per-step penalty.
            AddReward(-0.01f);

            MovementState movement;
            float actionReward;
            switch (action)
            {
                case STOP:
                    movement = MovementState.Stop;
                    actionReward = -0.02f;
                    break;
                case FORWARD:
                    movement = MovementState.Forward;
                    actionReward = 0.02f;
                    break;
                case RIGHT:
                    movement = MovementState.Right;
                    actionReward = 0.01f;
                    break;
                case LEFT:
                    movement = MovementState.Left;
                    actionReward = 0.01f;
                    break;
                default:
                    // Halt the robot before failing so it does not keep executing a stale movement.
                    movementController.currentMovementState = MovementState.Stop;
                    throw new ArgumentException("Invalid action value. Stop movement.");
            }

            movementController.currentMovementState = movement;
            lastChosenMovement = movement;
            AddReward(actionReward);

            if (!useOriginalActionMasking)
            {
                CollectLastMovementStates(lastChosenMovement);
            }

            // Re-render after moving. Presumably ImageProcessor recomputes
            // CenterOfGravity on render; TODO confirm.
            if (renderCamera != null)
            {
                renderCamera.Render();
            }

            RewardAgent();
        }

        /// <summary>
        /// Adds a reward based on where the center of gravity lies horizontally.
        /// </summary>
        /// <remarks>
        /// Reward by region:
        /// <list type="bullet">
        ///   <item>Far left or far right: from -0.5 at the image edge to 0 at the near band.</item>
        ///   <item>Near band: from 0 at its border to 1 at the image center.</item>
        ///   <item>Outside [0, width]: -1, and the academy is reset.</item>
        /// </list>
        /// </remarks>
        private void RewardAgent()
        {
            float centerOfGravityX = imageProcessor.CenterOfGravity.X;
            int imageWidth = renderCamera.targetTexture.width;
            int leftNearBorder = centerOfImageX - nearAreaLimit;
            int rightNearBorder = centerOfImageX + nearAreaLimit;
            float reward;

            if (centerOfGravityX >= 0 && centerOfGravityX <= leftNearBorder)
            {
                // Far left of the center.
                float range = leftNearBorder;
                reward = -(1 - (centerOfGravityX / range));
                reward = Mathf.Clamp(reward, -1, 0) / 2;
            }
            else if (centerOfGravityX <= centerOfImageX && centerOfGravityX >= leftNearBorder)
            {
                // Near left of the center.
                float range = centerOfImageX - leftNearBorder;
                float distanceToLeftFarBorder = centerOfGravityX - leftNearBorder;
                reward = distanceToLeftFarBorder / range;
            }
            else if (centerOfGravityX >= rightNearBorder && centerOfGravityX <= imageWidth)
            {
                // Far right of the center.
                float range = imageWidth - rightNearBorder;
                reward = -((centerOfGravityX - rightNearBorder) / range);
                reward = Mathf.Clamp(reward, -1, 0) / 2;
            }
            else if (centerOfGravityX >= centerOfImageX && centerOfGravityX <= rightNearBorder)
            {
                // Near right of the center.
                float range = rightNearBorder - centerOfImageX;
                float distanceToCenterOfImage = centerOfGravityX - centerOfImageX;
                reward = 1 - distanceToCenterOfImage / range;
            }
            else
            {
                // Outside the image. Keep the -1 penalty instead of letting a later
                // assignment overwrite it with 0.
                reward = -1;
                // NOTE(review): this resets the academy immediately but does not mark the
                // episode as done; consider Done() (as OnTriggerEnter does) if an episode
                // boundary is intended.
                AgentReset();
                Debug.Log("Out of image range");
            }

            Debug.Log("Reward: " + reward);
            AddReward(reward);
        }

        /// <summary>
        /// Appends a movement to the history queue and keeps at most
        /// <see cref="maxStoredMovementStates"/> entries, oldest first.
        /// A limit of zero or less clears the history and stores nothing.
        /// </summary>
        /// <param name="movementState">The movement that was just executed.</param>
        private void CollectLastMovementStates(MovementState movementState)
        {
            // Re-check Count every iteration so the queue is trimmed to make room for
            // the new entry even if the limit was lowered at runtime.
            while (lastActions.Count > 0 && lastActions.Count >= maxStoredMovementStates)
            {
                lastActions.Dequeue();
            }

            if (maxStoredMovementStates > 0)
            {
                lastActions.Enqueue(movementState);
            }
        }

        /// <summary>Resets the environment by delegating to the academy.</summary>
        public override void AgentReset()
        {
            academy.AcademyReset();
        }

        /// <summary>Ends the episode when the robot enters a trigger tagged "Goal".</summary>
        private void OnTriggerEnter(Collider other)
        {
            if (other.transform.CompareTag("Goal"))
            {
                Done();
            }
        }

        /// <summary>
        /// Renders the camera, then requests a decision. During training this happens
        /// every physics tick. At inference it happens once
        /// <see cref="timeBetweenDecisionsAtInference"/> seconds have accumulated.
        /// </summary>
        private void WaitTimeInference()
        {
            if (renderCamera != null)
            {
                renderCamera.Render();
            }

            if (!academy.GetIsInference())
            {
                RequestDecision();
                return;
            }

            if (timeSinceDecision >= timeBetweenDecisionsAtInference)
            {
                timeSinceDecision = 0f;
                RequestDecision();
            }
            else
            {
                timeSinceDecision += Time.fixedDeltaTime;
            }
        }
    }
}