You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

CozmoAgent.cs 8.5KB


  1. ///-----------------------------------------------------------------
  2. /// Namespace: <Cozmo>
  3. /// Class: <CozmoAgent>
  4. /// Description: <The actual agent in the scene. Collects observations and executes the actions.
  5. /// Also rewards the agent and sets an actionmask.>
  6. /// Author: <Tobias Hassel> Date: <29.07.2019>
  7. /// Notes: <>
  8. ///-----------------------------------------------------------------
  9. ///
  10. using MLAgents;
  11. using OpenCvSharp;
  12. using System;
  13. using UnityEngine;
  namespace Cozmo
  {
      /// <summary>
      /// The agent in the scene. On every decision it sets an action mask based on the
      /// previously chosen movement, forwards the chosen action to the
      /// <see cref="CozmoMovementController"/>, and rewards itself depending on how far the
      /// center of gravity reported by the <see cref="ImageProcessor"/> is from the
      /// horizontal center of the rendered image.
      /// </summary>
      public class CozmoAgent : Agent
      {
          // Possible actions, as received in vectorAction[0] and as passed to SetActionMask.
          private const int STOP = 0;
          private const int FORWARD = 1;
          private const int RIGHT = 2;
          private const int LEFT = 3;

          // Used to determine different areas in the image (near to the center, far away).
          // Applied to half of the image width in Start().
          private const float NEAR_AREA_PERCENTAGE_OFFSET = 0.3f;

          [Tooltip("The virtual Cozmo camera")]
          public Camera renderCamera;

          [Tooltip("Reference to the CozmoMovement script")]
          public CozmoMovementController movementController;

          // Minimum time between two decision requests while the academy reports inference mode.
          public float timeBetweenDecisionsAtInference;

          private Academy academy;                // The CozmoAcademy found in the scene
          private float timeSinceDecision;        // Time accumulated since the last decision request
          private ImageProcessor imageProcessor;  // ImageProcessor component attached to renderCamera
          private int nearAreaLimit = 0;          // Half-width (in x) of the "near the image center" area
          private int centerOfImageX = 0;         // Middle of the image in x direction
          private MovementState lastChosenMovement = MovementState.Stop; // The last action/movement that was executed

          // Timestamp of the previous AgentAction call. Assigned in Start(): Unity does not allow
          // Time.time to be read from an instance field initializer of a MonoBehaviour.
          private double startTime;

          /// <summary>
          /// Resolves scene references and derives the image areas from the width of the
          /// camera's target texture.
          /// </summary>
          private void Start()
          {
              startTime = Time.time;

              academy = FindObjectOfType<CozmoAcademy>();
              if (academy == null)
              {
                  // academy is dereferenced unguarded in WaitTimeInference() and AgentReset().
                  Debug.LogError("CozmoAgent: no CozmoAcademy found in the scene.");
              }

              imageProcessor = renderCamera.GetComponent<ImageProcessor>();

              // Integer division first, exactly as before: the near area is a fraction of the half width.
              centerOfImageX = renderCamera.targetTexture.width / 2;
              nearAreaLimit = (int)(centerOfImageX * NEAR_AREA_PERCENTAGE_OFFSET);
          }

          /// <summary>
          /// Renders the camera and requests a decision (throttled in inference mode).
          /// </summary>
          public void FixedUpdate()
          {
              WaitTimeInference();
          }

          /// <summary>
          /// Adds no vector observations; only sets the action mask for the upcoming decision.
          /// </summary>
          public override void CollectObservations()
          {
              SetMask();
          }

          /// <summary>
          /// Masks the actions that must not follow the last chosen movement:
          /// STOP is never allowed, and a turn may not be directly followed by the opposite turn.
          /// </summary>
          /// <exception cref="ArgumentException">lastChosenMovement is not a handled MovementState.</exception>
          private void SetMask()
          {
              switch (lastChosenMovement)
              {
                  // Do not allow stop after a stop or after forward
                  case MovementState.Stop:
                  case MovementState.Forward:
                      SetActionMask(STOP);
                      break;

                  // Do not allow stop & left after right
                  case MovementState.Right:
                      SetActionMask(STOP);
                      SetActionMask(LEFT);
                      break;

                  // Do not allow stop & right after left
                  case MovementState.Left:
                      SetActionMask(STOP);
                      SetActionMask(RIGHT);
                      break;

                  default:
                      throw new ArgumentException("Invalid MovementState.");
              }
          }

          /// <summary>
          /// Executes the chosen action: logs the time since the previous call, hands the movement
          /// to the movement controller, re-renders the camera and computes the reward.
          /// </summary>
          /// <param name="vectorAction">Action vector; element 0 is floored to one of STOP/FORWARD/RIGHT/LEFT.</param>
          /// <param name="textAction">Not used.</param>
          /// <exception cref="ArgumentException">The floored action value is not one of the four known actions.</exception>
          public override void AgentAction(float[] vectorAction, string textAction)
          {
              double elapsedTime = Time.time - startTime;
              print("Elapsed time: " + elapsedTime);
              startTime = Time.time;

              int action = Mathf.FloorToInt(vectorAction[0]);

              // NOTE(review): RewardAgent() below finishes with SetReward(...). Assuming SetReward
              // replaces the current step reward (confirm against the ML-Agents version in use),
              // neither this penalty nor the per-action "test" rewards survive. Left unchanged
              // because the intended reward shaping cannot be determined from this file.
              AddReward(-0.01f);

              switch (action)
              {
                  case STOP:
                      ApplyMovement(MovementState.Stop, -0.1f);
                      break;
                  case FORWARD:
                      ApplyMovement(MovementState.Forward, 0.01f);
                      break;
                  case RIGHT:
                      ApplyMovement(MovementState.Right, -0.02f);
                      break;
                  case LEFT:
                      ApplyMovement(MovementState.Left, -0.02f);
                      break;
                  default:
                      throw new ArgumentException("Invalid action value. Stop movement.");
              }

              // Render new image after movement in order to update the centerOfGravity
              if (renderCamera != null)
              {
                  renderCamera.Render();
              }

              RewardAgent();
          }

          // Hands the movement to the controller, remembers it for the next action mask and
          // sets the per-action test reward.
          private void ApplyMovement(MovementState movement, float testReward)
          {
              movementController.currentMovementState = movement;
              lastChosenMovement = movement;
              //Test
              SetReward(testReward);
          }

          /// <summary>
          /// Sets the reward based on how far the center of gravity is from the center of the image.
          /// Inside the near area the reward rises linearly from 0 at the area border to 1 at the
          /// center; in the far areas it falls linearly from 0 at the near border to -0.5 at the
          /// image edge. A center of gravity outside [0, image width] yields -1 and resets the agent.
          /// </summary>
          private void RewardAgent()
          {
              float centerOfGravityX = imageProcessor.CenterOfGravity.X;
              int leftNearBorder = centerOfImageX - nearAreaLimit;
              int rightNearBorder = centerOfImageX + nearAreaLimit;
              float reward;

              // Center of gravity is far away from the center (left)
              if (centerOfGravityX >= 0 && centerOfGravityX <= leftNearBorder)
              {
                  float range = leftNearBorder;
                  reward = -(1 - (centerOfGravityX / range));
                  // Clamp to [-1, 0], then halve so the far areas contribute at most -0.5
                  reward = Mathf.Clamp(reward, -1, 0) / 2;
              }
              // Center of gravity is near left of the center
              else if (centerOfGravityX >= leftNearBorder && centerOfGravityX <= centerOfImageX)
              {
                  float range = nearAreaLimit;
                  float distanceToLeftFarBorder = centerOfGravityX - leftNearBorder;
                  reward = distanceToLeftFarBorder / range;
              }
              // Center of gravity is far away from the center (right)
              else if (centerOfGravityX >= rightNearBorder && centerOfGravityX <= renderCamera.targetTexture.width)
              {
                  float range = renderCamera.targetTexture.width - rightNearBorder;
                  reward = -((centerOfGravityX - rightNearBorder) / range);
                  // Clamp to [-1, 0], then halve so the far areas contribute at most -0.5
                  reward = Mathf.Clamp(reward, -1, 0) / 2;
              }
              // Center of gravity is near right of the center
              else if (centerOfGravityX >= centerOfImageX && centerOfGravityX <= rightNearBorder)
              {
                  float range = nearAreaLimit;
                  float distanceToCenterOfImage = centerOfGravityX - centerOfImageX;
                  reward = 1 - distanceToCenterOfImage / range;
              }
              else
              {
                  // Keep the penalty in `reward`: previously the final SetReward(reward) below
                  // replaced this -1 with the initial value 0.
                  reward = -1;
                  SetReward(reward);
                  AgentReset();
                  Debug.Log("Out of image range");
              }

              Debug.Log("Reward: " + reward);
              SetReward(reward);
          }

          /// <summary>
          /// Resets the agent by resetting the whole academy.
          /// </summary>
          public override void AgentReset()
          {
              academy.AcademyReset();
          }

          /// <summary>
          /// Marks the agent as done when it enters a trigger tagged "Goal".
          /// </summary>
          private void OnTriggerEnter(Collider other)
          {
              if (other.transform.CompareTag("Goal"))
              {
                  Done();
              }
          }

          /// <summary>
          /// Renders the camera, then requests a decision: on every call when the academy is not
          /// in inference mode, otherwise only once timeBetweenDecisionsAtInference has elapsed.
          /// </summary>
          private void WaitTimeInference()
          {
              if (renderCamera != null)
              {
                  renderCamera.Render();
              }

              if (!academy.GetIsInference())
              {
                  RequestDecision();
                  return;
              }

              if (timeSinceDecision >= timeBetweenDecisionsAtInference)
              {
                  timeSinceDecision = 0f;
                  RequestDecision();
              }
              else
              {
                  timeSinceDecision += Time.fixedDeltaTime;
              }
          }
      }
  }