You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

EditModeTestActionMasker.cs 4.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. using NUnit.Framework;
  2. namespace MLAgents.Tests
  3. {
  4. public class EditModeTestActionMasker
  5. {
  6. [Test]
  7. public void Contruction()
  8. {
  9. var bp = new BrainParameters();
  10. var masker = new ActionMasker(bp);
  11. Assert.IsNotNull(masker);
  12. }
  13. [Test]
  14. public void FailsWithContinuous()
  15. {
  16. var bp = new BrainParameters();
  17. bp.vectorActionSpaceType = SpaceType.continuous;
  18. bp.vectorActionSize = new int[1] {4};
  19. var masker = new ActionMasker(bp);
  20. masker.SetActionMask(0, new int[1] {0});
  21. Assert.Catch<UnityAgentsException>(() => masker.GetMask());
  22. }
  23. [Test]
  24. public void NullMask()
  25. {
  26. var bp = new BrainParameters();
  27. bp.vectorActionSpaceType = SpaceType.discrete;
  28. var masker = new ActionMasker(bp);
  29. var mask = masker.GetMask();
  30. Assert.IsNull(mask);
  31. }
  32. [Test]
  33. public void FirstBranchMask()
  34. {
  35. var bp = new BrainParameters();
  36. bp.vectorActionSpaceType = SpaceType.discrete;
  37. bp.vectorActionSize = new int[3] {4, 5, 6};
  38. var masker = new ActionMasker(bp);
  39. var mask = masker.GetMask();
  40. Assert.IsNull(mask);
  41. masker.SetActionMask(0, new int[]{1,2,3});
  42. mask = masker.GetMask();
  43. Assert.IsFalse(mask[0]);
  44. Assert.IsTrue(mask[1]);
  45. Assert.IsTrue(mask[2]);
  46. Assert.IsTrue(mask[3]);
  47. Assert.IsFalse(mask[4]);
  48. Assert.AreEqual(mask.Length, 15);
  49. }
  50. [Test]
  51. public void SecondBranchMask()
  52. {
  53. var bp = new BrainParameters();
  54. bp.vectorActionSpaceType = SpaceType.discrete;
  55. bp.vectorActionSize = new int[3] {4, 5, 6};
  56. var masker = new ActionMasker(bp);
  57. bool[] mask = masker.GetMask();
  58. masker.SetActionMask(1, new int[]{1,2,3});
  59. mask = masker.GetMask();
  60. Assert.IsFalse(mask[0]);
  61. Assert.IsFalse(mask[4]);
  62. Assert.IsTrue(mask[5]);
  63. Assert.IsTrue(mask[6]);
  64. Assert.IsTrue(mask[7]);
  65. Assert.IsFalse(mask[8]);
  66. Assert.IsFalse(mask[9]);
  67. }
  68. [Test]
  69. public void MaskReset()
  70. {
  71. var bp = new BrainParameters();
  72. bp.vectorActionSpaceType = SpaceType.discrete;
  73. bp.vectorActionSize = new int[3] {4, 5, 6};
  74. var masker = new ActionMasker(bp);
  75. var mask = masker.GetMask();
  76. masker.SetActionMask(1, new int[3]{1,2,3});
  77. mask = masker.GetMask();
  78. masker.ResetMask();
  79. mask = masker.GetMask();
  80. for (var i = 0; i < 15; i++)
  81. {
  82. Assert.IsFalse(mask[i]);
  83. }
  84. }
  85. [Test]
  86. public void ThrowsError()
  87. {
  88. var bp = new BrainParameters();
  89. bp.vectorActionSpaceType = SpaceType.discrete;
  90. bp.vectorActionSize = new int[3] {4, 5, 6};
  91. var masker = new ActionMasker(bp);
  92. Assert.Catch<UnityAgentsException>(
  93. () => masker.SetActionMask(0, new int[1]{5}));
  94. Assert.Catch<UnityAgentsException>(
  95. () => masker.SetActionMask(1, new int[1]{5}));
  96. masker.SetActionMask(2, new int[1] {5});
  97. Assert.Catch<UnityAgentsException>(
  98. () => masker.SetActionMask(3, new int[1]{1}));
  99. masker.GetMask();
  100. masker.ResetMask();
  101. masker.SetActionMask(0, new int[4] {0, 1, 2, 3});
  102. Assert.Catch<UnityAgentsException>(
  103. () => masker.GetMask());
  104. }
  105. [Test]
  106. public void MultipleMaskEdit()
  107. {
  108. var bp = new BrainParameters();
  109. bp.vectorActionSpaceType = SpaceType.discrete;
  110. bp.vectorActionSize = new int[3] {4, 5, 6};
  111. var masker = new ActionMasker(bp);
  112. masker.SetActionMask(0, new int[2] {0, 1});
  113. masker.SetActionMask(0, new int[1] {3});
  114. masker.SetActionMask(2, new int[1] {1});
  115. var mask = masker.GetMask();
  116. for (var i = 0; i < 15; i++)
  117. {
  118. if ((i == 0) || (i == 1) || (i == 3)|| (i == 10))
  119. {
  120. Assert.IsTrue(mask[i]);
  121. }
  122. else
  123. {
  124. Assert.IsFalse(mask[i]);
  125. }
  126. }
  127. }
  128. }
  129. }