Last active
May 19, 2022 08:43
-
-
Save ArztSamuel/499e617844ca4ce6e222183bd23752f0 to your computer and use it in GitHub Desktop.
Most important parts of the Agent code used for the project of https://youtu.be/VMp6pq6_QjI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class ParkingCarAgent : Agent | |
{ | |
/// <summary>Transform of the parking spot; distance and rotation to it drive the rewards and observations.</summary>
[SerializeField] private Transform TargetParkingSpot;

/// <summary>
/// Width of one distance band: a reward is given each time the agent gets one band closer
/// to the target (the distance penalty uses twice this width).
/// </summary>
[SerializeField] private float DistanceRewardInterval = 3f;

// ---- Thresholds defining when the task counts as complete ----

/// <summary>Maximum distance to the target at which the task can complete.</summary>
[SerializeField] private float DistanceThreshold = 2f;

/// <summary>Rotation difference (degrees) up to which the full completion reward is given.</summary>
[SerializeField] private float RotationThreshold = 20f;

/// <summary>
/// Maximum absolute speed at which the task can complete.
/// The identifier (including its spelling) is kept because other members reference it.
/// </summary>
[SerializeField] private float SpeedTheshold = 5f;

/// <summary>Bounds the agent may not leave; also used to normalize positions.</summary>
[SerializeField] private Bounds AllowedBounds;

/// <summary>Distance sensors whose normalized readings are appended to the observations.</summary>
private DistanceSensor[] distanceSensors;
... | |
/// <summary>
/// Collects the observation vector for the current step. Order of the values:
/// current speed, normalized agent x, normalized agent z, normalized agent y rotation,
/// target-minus-agent offset in normalized x and z, target-minus-agent normalized
/// y rotation, followed by one normalized distance per sensor.
/// </summary>
public override void CollectObservations()
{
    base.CollectObservations();

    // --- Agent state: speed, position on the ground plane, heading ---
    Vector3 agentPos = GetNormalizedPosition(this.transform.position);
    Vector3 agentRot = GetNormalizedRotation(this.transform.rotation);
    AddVectorObs(carPhysics.CurrentSpeed);
    AddVectorObs(agentPos.x);
    AddVectorObs(agentPos.z);
    AddVectorObs(agentRot.y);

    // --- Target state, expressed relative to the agent ---
    Vector3 targetPos = GetNormalizedPosition(TargetParkingSpot.position);
    Vector3 targetRot = GetNormalizedRotation(TargetParkingSpot.rotation);
    AddVectorObs(targetPos.x - agentPos.x);
    AddVectorObs(targetPos.z - agentPos.z);
    // NOTE(review): this is a plain difference of two normalized euler angles, so it is
    // not wrapped around (e.g. headings of 350 and 10 degrees yield a large delta).
    AddVectorObs(targetRot.y - agentRot.y);

    // --- Sensor readings: refresh each sensor, then append its normalized distance ---
    DistanceSensor[] sensors = distanceSensors;
    for (int i = 0; i < sensors.Length; i++)
    {
        DistanceSensor current = sensors[i];
        current.UpdateSensorReadings();
        AddVectorObs(current.NormalizedDistance);
    }
}
/// <summary>
/// Applies the received action to the car and hands out rewards:
/// a small reward / penalty for crossing distance bands towards / away from the target,
/// a completion reward once the car is close enough and slow enough, and a penalty
/// when the car leaves the allowed bounds. Completion and leaving the bounds end the episode.
/// </summary>
/// <param name="vectorAction">
/// Continuous actions. Index 0 is the drive input (its positive part is used as throttle,
/// its negated negative part as braking), index 1 is used directly as turning input.
/// </param>
/// <param name="textAction">Forwarded to the base implementation only.</param>
public override void AgentAction(float[] vectorAction, string textAction)
{
    base.AgentAction(vectorAction, textAction);

    if (IsDone())
    {
        return;
    }

    // ----- Apply inputs -----
    // A single signed value drives both throttle and brake; only one of them is non-zero.
    carPhysics.CurrentThrottle = Mathf.Max(0f, vectorAction[0]);
    carPhysics.CurrentBraking = Mathf.Max(0f, -vectorAction[0]);
    carPhysics.CurrentTurning = vectorAction[1];

    // ----- Distance shaping -----
    // Note: squared distances could be used here for performance.
    float distanceToTarget = Vector3.Distance(this.transform.position, TargetParkingSpot.transform.position);
    if (distanceToTarget < previousDistance)
    {
        // Reward only when a band boundary was crossed on the way in.
        int currentBand = (int)(distanceToTarget / DistanceRewardInterval);
        int previousBand = (int)(previousDistance / DistanceRewardInterval);
        if (currentBand < previousBand)
        {
            AddReward(0.02f);
        }

        previousDistance = distanceToTarget;
    }
    else
    {
        // The penalty bands are twice as wide as the reward bands ('* 2' is a hard coded,
        // hand-tuned value) so that the penalty occurs less frequently than the reward and
        // the AI is not 'scared' of corrective maneuvers that first increase the distance.
        float penaltyInterval = DistanceRewardInterval * 2;
        bool crossedOutwards = (int)(distanceToTarget / penaltyInterval) > (int)(previousDistance / penaltyInterval);
        if (crossedOutwards)
        {
            if (Verbose)
            {
                Debug.Log("Distance based penalty");
            }

            AddReward(-0.04f);
            // The reference distance is only moved outwards once a penalty was given.
            previousDistance = distanceToTarget;
        }
    }

    // ----- Task completion: close enough and slow enough -----
    if (distanceToTarget <= DistanceThreshold && Mathf.Abs(carPhysics.CurrentSpeed) <= SpeedTheshold)
    {
        // Angles above 90 degrees are folded back, so both parking directions count alike.
        float rotationDiff = Quaternion.Angle(this.transform.rotation, TargetParkingSpot.rotation);
        if (rotationDiff > 90)
        {
            rotationDiff = 180 - rotationDiff;
        }

        // Full reward within the rotation threshold; beyond it the reward shrinks
        // linearly with how far from parallel the car is parked.
        float completionReward = rotationDiff > RotationThreshold
            ? 1 - GetNormalizedValue(rotationDiff, RotationThreshold, 90)
            : 1f;

        AddReward(completionReward);
        Done();
        return;
    }

    // ----- Leaving the allowed area ends the episode with a penalty -----
    // NOTE(review): the position is truncated to integers before the containment test,
    // which makes the check inexact by up to one unit per axis - confirm this is intended.
    Vector3 position = transform.position;
    Vector3Int truncatedPosition = new Vector3Int((int)position.x, (int)position.y, (int)position.z);
    if (!AllowedBounds.Contains(truncatedPosition))
    {
        AddReward(-1.0f);
        Done();
    }
}
/// <summary>
/// Expresses a position relative to <c>AllowedBounds</c>, component by component:
/// the minimum corner of the bounds maps to 0 and the maximum corner to 1.
/// </summary>
/// <param name="position">Position to normalize.</param>
/// <returns>Vector holding the normalized x, y and z components.</returns>
private Vector3 GetNormalizedPosition(in Vector3 position)
{
    Vector3 lower = AllowedBounds.min;
    Vector3 upper = AllowedBounds.max;

    return new Vector3(
        GetNormalizedValue(position.x, lower.x, upper.x),
        GetNormalizedValue(position.y, lower.y, upper.y),
        GetNormalizedValue(position.z, lower.z, upper.z));
}
/// <summary>
/// Converts a rotation to its euler angles and normalizes each angle
/// against the range 0 to 360.
/// </summary>
/// <param name="rotation">Rotation to normalize.</param>
/// <returns>Vector holding the normalized x, y and z euler angles.</returns>
private Vector3 GetNormalizedRotation(in Quaternion rotation)
{
    // Read the euler representation once and normalize every axis the same way.
    Vector3 euler = rotation.eulerAngles;

    return new Vector3(
        GetNormalizedValue(euler.x, 0, 360),
        GetNormalizedValue(euler.y, 0, 360),
        GetNormalizedValue(euler.z, 0, 360));
}
/// <summary>
/// Linearly maps <paramref name="currentValue"/> so that <paramref name="minValue"/>
/// becomes 0 and <paramref name="maxValue"/> becomes 1. The result is not clamped:
/// values outside the range map outside of 0..1.
/// </summary>
/// <param name="currentValue">Value to normalize.</param>
/// <param name="minValue">Value that maps to 0.</param>
/// <param name="maxValue">Value that maps to 1.</param>
/// <returns>
/// The normalized value, or 0 when the range is empty
/// (<paramref name="minValue"/> equals <paramref name="maxValue"/>).
/// </returns>
private float GetNormalizedValue(float currentValue, float minValue, float maxValue)
{
    float range = maxValue - minValue;

    // An empty range would divide by zero and produce NaN / Infinity, which would then
    // leak into observations and rewards; report the lower end of the scale instead.
    if (range == 0f)
    {
        return 0f;
    }

    return (currentValue - minValue) / range;
}
/// <summary>
/// Penalizes the agent when it collides with an object that carries a
/// <c>Knockable</c> component or that has a <c>ParkingCar</c> component on itself
/// or one of its parents. Other collisions are ignored here.
/// </summary>
/// <param name="collision">Collision data of the contact.</param>
void OnCollisionEnter(Collision collision)
{
    GameObject hitObject = collision.collider.gameObject;

    bool isPenalizedObstacle = hitObject.GetComponent<Knockable>() || hitObject.GetComponentInParent<ParkingCar>();
    if (!isPenalizedObstacle)
    {
        return;
    }

    AddReward(-0.12f);
}
... | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment