You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

EditModeTestInternalBrainTensorApplier.cs 4.8KB

Lines 1–136
  1. using System.Collections.Generic;
  2. using System.Linq;
  3. using NUnit.Framework;
  4. using UnityEngine;
  5. using System.Reflection;
  6. using MLAgents.InferenceBrain;
  namespace MLAgents.Tests
  {
      /// <summary>
      /// Edit-mode tests for <see cref="TensorApplier"/> and the individual output appliers
      /// (continuous actions, discrete actions, memories, value estimates).
      /// Each Apply* test feeds a hand-built <see cref="Tensor"/> to an applier together with
      /// two throw-away agents, then reads back the <see cref="AgentAction"/> each agent got.
      /// </summary>
      public class EditModeTestInternalBrainTensorApplier
      {
          // GameObjects created by GetFakeAgentInfos; destroyed after every test so they do
          // not accumulate in the scene that is open in the editor.
          private readonly List<GameObject> _createdObjects = new List<GameObject>();

          /// <summary>
          /// Agent that exposes the non-public <c>action</c> field declared on
          /// <see cref="Agent"/>, so tests can observe what an applier wrote into it.
          /// </summary>
          private class TestAgent : Agent
          {
              public AgentAction GetAction()
              {
                  FieldInfo f = typeof(Agent).GetField(
                      "action", BindingFlags.Instance | BindingFlags.NonPublic);
                  // Fail with a clear message (rather than a NullReferenceException) if the
                  // private field is ever renamed or removed.
                  Assert.IsNotNull(f, "Agent has no non-public instance field named 'action'.");
                  return (AgentAction) f.GetValue(this);
              }
          }

          /// <summary>
          /// Destroys every GameObject created by <see cref="GetFakeAgentInfos"/> during the
          /// test that just ran.
          /// </summary>
          [TearDown]
          public void DestroyFakeAgents()
          {
              foreach (var go in _createdObjects)
              {
                  // Unity's overloaded == treats already-destroyed objects as null.
                  if (go != null)
                  {
                      UnityEngine.Object.DestroyImmediate(go);
                  }
              }
              _createdObjects.Clear();
          }

          /// <summary>
          /// Creates two fresh <see cref="TestAgent"/>s ("goA" then "goB"), each on its own
          /// GameObject and paired with an empty <see cref="AgentInfo"/>.
          /// </summary>
          private Dictionary<Agent, AgentInfo> GetFakeAgentInfos()
          {
              return new Dictionary<Agent, AgentInfo>
              {
                  {CreateTestAgent("goA"), new AgentInfo()},
                  {CreateTestAgent("goB"), new AgentInfo()}
              };
          }

          /// <summary>Creates a GameObject with a TestAgent and registers it for cleanup.</summary>
          private TestAgent CreateTestAgent(string objectName)
          {
              var go = new GameObject(objectName);
              _createdObjects.Add(go);
              return go.AddComponent<TestAgent>();
          }

          /// <summary>
          /// Returns the action recorded on <paramref name="agents"/>[<paramref name="index"/>].
          /// The list is expected to come from the dictionary passed to the applier, so its
          /// order presumably matches the applier's row-to-agent mapping (the assertions rely on it).
          /// </summary>
          private static AgentAction ActionOf(IList<Agent> agents, int index)
          {
              return ((TestAgent) agents[index]).GetAction();
          }

          [Test]
          public void Contruction()
          {
              var bp = new BrainParameters();
              var tensorApplier = new TensorApplier(bp, 0);
              Assert.IsNotNull(tensorApplier);
          }

          [Test]
          public void ApplyContinuousActionOutput()
          {
              var inputTensor = new Tensor()
              {
                  Shape = new long[] {2, 3},
                  Data = new float[,] {{1, 2, 3}, {4, 5, 6}}
              };
              var agentInfos = GetFakeAgentInfos();

              var applier = new ContinuousActionOutputApplier();
              applier.Apply(inputTensor, agentInfos);

              // Row i of the tensor is expected to become agent i's vector action.
              var agents = agentInfos.Keys.ToList();
              var action = ActionOf(agents, 0);
              Assert.AreEqual(1, action.vectorActions[0]);
              Assert.AreEqual(2, action.vectorActions[1]);
              Assert.AreEqual(3, action.vectorActions[2]);

              action = ActionOf(agents, 1);
              Assert.AreEqual(4, action.vectorActions[0]);
              Assert.AreEqual(5, action.vectorActions[1]);
              Assert.AreEqual(6, action.vectorActions[2]);
          }

          [Test]
          public void ApplyDiscreteActionOutput()
          {
              var inputTensor = new Tensor()
              {
                  Shape = new long[] {2, 5},
                  Data = new float[,] {{0.5f, 22.5f, 0.1f, 5f, 1f},
                                       {4f, 5f, 6f, 7f, 8f}}
              };
              var agentInfos = GetFakeAgentInfos();

              // NOTE(review): {2, 3} presumably splits each row into logits for two branches
              // (columns 0-1 and 2-4), and the trailing 0 looks like a sampling seed that
              // makes the expected indices below reproducible — confirm against the applier.
              var applier = new DiscreteActionOutputApplier(new int[]{2, 3}, 0);
              applier.Apply(inputTensor, agentInfos);

              var agents = agentInfos.Keys.ToList();
              var action = ActionOf(agents, 0);
              Assert.AreEqual(1, action.vectorActions[0]);
              Assert.AreEqual(1, action.vectorActions[1]);

              action = ActionOf(agents, 1);
              Assert.AreEqual(1, action.vectorActions[0]);
              Assert.AreEqual(2, action.vectorActions[1]);
          }

          [Test]
          public void ApplyMemoryOutput()
          {
              var inputTensor = new Tensor()
              {
                  Shape = new long[] {2, 5},
                  Data = new float[,] {{0.5f, 22.5f, 0.1f, 5f, 1f},
                                       {4f, 5f, 6f, 7f, 8f}}
              };
              var agentInfos = GetFakeAgentInfos();

              var applier = new MemoryOutputApplier();
              applier.Apply(inputTensor, agentInfos);

              // Row i of the tensor is expected to become agent i's memories.
              var agents = agentInfos.Keys.ToList();
              var action = ActionOf(agents, 0);
              Assert.AreEqual(0.5f, action.memories[0]);
              Assert.AreEqual(22.5f, action.memories[1]);

              action = ActionOf(agents, 1);
              Assert.AreEqual(6, action.memories[2]);
              Assert.AreEqual(7, action.memories[3]);
          }

          [Test]
          public void ApplyValueEstimate()
          {
              var inputTensor = new Tensor()
              {
                  Shape = new long[] {2, 1},
                  Data = new float[,] {{0.5f}, {8f}}
              };
              var agentInfos = GetFakeAgentInfos();

              var applier = new ValueEstimateApplier();
              applier.Apply(inputTensor, agentInfos);

              // The single column of row i is expected to become agent i's value estimate.
              var agents = agentInfos.Keys.ToList();
              Assert.AreEqual(0.5f, ActionOf(agents, 0).value);
              Assert.AreEqual(8, ActionOf(agents, 1).value);
          }
      }
  }