// Repository-viewer export residue: 163 lines, 5.3 KiB, C#.
// Last change recorded by the viewer: 2019-04-10 13:39:36 +02:00
using MLAgents;
using OpenCvSharp;
using System;
// (viewer blame timestamp: 2019-04-10 13:39:36 +02:00)
using System.Collections;
using System.Collections.Generic;
using UnityEngine;
/// <summary>
/// ML-Agents agent that drives the Cozmo movement controller.
/// Each decision selects one of four discrete actions (stop, forward, right, left).
/// The step reward depends on how far the X coordinate of the image processor's
/// center of gravity lies from the horizontal center of the camera's target texture.
/// The episode is marked done when the agent enters a trigger tagged "Goal".
/// </summary>
public class CozmoAgent : Agent
{
    // Possible actions (value of vectorAction[0] after flooring)
    private const int STOP = 0;
    private const int FORWARD = 1;
    private const int RIGHT = 2;
    private const int LEFT = 3;

    // Used to determine different areas in the image (near to the center, far away)
    private const float NEAR_AREA_PERCENTAGE_OFFSET = 0.3f;

    // Small penalty added on every action step on top of the positional reward
    private const float TIME_PENALTY = -0.01f;

    [Tooltip("The virtual Cozmo camera")]
    public Camera renderCamera;

    [Tooltip("Reference to the CozmoMovement script")]
    public CozmoMovementController movementController;

    public float timeBetweenDecisionsAtInference;

    private Academy academy;               // CozmoAcademy found in the scene
    private float timeSinceDecision;       // time since last decision
    private ImageProcessor imageProcessor; // ImageProcessor attached to the render camera
    private int nearAreaLimit = 0;         // half-width (in pixels) of the "near the center" area
    private int centerOfImageX = 0;        // middle of the image in x direction
    private int imageWidth = 0;            // width of the camera's target texture, cached in Start

    /// <summary>
    /// Resolves scene references and caches the image geometry used by the reward function.
    /// Missing references are reported through Debug.LogError instead of failing with an
    /// anonymous NullReferenceException.
    /// </summary>
    private void Start()
    {
        academy = FindObjectOfType<CozmoAcademy>();
        if (academy == null)
        {
            Debug.LogError("CozmoAgent: no CozmoAcademy found in the scene.");
        }

        if (renderCamera == null)
        {
            Debug.LogError("CozmoAgent: renderCamera is not assigned.");
            return;
        }

        imageProcessor = renderCamera.GetComponent<ImageProcessor>();
        if (imageProcessor == null)
        {
            Debug.LogError("CozmoAgent: renderCamera has no ImageProcessor component.");
        }

        RenderTexture target = renderCamera.targetTexture;
        if (target == null)
        {
            Debug.LogError("CozmoAgent: renderCamera has no target texture.");
            return;
        }

        imageWidth = target.width;
        centerOfImageX = imageWidth / 2;
        nearAreaLimit = (int)(centerOfImageX * NEAR_AREA_PERCENTAGE_OFFSET);
    }

    /// <summary>
    /// Renders the camera and requests a decision on every physics step
    /// (throttled while the academy reports inference mode).
    /// </summary>
    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// Applies the chosen action to the movement controller and rewards the agent.
    /// </summary>
    /// <param name="vectorAction">Action vector; element 0 is floored to the action id (0..3).</param>
    /// <param name="textAction">Unused.</param>
    /// <exception cref="ArgumentException">The action id is not one of the known actions.</exception>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        int action = Mathf.FloorToInt(vectorAction[0]);

        switch (action)
        {
            case STOP:
                movementController.currentMovementState = MovementState.Stop;
                break;
            case FORWARD:
                movementController.currentMovementState = MovementState.Forward;
                break;
            case RIGHT:
                movementController.currentMovementState = MovementState.Right;
                break;
            case LEFT:
                movementController.currentMovementState = MovementState.Left;
                break;
            default:
                // Actually stop, as the exception message states, before reporting the error.
                movementController.currentMovementState = MovementState.Stop;
                throw new ArgumentException("Invalid action value. Stop movement.");
        }

        // Render new image after movement in order to update the center of gravity
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        RewardAgent();

        // RewardAgent() uses SetReward, which replaces the step reward. The time penalty
        // therefore has to be added afterwards, otherwise it would be overwritten.
        AddReward(TIME_PENALTY);
    }

    /// <summary>
    /// Sets the step reward based on how far the center of gravity is from the center of the image.
    /// </summary>
    private void RewardAgent()
    {
        float centerOfGravityX = imageProcessor.CenterOfGravity.X;
        SetReward(ComputeReward(centerOfGravityX, centerOfImageX, nearAreaLimit, imageWidth));
    }

    /// <summary>
    /// Piecewise-linear reward over the horizontal position of the center of gravity:
    /// +1 at the image center falling to 0 at the borders of the near area, then
    /// falling further to -1 at the image edges (clamped to -1 beyond them).
    /// </summary>
    /// <param name="centerOfGravityX">X coordinate of the center of gravity.</param>
    /// <param name="centerX">X coordinate of the image center.</param>
    /// <param name="nearLimit">Half-width of the near area around the center.</param>
    /// <param name="width">Width of the image.</param>
    /// <returns>The reward for the given position.</returns>
    private static float ComputeReward(float centerOfGravityX, int centerX, int nearLimit, int width)
    {
        int leftNearBorder = centerX - nearLimit;
        int rightNearBorder = centerX + nearLimit;

        // Far away from the center (left): 0 at the near border down to -1 at x = 0
        if (centerOfGravityX <= leftNearBorder)
        {
            float range = leftNearBorder;
            // Clamp in order to handle a center of gravity outside of the image
            return Mathf.Clamp(-(1 - (centerOfGravityX / range)), -1f, 0f);
        }

        // Near left of the center: 0 at the near border up to 1 at the center
        if (centerOfGravityX <= centerX)
        {
            return (centerOfGravityX - leftNearBorder) / nearLimit;
        }

        // Far away from the center (right): 0 at the near border down to -1 at the right edge
        if (centerOfGravityX >= rightNearBorder)
        {
            float range = width - rightNearBorder;
            // Clamp in order to handle a center of gravity outside of the image
            return Mathf.Clamp(-((centerOfGravityX - rightNearBorder) / range), -1f, 0f);
        }

        // Near right of the center: 1 at the center down to 0 at the near border
        return 1 - (centerOfGravityX - centerX) / nearLimit;
    }

    /// <summary>
    /// Resets the agent by resetting the academy.
    /// </summary>
    public override void AgentReset()
    {
        academy.AcademyReset();
    }

    /// <summary>
    /// Marks the agent as done once it enters a trigger tagged "Goal".
    /// </summary>
    private void OnTriggerEnter(Collider other)
    {
        if (other.transform.CompareTag("Goal"))
        {
            Done();
        }
    }

    /// <summary>
    /// Renders the camera and requests a decision. Outside inference mode a decision is
    /// requested on every call; in inference mode only once
    /// <see cref="timeBetweenDecisionsAtInference"/> seconds have accumulated.
    /// </summary>
    private void WaitTimeInference()
    {
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        if (!academy.GetIsInference())
        {
            RequestDecision();
            return;
        }

        if (timeSinceDecision >= timeBetweenDecisionsAtInference)
        {
            timeSinceDecision = 0f;
            RequestDecision();
        }
        else
        {
            timeSinceDecision += Time.fixedDeltaTime;
        }
    }
}