Cleaned up CartPole: cross-entropy method on CartPole-v0 with a deterministic linear policy.
# Source: http://rl-gym-doc.s3-website-us-west-2.amazonaws.com/mlss/lab1.html
import gym
import numpy as np
from gym.wrappers.monitoring import Monitor

from policy import Policy

# Task settings:
env = gym.make('CartPole-v0')  # Change as needed
env = Monitor(env, 'tmp/cart-pole-cross-entropy-1', force=True)
num_steps = 500  # maximum length of episode

# Alg settings:
n_iter = 100  # number of iterations of CEM
batch_size = 25  # number of samples per batch
elite_ratio = 0.2  # fraction of samples used as elite set

dim_theta = Policy.get_dim_theta(env)

# Initialize mean and standard deviation
theta_mean = np.zeros(dim_theta)
theta_std = np.ones(dim_theta)

# Now, for the algorithm
for iteration in range(n_iter):
    # Sample parameter vectors from a diagonal Gaussian
    thetas = np.vstack([np.random.multivariate_normal(theta_mean, np.diag(theta_std ** 2))
                        for _ in range(batch_size)])
    rewards = [Policy.make_policy(env, theta).evaluate(env, num_steps) for theta in thetas]
    # Get elite parameters: argsort is ascending, so the last n_elite
    # indices are the highest-reward samples
    n_elite = int(batch_size * elite_ratio)
    elite_indices = np.argsort(rewards)[batch_size - n_elite:batch_size]
    elite_thetas = [thetas[i] for i in elite_indices]
    # Update theta_mean, theta_std by refitting the Gaussian to the elites
    theta_mean = np.mean(elite_thetas, axis=0)
    theta_std = np.std(elite_thetas, axis=0)
    if iteration % 10 == 0:
        print("iteration %i. mean f: %8.3g. max f: %8.3g" % (iteration, np.mean(rewards), np.max(rewards)))
        print("theta mean %s \n theta std %s" % (theta_mean, theta_std))

# Demonstrate this policy
Policy.make_policy(env, theta_mean).evaluate(env, num_steps, render=True)
env.close()
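The loop above is the generic cross-entropy method: sample parameters from a diagonal Gaussian, keep the top elite_ratio fraction by episode reward, and refit the Gaussian to those elites. The update itself does not depend on gym at all; as a sanity check, here is a minimal self-contained sketch of the same update on a made-up toy objective (the function f and its optimum are illustrative, not part of the gist; elementwise Gaussian sampling is equivalent to multivariate_normal with a diagonal covariance):

import numpy as np

def f(theta):
    # Toy objective, maximized at theta = [1, -2] (illustrative only)
    return -np.sum((theta - np.array([1.0, -2.0])) ** 2)

theta_mean, theta_std = np.zeros(2), np.ones(2)
batch_size, n_elite = 25, 5
for _ in range(50):
    # Same update as above: sample, rank, refit to the elites
    thetas = theta_mean + theta_std * np.random.randn(batch_size, 2)
    rewards = np.array([f(t) for t in thetas])
    elite = thetas[np.argsort(rewards)[-n_elite:]]
    theta_mean, theta_std = elite.mean(axis=0), elite.std(axis=0)

print(theta_mean)  # approaches [1., -2.]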
policy.py
# ================================================================
# Policies
# ================================================================
import numpy as np
from gym.spaces import Discrete, Box


class Policy(object):
    def __init__(self):
        pass

    def act(self, obs):
        raise NotImplementedError

    @staticmethod
    def make_policy(env, theta):
        if isinstance(env.action_space, Discrete):
            return DeterministicDiscreteActionLinearPolicy(
                theta,
                env.observation_space,
                env.action_space,
            )
        elif isinstance(env.action_space, Box):
            return DeterministicContinuousActionLinearPolicy(
                theta,
                env.observation_space,
                env.action_space,
            )
        else:
            raise NotImplementedError

    @staticmethod
    def get_dim_theta(env):
        # One weight per (observation dim, action) pair, plus a bias per action
        if isinstance(env.action_space, Discrete):
            return (env.observation_space.shape[0] + 1) * env.action_space.n
        elif isinstance(env.action_space, Box):
            return (env.observation_space.shape[0] + 1) * env.action_space.shape[0]
        else:
            raise NotImplementedError

    def evaluate(self, env, num_steps, render=False):
        total_rew = 0
        ob = env.reset()
        for t in range(num_steps):
            a = self.act(ob)
            (ob, reward, done, _info) = env.step(a)
            total_rew += reward
            if render and t % 3 == 0:
                env.render()
            if done:
                break
        return total_rew


class DeterministicDiscreteActionLinearPolicy(Policy):
    def __init__(self, theta, ob_space, ac_space):
        """
        dim_ob: dimension of observations
        n_actions: number of actions
        theta: flat vector of parameters
        """
        Policy.__init__(self)
        dim_ob = ob_space.shape[0]
        n_actions = ac_space.n
        assert len(theta) == (dim_ob + 1) * n_actions
        self.W = theta[0: dim_ob * n_actions].reshape(dim_ob, n_actions)
        self.b = theta[dim_ob * n_actions:].reshape(1, n_actions)

    def act(self, ob):
        """
        Pick the action whose linear score ob.dot(W) + b is largest.
        """
        y = ob.dot(self.W) + self.b
        a = y.argmax()
        return a


class DeterministicContinuousActionLinearPolicy(Policy):
    def __init__(self, theta, ob_space, ac_space):
        """
        dim_ob: dimension of observations
        dim_ac: dimension of action vector
        theta: flat vector of parameters
        """
        Policy.__init__(self)
        self.ac_space = ac_space
        dim_ob = ob_space.shape[0]
        dim_ac = ac_space.shape[0]
        assert len(theta) == (dim_ob + 1) * dim_ac
        self.W = theta[0: dim_ob * dim_ac].reshape(dim_ob, dim_ac)
        self.b = theta[dim_ob * dim_ac:]

    def act(self, ob):
        # Clip the linear action into the valid Box range
        a = np.clip(ob.dot(self.W) + self.b, self.ac_space.low, self.ac_space.high)
        return a
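For CartPole-v0 the observation is 4-dimensional and there are 2 discrete actions, so get_dim_theta returns (4 + 1) * 2 = 10 parameters: a 4x2 weight matrix plus a 1x2 bias. A quick shape check is possible without creating an environment by building the spaces by hand; in the sketch below the Box bounds are arbitrary placeholders, not CartPole's real limits:

import numpy as np
from gym.spaces import Discrete, Box

from policy import DeterministicDiscreteActionLinearPolicy

ob_space = Box(low=-1.0, high=1.0, shape=(4,))  # stands in for CartPole observations
ac_space = Discrete(2)

theta = np.random.randn((4 + 1) * 2)  # 10 parameters
pi = DeterministicDiscreteActionLinearPolicy(theta, ob_space, ac_space)
print(pi.W.shape, pi.b.shape)  # (4, 2) (1, 2)
print(pi.act(np.zeros(4)))     # 0 or 1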