You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

LearningBrainEditor.cs 3.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. using UnityEngine;
  2. using UnityEditor;
  3. namespace MLAgents
  4. {
  5. /// <summary>
  6. /// CustomEditor for the LearningBrain class. Defines the default Inspector view for a
  7. /// LearningBrain.
  8. /// Shows the BrainParameters of the Brain and expose a tool to deep copy BrainParameters
  9. /// between brains. Also exposes a drag box for the Model that will be used by the
  10. /// LearningBrain.
  11. /// </summary>
  12. [CustomEditor(typeof(LearningBrain))]
  13. public class LearningBrainEditor : BrainEditor
  14. {
  15. private const string ModelPropName = "model";
  16. private const string InferenceDevicePropName = "inferenceDevice";
  17. private const float TimeBetweenModelReloads = 2f;
  18. // Time since the last reload of the model
  19. private float _timeSinceModelReload;
  20. // Whether or not the model needs to be reloaded
  21. private bool _requireReload;
  22. /// <summary>
  23. /// Called when the user opens the Inspector for the LearningBrain
  24. /// </summary>
  25. public void OnEnable()
  26. {
  27. _requireReload = true;
  28. EditorApplication.update += IncreaseTimeSinceLastModelReload;
  29. }
  30. /// <summary>
  31. /// Called when the user leaves the Inspector for the LearningBrain
  32. /// </summary>
  33. public void OnDisable()
  34. {
  35. EditorApplication.update -= IncreaseTimeSinceLastModelReload;
  36. }
  37. public override void OnInspectorGUI()
  38. {
  39. EditorGUILayout.LabelField("Learning Brain", EditorStyles.boldLabel);
  40. var brain = (LearningBrain) target;
  41. var serializedBrain = serializedObject;
  42. EditorGUI.BeginChangeCheck();
  43. base.OnInspectorGUI();
  44. serializedBrain.Update();
  45. var tfGraphModel = serializedBrain.FindProperty(ModelPropName);
  46. EditorGUILayout.ObjectField(tfGraphModel);
  47. var inferenceDevice = serializedBrain.FindProperty(InferenceDevicePropName);
  48. EditorGUILayout.PropertyField(inferenceDevice);
  49. serializedBrain.ApplyModifiedProperties();
  50. if (EditorGUI.EndChangeCheck())
  51. {
  52. _requireReload = true;
  53. }
  54. if (_requireReload && _timeSinceModelReload > TimeBetweenModelReloads)
  55. {
  56. brain.ReloadModel();
  57. _requireReload = false;
  58. _timeSinceModelReload = 0;
  59. }
  60. // Display all failed checks
  61. var failedChecks = brain.GetModelFailedChecks();
  62. foreach (var check in failedChecks)
  63. {
  64. if (check != null)
  65. {
  66. EditorGUILayout.HelpBox(check, MessageType.Warning);
  67. }
  68. }
  69. }
  70. /// <summary>
  71. /// Increases the time since last model reload by the deltaTime since the last Update call
  72. /// from the UnityEditor
  73. /// </summary>
  74. private void IncreaseTimeSinceLastModelReload()
  75. {
  76. _timeSinceModelReload += Time.deltaTime;
  77. }
  78. }
  79. }