using MLAgents;
using OpenCvSharp;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// ML-Agents agent that steers a virtual Cozmo robot. It is rewarded for keeping the
/// center of gravity reported by <see cref="ImageProcessor"/> close to the horizontal
/// center of the camera image.
/// </summary>
public class CozmoAgent : Agent
{
    // Possible discrete actions (value of vectorAction[0]).
    private const int STOP = 0;
    private const int FORWARD = 1;
    private const int RIGHT = 2;
    private const int LEFT = 3;

    // Fraction of the half image width that counts as the "near to the center" area.
    private const float NEAR_AREA_PERCENTAGE_OFFSET = 0.3f;

    // Per-step penalty that nudges the agent toward reaching the goal quickly.
    private const float STEP_PENALTY = -0.01f;

    [Tooltip("The virtual Cozmo camera")]
    public Camera renderCamera;

    [Tooltip("Reference to the CozmoMovement script")]
    public CozmoMovementController movementController;

    /// <summary>Minimum time in seconds between decision requests while running in inference mode.</summary>
    public float timeBetweenDecisionsAtInference;

    private Academy academy;                // CozmoAcademy found in the scene
    private float timeSinceDecision;        // time since the last decision (inference mode only)
    private ImageProcessor imageProcessor;  // ImageProcessor attached to renderCamera
    private int nearAreaLimit = 0;          // half-width (pixels) of the "near to the image center" area
    private int centerOfImageX = 0;         // middle of the image in x direction (pixels)

    /// <summary>
    /// Resolves scene references and precomputes the image-center geometry from the
    /// camera's target texture.
    /// </summary>
    private void Start()
    {
        academy = FindObjectOfType<CozmoAcademy>();
        if (academy == null)
        {
            Debug.LogError("CozmoAgent: no CozmoAcademy found in the scene.");
        }

        imageProcessor = renderCamera.GetComponent<ImageProcessor>();

        // Integer division matches the pixel grid; the near area is a fraction of the half width.
        centerOfImageX = renderCamera.targetTexture.width / 2;
        nearAreaLimit = (int)(centerOfImageX * NEAR_AREA_PERCENTAGE_OFFSET);
    }

    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// Applies the chosen discrete action to the movement controller, re-renders the
    /// camera so the image processor sees the new view, and assigns this step's reward.
    /// </summary>
    /// <param name="vectorAction">Element 0 holds the discrete action (STOP/FORWARD/RIGHT/LEFT).</param>
    /// <param name="textAction">Unused.</param>
    /// <exception cref="ArgumentException">The action value is not one of the known actions.</exception>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        int action = Mathf.FloorToInt(vectorAction[0]);

        // Small time penalty each step; accumulated together with the centering reward below.
        AddReward(STEP_PENALTY);

        switch (action)
        {
            case STOP:
                movementController.currentMovementState = MovementState.Stop;
                break;
            case FORWARD:
                movementController.currentMovementState = MovementState.Forward;
                break;
            case RIGHT:
                movementController.currentMovementState = MovementState.Right;
                break;
            case LEFT:
                movementController.currentMovementState = MovementState.Left;
                break;
            default:
                // Make the message true: halt the robot before reporting the bad action.
                movementController.currentMovementState = MovementState.Stop;
                throw new ArgumentException("Invalid action value. Stop movement.");
        }

        // Render a new image after the movement so the center of gravity is up to date.
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        RewardAgent();
    }

    /// <summary>
    /// Adds a reward based on how far the center of gravity is from the image center.
    /// Uses AddReward (not SetReward) so the per-step penalty added in AgentAction is kept.
    /// </summary>
    private void RewardAgent()
    {
        float centerOfGravityX = imageProcessor.CenterOfGravity.X;
        float reward = ComputeCenteringReward(
            centerOfGravityX,
            centerOfImageX,
            nearAreaLimit,
            renderCamera.targetTexture.width);
        AddReward(reward);
    }

    /// <summary>
    /// Piecewise-linear centering reward.
    /// Near area (within <paramref name="nearLimit"/> of the center): rises linearly from
    /// 0 at the near border to 1 at the center. Far area: falls linearly from 0 at the
    /// near border to -1 at the image edge, clamped to [-1, 0] for points outside the image.
    /// </summary>
    /// <param name="x">Horizontal position of the center of gravity, in pixels.</param>
    /// <param name="centerX">Horizontal image center, in pixels.</param>
    /// <param name="nearLimit">Half-width of the near area, in pixels.</param>
    /// <param name="width">Image width, in pixels.</param>
    /// <returns>A reward in [-1, 1].</returns>
    private static float ComputeCenteringReward(float x, int centerX, int nearLimit, int width)
    {
        float leftNearBorder = centerX - nearLimit;
        float rightNearBorder = centerX + nearLimit;

        // Far left of the center.
        if (x <= leftNearBorder)
        {
            float range = leftNearBorder;
            // A zero-width far area (e.g. near percentage tuned to 1.0) would divide by zero.
            if (range <= 0f)
            {
                return -1f;
            }
            return Mathf.Clamp(-(1f - (x / range)), -1f, 0f);
        }

        // Near left of the center. Reachable only when nearLimit > 0, so the division is safe.
        if (x <= centerX)
        {
            return (x - leftNearBorder) / nearLimit;
        }

        // Far right of the center.
        if (x >= rightNearBorder)
        {
            float range = width - rightNearBorder;
            if (range <= 0f)
            {
                return -1f;
            }
            return Mathf.Clamp(-((x - rightNearBorder) / range), -1f, 0f);
        }

        // Near right of the center. Reachable only when nearLimit > 0.
        return 1f - (x - centerX) / nearLimit;
    }

    /// <summary>Resets the academy whenever this agent is reset.</summary>
    public override void AgentReset()
    {
        academy.AcademyReset();
    }

    /// <summary>Ends the episode when the agent touches an object tagged "Goal".</summary>
    private void OnTriggerEnter(Collider other)
    {
        if (other.transform.CompareTag("Goal"))
        {
            Done();
        }
    }

    /// <summary>
    /// Renders the camera each fixed step and requests a decision. During training it
    /// requests one every step; in inference mode it waits at least
    /// <see cref="timeBetweenDecisionsAtInference"/> seconds between requests.
    /// </summary>
    private void WaitTimeInference()
    {
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        if (!academy.GetIsInference())
        {
            RequestDecision();
            return;
        }

        if (timeSinceDecision >= timeBetweenDecisionsAtInference)
        {
            timeSinceDecision = 0f;
            RequestDecision();
        }
        else
        {
            timeSinceDecision += Time.fixedDeltaTime;
        }
    }
}