You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

EditModeTestInternalBrainTensorGenerator.cs 5.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using NUnit.Framework;
  5. using UnityEngine;
  6. using MLAgents.InferenceBrain;
  7. namespace MLAgents.Tests
  8. {
  /// <summary>
  /// Edit-mode tests for the tensor generators used by the internal brain:
  /// <see cref="TensorGenerator"/>, <see cref="BatchSizeGenerator"/>,
  /// <see cref="SequenceLengthGenerator"/>, <see cref="VectorObservationGenerator"/>,
  /// <see cref="RecurrentInputGenerator"/>, <see cref="PreviousActionInputGenerator"/>
  /// and <see cref="ActionMaskInputGenerator"/>.
  /// </summary>
  public class EditModeTestInternalBrainTensorGenerator
  {
      /// <summary>
      /// Batch size passed to every generator. It is larger than the number of fake
      /// agents (two), and the assertions below only inspect the rows filled from those agents.
      /// </summary>
      private const int BatchSize = 4;

      /// <summary>
      /// GameObjects created by <see cref="GetFakeAgentInfos"/> during the current test.
      /// They are destroyed in <see cref="DestroyFakeAgents"/> so they do not accumulate
      /// in the edit-mode scene across tests.
      /// </summary>
      private readonly List<GameObject> _createdGameObjects = new List<GameObject>();

      /// <summary>Minimal concrete <see cref="Agent"/> used only as a dictionary key.</summary>
      private class TestAgent : Agent
      {
      }

      /// <summary>
      /// Destroys every GameObject created by the current test.
      /// </summary>
      [TearDown]
      public void DestroyFakeAgents()
      {
          foreach (var go in _createdGameObjects)
          {
              // Unity overloads == null, so this also skips objects already destroyed elsewhere.
              if (go != null)
              {
                  // Object.Destroy cannot be used outside play mode; DestroyImmediate can.
                  // Fully qualified because `using System;` makes bare `Object` ambiguous.
                  UnityEngine.Object.DestroyImmediate(go);
              }
          }
          _createdGameObjects.Clear();
      }

      /// <summary>
      /// Creates a GameObject named <paramref name="name"/> with a <see cref="TestAgent"/>
      /// attached, and records the GameObject for teardown.
      /// </summary>
      /// <param name="name">Name given to the new GameObject.</param>
      /// <returns>The attached <see cref="TestAgent"/> component.</returns>
      private TestAgent CreateTestAgent(string name)
      {
          var go = new GameObject(name);
          _createdGameObjects.Add(go);
          return go.AddComponent<TestAgent>();
      }

      /// <summary>
      /// Builds two fake agents with hand-picked <see cref="AgentInfo"/> data.
      /// <list type="bullet">
      /// <item>Agent A: vector observation {1, 2, 3}, previous actions {1, 2},
      /// no memories and no action masks.</item>
      /// <item>Agent B: vector observation {4, 5, 6}, memories {1, 1, 1},
      /// previous actions {3, 4} and action mask {true, false, false, false, false}.</item>
      /// </list>
      /// </summary>
      /// <returns>A dictionary mapping agent A and agent B to their infos, inserted in that order.</returns>
      /// <remarks>
      /// NOTE(review): the tests assert that agent A fills row 0 and agent B fills row 1.
      /// That relies on the generators enumerating this dictionary in insertion order.
      /// <see cref="Dictionary{TKey,TValue}"/> does not formally guarantee that order,
      /// although it holds in practice when nothing has been removed.
      /// </remarks>
      private Dictionary<Agent, AgentInfo> GetFakeAgentInfos()
      {
          var agentA = CreateTestAgent("goA");
          var infoA = new AgentInfo()
          {
              stackedVectorObservation = new List<float> {1f, 2f, 3f},
              memories = null,
              storedVectorActions = new float[] {1, 2},
              actionMasks = null,
          };

          var agentB = CreateTestAgent("goB");
          var infoB = new AgentInfo()
          {
              stackedVectorObservation = new List<float> {4f, 5f, 6f},
              memories = new List<float> {1f, 1f, 1f},
              storedVectorActions = new float[] {3, 4},
              actionMasks = new bool[] {true, false, false, false, false},
          };

          return new Dictionary<Agent, AgentInfo>() {{agentA, infoA}, {agentB, infoB}};
      }

      /// <summary>A <see cref="TensorGenerator"/> can be constructed from default brain parameters.</summary>
      /// <remarks>The misspelled name is kept so the test's identity in reports and filters is unchanged.</remarks>
      [Test]
      public void Contruction()
      {
          var bp = new BrainParameters();

          var tensorGenerator = new TensorGenerator(bp, 0);

          Assert.IsNotNull(tensorGenerator);
      }

      /// <summary>The batch-size generator writes the requested batch size into an int[] tensor.</summary>
      [Test]
      public void GenerateBatchSize()
      {
          var inputTensor = new Tensor();
          var generator = new BatchSizeGenerator();

          generator.Generate(inputTensor, BatchSize, null);

          var data = inputTensor.Data as int[];
          Assert.IsNotNull(data);
          Assert.AreEqual(BatchSize, data[0]);
      }

      /// <summary>The sequence-length generator writes 1 into an int[] tensor regardless of batch size.</summary>
      [Test]
      public void GenerateSequenceLength()
      {
          var inputTensor = new Tensor();
          var generator = new SequenceLengthGenerator();

          generator.Generate(inputTensor, BatchSize, null);

          var data = inputTensor.Data as int[];
          Assert.IsNotNull(data);
          Assert.AreEqual(1, data[0]);
      }

      /// <summary>Each agent's stacked vector observation is copied into its own row.</summary>
      [Test]
      public void GenerateVectorObservation()
      {
          var inputTensor = new Tensor()
          {
              Shape = new long[] {2, 3}
          };
          var agentInfos = GetFakeAgentInfos();
          var generator = new VectorObservationGenerator();

          generator.Generate(inputTensor, BatchSize, agentInfos);

          var data = inputTensor.Data as float[,];
          Assert.IsNotNull(data);
          Assert.AreEqual(1f, data[0, 0]);
          Assert.AreEqual(3f, data[0, 2]);
          Assert.AreEqual(4f, data[1, 0]);
          Assert.AreEqual(6f, data[1, 2]);
      }

      /// <summary>
      /// Memories are copied per agent. Agent A has no memories, so its row is expected to be zero.
      /// Agent B's three memories are expected in the leading columns of a five-wide row,
      /// with the trailing column zero.
      /// </summary>
      [Test]
      public void GenerateRecurrentInput()
      {
          var inputTensor = new Tensor()
          {
              Shape = new long[] {2, 5}
          };
          var agentInfos = GetFakeAgentInfos();
          var generator = new RecurrentInputGenerator();

          generator.Generate(inputTensor, BatchSize, agentInfos);

          var data = inputTensor.Data as float[,];
          Assert.IsNotNull(data);
          Assert.AreEqual(0f, data[0, 0]);
          Assert.AreEqual(0f, data[0, 4]);
          Assert.AreEqual(1f, data[1, 0]);
          Assert.AreEqual(0f, data[1, 4]);
      }

      /// <summary>Each agent's stored previous actions are written into an integer tensor row.</summary>
      [Test]
      public void GeneratePreviousActionInput()
      {
          var inputTensor = new Tensor()
          {
              Shape = new long[] {2, 2},
              ValueType = Tensor.TensorType.Integer
          };
          var agentInfos = GetFakeAgentInfos();
          var generator = new PreviousActionInputGenerator();

          generator.Generate(inputTensor, BatchSize, agentInfos);

          var data = inputTensor.Data as int[,];
          Assert.IsNotNull(data);
          Assert.AreEqual(1, data[0, 0]);
          Assert.AreEqual(2, data[0, 1]);
          Assert.AreEqual(3, data[1, 0]);
          Assert.AreEqual(4, data[1, 1]);
      }

      /// <summary>
      /// Action masks are written per agent. Agent B masks index 0 (true), and that entry is
      /// expected to be 0. Agent B's other entries, and every entry for agent A (null masks),
      /// are expected to be 1.
      /// </summary>
      [Test]
      public void GenerateActionMaskInput()
      {
          var inputTensor = new Tensor()
          {
              Shape = new long[] {2, 5},
              ValueType = Tensor.TensorType.FloatingPoint
          };
          var agentInfos = GetFakeAgentInfos();
          var generator = new ActionMaskInputGenerator();

          generator.Generate(inputTensor, BatchSize, agentInfos);

          var data = inputTensor.Data as float[,];
          Assert.IsNotNull(data);
          Assert.AreEqual(1f, data[0, 0]);
          Assert.AreEqual(1f, data[0, 4]);
          Assert.AreEqual(0f, data[1, 0]);
          Assert.AreEqual(1f, data[1, 4]);
      }
  }
  134. }