using MLAgents;
using OpenCvSharp;
using System;
using System.Collections;
using System.Collections.Generic;
using UnityEngine;

/// <summary>
/// Agent that drives the virtual Cozmo. Each action selects a movement state on the
/// <see cref="CozmoMovementController"/>, and the reward depends on how close the
/// image processor's center of gravity is to the horizontal center of the camera image.
/// </summary>
public class CozmoAgent : Agent
{
    // Possible Actions
    private const int STOP = 0;
    private const int FORWARD = 1;
    private const int RIGHT = 2;
    private const int LEFT = 3;

    // Used to determine different areas in the image (near to the center, far away).
    // Both are fractions of half the image width.
    private const float NEAR_AREA_PERCENTAGE_OFFSET = 0.2f;
    private const float FAR_AREA_PERCENTAGE_OFFSET = 0.3f;

    // Reward added on every action, independent of the image content.
    private const float STEP_PENALTY = -0.01f;

    [Tooltip("The virtual Cozmo camera")]
    public Camera renderCamera;

    [Tooltip("Reference to the CozmoMovement script")]
    public CozmoMovementController movementController;

    // Minimum accumulated fixed-update time between two decision requests while the
    // academy reports inference mode (compared against Time.fixedDeltaTime sums).
    public float timeBetweenDecisionsAtInference;

    private Academy academy;                // CozmoAcademy
    private float timeSinceDecision;        // time since last decision
    private ImageProcessor imageProcessor;  // reference to the ImageProcessor
    private int nearAreaLimit = 0;          // X distance from the image center that bounds the "near" area
    private int farAreaLimit = 0;           // X distance from the image center that bounds the "far" area

    /// <summary>
    /// Resolves the academy and the image processor and derives the near/far pixel
    /// limits from the width of the camera's target texture.
    /// </summary>
    private void Start()
    {
        academy = FindObjectOfType<CozmoAcademy>();
        if (academy == null)
        {
            Debug.LogError("CozmoAgent: no CozmoAcademy found in the scene.");
        }

        // The ImageProcessor is looked up on the same GameObject as the render camera.
        imageProcessor = renderCamera.GetComponent<ImageProcessor>();
        if (imageProcessor == null)
        {
            Debug.LogError("CozmoAgent: the render camera has no ImageProcessor component.");
        }

        // Integer division: the same half width is used as the image center in RewardAgent().
        int halfImageWidth = renderCamera.targetTexture.width / 2;
        nearAreaLimit = (int)(halfImageWidth * NEAR_AREA_PERCENTAGE_OFFSET);
        farAreaLimit = (int)(halfImageWidth * FAR_AREA_PERCENTAGE_OFFSET);
    }

    /// <summary>
    /// Runs the render / decision-request logic once per physics step.
    /// </summary>
    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// Applies the chosen action and rewards the agent for the resulting image.
    /// </summary>
    /// <param name="vectorAction">
    /// Only element 0 is read; it is floored to one of STOP, FORWARD, RIGHT or LEFT.
    /// </param>
    /// <param name="textAction">Not used.</param>
    /// <exception cref="ArgumentException">
    /// The floored action is not one of the four known actions. The movement state is
    /// set to <c>MovementState.Stop</c> before the exception is thrown.
    /// </exception>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        int action = Mathf.FloorToInt(vectorAction[0]);

        // Per-step penalty. RewardAgent() adds to the step reward instead of
        // overwriting it, so this penalty is part of the final reward.
        AddReward(STEP_PENALTY);

        switch (action)
        {
            case STOP:
                movementController.currentMovementState = MovementState.Stop;
                break;
            case FORWARD:
                movementController.currentMovementState = MovementState.Forward;
                break;
            case RIGHT:
                movementController.currentMovementState = MovementState.Right;
                break;
            case LEFT:
                movementController.currentMovementState = MovementState.Left;
                break;
            default:
                // Do what the message says: do not keep executing the previous movement.
                movementController.currentMovementState = MovementState.Stop;
                throw new ArgumentException("Invalid action value. Stop movement.");
        }

        // Render new image after movement in order to update the centerOfGravity
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        RewardAgent();

        // NOTE(review): the processor is switched off after every action; presumably it is
        // re-enabled elsewhere (not visible in this file) - confirm.
        imageProcessor.enabled = false;
    }

    /// <summary>
    /// Adds a reward based on the horizontal distance between the image processor's
    /// center of gravity and the center of the camera image.
    /// </summary>
    private void RewardAgent()
    {
        // Integer division on purpose: matches the half width used for the limits in Start().
        float centerOfImageX = renderCamera.targetTexture.width / 2;
        float centerOfGravityX = imageProcessor.CenterOfGravity.X;
        float offsetFromCenter = Mathf.Abs(centerOfGravityX - centerOfImageX);

        // AddReward (not SetReward): SetReward would discard the step penalty that
        // AgentAction() has already added for this step.
        AddReward(CenteringReward(offsetFromCenter, nearAreaLimit, farAreaLimit));
    }

    /// <summary>
    /// Piecewise-linear reward, symmetric around the image center:
    /// 1 at the center, falling to 0 at <paramref name="nearLimit"/>, then falling to
    /// -1 at <paramref name="farLimit"/>, and -1 for anything further away.
    /// </summary>
    /// <param name="offset">Absolute horizontal distance from the image center, in pixels.</param>
    /// <param name="nearLimit">Distance at which the reward reaches 0.</param>
    /// <param name="farLimit">Distance at which the reward reaches -1.</param>
    /// <returns>The reward for the given offset.</returns>
    private static float CenteringReward(float offset, float nearLimit, float farLimit)
    {
        if (offset >= farLimit)
        {
            return -1f;
        }

        if (offset >= nearLimit)
        {
            // Reached only when nearLimit <= offset < farLimit, so the divisor is > 0.
            return -(offset - nearLimit) / (farLimit - nearLimit);
        }

        // Reached only when offset < nearLimit, so nearLimit is > 0.
        return 1f - (offset / nearLimit);
    }

    /// <summary>
    /// Resets the agent by resetting the whole academy.
    /// </summary>
    public override void AgentReset()
    {
        academy.AcademyReset();
    }

    /// <summary>
    /// Marks the agent as done when it enters a trigger tagged "Goal".
    /// </summary>
    /// <param name="other">The collider whose trigger volume was entered.</param>
    private void OnTriggerEnter(Collider other)
    {
        if (other.transform.CompareTag("Goal"))
        {
            print("Collission");
            Done();
        }
    }

    /// <summary>
    /// Renders the camera (if assigned) and requests a decision: on every call while the
    /// academy is not in inference mode, otherwise only once at least
    /// <see cref="timeBetweenDecisionsAtInference"/> of fixed-update time has accumulated.
    /// </summary>
    private void WaitTimeInference()
    {
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        if (!academy.GetIsInference())
        {
            RequestDecision();
            return;
        }

        if (timeSinceDecision >= timeBetweenDecisionsAtInference)
        {
            timeSinceDecision = 0f;
            RequestDecision();
        }
        else
        {
            timeSinceDecision += Time.fixedDeltaTime;
        }
    }
}