using MLAgents;
using OpenCvSharp;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// ML-Agents agent that drives the virtual Cozmo robot. Each step it executes one of four
/// discrete movements (stop, forward, right, left) and is rewarded according to how far the
/// horizontal center of gravity reported by the <see cref="ImageProcessor"/> on the render
/// camera is from the horizontal center of the camera's target texture.
/// </summary>
public class CozmoAgent : Agent
{
    // Possible Actions (values of vectorAction[0])
    private const int STOP = 0;
    private const int FORWARD = 1;
    private const int RIGHT = 2;
    private const int LEFT = 3;

    // Fraction of half the image width that counts as the area "near" the image center
    // (everything further out is the "far away" area).
    private const float NEAR_AREA_PERCENTAGE_OFFSET = 0.3f;

    [Tooltip("The virtual Cozmo camera")]
    public Camera renderCamera;

    [Tooltip("Reference to the CozmoMovement script")]
    public CozmoMovementController movementController;

    /// <summary>Seconds between decision requests while the academy is in inference mode.</summary>
    public float timeBetweenDecisionsAtInference;

    private Academy academy;                 // CozmoAcademy
    private float timeSinceDecision;         // time since last decision (inference mode only)
    private ImageProcessor imageProcessor;   // ImageProcessor attached to the render camera
    private int nearAreaLimit = 0;           // half-width in pixels of the near-to-center area
    private int centerOfImageX = 0;          // middle of the image in x direction
    private MovementState lastChosenMovement = MovementState.Stop; // last action/movement executed

    // Time of the previous AgentAction call. Must NOT be initialised from Time.time in the field
    // initializer: that runs in the MonoBehaviour constructor, where Unity forbids reading Time.time.
    private double startTime;

    /// <summary>
    /// Resolves the academy and image processor and derives the image-center geometry from the
    /// render camera's target texture.
    /// </summary>
    private void Start()
    {
        academy = FindObjectOfType(typeof(CozmoAcademy)) as CozmoAcademy;
        imageProcessor = renderCamera.GetComponent<ImageProcessor>();

        int imageWidth = renderCamera.targetTexture.width;
        // Integer division first, then scale by the percentage, then truncate (as before).
        nearAreaLimit = (int)(imageWidth / 2 * NEAR_AREA_PERCENTAGE_OFFSET);
        centerOfImageX = imageWidth / 2;

        startTime = Time.time;
    }

    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>No vector observations are added here; only the action mask is updated.</summary>
    public override void CollectObservations()
    {
        SetMask();
    }

    /// <summary>
    /// Masks actions that may not follow the previously executed movement.
    /// </summary>
    /// <exception cref="ArgumentException">If <c>lastChosenMovement</c> is not a known state.</exception>
    private void SetMask()
    {
        switch (lastChosenMovement)
        {
            // Do not allow stop decision after a stop
            case MovementState.Stop:
                SetActionMask(STOP);
                break;
            // Do not allow stop after forward
            case MovementState.Forward:
                SetActionMask(STOP);
                break;
            // Do not allow stop & left after right
            case MovementState.Right:
                SetActionMask(STOP);
                SetActionMask(LEFT);
                break;
            // Do not allow stop & right after left
            case MovementState.Left:
                SetActionMask(STOP);
                SetActionMask(RIGHT);
                break;
            default:
                throw new ArgumentException("Invalid MovementState.");
        }
    }

    /// <summary>
    /// Applies the chosen discrete action to the movement controller, re-renders the camera and
    /// rewards the agent based on the resulting center of gravity.
    /// </summary>
    /// <param name="vectorAction">Element 0 holds the discrete action index (STOP..LEFT).</param>
    /// <param name="textAction">Unused.</param>
    /// <exception cref="ArgumentException">If the action index is not one of the four actions.</exception>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        double elapsedTime = Time.time - startTime;
        print("Elapsed time: " + elapsedTime);
        startTime = Time.time;

        int action = Mathf.FloorToInt(vectorAction[0]);

        // NOTE(review): this step penalty and the per-action "Test" rewards below are all
        // overwritten by the final SetReward call in RewardAgent(), so they currently have no
        // effect on the reward the agent receives. Use AddReward there if they should count.
        AddReward(-0.01f);

