Browse Source

New Reward function --> only near and far areas now instead of three areas

Development
Tobi 5 years ago
parent
commit
8bb5da8946
1 changed file with 17 additions and 39 deletions
  1. 17
    39
      Assets/Scripts/ML Cozmo/CozmoAgent.cs

+ 17
- 39
Assets/Scripts/ML Cozmo/CozmoAgent.cs View File

private const int LEFT = 3; private const int LEFT = 3;


// Used to determine different areas in the image (near to the center, far away) // Used to determine different areas in the image (near to the center, far away)
private const float NEAR_AREA_PERCENTAGE_OFFSET = 0.2f;
private const float FAR_AREA_PERCENTAGE_OFFSET = 0.3f;
private const float NEAR_AREA_PERCENTAGE_OFFSET = 0.3f;


[Tooltip("The virtual Cozmo camera")] [Tooltip("The virtual Cozmo camera")]
public Camera renderCamera; public Camera renderCamera;
//[Tooltip("Final cropped and scaled rendertexture")]
//public RenderTexture renderTextureScaled;
[Tooltip("Reference to the CozmoMovement script")] [Tooltip("Reference to the CozmoMovement script")]
public CozmoMovementController movementController; public CozmoMovementController movementController;
public float timeBetweenDecisionsAtInference; public float timeBetweenDecisionsAtInference;
private float timeSinceDecision; // time since last decision private float timeSinceDecision; // time since last decision
private ImageProcessor imageProcessor; // reference to the ImageProcessor private ImageProcessor imageProcessor; // reference to the ImageProcessor
private int nearAreaLimit = 0; // X coordinate limit for the near to the imagecenter area private int nearAreaLimit = 0; // X coordinate limit for the near to the imagecenter area
private int farAreaLimit = 0; // X coordinate limit for the far away to the imagecenter area


private int centerOfImageX = 0; // Middle of the image in x direction


private void Start() private void Start()
{ {
academy = FindObjectOfType(typeof(CozmoAcademy)) as CozmoAcademy; academy = FindObjectOfType(typeof(CozmoAcademy)) as CozmoAcademy;
imageProcessor = renderCamera.GetComponent<ImageProcessor>(); imageProcessor = renderCamera.GetComponent<ImageProcessor>();
nearAreaLimit = (int)(renderCamera.targetTexture.width / 2 * NEAR_AREA_PERCENTAGE_OFFSET); nearAreaLimit = (int)(renderCamera.targetTexture.width / 2 * NEAR_AREA_PERCENTAGE_OFFSET);
farAreaLimit = (int)(renderCamera.targetTexture.width / 2 * FAR_AREA_PERCENTAGE_OFFSET);
centerOfImageX = renderCamera.targetTexture.width / 2;
} }




// to be implemented by the developer // to be implemented by the developer
public override void AgentAction(float[] vectorAction, string textAction) public override void AgentAction(float[] vectorAction, string textAction)
{ {
//print("Action before FloorToInt: " + vectorAction[0]);
int action = Mathf.FloorToInt(vectorAction[0]); int action = Mathf.FloorToInt(vectorAction[0]);
Point centerOfGravity = imageProcessor.CenterOfGravity; Point centerOfGravity = imageProcessor.CenterOfGravity;
//Vector3 targetPos = transform.position;

//print("Action after FloorToInt: " + action);


AddReward(-0.01f); AddReward(-0.01f);


} }


RewardAgent(); RewardAgent();
imageProcessor.enabled = false;
} }


/// <summary>
/// TODO: Cleanup code
/// </summary>
// Set the reward for the agent based on how far away the center of gravity is from the center of the image
private void RewardAgent() private void RewardAgent()
{ {
float centerOfImageX = renderCamera.targetTexture.width / 2;
float centerOfGravityX = imageProcessor.CenterOfGravity.X; float centerOfGravityX = imageProcessor.CenterOfGravity.X;
float reward = 0; float reward = 0;


// Center of gravity is far left of the center
if (centerOfGravityX <= centerOfImageX - farAreaLimit)
// Center of gravity is far away from the center (left)
if (centerOfGravityX <= centerOfImageX - nearAreaLimit)
{ {
reward = -1;
}
// Center of gravity is between far and near left of the center
else if (centerOfGravityX <= centerOfImageX - nearAreaLimit)
{
float range = (centerOfImageX - nearAreaLimit) - (centerOfImageX - farAreaLimit);
float distanceToLeftFarBorder = centerOfGravityX - (centerOfImageX - farAreaLimit);
reward = -(1 - (distanceToLeftFarBorder / range));
float range = centerOfImageX - nearAreaLimit;
reward = -(1 - (centerOfGravityX / range));
// Clamp the reward to max -1 in order to handle rewards if the center of gravity is outside of the image
reward = Mathf.Clamp(reward, -1, 0);
} }
// Center of gravity is near left of the center // Center of gravity is near left of the center
else if (centerOfGravityX <= centerOfImageX) else if (centerOfGravityX <= centerOfImageX)
float distanceToLeftFarBorder = centerOfGravityX - (centerOfImageX - nearAreaLimit); float distanceToLeftFarBorder = centerOfGravityX - (centerOfImageX - nearAreaLimit);
reward = (distanceToLeftFarBorder / range); reward = (distanceToLeftFarBorder / range);
} }
// Center of gravity is far right of the center
else if (centerOfGravityX >= centerOfImageX + farAreaLimit)
{
reward = -1;
}
// Center of gravity is between far and near right of the center
// Center of gravity is far away from the center (right)
else if (centerOfGravityX >= centerOfImageX + nearAreaLimit) else if (centerOfGravityX >= centerOfImageX + nearAreaLimit)
{ {
float range = (centerOfImageX + farAreaLimit) - (centerOfImageX + nearAreaLimit);
float distanceToLeftFarBorder = centerOfGravityX - (centerOfImageX + nearAreaLimit);
reward = -(distanceToLeftFarBorder / range);
float range = renderCamera.targetTexture.width - (centerOfImageX + nearAreaLimit);
reward = -(((centerOfGravityX - (centerOfImageX + nearAreaLimit)) / range));
// Clamp the reward to max -1 in order to handle rewards if the center of gravity is outside of the image
reward = Mathf.Clamp(reward, -1, 0);
} }
// Center of gravity is near right of the center // Center of gravity is near right of the center
else if (centerOfGravityX >= centerOfImageX) else if (centerOfGravityX >= centerOfImageX)
{ {
float range = (centerOfImageX + nearAreaLimit) - centerOfImageX; float range = (centerOfImageX + nearAreaLimit) - centerOfImageX;
float distanceToLeftFarBorder = centerOfGravityX - centerOfImageX;
reward = (1 - distanceToLeftFarBorder / range);
float distanceToCenterOfImage = centerOfGravityX - centerOfImageX;
reward = (1 - distanceToCenterOfImage / range);
} }


SetReward(reward); SetReward(reward);
{ {
if (other.transform.CompareTag("Goal")) if (other.transform.CompareTag("Goal"))
{ {
print("Collission");
Done(); Done();
} }
} }

Loading…
Cancel
Save