Last active
July 13, 2023 12:42
-
-
Save WuXinyang2012/647a1aca65691578155bdae1a6ea4f6c to your computer and use it in GitHub Desktop.
A CartPole swing-up environment modeled on PILCO that observes the pole angle as (cos(theta), sin(theta)) instead of theta.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Cart pole swing-up: Identical version to PILCO V0.9 | |
""" | |
import logging | |
import math | |
import gym | |
from gym import spaces | |
from gym.utils import seeding | |
import numpy as np | |
logger = logging.getLogger(__name__) | |
class CartPoleSwingUpEnv(gym.Env):
    """Cart-pole swing-up task with a saturating tip-distance cost.

    Internal state is ``(x, x_dot, theta, theta_dot)``. ``theta`` is measured
    from the upright position (``theta = 0`` is the goal) and is never wrapped,
    so it can grow without bound as the pole spins. Episodes start near
    ``theta = pi`` (hanging down). The dynamics in ``_step`` and the drawing in
    ``_render`` both place the pole tip at ``(x - l*sin(theta), l*cos(theta))``.

    Observations are ``[x, x_dot, cos(theta), sin(theta), theta_dot]``, so the
    angle enters the observation continuously. The reward is ``-cost``, where
    ``cost = 1 - exp(-0.5 * d**2 / 0.25**2)`` and ``d`` is the distance from the
    pole tip to the upright goal ``(0, l)``.

    An episode ends only when the cart leaves ``[-x_threshold, x_threshold]``.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 50
    }

    def __init__(self):
        self.g = 9.82  # gravity
        self.m_c = 0.5  # cart mass
        self.m_p = 0.5  # pendulum mass
        self.total_m = (self.m_p + self.m_c)
        self.l = 0.6  # pole's length
        self.m_p_l = (self.m_p * self.l)
        self.force_mag = 10.0
        self.dt = 0.01  # seconds between state updates
        self.b = 0.1  # friction coefficient

        # Angle limit kept for reference only: the angle never terminates an
        # episode (see _step), since the pole must be able to swing through.
        self.theta_threshold_radians = 180 * 2 * np.pi / 360
        self.x_threshold = 2.4
        self.x_dot_threshold = 8
        self.theta_dot_threshold = 8

        # Nominal bounds of [x, x_dot, cos(theta), sin(theta), theta_dot].
        # NOTE(review): velocities are not clipped in _step, so observations
        # may exceed these bounds.
        high = np.array([
            self.x_threshold,
            self.x_dot_threshold,
            1.,
            1.,
            self.theta_dot_threshold
        ], dtype=np.float32)

        self.action_space = spaces.Box(-self.force_mag, self.force_mag, shape=(1,))
        self.observation_space = spaces.Box(-high, high)

        self._seed()
        self.viewer = None
        self.state = None
        # Also (re)set in _reset; initialised here so the attribute always exists.
        self.steps_beyond_done = None

    def _seed(self, seed=None):
        """Seed the env's RNG; return the list of seeds actually used."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def _step(self, action):
        """Advance the simulation by one Euler step of length ``dt``.

        ``action`` is the horizontal force on the cart: a scalar or any
        array-like whose first element is used. It is clipped to
        ``[-force_mag, force_mag]``.

        Returns ``(observation, reward, done, info)``.
        """
        # Accept scalars as well as 1-element arrays/lists.
        action = np.clip(np.asarray(action, dtype=np.float64).reshape(-1)[0],
                         -self.force_mag, self.force_mag)

        x, x_dot, theta, theta_dot = self.state
        s = math.sin(theta)
        c = math.cos(theta)

        # Cart-pole equations of motion with viscous cart friction b, written
        # for theta measured from upright.
        xdot_update = (-2 * self.m_p_l * (theta_dot ** 2) * s
                       + 3 * self.m_p * self.g * s * c
                       + 4 * action - 4 * self.b * x_dot) \
            / (4 * self.total_m - 3 * self.m_p * c ** 2)
        thetadot_update = (-3 * self.m_p_l * (theta_dot ** 2) * s * c
                           + 6 * self.total_m * self.g * s
                           + 6 * (action - self.b * x_dot) * c) \
            / (4 * self.l * self.total_m - 3 * self.m_p_l * c ** 2)

        # Explicit Euler: positions are advanced with the pre-update velocities.
        x = x + x_dot * self.dt
        theta = theta + theta_dot * self.dt
        x_dot = x_dot + xdot_update * self.dt
        theta_dot = theta_dot + thetadot_update * self.dt

        # theta is deliberately left unwrapped; observations use cos/sin.
        self.state = (x, x_dot, theta, theta_dot)

        # Only leaving the track ends the episode; the angle is unconstrained.
        done = bool(x < -self.x_threshold or x > self.x_threshold)

        # Saturating cost on the distance between the pole tip and the goal.
        goal = np.array([0.0, self.l])
        # The tip is at x - l*sin(theta). This matches the dynamics above: near
        # upright, a positive force gives a positive theta acceleration, because
        # the pole lags behind the cart and its tip moves toward -x. It also
        # matches _render, which rotates the pole's +y axis by theta.
        pole_x = -self.l * np.sin(theta)
        pole_y = self.l * np.cos(theta)
        position = np.array([x + pole_x, pole_y])
        squared_distance = np.sum((position - goal) ** 2)
        squared_sigma = 0.25 ** 2
        costs = 1 - np.exp(-0.5 * squared_distance / squared_sigma)

        return self._get_obs(), -costs, done, {}

    def _get_obs(self):
        """Return ``[x, x_dot, cos(theta), sin(theta), theta_dot]``."""
        x, x_dot, theta, theta_dot = self.state
        return np.array([x, x_dot, np.cos(theta), np.sin(theta), theta_dot])

    def _reset(self):
        """Reset to a noisy hanging-down state (theta ~ pi); return the observation."""
        self.state = self.np_random.normal(loc=np.array([0.0, 0.0, np.pi, 0.0]),
                                           scale=np.array([0.02, 0.02, 0.02, 0.02]))
        self.steps_beyond_done = None
        return self._get_obs()

    def _render(self, mode='human', close=False):
        """Draw the cart and pole; return an RGB array when ``mode == 'rgb_array'``.

        Passing ``close=True`` closes and discards the viewer.
        """
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
            return

        screen_width = 600
        screen_height = 400

        world_width = 5  # max visible position of cart
        scale = screen_width / world_width
        carty = 200  # TOP OF CART
        polewidth = 6.0
        polelen = scale * self.l
        cartwidth = 40.0
        cartheight = 20.0

        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.Viewer(screen_width, screen_height)

            # Cart.
            l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
            cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            cart.set_color(1, 0, 0)
            self.viewer.add_geom(cart)

            # Pole, drawn along its local +y axis and rotated by theta.
            l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
            pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            pole.set_color(0, 0, 1)
            self.poletrans = rendering.Transform(translation=(0, 0))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)

            # Axle at the pivot.
            self.axle = rendering.make_circle(polewidth / 2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(0.1, 1, 1)
            self.viewer.add_geom(self.axle)

            # Bob at the free end of the pole. The offset is in the pole's own
            # (pixel) frame, so it follows the pole's rotation and the cart and
            # never needs updating per frame.
            self.pole_bob = rendering.make_circle(polewidth / 2)
            self.pole_bob_trans = rendering.Transform(translation=(0, polelen - polewidth / 2))
            self.pole_bob.add_attr(self.pole_bob_trans)
            self.pole_bob.add_attr(self.poletrans)
            self.pole_bob.add_attr(self.carttrans)
            self.pole_bob.set_color(0, 0, 0)
            self.viewer.add_geom(self.pole_bob)

            # Wheels at the bottom corners of the cart.
            self.wheel_l = rendering.make_circle(cartheight / 4)
            self.wheel_r = rendering.make_circle(cartheight / 4)
            self.wheeltrans_l = rendering.Transform(translation=(-cartwidth / 2, -cartheight / 2))
            self.wheeltrans_r = rendering.Transform(translation=(cartwidth / 2, -cartheight / 2))
            self.wheel_l.add_attr(self.wheeltrans_l)
            self.wheel_l.add_attr(self.carttrans)
            self.wheel_r.add_attr(self.wheeltrans_r)
            self.wheel_r.add_attr(self.carttrans)
            self.wheel_l.set_color(0, 0, 0)  # black
            self.wheel_r.set_color(0, 0, 0)  # black
            self.viewer.add_geom(self.wheel_l)
            self.viewer.add_geom(self.wheel_r)

            # Track just below the wheels.
            track_y = carty - cartheight / 2 - cartheight / 4
            self.track = rendering.Line((0, track_y), (screen_width, track_y))
            self.track.set_color(0, 0, 0)
            self.viewer.add_geom(self.track)

        if self.state is None:
            return None

        x = self.state
        cartx = x[0] * scale + screen_width / 2.0  # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(x[2])

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    # Public gym API. The underscore methods above follow the older gym
    # convention, in which gym.Env's public methods call them. Defining the
    # public names directly as well keeps the env usable on gym releases that
    # require them.
    # NOTE(review): assumes the 4-tuple step / obs-only reset API (gym < 0.26).
    def seed(self, seed=None):
        return self._seed(seed)

    def step(self, action):
        return self._step(action)

    def reset(self):
        return self._reset()

    def render(self, mode='human', close=False):
        return self._render(mode=mode, close=close)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
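Why observe (cos(theta), sin(theta)) instead of theta:
In this environment theta is never wrapped, so it grows without bound as the pole spins. A wrapped angle would instead jump between -pi and pi. Its cosine and sine vary continuously and always stay in [-1, 1], which keeps the observation continuous and easy to normalize.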