You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

CozmoAgent.cs 7.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. using MLAgents;
  2. using OpenCvSharp;
  3. using System;
  4. using System.Collections;
  5. using System.Collections.Generic;
  6. using UnityEngine;
  7. public class CozmoAgent : Agent
  8. {
  9. // Possible Actions
  10. private const int STOP = 0;
  11. private const int FORWARD = 1;
  12. private const int RIGHT = 2;
  13. private const int LEFT = 3;
  14. // Used to determine different areas in the image (near to the center, far away)
  15. private const float NEAR_AREA_PERCENTAGE_OFFSET = 0.3f;
  16. [Tooltip("The virtual Cozmo camera")]
  17. public Camera renderCamera;
  18. [Tooltip("Reference to the CozmoMovement script")]
  19. public CozmoMovementController movementController;
  20. public float timeBetweenDecisionsAtInference;
  21. private Academy academy; // CozmoAcademy
  22. private float timeSinceDecision; // time since last decision
  23. private ImageProcessor imageProcessor; // reference to the ImageProcessor
  24. private int nearAreaLimit = 0; // X coordinate limit for the near to the imagecenter area
  25. private int centerOfImageX = 0; // Middle of the image in x direction
  26. private MovementState lastChosenMovement = MovementState.Stop; // The last action/movement that was executed
  27. private double startTime = Time.time;
  28. private void Start()
  29. {
  30. academy = FindObjectOfType(typeof(CozmoAcademy)) as CozmoAcademy;
  31. imageProcessor = renderCamera.GetComponent<ImageProcessor>();
  32. nearAreaLimit = (int)(renderCamera.targetTexture.width / 2 * NEAR_AREA_PERCENTAGE_OFFSET);
  33. centerOfImageX = renderCamera.targetTexture.width / 2;
  34. }
  35. public void FixedUpdate()
  36. {
  37. WaitTimeInference();
  38. }
  39. public override void CollectObservations()
  40. {
  41. SetMask();
  42. }
  43. // Set ActionMask for training
  44. private void SetMask()
  45. {
  46. switch (lastChosenMovement)
  47. {
  48. // Do not allow stop decision after a stop
  49. case (MovementState.Stop):
  50. SetActionMask(STOP);
  51. break;
  52. // Do not allow stop after forward
  53. case (MovementState.Forward):
  54. SetActionMask(STOP);
  55. break;
  56. // Do not allow stop & left after right
  57. case (MovementState.Right):
  58. SetActionMask(STOP);
  59. SetActionMask(LEFT);
  60. break;
  61. // Do not allow stop & right after left
  62. case (MovementState.Left):
  63. SetActionMask(STOP);
  64. SetActionMask(RIGHT);
  65. break;
  66. default:
  67. throw new ArgumentException("Invalid MovementState.");
  68. }
  69. }
  70. // to be implemented by the developer
  71. public override void AgentAction(float[] vectorAction, string textAction)
  72. {
  73. double elapsedTime = Time.time - startTime;
  74. print("Elapsed time: " + elapsedTime);
  75. startTime = Time.time;
  76. int action = Mathf.FloorToInt(vectorAction[0]);
  77. Point centerOfGravity = imageProcessor.CenterOfGravity;
  78. AddReward(-0.01f);
  79. switch (action)
  80. {
  81. case STOP:
  82. movementController.currentMovementState = MovementState.Stop;
  83. lastChosenMovement = MovementState.Stop;
  84. //Test
  85. SetReward(-0.1f);
  86. break;
  87. case FORWARD:
  88. movementController.currentMovementState = MovementState.Forward;
  89. lastChosenMovement = MovementState.Forward;
  90. //Test
  91. SetReward(0.01f);
  92. break;
  93. case RIGHT:
  94. movementController.currentMovementState = MovementState.Right;
  95. lastChosenMovement = MovementState.Right;
  96. //Test
  97. SetReward(-0.02f);
  98. break;
  99. case LEFT:
  100. movementController.currentMovementState = MovementState.Left;
  101. lastChosenMovement = MovementState.Left;
  102. //Test
  103. SetReward(-0.02f);
  104. break;
  105. default:
  106. //movement.Move(0);
  107. throw new ArgumentException("Invalid action value. Stop movement.");
  108. }
  109. // Render new image after movement in order to update the centerOfGravity
  110. if (renderCamera != null)
  111. {
  112. renderCamera.Render();
  113. }
  114. RewardAgent();
  115. }
  116. // Set the reward for the agent based on how far away the center of gravity is from the center of the image
  117. private void RewardAgent()
  118. {
  119. float centerOfGravityX = imageProcessor.CenterOfGravity.X;
  120. float reward = 0;
  121. // Center of gravity is far away from the center (left)
  122. if (centerOfGravityX <= centerOfImageX - nearAreaLimit && centerOfGravityX >= 0)
  123. {
  124. float range = centerOfImageX - nearAreaLimit;
  125. reward = -(1 - (centerOfGravityX / range));
  126. // Clamp the reward to max -1 in order to handle rewards if the center of gravity is outside of the image
  127. reward = Mathf.Clamp(reward, -1, 0) / 2;
  128. }
  129. // Center of gravity is near left of the center
  130. else if ((centerOfGravityX <= centerOfImageX) && (centerOfGravityX >= (centerOfImageX - nearAreaLimit)))
  131. {
  132. float range = centerOfImageX - (centerOfImageX - nearAreaLimit);
  133. float distanceToLeftFarBorder = centerOfGravityX - (centerOfImageX - nearAreaLimit);
  134. reward = (distanceToLeftFarBorder / range);
  135. }
  136. // Center of gravity is far away from the center (right)
  137. else if ((centerOfGravityX >= (centerOfImageX + nearAreaLimit)) && (centerOfGravityX <= renderCamera.targetTexture.width))
  138. {
  139. float range = renderCamera.targetTexture.width - (centerOfImageX + nearAreaLimit);
  140. reward = -(((centerOfGravityX - (centerOfImageX + nearAreaLimit)) / range));
  141. // Clamp the reward to max -1 in order to handle rewards if the center of gravity is outside of the image
  142. reward = Mathf.Clamp(reward, -1, 0) / 2;
  143. }
  144. // Center of gravity is near right of the center
  145. else if ((centerOfGravityX >= centerOfImageX) && (centerOfGravityX <= (centerOfImageX + nearAreaLimit)))
  146. {
  147. float range = (centerOfImageX + nearAreaLimit) - centerOfImageX;
  148. float distanceToCenterOfImage = centerOfGravityX - centerOfImageX;
  149. reward = (1 - distanceToCenterOfImage / range);
  150. }
  151. else
  152. {
  153. SetReward(-1);
  154. AgentReset();
  155. Debug.Log("Out of image range");
  156. }
  157. Debug.Log("Reward: " + reward);
  158. SetReward(reward);
  159. }
  160. // to be implemented by the developer
  161. public override void AgentReset()
  162. {
  163. academy.AcademyReset();
  164. }
  165. private void OnTriggerEnter(Collider other)
  166. {
  167. if (other.transform.CompareTag("Goal"))
  168. {
  169. Done();
  170. }
  171. }
  172. private void WaitTimeInference()
  173. {
  174. if (renderCamera != null)
  175. {
  176. renderCamera.Render();
  177. }
  178. if (!academy.GetIsInference())
  179. {
  180. RequestDecision();
  181. }
  182. else
  183. {
  184. if (timeSinceDecision >= timeBetweenDecisionsAtInference)
  185. {
  186. timeSinceDecision = 0f;
  187. RequestDecision();
  188. }
  189. else
  190. {
  191. timeSinceDecision += Time.fixedDeltaTime;
  192. }
  193. }
  194. }
  195. }