You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

CozmoAgent.cs 10KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. ///-----------------------------------------------------------------
  2. /// Namespace: <Cozmo>
  3. /// Class: <CozmoAgent>
  4. /// Description: <The actual agent in the scene. Collects observations and executes the actions.
  5. /// Also rewards the agent and sets an actionmask.>
  6. /// Author: <Tobias Hassel> Date: <29.07.2019>
  7. /// Notes: <>
  8. ///-----------------------------------------------------------------
  9. ///
  10. using MLAgents;
  11. using OpenCvSharp;
  12. using System;
  13. using System.Collections.Generic;
  14. using UnityEngine;
  15. namespace Cozmo
  16. {
  17. public class CozmoAgent : Agent
  18. {
  19. // Possible Actions
  20. private const int STOP = 0;
  21. private const int FORWARD = 1;
  22. private const int RIGHT = 2;
  23. private const int LEFT = 3;
  24. // Used to determine different areas in the image (near to the center, far away)
  25. private const float NEAR_AREA_PERCENTAGE_OFFSET = 0.3f;
  26. [Tooltip("The virtual Cozmo camera")]
  27. public Camera renderCamera;
  28. [Tooltip("Reference to the CozmoMovement script")]
  29. public CozmoMovementController movementController;
  30. public float timeBetweenDecisionsAtInference;
  31. public int maxStoredMovementStates = 1;
  32. private Academy academy; // CozmoAcademy
  33. private float timeSinceDecision; // time since last decision
  34. private ImageProcessor imageProcessor; // reference to the ImageProcessor
  35. private int nearAreaLimit = 0; // X coordinate limit for the near to the imagecenter area
  36. private int centerOfImageX = 0; // Middle of the image in x direction
  37. private MovementState lastChosenMovement = MovementState.Stop; // The last action/movement that was executed
  38. private Queue<MovementState> lastActions = new Queue<MovementState>();
  39. private double startTime = Time.time;
  40. private void Start()
  41. {
  42. academy = FindObjectOfType(typeof(CozmoAcademy)) as CozmoAcademy;
  43. imageProcessor = renderCamera.GetComponent<ImageProcessor>();
  44. nearAreaLimit = (int)(renderCamera.targetTexture.width / 2 * NEAR_AREA_PERCENTAGE_OFFSET);
  45. centerOfImageX = renderCamera.targetTexture.width / 2;
  46. }
  47. public void FixedUpdate()
  48. {
  49. WaitTimeInference();
  50. }
  51. public override void CollectObservations()
  52. {
  53. SetMask();
  54. }
  55. // Set ActionMask for training
  56. private void SetMask()
  57. {
  58. // Do not allow stop decision after a stop
  59. if (lastChosenMovement == MovementState.Stop)
  60. {
  61. SetActionMask(STOP);
  62. }
  63. // Do not allow left decision if right was in the last actions
  64. if (lastActions.Contains(MovementState.Right))
  65. {
  66. SetActionMask(LEFT);
  67. }
  68. // Do not allow right decision if left was in the last actions
  69. if (lastActions.Contains(MovementState.Left))
  70. {
  71. SetActionMask(RIGHT);
  72. }
  73. //switch (lastChosenMovement)
  74. //{
  75. // // Do not allow stop decision after a stop
  76. // case (MovementState.Stop):
  77. // SetActionMask(STOP);
  78. // break;
  79. // // Do not allow stop after forward
  80. // case (MovementState.Forward):
  81. // //SetActionMask(STOP);
  82. // break;
  83. // // Do not allow stop & left after right
  84. // case (MovementState.Right):
  85. // //SetActionMask(STOP);
  86. // if (lastActions.Contains(MovementState.))
  87. // SetActionMask(LEFT);
  88. // break;
  89. // // Do not allow stop & right after left
  90. // case (MovementState.Left):
  91. // //SetActionMask(STOP);
  92. // SetActionMask(RIGHT);
  93. // break;
  94. // default:
  95. // throw new ArgumentException("Invalid MovementState.");
  96. //}
  97. }
  98. // to be implemented by the developer
  99. public override void AgentAction(float[] vectorAction, string textAction)
  100. {
  101. double elapsedTime = Time.time - startTime;
  102. print("Elapsed time: " + elapsedTime);
  103. startTime = Time.time;
  104. int action = Mathf.FloorToInt(vectorAction[0]);
  105. Point centerOfGravity = imageProcessor.CenterOfGravity;
  106. AddReward(-0.01f);
  107. switch (action)
  108. {
  109. case STOP:
  110. movementController.currentMovementState = MovementState.Stop;
  111. lastChosenMovement = MovementState.Stop;
  112. //Test
  113. SetReward(-0.02f);
  114. break;
  115. case FORWARD:
  116. movementController.currentMovementState = MovementState.Forward;
  117. lastChosenMovement = MovementState.Forward;
  118. //Test
  119. SetReward(0.02f);
  120. break;
  121. case RIGHT:
  122. movementController.currentMovementState = MovementState.Right;
  123. lastChosenMovement = MovementState.Right;
  124. //Test
  125. SetReward(0.01f);
  126. break;
  127. case LEFT:
  128. movementController.currentMovementState = MovementState.Left;
  129. lastChosenMovement = MovementState.Left;
  130. //Test
  131. SetReward(0.01f);
  132. break;
  133. default:
  134. //movement.Move(0);
  135. throw new ArgumentException("Invalid action value. Stop movement.");
  136. }
  137. CollectLastMovementStates(lastChosenMovement);
  138. // Render new image after movement in order to update the centerOfGravity
  139. if (renderCamera != null)
  140. {
  141. renderCamera.Render();
  142. }
  143. RewardAgent();
  144. }
  145. // Set the reward for the agent based on how far away the center of gravity is from the center of the image
  146. private void RewardAgent()
  147. {
  148. float centerOfGravityX = imageProcessor.CenterOfGravity.X;
  149. float reward = 0;
  150. // Center of gravity is far away from the center (left)
  151. if (centerOfGravityX <= centerOfImageX - nearAreaLimit && centerOfGravityX >= 0)
  152. {
  153. float range = centerOfImageX - nearAreaLimit;
  154. reward = -(1 - (centerOfGravityX / range));
  155. // Clamp the reward to max -1 in order to handle rewards if the center of gravity is outside of the image
  156. reward = Mathf.Clamp(reward, -1, 0) / 2;
  157. }
  158. // Center of gravity is near left of the center
  159. else if ((centerOfGravityX <= centerOfImageX) && (centerOfGravityX >= (centerOfImageX - nearAreaLimit)))
  160. {
  161. float range = centerOfImageX - (centerOfImageX - nearAreaLimit);
  162. float distanceToLeftFarBorder = centerOfGravityX - (centerOfImageX - nearAreaLimit);
  163. reward = (distanceToLeftFarBorder / range);
  164. }
  165. // Center of gravity is far away from the center (right)
  166. else if ((centerOfGravityX >= (centerOfImageX + nearAreaLimit)) && (centerOfGravityX <= renderCamera.targetTexture.width))
  167. {
  168. float range = renderCamera.targetTexture.width - (centerOfImageX + nearAreaLimit);
  169. reward = -(((centerOfGravityX - (centerOfImageX + nearAreaLimit)) / range));
  170. // Clamp the reward to max -1 in order to handle rewards if the center of gravity is outside of the image
  171. reward = Mathf.Clamp(reward, -1, 0) / 2;
  172. }
  173. // Center of gravity is near right of the center
  174. else if ((centerOfGravityX >= centerOfImageX) && (centerOfGravityX <= (centerOfImageX + nearAreaLimit)))
  175. {
  176. float range = (centerOfImageX + nearAreaLimit) - centerOfImageX;
  177. float distanceToCenterOfImage = centerOfGravityX - centerOfImageX;
  178. reward = (1 - distanceToCenterOfImage / range);
  179. }
  180. else
  181. {
  182. SetReward(-1);
  183. AgentReset();
  184. Debug.Log("Out of image range");
  185. }
  186. Debug.Log("Reward: " + reward);
  187. SetReward(reward);
  188. }
  189. // Store the last movementStates in a Queue
  190. private void CollectLastMovementStates(MovementState movementState)
  191. {
  192. // Check if Queue exists and values should be stored
  193. if ((lastActions != null) && (maxStoredMovementStates > 0))
  194. {
  195. // maxStoredMovementStates is reached
  196. if (lastActions.Count >= maxStoredMovementStates)
  197. {
  198. // deque first value(s) when maxStoredMovementStates is reached
  199. for (int i = 0; i <= (lastActions.Count - maxStoredMovementStates); i++)
  200. {
  201. lastActions.Dequeue();
  202. }
  203. }
  204. // add last action to queue
  205. lastActions.Enqueue(movementState);
  206. }
  207. }
  208. // to be implemented by the developer
  209. public override void AgentReset()
  210. {
  211. academy.AcademyReset();
  212. }
  213. private void OnTriggerEnter(Collider other)
  214. {
  215. if (other.transform.CompareTag("Goal"))
  216. {
  217. Done();
  218. }
  219. }
  220. private void WaitTimeInference()
  221. {
  222. if (renderCamera != null)
  223. {
  224. renderCamera.Render();
  225. }
  226. if (!academy.GetIsInference())
  227. {
  228. RequestDecision();
  229. }
  230. else
  231. {
  232. if (timeSinceDecision >= timeBetweenDecisionsAtInference)
  233. {
  234. timeSinceDecision = 0f;
  235. RequestDecision();
  236. }
  237. else
  238. {
  239. timeSinceDecision += Time.fixedDeltaTime;
  240. }
  241. }
  242. }
  243. }
  244. }