You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

BrainParametersDrawer.cs 16KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. using UnityEngine;
  2. using UnityEditor;
  3. namespace MLAgents
  4. {
  /// <summary>
  /// PropertyDrawer for BrainParameters. Defines how BrainParameters are displayed in the
  /// Inspector.
  /// </summary>
  [CustomPropertyDrawer(typeof(BrainParameters))]
  public class BrainParametersDrawer : PropertyDrawer
  {
      // The height of a line in the Unity Inspectors
      private const float LineHeight = 17f;
      // Lines used by the Vector Observation section: its label plus two fields.
      private const int VecObsNumLine = 3;
      private const string CamResPropName = "cameraResolutions";
      private const string ActionSizePropName = "vectorActionSize";
      private const string ActionTypePropName = "vectorActionSpaceType";
      private const string ActionDescriptionPropName = "vectorActionDescriptions";
      private const string VecObsPropName = "vectorObservationSize";
      private const string NumVecObsPropName = "numStackedVectorObservations";
      private const string CamWidthPropName = "width";
      private const string CamHeightPropName = "height";
      private const string CamGrayPropName = "blackAndWhite";
      private const int DefaultCameraWidth = 84;
      private const int DefaultCameraHeight = 84;
      private const bool DefaultCameraGray = false;
      // enumValueIndex of the action space type that this drawer lays out as continuous;
      // every other index is laid out as discrete.
      private const int ContinuousSpaceTypeIndex = 1;

      /// <inheritdoc />
      public override float GetPropertyHeight(SerializedProperty property, GUIContent label)
      {
          if (!property.isExpanded)
          {
              // Only the foldout line is shown when collapsed.
              return LineHeight;
          }
          return LineHeight +
              GetHeightDrawVectorObservation() +
              GetHeightDrawVisualObservation(property) +
              GetHeightDrawVectorAction(property) +
              GetHeightDrawVectorActionDescriptions(property);
      }

      /// <inheritdoc />
      public override void OnGUI(Rect position, SerializedProperty property, GUIContent label)
      {
          // Draw from indent level 0 and restore the caller's level on exit.
          var indent = EditorGUI.indentLevel;
          EditorGUI.indentLevel = 0;
          position.height = LineHeight;
          property.isExpanded = EditorGUI.Foldout(position, property.isExpanded, label);
          position.y += LineHeight;
          if (property.isExpanded)
          {
              EditorGUI.BeginProperty(position, label, property);
              EditorGUI.indentLevel++;
              // Each section is drawn at the current rect, which is then advanced by the
              // same height GetPropertyHeight reserves for that section.
              // Vector Observations
              DrawVectorObservation(position, property);
              position.y += GetHeightDrawVectorObservation();
              // Visual Observations
              DrawVisualObservations(position, property);
              position.y += GetHeightDrawVisualObservation(property);
              // Vector Action
              DrawVectorAction(position, property);
              position.y += GetHeightDrawVectorAction(property);
              // Vector Action Descriptions
              DrawVectorActionDescriptions(position, property);
              EditorGUI.EndProperty();
          }
          EditorGUI.indentLevel = indent;
      }

      /// <summary>
      /// Draws the Vector Observations for the Brain Parameters
      /// </summary>
      /// <param name="position">Rectangle on the screen to use for the property GUI.</param>
      /// <param name="property">The SerializedProperty of the BrainParameters
      /// to make the custom GUI for.</param>
      private static void DrawVectorObservation(Rect position, SerializedProperty property)
      {
          EditorGUI.LabelField(position, "Vector Observation");
          position.y += LineHeight;
          EditorGUI.indentLevel++;
          EditorGUI.PropertyField(position,
              property.FindPropertyRelative(VecObsPropName),
              new GUIContent("Space Size",
                  "Length of state " +
                  "vector for brain (In Continuous state space). " +
                  "Or number of possible values (in Discrete state space)."));
          position.y += LineHeight;
          EditorGUI.PropertyField(position,
              property.FindPropertyRelative(NumVecObsPropName),
              new GUIContent("Stacked Vectors",
                  "Number of states that will be stacked before " +
                  "being fed to the neural network."));
          EditorGUI.indentLevel--;
      }

      /// <summary>
      /// The Height required to draw the Vector Observations parameters
      /// </summary>
      /// <returns>The height of the drawer of the Vector Observations </returns>
      private static float GetHeightDrawVectorObservation()
      {
          return VecObsNumLine * LineHeight;
      }

      /// <summary>
      /// Draws the Visual Observations parameters for the Brain Parameters
      /// </summary>
      /// <param name="position">Rectangle on the screen to use for the property GUI.</param>
      /// <param name="property">The SerializedProperty of the BrainParameters
      /// to make the custom GUI for.</param>
      private static void DrawVisualObservations(Rect position, SerializedProperty property)
      {
          EditorGUI.LabelField(position, "Visual Observations");
          position.y += LineHeight;
          var quarter = position.width / 4;
          var resolutions = property.FindPropertyRelative(CamResPropName);
          DrawVisualObsButtons(position, resolutions);
          position.y += LineHeight;
          // Four equal-width columns: Index, Width, Height and Gray.
          var indexRect = new Rect(position.x, position.y, quarter, position.height);
          var widthRect = new Rect(position.x + quarter, position.y, quarter, position.height);
          var heightRect = new Rect(position.x + 2 * quarter, position.y, quarter, position.height);
          var bwRect = new Rect(position.x + 3 * quarter, position.y, quarter, position.height);
          EditorGUI.indentLevel++;
          if (resolutions.arraySize > 0)
          {
              // Column headers are only shown when there is at least one row.
              EditorGUI.LabelField(indexRect, "Index");
              indexRect.y += LineHeight;
              EditorGUI.LabelField(widthRect, "Width");
              widthRect.y += LineHeight;
              EditorGUI.LabelField(heightRect, "Height");
              heightRect.y += LineHeight;
              EditorGUI.LabelField(bwRect, "Gray");
              bwRect.y += LineHeight;
          }
          // One editable row per camera resolution.
          for (var i = 0; i < resolutions.arraySize; i++)
          {
              EditorGUI.LabelField(indexRect, "Obs " + i);
              indexRect.y += LineHeight;
              var res = resolutions.GetArrayElementAtIndex(i);
              var w = res.FindPropertyRelative(CamWidthPropName);
              w.intValue = EditorGUI.IntField(widthRect, w.intValue);
              widthRect.y += LineHeight;
              var h = res.FindPropertyRelative(CamHeightPropName);
              h.intValue = EditorGUI.IntField(heightRect, h.intValue);
              heightRect.y += LineHeight;
              var bw = res.FindPropertyRelative(CamGrayPropName);
              bw.boolValue = EditorGUI.Toggle(bwRect, bw.boolValue);
              bwRect.y += LineHeight;
          }
          EditorGUI.indentLevel--;
      }

      /// <summary>
      /// Draws the buttons to add and remove the visual observations parameters
      /// </summary>
      /// <param name="position">Rectangle on the screen to use for the property GUI.</param>
      /// <param name="resolutions">The SerializedProperty of the resolution array
      /// to make the custom GUI for.</param>
      private static void DrawVisualObsButtons(Rect position, SerializedProperty resolutions)
      {
          var widthEighth = position.width / 8;
          var addButtonRect = new Rect(position.x + widthEighth, position.y,
              3 * widthEighth, position.height);
          var removeButtonRect = new Rect(position.x + 4 * widthEighth, position.y,
              3 * widthEighth, position.height);
          if (resolutions.arraySize == 0)
          {
              // No "Remove Last" button is shown, so the add button takes its space.
              addButtonRect.width *= 2;
          }
          if (GUI.Button(addButtonRect, "Add New", EditorStyles.miniButton))
          {
              resolutions.arraySize += 1;
              var newRes = resolutions.GetArrayElementAtIndex(resolutions.arraySize - 1);
              newRes.FindPropertyRelative(CamWidthPropName).intValue = DefaultCameraWidth;
              newRes.FindPropertyRelative(CamHeightPropName).intValue = DefaultCameraHeight;
              newRes.FindPropertyRelative(CamGrayPropName).boolValue = DefaultCameraGray;
          }
          if (resolutions.arraySize > 0)
          {
              if (GUI.Button(removeButtonRect, "Remove Last", EditorStyles.miniButton))
              {
                  resolutions.arraySize -= 1;
              }
          }
      }

      /// <summary>
      /// The Height required to draw the Visual Observations parameters
      /// </summary>
      /// <param name="property">The SerializedProperty of the BrainParameters.</param>
      /// <returns>The height of the drawer of the Visual Observations </returns>
      private static float GetHeightDrawVisualObservation(SerializedProperty property)
      {
          var resolutionCount = property.FindPropertyRelative(CamResPropName).arraySize;
          // Section label + add/remove buttons, plus one row per resolution.
          var lines = 2 + resolutionCount;
          if (resolutionCount > 0)
          {
              // Column-header row, drawn only when rows exist.
              lines += 1;
          }
          return LineHeight * lines;
      }

      /// <summary>
      /// Whether the action space type of the BrainParameters is laid out as continuous.
      /// Every other enum index is laid out as discrete, so drawing and height
      /// computation must both use this single check to stay in agreement.
      /// </summary>
      /// <param name="property">The SerializedProperty of the BrainParameters.</param>
      /// <returns>True when the continuous layout should be used.</returns>
      private static bool IsContinuousActionSpace(SerializedProperty property)
      {
          return property.FindPropertyRelative(ActionTypePropName).enumValueIndex
              == ContinuousSpaceTypeIndex;
      }

      /// <summary>
      /// Draws the Vector Actions parameters for the Brain Parameters
      /// </summary>
      /// <param name="position">Rectangle on the screen to use for the property GUI.</param>
      /// <param name="property">The SerializedProperty of the BrainParameters
      /// to make the custom GUI for.</param>
      private static void DrawVectorAction(Rect position, SerializedProperty property)
      {
          EditorGUI.LabelField(position, "Vector Action");
          position.y += LineHeight;
          var indent = EditorGUI.indentLevel;
          EditorGUI.indentLevel++;
          var bpVectorActionType = property.FindPropertyRelative(ActionTypePropName);
          EditorGUI.PropertyField(
              position,
              bpVectorActionType,
              new GUIContent("Space Type",
                  "Corresponds to whether state vector contains a single integer (Discrete) " +
                  "or a series of real-valued floats (Continuous)."));
          position.y += LineHeight;
          // Read the type after the field so a change made this frame is laid out immediately.
          if (bpVectorActionType.enumValueIndex == ContinuousSpaceTypeIndex)
          {
              DrawContinuousVectorAction(position, property);
          }
          else
          {
              DrawDiscreteVectorAction(position, property);
          }
          // Restore so the following sections do not inherit this section's indent.
          EditorGUI.indentLevel = indent;
      }

      /// <summary>
      /// Draws the Continuous Vector Actions parameters for the Brain Parameters
      /// </summary>
      /// <param name="position">Rectangle on the screen to use for the property GUI.</param>
      /// <param name="property">The SerializedProperty of the BrainParameters
      /// to make the custom GUI for.</param>
      private static void DrawContinuousVectorAction(Rect position, SerializedProperty property)
      {
          var vecActionSize = property.FindPropertyRelative(ActionSizePropName);
          // In the continuous layout the action size array holds a single entry:
          // the length of the continuous action vector.
          vecActionSize.arraySize = 1;
          SerializedProperty continuousActionSize =
              vecActionSize.GetArrayElementAtIndex(0);
          EditorGUI.PropertyField(
              position,
              continuousActionSize,
              new GUIContent("Space Size", "Length of continuous action vector."));
      }

      /// <summary>
      /// Draws the Discrete Vector Actions parameters for the Brain Parameters
      /// </summary>
      /// <param name="position">Rectangle on the screen to use for the property GUI.</param>
      /// <param name="property">The SerializedProperty of the BrainParameters
      /// to make the custom GUI for.</param>
      private static void DrawDiscreteVectorAction(Rect position, SerializedProperty property)
      {
          var vecActionSize = property.FindPropertyRelative(ActionSizePropName);
          // The number of branches is the array length; a negative entry is clamped to 0.
          vecActionSize.arraySize = Mathf.Max(0, EditorGUI.IntField(
              position, "Branches Size", vecActionSize.arraySize));
          position.y += LineHeight;
          // Offset the per-branch fields so they read as children of "Branches Size".
          position.x += 20;
          position.width -= 20;
          for (var branchIndex = 0;
              branchIndex < vecActionSize.arraySize;
              branchIndex++)
          {
              SerializedProperty branchActionSize =
                  vecActionSize.GetArrayElementAtIndex(branchIndex);
              EditorGUI.PropertyField(
                  position,
                  branchActionSize,
                  new GUIContent("Branch " + branchIndex + " Size",
                      "Number of possible actions for the branch number " + branchIndex + "."));
              position.y += LineHeight;
          }
      }

      /// <summary>
      /// The Height required to draw the Vector Action parameters
      /// </summary>
      /// <param name="property">The SerializedProperty of the BrainParameters.</param>
      /// <returns>The height of the drawer of the Vector Action </returns>
      private static float GetHeightDrawVectorAction(SerializedProperty property)
      {
          // Section label + "Space Type" field, plus one line per action size entry
          // (the single continuous size, or one per discrete branch).
          var lines = 2 + property.FindPropertyRelative(ActionSizePropName).arraySize;
          if (!IsContinuousActionSpace(property))
          {
              // The discrete layout also draws a "Branches Size" field.
              lines += 1;
          }
          return lines * LineHeight;
      }

      /// <summary>
      /// Draws the Vector Actions descriptions for the Brain Parameters
      /// </summary>
      /// <param name="position">Rectangle on the screen to use for the property GUI.</param>
      /// <param name="property">The SerializedProperty of the BrainParameters
      /// to make the custom GUI for.</param>
      private static void DrawVectorActionDescriptions(Rect position, SerializedProperty property)
      {
          var vecActionSize = property.FindPropertyRelative(ActionSizePropName);
          var isContinuous = IsContinuousActionSpace(property);
          int numberOfDescriptions;
          if (isContinuous)
          {
              // One description per continuous action; the action count is element 0.
              numberOfDescriptions = vecActionSize.arraySize > 0
                  ? vecActionSize.GetArrayElementAtIndex(0).intValue
                  : 0;
          }
          else
          {
              // One description per discrete branch.
              numberOfDescriptions = vecActionSize.arraySize;
          }
          var vecActionDescriptions =
              property.FindPropertyRelative(ActionDescriptionPropName);
          // A negative size typed into the Inspector must not become a negative array length.
          vecActionDescriptions.arraySize = Mathf.Max(0, numberOfDescriptions);
          var content = isContinuous
              ? new GUIContent("Action Descriptions",
                  "A list of strings used to name the available actions for the Brain.")
              : new GUIContent("Branch Descriptions",
                  "A list of strings used to name the available branches for the Brain.");
          var indent = EditorGUI.indentLevel;
          EditorGUI.indentLevel++;
          EditorGUI.PropertyField(position, vecActionDescriptions, content, true);
          EditorGUI.indentLevel = indent;
      }

      /// <summary>
      /// The Height required to draw the Action Descriptions
      /// </summary>
      /// <param name="property">The SerializedProperty of the BrainParameters.</param>
      /// <returns>The height of the drawer of the Action Descriptions </returns>
      private static float GetHeightDrawVectorActionDescriptions(SerializedProperty property)
      {
          var descriptions = property.FindPropertyRelative(ActionDescriptionPropName);
          // The foldout line is always shown.
          var lines = 1;
          if (descriptions.isExpanded)
          {
              // When expanded: one extra line (presumably the array's Size field)
              // plus one line per description.
              lines += descriptions.arraySize + 1;
          }
          return lines * LineHeight;
      }
  }
  345. }