Last active
July 2, 2019 12:57
-
-
Save livoras/8e749f9e56f5450485fd82990d5db296 to your computer and use it in GitHub Desktop.
PPO2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
from gym import spaces | |
import numpy as np | |
import math | |
import random | |
from gym.utils import seeding | |
class BullsCowsEnv(gym.Env):
    """Bulls-and-Cows guessing game exposed through the gym interface.

    A hidden solution of four distinct digits is drawn on every reset. An
    action is a four-digit guess. The observation is a (10, 8) board with one
    row per guess: the four guessed digits followed by
    ``[right_count, 1, misplaced_count, 2]``, where ``right_count`` is the
    number of digits in the correct position and ``misplaced_count`` the
    number of digits present in the solution but at another position. The
    constants 1 and 2 are markers that ``render`` prints as "A" and "B".
    """

    def __init__(self):
        super(BullsCowsEnv, self).__init__()
        self.action_space = spaces.Box(
            low=np.array([0, 0, 0, 0]),
            high=np.array([9, 9, 9, 9]),
            # np.int was only an alias of the builtin int and has been removed
            # from NumPy; use an explicit integer dtype instead.
            dtype=np.int64,
        )
        self.observation_space = spaces.Box(
            low=0,
            high=9,
            shape=(10, 8),
            dtype=np.uint,
        )
        self.guess_count = 0
        self.state = []
        self.solution = []
        self.old_guess = []
        self.max_right_count = 0
        self.max_position_count = 0
        # reset() goes last: it plays an opening guess that updates the max
        # counters, and re-zeroing them afterwards would make the first
        # episode score differently from every later one.
        self.reset()
        # self.seed()

    def step(self, action):
        """Play one guess.

        Args:
            action: four values inside ``action_space``; each is rounded to
                an int before use.

        Returns:
            ``(state, reward, done, info)``. A guess with repeated digits, or
            one identical to the previous guess, ends the episode with reward
            -100. A fully correct guess ends it with reward 100. Otherwise the
            reward is -1 plus 10 per improvement of the best ``right_count``
            and 3 per improvement of the best ``misplaced_count`` seen so far
            (negative when the guess is worse than the best). The episode
            also ends once ten guesses have been counted.
        """
        assert self.action_space.contains(action)
        action = [int(round(x)) for x in action]
        done = False
        reward = -1
        if len(action) > len(set(action)) or np.array_equal(action, self.old_guess):
            done = True
            reward = -100
        else:
            prompt = self.get_prompt(action)
            self.state[self.guess_count] = np.concatenate((action, prompt))
            right_count = prompt[0]
            position_count = prompt[2]
            reward += (right_count - self.max_right_count) * 10
            # Compare like with like: the misplaced-digit delta must be taken
            # from position_count, not right_count.
            reward += (position_count - self.max_position_count) * 3
            self.max_right_count = max(right_count, self.max_right_count)
            self.max_position_count = max(position_count, self.max_position_count)
            if prompt[0] == 4:
                reward = 100
                done = True
        self.old_guess = action
        self.guess_count += 1
        if self.guess_count > 9:
            done = True
        return self.state, reward, done, {}

    def get_prompt(self, guess):
        """Score ``guess`` against the solution.

        Returns:
            ``np.array([right_count, 1, misplaced_count, 2])``.
        """
        right_count = 0
        position_count = 0
        for i, num in enumerate(guess):
            if num == self.solution[i]:
                right_count += 1
            if num in self.solution:
                position_count += 1
        return np.array([right_count, 1, position_count - right_count, 2])

    def seed(self, seed=None):
        """Create ``self.np_random`` from ``seed`` and return ``[seed]``.

        NOTE(review): ``new_solution`` draws from the global ``np.random``,
        so this seed does not currently affect the generated solutions.
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """Start a new game and return the board after one random opening guess."""
        self.guess_count = 0
        self.solution = self.new_solution()
        self.old_guess = np.zeros(shape=[1, 4])
        self.max_right_count = 0
        self.max_position_count = 0
        self.state = self.init_state()
        # The opening guess is an independent random draw, not the solution.
        return self.step(self.new_solution())[0]

    def init_state(self):
        """Return an empty (all-zero) 10x8 board."""
        return np.zeros(shape=(10, 8), dtype=np.uint)

    def new_solution(self):
        """Draw four distinct digits from 0..9.

        ``np.random.randint`` excludes its upper bound, so the bound must be
        10 for the digit 9 (which the action space allows) to be drawable.
        """
        solution = np.random.randint(0, 10, size=4)
        while len(solution) != len(set(solution)):
            solution = np.random.randint(0, 10, size=4)
        return solution

    def render(self, mode='human', close=False):
        """Print the solution, then one line per board row.

        Columns 0-3 are printed as digits. In the remaining columns, even
        positions are printed as digits and odd positions as markers
        (1 -> "A", 2 -> "B", anything else -> "_").
        """
        markers = {1: "A", 2: "B"}
        print(self.solution)
        for row in self.state:
            cells = []
            for j, num in enumerate(row):
                num = int(num)
                if j <= 3 or j % 2 == 0:
                    cells.append(str(num))
                else:
                    cells.append(markers.get(num, "_"))
            print("".join(cells))
if __name__ == '__main__':
    # Smoke test: play a single random guess and print the resulting board.
    demo_env = BullsCowsEnv()
    guess = demo_env.action_space.sample()
    print(guess)
    demo_env.step(guess)
    demo_env.render()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
from gym import spaces | |
import numpy as np | |
import math | |
import random | |
from gym.utils import seeding | |
import traceback | |
class BullsCowsEnv(gym.Env):
    """Bulls-and-Cows guessing game exposed through the gym interface.

    A hidden solution of four distinct digits is drawn on every reset. An
    action is a four-digit guess. The observation is a (20, 4) board in which
    row ``2 * k`` holds the k-th guess and row ``2 * k + 1`` its prompt
    ``[right_count, 1, misplaced_count, 2]``. ``right_count`` is the number
    of digits in the correct position, ``misplaced_count`` the number of
    digits present in the solution but at another position. The constants 1
    and 2 are markers that ``render`` prints as "A" and "B".
    """

    def __init__(self):
        super(BullsCowsEnv, self).__init__()
        self.action_space = spaces.Box(
            low=np.array([0, 0, 0, 0]),
            high=np.array([9, 9, 9, 9]),
            # np.int was only an alias of the builtin int and has been removed
            # from NumPy; use an explicit integer dtype instead.
            dtype=np.int64,
        )
        self.observation_space = spaces.Box(
            low=0,
            high=9,
            shape=(20, 4),
            dtype=np.uint,
        )
        self.guess_count = 0
        self.state = []
        self.solution = []
        self.old_guess = []
        self.max_right_count = 0
        self.max_position_count = 0
        # reset() goes last so nothing it sets up is overwritten afterwards.
        self.reset()
        # self.seed()

    def step(self, action):
        """Play one guess.

        Args:
            action: four values inside ``action_space``; each is rounded to
                an int before use.

        Returns:
            ``(state, reward, done, info)``. A guess with repeated digits, or
            one already recorded on the board, ends the episode with reward
            -1. Otherwise the reward is ``1 + 10 * right_count +
            3 * misplaced_count``, plus 100 when all four digits are right
            (which also ends the episode). The episode also ends once ten
            guesses have been counted.
        """
        assert self.action_space.contains(action)
        action = [int(round(x)) for x in action]
        done = False
        reward = -1
        if len(action) > len(set(action)) or self.is_ever_guess(action):
            done = True
        else:
            reward = 1
            prompt = self.get_prompt(action)
            self.state[self.guess_count * 2] = action
            self.state[self.guess_count * 2 + 1] = prompt
            right_count = prompt[0]
            position_count = prompt[2]
            reward += right_count * 10
            reward += position_count * 3
            if prompt[0] == 4:
                reward += 100
                done = True
        self.old_guess = action
        self.guess_count += 1
        if self.guess_count > 9:
            done = True
        return self.state, reward, done, {}

    def get_prompt(self, guess):
        """Score ``guess`` against the solution.

        Returns:
            ``np.array([right_count, 1, misplaced_count, 2])``.
        """
        right_count = 0
        position_count = 0
        for i, num in enumerate(guess):
            if num == self.solution[i]:
                right_count += 1
            if num in self.solution:
                position_count += 1
        return np.array([right_count, 1, position_count - right_count, 2])

    def seed(self, seed=None):
        """Create ``self.np_random`` from ``seed`` and return ``[seed]``.

        NOTE(review): ``new_solution`` draws from the global ``np.random``,
        so this seed does not currently affect the generated solutions.
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def reset(self):
        """Start a new game and return the board after one random opening guess."""
        self.guess_count = 0
        self.solution = self.new_solution()
        self.old_guess = np.zeros(shape=[1, 4])
        self.max_right_count = 0
        self.max_position_count = 0
        self.state = self.init_state()
        # The opening guess is an independent random draw, not the solution.
        return self.step(self.new_solution())[0]

    def init_state(self):
        """Return an empty (all-zero) 20x4 board."""
        return np.zeros(shape=(20, 4), dtype=np.uint)

    def new_solution(self):
        """Draw four distinct digits from 0..9.

        ``np.random.randint`` excludes its upper bound, so the bound must be
        10 for the digit 9 (which the action space allows) to be drawable.
        """
        solution = np.random.randint(0, 10, size=4)
        while len(solution) != len(set(solution)):
            solution = np.random.randint(0, 10, size=4)
        return solution

    def is_ever_guess(self, action):
        """Return True when ``action`` equals a guess recorded this episode.

        Only the guess rows (even indices below ``2 * guess_count``) are
        compared. Prompt rows share the board, and a prompt such as
        ``[0, 1, 3, 2]`` is also a legal guess that must not be treated as
        a repeat.
        """
        recorded_guesses = self.state[:2 * self.guess_count:2]
        return any(np.array_equal(row, action) for row in recorded_guesses)

    def render(self, mode='human', close=False):
        """Print the solution, then one line per board row.

        Guess rows (even index) are printed as digits. In prompt rows (odd
        index) even columns are printed as digits and odd columns as markers
        (1 -> "A", 2 -> "B", anything else -> "_").
        """
        markers = {1: "A", 2: "B"}
        print(self.solution)
        for i, row in enumerate(self.state):
            is_prompt = i % 2 != 0
            cells = []
            for j, num in enumerate(row):
                num = int(num)
                if not is_prompt or j % 2 == 0:
                    cells.append(str(num))
                else:
                    cells.append(markers.get(num, "_"))
            print("".join(cells))
