2019-04-10 13:39:36 +02:00
|
|
|
|
using MLAgents;
|
2019-05-17 18:23:24 +02:00
|
|
|
|
using OpenCvSharp;
|
|
|
|
|
using System;
|
2019-04-10 13:39:36 +02:00
|
|
|
|
using System.Collections;
|
|
|
|
|
using System.Collections.Generic;
|
|
|
|
|
using UnityEngine;
|
|
|
|
|
|
|
|
|
|
/// <summary>
/// ML-Agents agent controlling a virtual Cozmo robot.
/// Each action moves or turns the robot, then re-renders the Cozmo camera. The reward depends on
/// how close the horizontal centre of gravity reported by <see cref="ImageProcessor"/> lies to the
/// horizontal centre of the camera's target texture.
/// </summary>
public class CozmoAgent : Agent
{
    // Possible discrete actions (value of vectorAction[0]).
    private const int STOP = 0;
    private const int FORWARD = 1;
    private const int RIGHT = 2;
    private const int LEFT = 3;

    // Fractions of the half image width that define two bands around the horizontal image centre:
    // a narrow "near" band and a wider "far" band.
    private const float NEAR_AREA_PERCENTAGE_OFFSET = 0.2f;
    private const float FAR_AREA_PERCENTAGE_OFFSET = 0.3f;

    [Tooltip("The virtual Cozmo camera")]
    public Camera renderCamera;

    [Tooltip("Reference to the CozmoMovement script")]
    public CozmoMovement movement;

    // Seconds (accumulated in fixed time steps) between decision requests while in inference mode.
    public float timeBetweenDecisionsAtInference;

    private Academy academy;                // the CozmoAcademy found in the scene
    private float timeSinceDecision;        // time accumulated since the last decision request (inference mode only)
    private ImageProcessor imageProcessor;  // ImageProcessor on renderCamera; supplies CenterOfGravity
    private int nearAreaLimit = 0;          // half-width in pixels of the "near" band around the image centre
    private int farAreaLimit = 0;           // half-width in pixels of the "far" band around the image centre

    /// <summary>
    /// Caches the academy and the camera's <see cref="ImageProcessor"/>.
    /// Also derives the near/far band half-widths from the width of the camera's target texture.
    /// </summary>
    private void Start()
    {
        academy = FindObjectOfType<CozmoAcademy>();
        imageProcessor = renderCamera.GetComponent<ImageProcessor>();

        // Integer division first, then scale: same arithmetic as width / 2 * offset.
        int halfWidth = renderCamera.targetTexture.width / 2;
        nearAreaLimit = (int)(halfWidth * NEAR_AREA_PERCENTAGE_OFFSET);
        farAreaLimit = (int)(halfWidth * FAR_AREA_PERCENTAGE_OFFSET);
    }

    /// <summary>
    /// Unity fixed-step callback. Renders the camera and requests decisions via <see cref="WaitTimeInference"/>.
    /// </summary>
    public void FixedUpdate()
    {
        WaitTimeInference();
    }

    /// <summary>
    /// Executes one discrete action, re-renders the camera and assigns the reward.
    /// Every outcome ends the episode with <c>Done()</c>.
    /// Rewards: +1 if the centre of gravity lies in the near band, -1 if it lies in the far band,
    /// and -2 otherwise.
    /// </summary>
    /// <param name="vectorAction">Element 0 holds the action index (STOP, FORWARD, RIGHT or LEFT).</param>
    /// <param name="textAction">Unused.</param>
    /// <exception cref="ArgumentException">The action index is not one of the known actions; movement is stopped first.</exception>
    public override void AgentAction(float[] vectorAction, string textAction)
    {
        int action = Mathf.FloorToInt(vectorAction[0]);

        // Small per-step penalty.
        // NOTE(review): every branch below calls SetReward. If SetReward replaces the reward accumulated
        // in this step (as ML-Agents documents), this penalty never reaches the trainer. Confirm intent.
        AddReward(-0.01f);

        switch (action)
        {
            case STOP:
                movement.Move(0);
                break;
            case FORWARD:
                movement.Move(1);
                break;
            case RIGHT:
                movement.Turn(1);
                break;
            case LEFT:
                movement.Turn(-1);
                break;
            default:
                movement.Move(0);
                throw new ArgumentException("Invalid action value. Stop movement.");
        }

        // Render a new image after the movement so the ImageProcessor can update its centre of gravity.
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        // Read the centre of gravity only after rendering. Reading it before the movement would base
        // the reward on the frame preceding this action.
        // (Assumes ImageProcessor updates CenterOfGravity during Render() — confirm.)
        Point centerOfGravity = imageProcessor.CenterOfGravity;
        int imageCenterX = renderCamera.targetTexture.width / 2;

        if (IsWithinBand(centerOfGravity.X, imageCenterX, nearAreaLimit))
        {
            Done();
            SetReward(1);
            print("Reward: +1");
        }
        else if (IsWithinBand(centerOfGravity.X, imageCenterX, farAreaLimit))
        {
            Done();
            SetReward(-1);
            print("Reward: -1");
        }
        else
        {
            Done();
            SetReward(-2);
            print("Reward: -2");
        }
    }

    /// <summary>
    /// Agent reset hook. Forwards the reset to the academy.
    /// </summary>
    public override void AgentReset()
    {
        academy.AcademyReset();
    }

    /// <summary>
    /// Returns true when <paramref name="x"/> lies strictly inside
    /// (<paramref name="center"/> - <paramref name="halfWidth"/>, <paramref name="center"/> + <paramref name="halfWidth"/>).
    /// </summary>
    private static bool IsWithinBand(int x, int center, int halfWidth)
    {
        return x > center - halfWidth && x < center + halfWidth;
    }

    /// <summary>
    /// Renders the camera, then requests a decision.
    /// Outside inference mode, a decision is requested on every call. In inference mode, at most
    /// one decision is requested per <see cref="timeBetweenDecisionsAtInference"/> seconds,
    /// accumulated in <c>Time.fixedDeltaTime</c> steps.
    /// </summary>
    private void WaitTimeInference()
    {
        if (renderCamera != null)
        {
            renderCamera.Render();
        }

        if (!academy.GetIsInference())
        {
            RequestDecision();
        }
        else if (timeSinceDecision >= timeBetweenDecisionsAtInference)
        {
            timeSinceDecision = 0f;
            RequestDecision();
        }
        else
        {
            timeSinceDecision += Time.fixedDeltaTime;
        }
    }
}
|