You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

MultinomialTest.cs 5.9KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. using System;
  2. using NUnit.Framework;
  3. using UnityEngine;
  4. using MLAgents.InferenceBrain;
  5. using MLAgents.InferenceBrain.Utils;
  6. namespace MLAgents.Tests
  7. {
  8. public class MultinomialTest
  9. {
  10. [Test]
  11. public void TestEvalP()
  12. {
  13. Multinomial m = new Multinomial(2018);
  14. Tensor src = new Tensor
  15. {
  16. Data = new float[1, 3] {{0.1f, 0.2f, 0.7f}},
  17. ValueType = Tensor.TensorType.FloatingPoint
  18. };
  19. Tensor dst = new Tensor
  20. {
  21. Data = new float[1, 3],
  22. ValueType = Tensor.TensorType.FloatingPoint
  23. };
  24. m.Eval(src, dst);
  25. float[] reference = {2, 2, 1};
  26. int i = 0;
  27. foreach (var f in dst.Data)
  28. {
  29. Assert.AreEqual(reference[i], f);
  30. ++i;
  31. }
  32. }
  33. [Test]
  34. public void TestEvalLogits()
  35. {
  36. Multinomial m = new Multinomial(2018);
  37. Tensor src = new Tensor
  38. {
  39. Data = new float[1, 3] {{Mathf.Log(0.1f) - 50, Mathf.Log(0.2f) - 50, Mathf.Log(0.7f) - 50}},
  40. ValueType = Tensor.TensorType.FloatingPoint
  41. };
  42. Tensor dst = new Tensor
  43. {
  44. Data = new float[1, 3],
  45. ValueType = Tensor.TensorType.FloatingPoint
  46. };
  47. m.Eval(src, dst);
  48. float[] reference = {2, 2, 2};
  49. int i = 0;
  50. foreach (var f in dst.Data)
  51. {
  52. Assert.AreEqual(reference[i], f);
  53. ++i;
  54. }
  55. }
  56. [Test]
  57. public void TestEvalBatching()
  58. {
  59. Multinomial m = new Multinomial(2018);
  60. Tensor src = new Tensor
  61. {
  62. Data = new float[2, 3]
  63. {
  64. {Mathf.Log(0.1f) - 50, Mathf.Log(0.2f) - 50, Mathf.Log(0.7f) - 50},
  65. {Mathf.Log(0.3f) - 25, Mathf.Log(0.4f) - 25, Mathf.Log(0.3f) - 25},
  66. },
  67. ValueType = Tensor.TensorType.FloatingPoint
  68. };
  69. Tensor dst = new Tensor
  70. {
  71. Data = new float[2, 3],
  72. ValueType = Tensor.TensorType.FloatingPoint
  73. };
  74. m.Eval(src, dst);
  75. float[] reference = {2, 2, 2, 0, 1, 0};
  76. int i = 0;
  77. foreach (var f in dst.Data)
  78. {
  79. Assert.AreEqual(reference[i], f);
  80. ++i;
  81. }
  82. }
  83. [Test]
  84. public void TestSrcInt()
  85. {
  86. Multinomial m = new Multinomial(2018);
  87. Tensor src = new Tensor
  88. {
  89. ValueType = Tensor.TensorType.Integer
  90. };
  91. Assert.Throws<NotImplementedException>(() => m.Eval(src, null));
  92. }
  93. [Test]
  94. public void TestDstInt()
  95. {
  96. Multinomial m = new Multinomial(2018);
  97. Tensor src = new Tensor
  98. {
  99. ValueType = Tensor.TensorType.FloatingPoint
  100. };
  101. Tensor dst = new Tensor
  102. {
  103. ValueType = Tensor.TensorType.Integer
  104. };
  105. Assert.Throws<ArgumentException>(() => m.Eval(src, dst));
  106. }
  107. [Test]
  108. public void TestSrcDataNull()
  109. {
  110. Multinomial m = new Multinomial(2018);
  111. Tensor src = new Tensor
  112. {
  113. ValueType = Tensor.TensorType.FloatingPoint
  114. };
  115. Tensor dst = new Tensor
  116. {
  117. ValueType = Tensor.TensorType.FloatingPoint
  118. };
  119. Assert.Throws<ArgumentNullException>(() => m.Eval(src, dst));
  120. }
  121. [Test]
  122. public void TestDstDataNull()
  123. {
  124. Multinomial m = new Multinomial(2018);
  125. Tensor src = new Tensor
  126. {
  127. ValueType = Tensor.TensorType.FloatingPoint,
  128. Data = new float[1]
  129. };
  130. Tensor dst = new Tensor
  131. {
  132. ValueType = Tensor.TensorType.FloatingPoint
  133. };
  134. Assert.Throws<ArgumentNullException>(() => m.Eval(src, dst));
  135. }
  136. [Test]
  137. public void TestSrcWrongShape()
  138. {
  139. Multinomial m = new Multinomial(2018);
  140. Tensor src = new Tensor
  141. {
  142. ValueType = Tensor.TensorType.FloatingPoint,
  143. Data = new float[1]
  144. };
  145. Tensor dst = new Tensor
  146. {
  147. ValueType = Tensor.TensorType.FloatingPoint,
  148. Data = new float[1]
  149. };
  150. Assert.Throws<ArgumentException>(() => m.Eval(src, dst));
  151. }
  152. [Test]
  153. public void TestDstWrongShape()
  154. {
  155. Multinomial m = new Multinomial(2018);
  156. Tensor src = new Tensor
  157. {
  158. ValueType = Tensor.TensorType.FloatingPoint,
  159. Data = new float[1, 1]
  160. };
  161. Tensor dst = new Tensor
  162. {
  163. ValueType = Tensor.TensorType.FloatingPoint,
  164. Data = new float[1]
  165. };
  166. Assert.Throws<ArgumentException>(() => m.Eval(src, dst));
  167. }
  168. [Test]
  169. public void TestUnequalBatchSize()
  170. {
  171. Multinomial m = new Multinomial(2018);
  172. Tensor src = new Tensor
  173. {
  174. ValueType = Tensor.TensorType.FloatingPoint,
  175. Data = new float[1, 1]
  176. };
  177. Tensor dst = new Tensor
  178. {
  179. ValueType = Tensor.TensorType.FloatingPoint,
  180. Data = new float[2, 1]
  181. };
  182. Assert.Throws<ArgumentException>(() => m.Eval(src, dst));
  183. }
  184. }
  185. }