if __name__ == '__main__':
    # Smoke test: show that a random guess is reported as "already guessed"
    # only after it has been played, then print the board.
    demo_env = BullsCowsEnv()
    guess = demo_env.action_space.sample()
    print(guess)
    print(demo_env.is_ever_guess(guess))
    demo_env.step(guess)
    print(demo_env.is_ever_guess(guess))
    demo_env.render()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
from gym import spaces | |
import numpy as np | |
import math | |
import random | |
from gym.utils import seeding | |
# Number of cells in the one-dimensional track.
LEN_OF_PACE = 20


class JumpWalkerEnv(gym.Env):
    """Custom Environment that follows gym interface.

    The track is an array of ``LEN_OF_PACE`` cells: 0 is plain ground, 1 is a
    hole and 2 marks the walker. The walker starts on cell 0 and wins by
    landing exactly on the last cell; leaving the track or landing on a hole
    loses.
    """
    metadata = {'render.modes': ['human']}

    def __init__(self):
        super(JumpWalkerEnv, self).__init__()
        # 0 left 1 right 2 jump
        self.action_space = spaces.Discrete(3)
        # One value per cell, e.g. 0100010002
        self.observation_space = spaces.Box(
            low=np.array([0] * LEN_OF_PACE),
            high=np.array([3] * LEN_OF_PACE),
            dtype=np.uint,
        )
        self.current_pos = 0
        self.state = self.make_map()
        self.seed()

    def step(self, action):
        """Move the walker and return ``(state, reward, done, info)``.

        Action 0 moves one cell left, 1 one cell right, anything else jumps
        two cells right. Reward is -100 (episode over) when the walker leaves
        the track or lands on a hole, 100 (episode over) when it reaches the
        last cell, and -1 for any other move.
        """
        reward = -1
        # Clear the walker marker from the cell it is leaving.
        self.state[self.current_pos] = 0
        if action == 0:  # left
            self.current_pos -= 1
        elif action == 1:  # right
            self.current_pos += 1
        else:  # jump
            self.current_pos += 2
        current_pos = self.current_pos
        done = False
        # The bounds checks come first so the cell lookup is never
        # evaluated for an off-track position.
        if current_pos < 0 or current_pos >= LEN_OF_PACE or self.state[current_pos] == 1:
            reward = -100
            done = True
        elif current_pos == LEN_OF_PACE - 1:
            reward = 100
            done = True
        else:
            self.state[current_pos] = 2
        return self.state, reward, done, {}

    def reset(self):
        """Put the walker back on cell 0 of a freshly generated track."""
        self.current_pos = 0
        self.state = self.make_map()
        return self.state

    def make_map(self):
        """Generate a random track as a numpy array.

        Cell 0 holds the walker, the last cell is always ground, and holes
        are never placed on two adjacent cells.
        """
        pace_map = []
        # Floor division keeps the bound an int; random.randint rejects
        # float arguments on current Python versions.
        hole_count = random.randint(1, (LEN_OF_PACE - 2) // 2)
        for i in range(LEN_OF_PACE):
            if i == 0:
                pace_map.append(2)
            elif i == LEN_OF_PACE - 1:
                pace_map.append(0)
            elif hole_count > 0 and random.randint(0, 100) % 2 == 0 and pace_map[i - 1] != 1:
                pace_map.append(1)
                hole_count -= 1
            else:
                pace_map.append(0)
        return np.array(pace_map)

    def seed(self, seed=None):
        """Create ``self.np_random`` from ``seed`` and return ``[seed]``.

        NOTE(review): ``make_map`` draws from the stdlib ``random`` module,
        so this seed does not currently affect the generated tracks.
        """
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    def render(self, mode='human', close=False):
        """Print the track: "_" for ground, "o" for a hole, "v" for the walker."""
        symbols = {0: "_", 1: "o"}
        print("".join(symbols.get(int(n), "v") for n in self.state))
# Human-readable names for the three discrete actions, indexed by action id.
ACTIONS_NAME = ["LEFT", "RIGHT", "JUMP"]

if __name__ == "__main__":
    # Play random episodes until one of them reaches the goal (reward 100).
    env = JumpWalkerEnv()
    while True:
        ok = False
        env.reset()
        print("TRY==============================")
        env.render()
        while True:
            action = env.action_space.sample()
            print("Action -> ", ACTIONS_NAME[action])
            # step() returns (obs, reward, done, info) in that order.
            # Unpacking reward and done the other way round ended every
            # episode after one step and never detected a success.
            obs, reward, done, _ = env.step(action)
            print("Reward -> ", reward)
            if reward == 100:
                ok = True
            env.render()
            if done:
                break
        if ok:
            break
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
from stable_baselines.common.policies import MlpPolicy | |
from stable_baselines.common.vec_env import SubprocVecEnv, DummyVecEnv | |
from stable_baselines.bench import Monitor | |
from stable_baselines import PPO2 | |
from stable_baselines.results_plotter import load_results, ts2xy | |
import os | |
import numpy as np | |
import time | |
from stable_baselines.common.vec_env import VecNormalize, VecFrameStack | |
# Environment to replay. Other ids this script has been pointed at:
# 'Pendulum-v0', 'BipedalWalker-v2', 'Acrobot-v1', 'LunarLander-v2',
# 'BipedalWalkerHardcore-v2', 'MountainCar-v0'.
ENV_NAME = 'MountainCarContinuous-v0'

# Directory holding the saved model; it is also passed as the location of the
# VecNormalize running averages.
log_dir = "./logs/"
os.makedirs(log_dir, exist_ok=True)
model_file = "{}{}_best_model_2.pkl".format(log_dir, ENV_NAME)
env_file = log_dir

# Whether the environment is wrapped in VecNormalize.
normalize = True

print(model_file, env_file)

# Single-environment vectorised wrapper; the factory is invoked by
# DummyVecEnv to create the underlying gym environment.
env = DummyVecEnv([lambda: gym.make(ENV_NAME)])
if normalize:
    env = VecNormalize(env)
def load_model():
    """Replay the saved PPO2 model in the module-level ``env`` forever.

    The model and, when ``normalize`` is set, the saved running averages are
    loaded once up front. Loading them before the first ``env.reset()``
    means the very first observation is already normalised with the saved
    statistics, and nothing is re-read from disk between episodes. Each
    episode is rendered for at most 2000 steps, followed by a 5 second pause.
    """
    print("Loading saved model...")
    model = PPO2.load(model_file)
    if normalize:
        print("Loading running averaing...")
        env.load_running_average(env_file)
        # NOTE(review): VecNormalize may keep updating its statistics during
        # playback; confirm whether training should be disabled here.

    while True:
        obs = env.reset()
        for _ in range(2000):
            action, _states = model.predict(obs)
            obs, rewards, dones, info = env.step(action)
            env.render()
            if dones:
                print("DONE==>")
                break
        time.sleep(5)


load_model()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gym | |
from stable_baselines.common.policies import MlpPolicy | |
from stable_baselines.common.vec_env import SubprocVecEnv, DummyVecEnv, VecNormalize | |
from stable_baselines.bench import Monitor | |
from stable_baselines import PPO2 | |
import os | |
import numpy as np | |
from stable_baselines.common import set_global_seeds | |
# Bookkeeping values kept at module level.
best_mean_reward = -np.inf
n_steps = 0

# Environment to train on. Other ids this script has been pointed at:
# 'Pendulum-v0', 'BipedalWalker-v2', 'LunarLander-v2',
# 'BipedalWalkerHardcore-v2', 'MountainCarContinuous-v0', 'MountainCar-v0'.
ENV_NAME = 'Acrobot-v1'

# Directory for monitor logs and the saved model; it is also passed as the
# location of the VecNormalize running averages.
log_dir = "./logs/"
os.makedirs(log_dir, exist_ok=True)
model_file = "{}{}_best_model_2.pkl".format(log_dir, ENV_NAME)
env_file = log_dir

# Whether to wrap the environments in VecNormalize.
# Open question from the author: when exactly is normalisation needed?
normalize = False
def linear_schedule(initial_value):
    """Build a schedule that scales ``initial_value`` by a progress factor.

    Args:
        initial_value: the base value; a string is converted with ``float``.

    Returns:
        A function ``func(progress)`` returning ``progress * initial_value``.
        Presumably the caller passes the remaining training progress
        (1 at the start, 0 at the end) -- confirm against the PPO2 API.
    """
    base = float(initial_value) if isinstance(initial_value, str) else initial_value

    def func(progress):
        scaled = progress * base
        return scaled

    return func
def make_env(env_id, rank=0, seed=0):
    """Return a zero-argument factory that builds one monitored environment.

    Args:
        env_id: id passed to ``gym.make``.
        rank: index of the environment; added to ``seed`` and used as the
            monitor file name inside ``log_dir``.
        seed: base random seed.

    Returns:
        A callable that, when invoked, seeds the global RNGs and the new
        environment with ``seed + rank`` and returns the environment wrapped
        in a ``Monitor`` with early resets allowed.
    """
    def _init():
        worker_seed = seed + rank
        set_global_seeds(worker_seed)
        monitor_path = os.path.join(log_dir, str(rank))
        raw_env = gym.make(env_id)
        raw_env.seed(worker_seed)
        return Monitor(raw_env, monitor_path, allow_early_resets=True)

    return _init
def learn_model():
    """Train PPO2 on 16 parallel copies of ``ENV_NAME``.

    Training resumes from ``model_file`` when that file exists (also
    restoring the VecNormalize running averages if ``normalize`` is set);
    otherwise a new model is created. The callback saves the model, and the
    running averages when normalising, every time it is invoked.
    """
    def callback(_locals, _globals):
        # Checkpoint the model held in the caller's local ``self``.
        print("Saving model...")
        _locals['self'].save(model_file)
        if normalize:
            print("Saving running average...")
            env.save_running_average(env_file)
        # Returning True lets training continue.
        return True

    env = SubprocVecEnv([make_env(ENV_NAME, i, 0) for i in range(16)])
    if normalize:
        print("Using VecNormalize...")
        env = VecNormalize(env)

    args = dict(
        n_steps=2048,
        nminibatches=32,
        noptepochs=10,
        verbose=1,
        cliprange=linear_schedule(0.2),
        learning_rate=linear_schedule(2.5e-4),
        ent_coef=0.001,
    )

    if os.path.exists(model_file):
        if normalize:
            env.load_running_average(env_file)
            print("Loaded running average...")
        model = PPO2.load(
            model_file,
            env,
            **args,
        )
        print("Loaded file...", model_file)
    else:
        model = PPO2(
            MlpPolicy,
            env,
            **args,
        )
    model.learn(total_timesteps=900000000, callback=callback)


# SubprocVecEnv starts worker processes. With the spawn/forkserver start
# methods the workers re-import this module, so the entry point must be
# guarded or each worker would try to start training itself.
if __name__ == '__main__':
    learn_model()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment