You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

HallwayAgent.cs 5.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. using System.Collections;
  2. using System.Collections.Generic;
  3. using UnityEngine;
  4. using MLAgents;
  5. public class HallwayAgent : Agent
  6. {
  7. public GameObject ground;
  8. public GameObject area;
  9. public GameObject orangeGoal;
  10. public GameObject redGoal;
  11. public GameObject orangeBlock;
  12. public GameObject redBlock;
  13. public bool useVectorObs;
  14. RayPerception rayPer;
  15. Rigidbody shortBlockRB;
  16. Rigidbody agentRB;
  17. Material groundMaterial;
  18. Renderer groundRenderer;
  19. HallwayAcademy academy;
  20. int selection;
  21. public override void InitializeAgent()
  22. {
  23. base.InitializeAgent();
  24. academy = FindObjectOfType<HallwayAcademy>();
  25. rayPer = GetComponent<RayPerception>();
  26. agentRB = GetComponent<Rigidbody>();
  27. groundRenderer = ground.GetComponent<Renderer>();
  28. groundMaterial = groundRenderer.material;
  29. }
  30. public override void CollectObservations()
  31. {
  32. if (useVectorObs)
  33. {
  34. float rayDistance = 12f;
  35. float[] rayAngles = { 20f, 60f, 90f, 120f, 160f };
  36. string[] detectableObjects = { "orangeGoal", "redGoal", "orangeBlock", "redBlock", "wall" };
  37. AddVectorObs(GetStepCount() / (float)agentParameters.maxStep);
  38. AddVectorObs(rayPer.Perceive(rayDistance, rayAngles, detectableObjects, 0f, 0f));
  39. }
  40. }
  41. IEnumerator GoalScoredSwapGroundMaterial(Material mat, float time)
  42. {
  43. groundRenderer.material = mat;
  44. yield return new WaitForSeconds(time);
  45. groundRenderer.material = groundMaterial;
  46. }
  47. public void MoveAgent(float[] act)
  48. {
  49. Vector3 dirToGo = Vector3.zero;
  50. Vector3 rotateDir = Vector3.zero;
  51. if (brain.brainParameters.vectorActionSpaceType == SpaceType.continuous)
  52. {
  53. dirToGo = transform.forward * Mathf.Clamp(act[0], -1f, 1f);
  54. rotateDir = transform.up * Mathf.Clamp(act[1], -1f, 1f);
  55. }
  56. else
  57. {
  58. int action = Mathf.FloorToInt(act[0]);
  59. switch (action)
  60. {
  61. case 1:
  62. dirToGo = transform.forward * 1f;
  63. break;
  64. case 2:
  65. dirToGo = transform.forward * -1f;
  66. break;
  67. case 3:
  68. rotateDir = transform.up * 1f;
  69. break;
  70. case 4:
  71. rotateDir = transform.up * -1f;
  72. break;
  73. }
  74. }
  75. transform.Rotate(rotateDir, Time.deltaTime * 150f);
  76. agentRB.AddForce(dirToGo * academy.agentRunSpeed, ForceMode.VelocityChange);
  77. }
  78. public override void AgentAction(float[] vectorAction, string textAction)
  79. {
  80. AddReward(-1f / agentParameters.maxStep);
  81. MoveAgent(vectorAction);
  82. }
  83. void OnCollisionEnter(Collision col)
  84. {
  85. if (col.gameObject.CompareTag("orangeGoal") || col.gameObject.CompareTag("redGoal"))
  86. {
  87. if ((selection == 0 && col.gameObject.CompareTag("orangeGoal")) ||
  88. (selection == 1 && col.gameObject.CompareTag("redGoal")))
  89. {
  90. SetReward(1f);
  91. StartCoroutine(GoalScoredSwapGroundMaterial(academy.goalScoredMaterial, 0.5f));
  92. }
  93. else
  94. {
  95. SetReward(-0.1f);
  96. StartCoroutine(GoalScoredSwapGroundMaterial(academy.failMaterial, 0.5f));
  97. }
  98. Done();
  99. }
  100. }
  101. public override void AgentReset()
  102. {
  103. float agentOffset = -15f;
  104. float blockOffset = 0f;
  105. selection = Random.Range(0, 2);
  106. if (selection == 0)
  107. {
  108. orangeBlock.transform.position =
  109. new Vector3(0f + Random.Range(-3f, 3f), 2f, blockOffset + Random.Range(-5f, 5f))
  110. + ground.transform.position;
  111. redBlock.transform.position =
  112. new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
  113. + ground.transform.position;
  114. }
  115. else
  116. {
  117. orangeBlock.transform.position =
  118. new Vector3(0f, -1000f, blockOffset + Random.Range(-5f, 5f))
  119. + ground.transform.position;
  120. redBlock.transform.position =
  121. new Vector3(0f, 2f, blockOffset + Random.Range(-5f, 5f))
  122. + ground.transform.position;
  123. }
  124. transform.position = new Vector3(0f + Random.Range(-3f, 3f),
  125. 1f, agentOffset + Random.Range(-5f, 5f))
  126. + ground.transform.position;
  127. transform.rotation = Quaternion.Euler(0f, Random.Range(0f, 360f), 0f);
  128. agentRB.velocity *= 0f;
  129. int goalPos = Random.Range(0, 2);
  130. if (goalPos == 0)
  131. {
  132. orangeGoal.transform.position = new Vector3(7f, 0.5f, 9f) + area.transform.position;
  133. redGoal.transform.position = new Vector3(-7f, 0.5f, 9f) + area.transform.position;
  134. }
  135. else
  136. {
  137. redGoal.transform.position = new Vector3(7f, 0.5f, 9f) + area.transform.position;
  138. orangeGoal.transform.position = new Vector3(-7f, 0.5f, 9f) + area.transform.position;
  139. }
  140. }
  141. }