123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224 |
- using System;
- using NUnit.Framework;
- using UnityEngine;
- using MLAgents.InferenceBrain;
- using MLAgents.InferenceBrain.Utils;
-
namespace MLAgents.Tests
{
    /// <summary>
    /// NUnit tests for <c>Multinomial.Eval(src, dst)</c>: three tests compare the values
    /// written into the destination tensor against fixed references, and the remaining
    /// tests check which exception type is thrown for unsupported or malformed tensors.
    /// </summary>
    public class MultinomialTest
    {
        /// <summary>
        /// Builds the object under test with the same constructor argument (2018)
        /// in every test.
        /// </summary>
        private static Multinomial NewMultinomial()
        {
            return new Multinomial(2018);
        }

        /// <summary>
        /// Creates a tensor with the given value type and data.
        /// </summary>
        /// <param name="valueType">Value assigned to <c>Tensor.ValueType</c>.</param>
        /// <param name="data">Value assigned to <c>Tensor.Data</c>; may be null.</param>
        private static Tensor MakeTensor(Tensor.TensorType valueType, Array data)
        {
            return new Tensor
            {
                Data = data,
                ValueType = valueType
            };
        }

        /// <summary>
        /// Creates a floating-point tensor holding <paramref name="data"/> (null for no data).
        /// </summary>
        private static Tensor FloatTensor(Array data)
        {
            return MakeTensor(Tensor.TensorType.FloatingPoint, data);
        }

        /// <summary>
        /// Walks <c>actual.Data</c> in enumeration order and asserts that each element
        /// equals the entry of <paramref name="expected"/> at the same position.
        /// </summary>
        private static void AssertValues(float[] expected, Tensor actual)
        {
            var position = 0;
            foreach (var value in actual.Data)
            {
                Assert.AreEqual(expected[position], value);
                position++;
            }
        }

        /// <summary>
        /// Asserts that calling <c>Eval(src, dst)</c> throws <typeparamref name="TException"/>.
        /// </summary>
        private static void AssertEvalThrows<TException>(Tensor src, Tensor dst)
            where TException : Exception
        {
            var multinomial = NewMultinomial();
            Assert.Throws<TException>(() => multinomial.Eval(src, dst));
        }

        /// <summary>
        /// A single 1x3 row {0.1, 0.2, 0.7} must produce {2, 2, 1} in the destination.
        /// </summary>
        [Test]
        public void TestEvalP()
        {
            var source = FloatTensor(new float[,] {{0.1f, 0.2f, 0.7f}});
            var destination = FloatTensor(new float[1, 3]);

            NewMultinomial().Eval(source, destination);

            AssertValues(new float[] {2, 2, 1}, destination);
        }

        /// <summary>
        /// A single 1x3 row of log(0.1), log(0.2), log(0.7), each shifted by -50,
        /// must produce {2, 2, 2} in the destination.
        /// </summary>
        [Test]
        public void TestEvalLogits()
        {
            var source = FloatTensor(new float[,]
            {
                {Mathf.Log(0.1f) - 50, Mathf.Log(0.2f) - 50, Mathf.Log(0.7f) - 50}
            });
            var destination = FloatTensor(new float[1, 3]);

            NewMultinomial().Eval(source, destination);

            AssertValues(new float[] {2, 2, 2}, destination);
        }

        /// <summary>
        /// A 2x3 source (rows shifted by -50 and -25 respectively) must produce
        /// {2, 2, 2} for the first row and {0, 1, 0} for the second.
        /// </summary>
        [Test]
        public void TestEvalBatching()
        {
            var source = FloatTensor(new float[,]
            {
                {Mathf.Log(0.1f) - 50, Mathf.Log(0.2f) - 50, Mathf.Log(0.7f) - 50},
                {Mathf.Log(0.3f) - 25, Mathf.Log(0.4f) - 25, Mathf.Log(0.3f) - 25}
            });
            var destination = FloatTensor(new float[2, 3]);

            NewMultinomial().Eval(source, destination);

            AssertValues(new float[] {2, 2, 2, 0, 1, 0}, destination);
        }

        /// <summary>
        /// An integer-typed source (with a null destination) must raise
        /// <see cref="NotImplementedException"/>.
        /// </summary>
        [Test]
        public void TestSrcInt()
        {
            var source = MakeTensor(Tensor.TensorType.Integer, null);

            AssertEvalThrows<NotImplementedException>(source, null);
        }

        /// <summary>
        /// An integer-typed destination must raise <see cref="ArgumentException"/>.
        /// </summary>
        [Test]
        public void TestDstInt()
        {
            var source = FloatTensor(null);
            var destination = MakeTensor(Tensor.TensorType.Integer, null);

            AssertEvalThrows<ArgumentException>(source, destination);
        }

        /// <summary>
        /// Floating-point tensors with no data on either side must raise
        /// <see cref="ArgumentNullException"/>.
        /// </summary>
        [Test]
        public void TestSrcDataNull()
        {
            var source = FloatTensor(null);
            var destination = FloatTensor(null);

            AssertEvalThrows<ArgumentNullException>(source, destination);
        }

        /// <summary>
        /// A source with data but a destination without data must raise
        /// <see cref="ArgumentNullException"/>.
        /// </summary>
        [Test]
        public void TestDstDataNull()
        {
            var source = FloatTensor(new float[1]);
            var destination = FloatTensor(null);

            AssertEvalThrows<ArgumentNullException>(source, destination);
        }

        /// <summary>
        /// A one-dimensional source array must raise <see cref="ArgumentException"/>.
        /// </summary>
        [Test]
        public void TestSrcWrongShape()
        {
            var source = FloatTensor(new float[1]);
            var destination = FloatTensor(new float[1]);

            AssertEvalThrows<ArgumentException>(source, destination);
        }

        /// <summary>
        /// A two-dimensional source paired with a one-dimensional destination must raise
        /// <see cref="ArgumentException"/>.
        /// </summary>
        [Test]
        public void TestDstWrongShape()
        {
            var source = FloatTensor(new float[1, 1]);
            var destination = FloatTensor(new float[1]);

            AssertEvalThrows<ArgumentException>(source, destination);
        }

        /// <summary>
        /// A 1x1 source paired with a 2x1 destination must raise
        /// <see cref="ArgumentException"/>.
        /// </summary>
        [Test]
        public void TestUnequalBatchSize()
        {
            var source = FloatTensor(new float[1, 1]);
            var destination = FloatTensor(new float[2, 1]);

            AssertEvalThrows<ArgumentException>(source, destination);
        }
    }
}
|