123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157 |
- using System;
- using UnityEngine;
- using System.Linq;
- using MLAgents;
-
/// <summary>
/// GridWorld agent: moves one unit per decision along the X/Z axes, is blocked by
/// "wall" objects, receives +1 and ends its episode on a "goal" cell, and receives -1
/// and ends its episode on a "pit" cell. A small penalty is applied every step.
/// </summary>
public class GridAgent : Agent
{
    // Tags of the scene objects inspected in the target cell before moving.
    private const string WallTag = "wall";
    private const string GoalTag = "goal";
    private const string PitTag = "pit";

    // Half-extents of the box used to probe the target cell. Moves are 1 unit, so this
    // is presumably kept below half a cell to avoid hitting neighbouring cells — confirm
    // against the collider sizes used in the scene.
    private static readonly Vector3 CellProbeHalfExtents = new Vector3(0.3f, 0.3f, 0.3f);

    private const int NoAction = 0; // do nothing!
    private const int Up = 1;
    private const int Down = 2;
    private const int Left = 3;
    private const int Right = 4;

    // Looked up at runtime in InitializeAgent; private and not serialized, so it never
    // appears in the inspector.
    private GridAcademy academy;

    // [Header] lives here (not on the private academy field) because Unity only draws
    // headers above fields that are shown in the inspector.
    [Header("Specific to GridWorld")]
    [Tooltip("Minimum time in seconds between decisions while the academy is in inference mode.")]
    public float timeBetweenDecisionsAtInference;

    // Seconds accumulated (in FixedUpdate steps) since the last inference-mode decision.
    private float timeSinceDecision;

    [Tooltip("Because we want an observation right before making a decision, we can force " +
             "a camera to render before making a decision. Place the agentCam here if using " +
             "RenderTexture as observations.")]
    public Camera renderCamera;

    [Tooltip("Selecting will turn on action masking. Note that a model trained with action " +
             "masking turned on may not behave optimally when action masking is turned off.")]
    public bool maskActions = true;

    /// <summary>
    /// Finds the scene's <see cref="GridAcademy"/>, which provides the grid size and the
    /// inference flag used by this agent.
    /// </summary>
    public override void InitializeAgent()
    {
        academy = FindObjectOfType<GridAcademy>();
        if (academy == null)
        {
            // Fail loudly here instead of with a NullReferenceException on the first step.
            Debug.LogError("GridAgent requires a GridAcademy in the scene.");
        }
    }

    /// <summary>
    /// Collects observations for the next decision. This environment uses visual
    /// observations only, so the sole work here is applying the optional action mask.
    /// </summary>
    public override void CollectObservations()
    {
        if (maskActions)
        {
            SetMask();
        }
    }

    /// <summary>
    /// Masks the movement actions that would push the agent past the outer edge of the
    /// grid (cells 0 .. gridSize - 1 on both the X and Z axes).
    /// </summary>
    private void SetMask()
    {
        var positionX = (int)transform.position.x;
        var positionZ = (int)transform.position.z;
        var maxPosition = academy.gridSize - 1;

        if (positionX == 0)
        {
            SetActionMask(Left);
        }

        if (positionX == maxPosition)
        {
            SetActionMask(Right);
        }

        if (positionZ == 0)
        {
            SetActionMask(Down);
        }

        if (positionZ == maxPosition)
        {
            SetActionMask(Up);
        }
    }

    /// <summary>
    /// Applies one discrete action. The move is cancelled if the target cell contains a
    /// wall. Otherwise the agent moves, and reaching a goal or a pit sets the final reward
    /// and ends the episode.
    /// </summary>
    /// <param name="vectorAction">Element 0 holds the action index (NoAction/Up/Down/Left/Right).</param>
    /// <param name="textAction">Unused.</param>
    /// <exception cref="ArgumentException">The action index is not one of the five known actions.</exception>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        // Per-step penalty, presumably to favour short paths to the goal.
        AddReward(-0.01f);
        int action = Mathf.FloorToInt(vectorAction[0]);

        Vector3 targetPos = transform.position;
        switch (action)
        {
            case NoAction:
                break;
            case Right:
                targetPos += Vector3.right;    // (1, 0, 0)
                break;
            case Left:
                targetPos += Vector3.left;     // (-1, 0, 0)
                break;
            case Up:
                targetPos += Vector3.forward;  // (0, 0, 1)
                break;
            case Down:
                targetPos += Vector3.back;     // (0, 0, -1)
                break;
            default:
                throw new ArgumentException("Invalid action value");
        }

        Collider[] blockTest = Physics.OverlapBox(targetPos, CellProbeHalfExtents);
        if (ContainsTag(blockTest, WallTag))
        {
            // A wall occupies the target cell: stay put.
            return;
        }

        transform.position = targetPos;

        // Terminal cells must end the episode; otherwise the agent could idle on the goal
        // re-collecting the reward, or simply walk through a pit.
        if (ContainsTag(blockTest, GoalTag))
        {
            SetReward(1f);
            Done();
        }
        if (ContainsTag(blockTest, PitTag))
        {
            SetReward(-1f);
            Done();
        }
    }

    /// <summary>
    /// True when any collider in <paramref name="colliders"/> is on a GameObject tagged
    /// <paramref name="tag"/>. "Any" rather than "exactly one", so that an object with
    /// several colliders is still detected.
    /// </summary>
    private static bool ContainsTag(Collider[] colliders, string tag)
    {
        return colliders.Any(col => col.gameObject.CompareTag(tag));
    }

    /// <summary>
    /// Called when this agent's episode ends; delegates to the academy's reset.
    /// </summary>
    public override void AgentReset()
    {
        academy.AcademyReset();
    }

    /// <summary>
    /// Unity physics-step callback; drives the decision cadence.
    /// </summary>
    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// Renders the observation camera (if assigned) and requests a decision. In training
    /// mode a decision is requested every physics step. In inference mode decisions are
    /// throttled to one per <see cref="timeBetweenDecisionsAtInference"/> seconds.
    /// </summary>
    private void WaitTimeInference()
    {
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        if (!academy.GetIsInference())
        {
            RequestDecision();
            return;
        }

        if (timeSinceDecision >= timeBetweenDecisionsAtInference)
        {
            timeSinceDecision = 0f;
            RequestDecision();
        }
        else
        {
            timeSinceDecision += Time.fixedDeltaTime;
        }
    }
}
|