You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

RandomNormalTest.cs 2.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. using System;
  2. using NUnit.Framework;
  3. using MLAgents.InferenceBrain;
  4. using MLAgents.InferenceBrain.Utils;
  namespace MLAgents.Tests
  {
      /// <summary>
      /// Unit tests for <see cref="RandomNormal"/>: reproducibility of seeded draws,
      /// scaling by the mean / standard-deviation constructor arguments, and the
      /// behaviour of <see cref="RandomNormal.FillTensor"/>.
      /// </summary>
      public class RandomNormalTest
      {
          // Absolute tolerance shared by every floating-point comparison in this fixture.
          private const double Tolerance = 0.0001;

          /// <summary>
          /// Constructed from 2018 only (presumably the seed), the first two draws must
          /// match the recorded values. These two values (z0 = -0.46666, z1 = -0.37989)
          /// are the baseline that the mean/stddev tests in this class scale.
          /// </summary>
          [Test]
          public void RandomNormalTestTwoDouble()
          {
              RandomNormal rn = new RandomNormal(2018);
              Assert.AreEqual(-0.46666, rn.NextDouble(), Tolerance);
              Assert.AreEqual(-0.37989, rn.NextDouble(), Tolerance);
          }

          /// <summary>
          /// Same seed with a second argument of 5.0: expected values are 5.0 + z
          /// for the baseline draws z.
          /// </summary>
          [Test]
          public void RandomNormalTestWithMean()
          {
              RandomNormal rn = new RandomNormal(2018, 5.0f);
              Assert.AreEqual(4.53333, rn.NextDouble(), Tolerance);
              Assert.AreEqual(4.6201, rn.NextDouble(), Tolerance);
          }

          /// <summary>
          /// Same seed with arguments (1.0, 4.2): expected values are 1.0 + 4.2 * z
          /// for the baseline draws z.
          /// </summary>
          [Test]
          public void RandomNormalTestWithStddev()
          {
              RandomNormal rn = new RandomNormal(2018, 1.0f, 4.2f);
              Assert.AreEqual(-0.9599, rn.NextDouble(), Tolerance);
              Assert.AreEqual(-0.5955, rn.NextDouble(), Tolerance);
          }

          /// <summary>
          /// Same seed with a negative second argument, (-3.2, 2.2): expected values are
          /// -3.2 + 2.2 * z for the baseline draws z.
          /// </summary>
          [Test]
          public void RandomNormalTestWithMeanStddev()
          {
              RandomNormal rn = new RandomNormal(2018, -3.2f, 2.2f);
              Assert.AreEqual(-4.2266, rn.NextDouble(), Tolerance);
              Assert.AreEqual(-4.0357, rn.NextDouble(), Tolerance);
          }

          /// <summary>
          /// Filling a tensor whose ValueType is Integer must throw
          /// <see cref="NotImplementedException"/>.
          /// </summary>
          [Test]
          public void RandomNormalTestTensorInt()
          {
              RandomNormal rn = new RandomNormal(1982);
              Tensor t = new Tensor
              {
                  ValueType = Tensor.TensorType.Integer
              };
              Assert.Throws<NotImplementedException>(() => rn.FillTensor(t));
          }

          /// <summary>
          /// Filling a floating-point tensor whose Data is left unset must throw
          /// <see cref="ArgumentNullException"/>.
          /// </summary>
          [Test]
          public void RandomNormalTestDataNull()
          {
              RandomNormal rn = new RandomNormal(1982);
              Tensor t = new Tensor
              {
                  ValueType = Tensor.TensorType.FloatingPoint
              };
              Assert.Throws<ArgumentNullException>(() => rn.FillTensor(t));
          }

          /// <summary>
          /// Filling a 3x4x2 float tensor from seed 1982 must reproduce the recorded
          /// 24 reference values, element by element.
          /// </summary>
          [Test]
          public void RandomNormalTestTensor()
          {
              RandomNormal rn = new RandomNormal(1982);
              Tensor t = new Tensor
              {
                  ValueType = Tensor.TensorType.FloatingPoint,
                  Data = Array.CreateInstance(typeof(float), new long[] { 3, 4, 2 })
              };
              rn.FillTensor(t);

              float[] reference =
              {
                  -0.2139822f,
                  0.5051259f,
                  -0.5640336f,
                  -0.3357787f,
                  -0.2055894f,
                  -0.09432302f,
                  -0.01419199f,
                  0.53621f,
                  -0.5507085f,
                  -0.2651141f,
                  0.09315512f,
                  -0.04918706f,
                  -0.179625f,
                  0.2280539f,
                  0.1883962f,
                  0.4047216f,
                  0.1704049f,
                  0.5050544f,
                  -0.3365685f,
                  0.3542781f,
                  0.5951571f,
                  0.03460682f,
                  -0.5537263f,
                  -0.4378373f,
              };

              // Check the count first. A tensor with too many elements would otherwise fail
              // with an IndexOutOfRangeException, and one with too few would pass silently.
              Assert.AreEqual(reference.Length, t.Data.Length, "Unexpected number of tensor elements");

              // foreach over a multidimensional array visits elements with the last index
              // varying fastest. The reference values are listed in that flat order.
              int i = 0;
              foreach (float f in t.Data)
              {
                  // NUnit's argument order is (expected, actual, delta).
                  Assert.AreEqual(reference[i], f, Tolerance, "Mismatch at flat index " + i);
                  ++i;
              }
          }
      }
  }