You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

CozmoAgent.cs 4.3KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. using MLAgents;
  2. using OpenCvSharp;
  3. using System;
  4. using System.Collections;
  5. using System.Collections.Generic;
  6. using UnityEngine;
  /// <summary>
  /// ML-Agents agent for the virtual Cozmo robot.
  /// Each decision maps the first element of the action vector to a stop / forward /
  /// turn command on <see cref="CozmoMovement"/>, re-renders the virtual camera and then
  /// ends the episode with a reward that depends on how close the image processor's
  /// center of gravity lies to the horizontal center of the rendered image.
  /// </summary>
  public class CozmoAgent : Agent
  {
      // Possible actions (value of vectorAction[0] after flooring).
      private const int STOP = 0;
      private const int FORWARD = 1;
      private const int RIGHT = 2;
      private const int LEFT = 3;

      // Fractions of half the image width that delimit the "near the center" and
      // "far from the center" bands around the horizontal image center.
      private const float NEAR_AREA_PERCENTAGE_OFFSET = 0.2f;
      private const float FAR_AREA_PERCENTAGE_OFFSET = 0.3f;

      [Tooltip("The virtual Cozmo camera")]
      public Camera renderCamera;

      [Tooltip("Reference to the CozmoMovement script")]
      public CozmoMovement movement;

      public float timeBetweenDecisionsAtInference;

      private Academy academy;               // the CozmoAcademy found in the scene
      private float timeSinceDecision;       // time accumulated since the last requested decision
      private ImageProcessor imageProcessor; // ImageProcessor attached to the render camera
      private int imageCenterX = 0;          // X coordinate of the horizontal image center
      private int nearAreaLimit = 0;         // half-width (in pixels) of the near-center band
      private int farAreaLimit = 0;          // half-width (in pixels) of the far-from-center band

      /// <summary>
      /// Resolves the academy and the image processor and caches the pixel geometry
      /// (image center and band limits) derived from the camera's target texture width.
      /// </summary>
      private void Start()
      {
          academy = FindObjectOfType<CozmoAcademy>();
          imageProcessor = renderCamera.GetComponent<ImageProcessor>();

          // Integer division on purpose: this mirrors the pixel-coordinate arithmetic
          // used for the comparisons in AgentAction.
          int halfWidth = renderCamera.targetTexture.width / 2;
          imageCenterX = halfWidth;
          nearAreaLimit = (int)(halfWidth * NEAR_AREA_PERCENTAGE_OFFSET);
          farAreaLimit = (int)(halfWidth * FAR_AREA_PERCENTAGE_OFFSET);
      }

      /// <summary>
      /// Renders the camera and requests a decision according to the inference timing rules.
      /// </summary>
      public void FixedUpdate()
      {
          WaitTimeInference();
      }

      /// <summary>
      /// Executes the chosen action, re-renders the camera and rewards the agent based on
      /// the horizontal position of the image processor's center of gravity.
      /// </summary>
      /// <param name="vectorAction">Action vector; only element 0 is used, floored to an action id.</param>
      /// <param name="textAction">Text action; not used by this agent.</param>
      /// <exception cref="ArgumentException">
      /// Thrown (after stopping the movement) when the action id is not one of the known actions.
      /// </exception>
      public override void AgentAction(float[] vectorAction, string textAction)
      {
          int action = Mathf.FloorToInt(vectorAction[0]);

          // NOTE(review): every branch at the end of this method calls SetReward, which
          // presumably replaces this per-step penalty - confirm against the ML-Agents
          // version in use before relying on it.
          AddReward(-0.01f);

          switch (action)
          {
              case STOP:
                  movement.Move(0);
                  break;
              case FORWARD:
                  movement.Move(1);
                  break;
              case RIGHT:
                  movement.Turn(1);
                  break;
              case LEFT:
                  movement.Turn(-1);
                  break;
              default:
                  movement.Move(0);
                  throw new ArgumentException("Invalid action value. Stop movement.");
          }

          // Render a new image after the movement so the center of gravity can be refreshed.
          if (renderCamera != null)
          {
              renderCamera.Render();
          }

          // Read the center of gravity only AFTER the render above. Reading it before the
          // movement (as a copied value) would score the action against the previous frame.
          // Presumably ImageProcessor updates CenterOfGravity during the render - confirm.
          Point centerOfGravity = imageProcessor.CenterOfGravity;
          var x = centerOfGravity.X;

          bool isNearCenter = x > imageCenterX - nearAreaLimit && x < imageCenterX + nearAreaLimit;
          bool isFarFromCenter = x > imageCenterX - farAreaLimit && x < imageCenterX + farAreaLimit;

          if (isNearCenter)
          {
              FinishWithReward(1f, "Reward: +1");
          }
          else if (isFarFromCenter)
          {
              FinishWithReward(-1f, "Reward: -1");
          }
          else
          {
              FinishWithReward(-2f, "Reward: -2");
          }
      }

      /// <summary>
      /// Resets the agent by resetting the whole academy.
      /// </summary>
      public override void AgentReset()
      {
          academy.AcademyReset();
      }

      /// <summary>
      /// Ends the episode when the agent enters a trigger collider tagged "Goal".
      /// </summary>
      private void OnTriggerEnter(Collider other)
      {
          if (other.transform.CompareTag("Goal"))
          {
              print("Collission");
              Done();
          }
      }

      /// <summary>
      /// Marks the agent as done, sets the given reward and logs the given message.
      /// </summary>
      private void FinishWithReward(float reward, string logMessage)
      {
          Done();
          SetReward(reward);
          print(logMessage);
      }

      /// <summary>
      /// Renders the camera, then requests a decision: on every call when the academy is
      /// not in inference mode, otherwise only once at least
      /// <see cref="timeBetweenDecisionsAtInference"/> seconds of fixed time have accumulated.
      /// </summary>
      private void WaitTimeInference()
      {
          if (renderCamera != null)
          {
              renderCamera.Render();
          }

          if (!academy.GetIsInference())
          {
              RequestDecision();
              return;
          }

          if (timeSinceDecision >= timeBetweenDecisionsAtInference)
          {
              timeSinceDecision = 0f;
              RequestDecision();
          }
          else
          {
              timeSinceDecision += Time.fixedDeltaTime;
          }
      }
  }