The slides are available below. This repository contains the sample code for the "Trying ChainerRL on SageMaker" ("SageMaker でChainerRLを使ってみる") part of the slides.
https://pages.awscloud.com/rs/112-TZM-766/images/Access_AWS-SageMaker-Matsuri-20190212.pdf
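To run this sample as a SageMaker training job, the script can be handed to the SageMaker Python SDK's Chainer estimator. The snippet below is a minimal sketch, assuming a SageMaker notebook instance with an attached IAM role; the instance type, framework version, and hyperparameter values are illustrative assumptions, not taken from the slides.

import sagemaker
from sagemaker.chainer import Chainer

# Assumes this runs on a SageMaker notebook instance whose execution role
# has SageMaker permissions.
role = sagemaker.get_execution_role()

estimator = Chainer(
    entry_point='train_dqn_gym.py',
    role=role,
    train_instance_count=1,
    train_instance_type='ml.m4.xlarge',  # illustrative choice
    framework_version='4.1.0',           # illustrative Chainer container version
    hyperparameters={
        'env': 'CartPole-v0',
        'gpu': -1,  # the script's --gpu flag; a negative value selects CPU
    },
)
# No input channels are needed: the script falls back to empty
# SM_CHANNEL_TRAIN/SM_CHANNEL_TEST values when none are provided.
estimator.fit()

Passing record_video='True' as a hyperparameter (a string, because SageMaker hyperparameters cannot express boolean switch flags) switches the script into its evaluation and video-recording mode instead of training.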
"""An example of training DQN against OpenAI Gym Envs. | |
This script is an example of training a DQN agent against OpenAI Gym envs. | |
Both discrete and continuous action spaces are supported. For continuous action | |
spaces, A NAF (Normalized Advantage Function) is used to approximate Q-values. | |
To solve CartPole-v0, run: | |
python train_dqn_gym.py --env CartPole-v0 | |
To solve Pendulum-v0, run: | |
python train_dqn_gym.py --env Pendulum-v0 | |
""" | |
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import
from builtins import *  # NOQA

import argparse
import os
import sys
def install(package):
    # Install runtime dependencies only on SageMaker, which is detected
    # via the SM_* environment variables it sets.
    if 'SM_OUTPUT_DATA_DIR' in os.environ:
        os.system('pip3 install {}'.format(package))


install('chainerrl==0.4.0')

# For local runs, provide the environment variables SageMaker would set.
if 'SM_OUTPUT_DATA_DIR' not in os.environ:
    os.environ['SM_OUTPUT_DATA_DIR'] = 'results'
    os.environ['SM_MODEL_DIR'] = 'model'
    os.environ['SM_CHANNEL_TRAIN'] = ''
    os.environ['SM_CHANNEL_TEST'] = ''
if 'SM_CHANNEL_TRAIN' not in os.environ:
    os.environ['SM_CHANNEL_TRAIN'] = ''
if 'SM_CHANNEL_TEST' not in os.environ:
    os.environ['SM_CHANNEL_TEST'] = ''
import chainer
from chainer import optimizers
import chainerrl
from chainerrl.agents.dqn import DQN
from chainerrl import experiments
from chainerrl import explorers
from chainerrl import links
from chainerrl import misc
from chainerrl import q_functions
from chainerrl import replay_buffer
import gym
from gym import spaces
import gym.wrappers
import numpy as np
def main():
    import logging
    logging.basicConfig(level=logging.WARNING)

    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='CartPole-v0')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random seed [0, 2 ** 32)')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--final-exploration-steps',
                        type=int, default=10 ** 4)
    parser.add_argument('--start-epsilon', type=float, default=1.0)
    parser.add_argument('--end-epsilon', type=float, default=0.1)
    parser.add_argument('--noisy-net-sigma', type=float, default=None)
    parser.add_argument('--demo', action='store_true', default=False)
    parser.add_argument('--load', type=str, default=None)
    parser.add_argument('--steps', type=int, default=10 ** 5)
    parser.add_argument('--eval-steps', type=int, default=1000)
    parser.add_argument('--prioritized-replay', action='store_true')
    parser.add_argument('--episodic-replay', action='store_true')
    parser.add_argument('--replay-start-size', type=int, default=1000)
    parser.add_argument('--target-update-interval', type=int, default=10 ** 2)
    parser.add_argument('--target-update-method', type=str, default='hard')
    parser.add_argument('--soft-update-tau', type=float, default=1e-2)
    parser.add_argument('--update-interval', type=int, default=1)
    parser.add_argument('--eval-n-runs', type=int, default=100)
    parser.add_argument('--eval-interval', type=int, default=10 ** 4)
    parser.add_argument('--n-hidden-channels', type=int, default=100)
    parser.add_argument('--n-hidden-layers', type=int, default=2)
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--minibatch-size', type=int, default=None)
    parser.add_argument('--reward-scale-factor', type=float, default=1e-3)
    parser.add_argument('--render-train', action='store_true')
    parser.add_argument('--render-eval', action='store_true')
    parser.add_argument('--monitor', action='store_true')
    parser.add_argument('--record-video', type=str, default=None)  # for SageMaker
    # Required for SageMaker
    parser.add_argument('--output-data-dir', type=str,
                        default=os.environ['SM_OUTPUT_DATA_DIR'])
    parser.add_argument('--model-dir', type=str,
                        default=os.environ['SM_MODEL_DIR'])
    parser.add_argument('--train', type=str,
                        default=os.environ['SM_CHANNEL_TRAIN'])
    parser.add_argument('--test', type=str,
                        default=os.environ['SM_CHANNEL_TEST'])
    args = parser.parse_args()
    # SageMaker hyperparameters cannot express boolean switch flags,
    # so --record-video is passed as a string instead.
    if args.record_video == "True":
        print("video recording mode.")
        args.demo = True
        args.monitor = True

    # Set a random seed used in ChainerRL
    misc.set_random_seed(args.seed, gpus=(args.gpu,))

    args.output_data_dir = experiments.prepare_output_dir(
        args, args.output_data_dir, argv=sys.argv)
    print('Output files are saved in {}'.format(args.output_data_dir))
    def clip_action_filter(a):
        return np.clip(a, action_space.low, action_space.high)

    def make_env(test):
        env = gym.make(args.env)
        # Use different random seeds for train and test envs
        env_seed = 2 ** 32 - 1 - args.seed if test else args.seed
        env.seed(env_seed)
        if args.monitor:
            env = gym.wrappers.Monitor(env, args.output_data_dir, force=True)
        if isinstance(env.action_space, spaces.Box):
            misc.env_modifiers.make_action_filtered(env, clip_action_filter)
        if not test:
            # Scale rewards during training so that training is easier
            misc.env_modifiers.make_reward_filtered(
                env, lambda x: x * args.reward_scale_factor)
        if ((args.render_eval and test) or
                (args.render_train and not test)):
            misc.env_modifiers.make_rendered(env)
        return env
    env = make_env(test=False)
    timestep_limit = env.spec.tags.get(
        'wrapper_config.TimeLimit.max_episode_steps')
    obs_space = env.observation_space
    obs_size = obs_space.low.size
    action_space = env.action_space

    if isinstance(action_space, spaces.Box):
        action_size = action_space.low.size
        # Use NAF to apply DQN to continuous action spaces
        q_func = q_functions.FCQuadraticStateQFunction(
            obs_size, action_size,
            n_hidden_channels=args.n_hidden_channels,
            n_hidden_layers=args.n_hidden_layers,
            action_space=action_space)
        # Use the Ornstein-Uhlenbeck process for exploration
        ou_sigma = (action_space.high - action_space.low) * 0.2
        explorer = explorers.AdditiveOU(sigma=ou_sigma)
    else:
        n_actions = action_space.n
        q_func = q_functions.FCStateQFunctionWithDiscreteAction(
            obs_size, n_actions,
            n_hidden_channels=args.n_hidden_channels,
            n_hidden_layers=args.n_hidden_layers)
        # Use epsilon-greedy for exploration
        explorer = explorers.LinearDecayEpsilonGreedy(
            args.start_epsilon, args.end_epsilon, args.final_exploration_steps,
            action_space.sample)

    if args.noisy_net_sigma is not None:
        links.to_factorized_noisy(q_func)
        # Turn off explorer
        explorer = explorers.Greedy()

    # Draw the computational graph and save it in the output directory.
    chainerrl.misc.draw_computational_graph(
        [q_func(np.zeros_like(obs_space.low, dtype=np.float32)[None])],
        os.path.join(args.output_data_dir, 'model'))

    opt = optimizers.Adam()
    opt.setup(q_func)
    rbuf_capacity = 5 * 10 ** 5
    if args.episodic_replay:
        if args.minibatch_size is None:
            args.minibatch_size = 4
        if args.prioritized_replay:
            betasteps = (args.steps - args.replay_start_size) \
                // args.update_interval
            rbuf = replay_buffer.PrioritizedEpisodicReplayBuffer(
                rbuf_capacity, betasteps=betasteps)
        else:
            rbuf = replay_buffer.EpisodicReplayBuffer(rbuf_capacity)
    else:
        if args.minibatch_size is None:
            args.minibatch_size = 32
        if args.prioritized_replay:
            betasteps = (args.steps - args.replay_start_size) \
                // args.update_interval
            rbuf = replay_buffer.PrioritizedReplayBuffer(
                rbuf_capacity, betasteps=betasteps)
        else:
            rbuf = replay_buffer.ReplayBuffer(rbuf_capacity)

    def phi(obs):
        return obs.astype(np.float32)
    agent = DQN(q_func, opt, rbuf, gpu=args.gpu, gamma=args.gamma,
                explorer=explorer, replay_start_size=args.replay_start_size,
                target_update_interval=args.target_update_interval,
                update_interval=args.update_interval,
                phi=phi, minibatch_size=args.minibatch_size,
                target_update_method=args.target_update_method,
                soft_update_tau=args.soft_update_tau,
                episodic_update=args.episodic_replay, episodic_update_len=16)

    if args.load:
        agent.load(args.load)

    eval_env = make_env(test=True)

    def start_display():
        # Start a virtual display so that envs can be rendered (and thus
        # recorded) on headless instances.
        from pyvirtualdisplay import Display
        display = Display(visible=0, size=(1024, 768))
        display.start()
        os.environ["DISPLAY"] = ":" + str(display.display) + "." + str(display.screen)
    if args.demo:
        print('Starting demo mode.')
        # start_display()
        eval_stats = experiments.eval_performance(
            env=eval_env,
            agent=agent,
            n_runs=args.eval_n_runs,
            max_episode_len=timestep_limit)
        print('n_runs: {} mean: {} median: {} stdev {}'.format(
            args.eval_n_runs, eval_stats['mean'], eval_stats['median'],
            eval_stats['stdev']))
    else:
        experiments.train_agent_with_evaluation(
            agent=agent, env=env, steps=args.steps,
            eval_n_runs=args.eval_n_runs, eval_interval=args.eval_interval,
            outdir=args.output_data_dir, eval_env=eval_env,
            max_episode_len=timestep_limit)
        agent.save(args.model_dir)

    if args.monitor:
        # When the Monitor wrapper is used, also close the wrapped env.
        eval_env.env.close()
    eval_env.close()
    env.close()


if __name__ == '__main__':
    main()
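The same video-recording path can be exercised locally. A sketch, assuming a model directory produced by an earlier training run (the script saves to model/ by default when run locally) and the dependencies gym's Monitor needs for video encoding (e.g. ffmpeg):

    python train_dqn_gym.py --env CartPole-v0 --load model --record-video True

Videos and evaluation statistics are written under the results/ output directory that the script uses for local runs.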