        switch (action)
        {
            case STOP:
                ApplyMovement(MovementState.Stop);
                //Test
                SetReward(-0.1f);
                break;
            case FORWARD:
                ApplyMovement(MovementState.Forward);
                //Test
                SetReward(0.01f);
                break;
            case RIGHT:
                ApplyMovement(MovementState.Right);
                //Test
                SetReward(-0.02f);
                break;
            case LEFT:
                ApplyMovement(MovementState.Left);
                //Test
                SetReward(-0.02f);
                break;
            default:
                throw new ArgumentException("Invalid action value. Stop movement.");
        }

        // Render new image after movement in order to update the centerOfGravity
        // (presumably the ImageProcessor recomputes it during rendering — confirm).
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        RewardAgent();
    }

    // Sends the movement to the controller and remembers it for the next action mask.
    private void ApplyMovement(MovementState movement)
    {
        movementController.currentMovementState = movement;
        lastChosenMovement = movement;
    }

    /// <summary>
    /// Sets the reward for the agent based on how far away the center of gravity is from the
    /// center of the image:
    /// far left: -0.5 at the left edge rising linearly to 0 at the near-area border;
    /// near left: 0 at the near-area border rising linearly to 1 at the center;
    /// near right: 1 at the center falling linearly to 0 at the near-area border;
    /// far right: 0 at the near-area border falling linearly to -0.5 at the right edge;
    /// outside the image: -1, and the environment is reset.
    /// </summary>
    private void RewardAgent()
    {
        float centerOfGravityX = imageProcessor.CenterOfGravity.X;
        int imageWidth = renderCamera.targetTexture.width;
        int leftNearBorder = centerOfImageX - nearAreaLimit;
        int rightNearBorder = centerOfImageX + nearAreaLimit;
        float reward;

        if (centerOfGravityX <= leftNearBorder && centerOfGravityX >= 0)
        {
            // Center of gravity is far away from the center (left)
            float range = leftNearBorder;
            reward = -(1 - (centerOfGravityX / range));
            // Clamp the reward to max -1 in order to handle rewards if the center of gravity is outside of the image
            reward = Mathf.Clamp(reward, -1, 0) / 2;
        }
        else if (centerOfGravityX <= centerOfImageX && centerOfGravityX >= leftNearBorder)
        {
            // Center of gravity is near left of the center
            float range = centerOfImageX - leftNearBorder;
            float distanceToLeftFarBorder = centerOfGravityX - leftNearBorder;
            reward = distanceToLeftFarBorder / range;
        }
        else if (centerOfGravityX >= rightNearBorder && centerOfGravityX <= imageWidth)
        {
            // Center of gravity is far away from the center (right)
            float range = imageWidth - rightNearBorder;
            reward = -((centerOfGravityX - rightNearBorder) / range);
            // Clamp the reward to max -1 in order to handle rewards if the center of gravity is outside of the image
            reward = Mathf.Clamp(reward, -1, 0) / 2;
        }
        else if (centerOfGravityX >= centerOfImageX && centerOfGravityX <= rightNearBorder)
        {
            // Center of gravity is near right of the center
            float range = rightNearBorder - centerOfImageX;
            float distanceToCenterOfImage = centerOfGravityX - centerOfImageX;
            reward = 1 - distanceToCenterOfImage / range;
        }
        else
        {
            // Center of gravity is outside of the image. Previously SetReward(-1) was called here
            // and then overwritten with 0 by the SetReward below; assign the penalty instead so it
            // is actually delivered.
            // NOTE(review): Done() is the usual ML-Agents way to end an episode; AgentReset() is
            // called directly here, which resets the academy without ending the episode.
            reward = -1f;
            AgentReset();
            Debug.Log("Out of image range");
        }

        Debug.Log("Reward: " + reward);
        SetReward(reward);
    }

    /// <summary>Resets the environment by delegating to the academy.</summary>
    public override void AgentReset()
    {
        academy.AcademyReset();
    }

    // Reaching an object tagged "Goal" marks the agent as done.
    private void OnTriggerEnter(Collider other)
    {
        if (other.transform.CompareTag("Goal"))
        {
            Done();
        }
    }

    /// <summary>
    /// Renders the camera every physics step and requests decisions: every step during training,
    /// or every <see cref="timeBetweenDecisionsAtInference"/> seconds during inference.
    /// </summary>
    private void WaitTimeInference()
    {
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        if (!academy.GetIsInference())
        {
            RequestDecision();
        }
        else
        {
            if (timeSinceDecision >= timeBetweenDecisionsAtInference)
            {
                timeSinceDecision = 0f;
                RequestDecision();
            }
            else
            {
                timeSinceDecision += Time.fixedDeltaTime;
            }
        }
    }
}