You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

PyramidAgent.cs 3.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using UnityEngine;
  6. using Random = UnityEngine.Random;
  7. using MLAgents;
  /// <summary>
  /// ML-Agents agent that moves around a <c>PyramidArea</c>, observes its
  /// surroundings through ray perception, and is rewarded for colliding with
  /// an object tagged "goal".
  /// </summary>
  public class PyramidAgent : Agent
  {
      /// <summary>Degrees-per-second factor applied to the rotation in <see cref="MoveAgent"/>.</summary>
      private const float RotateSpeed = 200f;

      /// <summary>Multiplier applied to the movement direction before it is added as a velocity change.</summary>
      private const float MoveSpeed = 2f;

      /// <summary>Reward set when the agent collides with an object tagged "goal".</summary>
      private const float GoalReward = 2f;

      /// <summary>
      /// Number of distinct placement indices handed out on reset (0..8).
      /// NOTE(review): presumably these are spawn-area indices understood by
      /// PyramidArea / PyramidSwitch -- confirm against those classes.
      /// </summary>
      private const int PlacementIndexCount = 9;

      /// <summary>Index of the first shuffled slot used for a stone pyramid (slots 0-2 go to the agent and the switch).</summary>
      private const int FirstStonePyramidSlot = 3;

      /// <summary>GameObject carrying the <c>PyramidArea</c> component this agent lives in.</summary>
      public GameObject area;
      private PyramidArea myArea;
      private Rigidbody agentRb;
      private RayPerception rayPer;
      private PyramidSwitch switchLogic;

      /// <summary>GameObject carrying the <c>PyramidSwitch</c> component for this area.</summary>
      public GameObject areaSwitch;

      /// <summary>When true, <see cref="CollectObservations"/> emits vector observations; when false it emits nothing.</summary>
      public bool useVectorObs;

      /// <summary>
      /// Caches the components the agent needs: its own Rigidbody and
      /// RayPerception, the PyramidArea on <see cref="area"/> and the
      /// PyramidSwitch on <see cref="areaSwitch"/>.
      /// </summary>
      public override void InitializeAgent()
      {
          base.InitializeAgent();
          agentRb = GetComponent<Rigidbody>();
          myArea = area.GetComponent<PyramidArea>();
          rayPer = GetComponent<RayPerception>();
          switchLogic = areaSwitch.GetComponent<PyramidSwitch>();
      }

      /// <summary>
      /// Adds the vector observations (only when <see cref="useVectorObs"/> is
      /// set): three ray-perception sweeps, the switch state, and the agent's
      /// velocity expressed in its local frame.
      /// </summary>
      public override void CollectObservations()
      {
          if (!useVectorObs)
          {
              return;
          }

          const float rayDistance = 35f;
          float[] rayAngles = {20f, 90f, 160f, 45f, 135f, 70f, 110f};
          float[] rayAngles1 = {25f, 95f, 165f, 50f, 140f, 75f, 115f};
          float[] rayAngles2 = {15f, 85f, 155f, 40f, 130f, 65f, 105f};
          string[] detectableObjects = {"block", "wall", "goal", "switchOff", "switchOn", "stone"};

          // Three sweeps with slightly shifted angle sets; the last Perceive
          // argument differs per sweep (0, 5, 10).
          // NOTE(review): the two trailing floats look like ray start/end
          // height offsets -- confirm against RayPerception.Perceive.
          AddVectorObs(rayPer.Perceive(rayDistance, rayAngles, detectableObjects, 0f, 0f));
          AddVectorObs(rayPer.Perceive(rayDistance, rayAngles1, detectableObjects, 0f, 5f));
          AddVectorObs(rayPer.Perceive(rayDistance, rayAngles2, detectableObjects, 0f, 10f));
          AddVectorObs(switchLogic.GetState());
          AddVectorObs(transform.InverseTransformDirection(agentRb.velocity));
      }

      /// <summary>
      /// Translates an action vector into a rotation and a velocity change.
      /// </summary>
      /// <param name="act">
      /// Continuous space: <c>act[0]</c> is forward/backward and <c>act[1]</c>
      /// is rotation about the up axis, each clamped to [-1, 1].
      /// Discrete space: <c>floor(act[0])</c> selects 1 = forward,
      /// 2 = backward, 3 = rotate (+up), 4 = rotate (-up); any other value
      /// leaves both directions at zero.
      /// </param>
      public void MoveAgent(float[] act)
      {
          var dirToGo = Vector3.zero;
          var rotateDir = Vector3.zero;

          if (brain.brainParameters.vectorActionSpaceType == SpaceType.continuous)
          {
              dirToGo = transform.forward * Mathf.Clamp(act[0], -1f, 1f);
              rotateDir = transform.up * Mathf.Clamp(act[1], -1f, 1f);
          }
          else
          {
              var action = Mathf.FloorToInt(act[0]);
              switch (action)
              {
                  case 1:
                      dirToGo = transform.forward * 1f;
                      break;
                  case 2:
                      dirToGo = transform.forward * -1f;
                      break;
                  case 3:
                      rotateDir = transform.up * 1f;
                      break;
                  case 4:
                      rotateDir = transform.up * -1f;
                      break;
              }
          }

          transform.Rotate(rotateDir, Time.deltaTime * RotateSpeed);
          agentRb.AddForce(dirToGo * MoveSpeed, ForceMode.VelocityChange);
      }

      /// <summary>
      /// Applies a small per-step penalty (-1 / maxStep) and then moves the
      /// agent according to <paramref name="vectorAction"/>.
      /// </summary>
      /// <param name="vectorAction">Action vector forwarded to <see cref="MoveAgent"/>.</param>
      /// <param name="textAction">Unused.</param>
      public override void AgentAction(float[] vectorAction, string textAction)
      {
          // Skip the penalty when maxStep is not positive: dividing by zero
          // would otherwise hand the trainer a non-finite reward.
          if (agentParameters.maxStep > 0)
          {
              AddReward(-1f / agentParameters.maxStep);
          }

          MoveAgent(vectorAction);
      }

      /// <summary>
      /// Resets the episode: clears the area, zeroes the agent's velocity,
      /// and re-places the agent, the switch and six stone pyramids using a
      /// fresh random permutation of the placement indices so that no two of
      /// them receive the same index.
      /// </summary>
      public override void AgentReset()
      {
          var items = ShuffledIndices(PlacementIndexCount);

          myArea.CleanPyramidArea();
          agentRb.velocity = Vector3.zero;
          myArea.PlaceObject(gameObject, items[0]);
          transform.rotation = Quaternion.Euler(new Vector3(0f, Random.Range(0, 360)));
          switchLogic.ResetSwitch(items[1], items[2]);

          // Remaining six indices (slots 3..8) each get one stone pyramid.
          for (var slot = FirstStonePyramidSlot; slot < items.Length; slot++)
          {
              myArea.CreateStonePyramid(1, items[slot]);
          }
      }

      /// <summary>
      /// Returns the integers 0..count-1 in uniformly random order
      /// (Fisher-Yates shuffle).
      /// </summary>
      /// <remarks>
      /// Draws from UnityEngine.Random (aliased as Random in this file) so the
      /// layout follows Unity's random state and is reproducible when that
      /// state is seeded, unlike ordering by freshly generated GUIDs.
      /// </remarks>
      /// <param name="count">Number of indices to produce.</param>
      /// <returns>A new array holding a permutation of 0..count-1.</returns>
      private static int[] ShuffledIndices(int count)
      {
          var indices = new int[count];
          for (var i = 0; i < count; i++)
          {
              indices[i] = i;
          }

          for (var i = count - 1; i > 0; i--)
          {
              // Integer overload of Random.Range excludes the upper bound,
              // so i + 1 lets an element stay where it is.
              var j = Random.Range(0, i + 1);
              var swap = indices[i];
              indices[i] = indices[j];
              indices[j] = swap;
          }

          return indices;
      }

      /// <summary>
      /// Ends the episode with the goal reward when the agent collides with
      /// an object tagged "goal".
      /// </summary>
      /// <param name="collision">Collision data supplied by Unity.</param>
      private void OnCollisionEnter(Collision collision)
      {
          if (collision.gameObject.CompareTag("goal"))
          {
              SetReward(GoalReward);
              Done();
          }
      }

      /// <summary>Intentionally empty: nothing extra happens when the agent is done.</summary>
      public override void AgentOnDone()
      {
      }
  }