A PPO train eval method
"""PPO Learner implementation.""" | |
import gin | |
import tensorflow.compat.v2 as tf | |
from tf_agents.experimental.train import learner | |
from tf_agents.networks import utils | |
from tf_agents.utils import common | |
@gin.configurable | |
class PPOLearner(object): | |
"""Manages all the learning details needed when training an PPO. | |
These include: | |
* Using distribution strategies correctly | |
* Summaries | |
* Checkpoints | |
* Minimizing entering/exiting TF context: | |
Especially in the case of TPUs scheduling a single TPU program to | |
perform multiple train steps is critical for performance. | |
* Generalizes the train call to be done correctly across CPU, GPU, or TPU | |
executions managed by DistributionStrategies. This uses `strategy.run` and | |
then makes sure to do a reduce operation over the `LossInfo` returned by | |
the agent. | |
""" | |
  def __init__(self,
               root_dir,
               train_step,
               agent,
               max_num_sequences=None,
               minibatch_size=None,
               shuffle_buffer_size=None,
               after_train_strategy_step_fn=None,
               triggers=None,
               checkpoint_interval=100000,
               summary_interval=1000,
               use_kwargs_in_agent_train=False,
               strategy=None):
"""Initializes a PPOLearner instance. | |
Args: | |
root_dir: Main directory path where checkpoints, saved_models, and | |
summaries will be written to. | |
train_step: a scalar tf.int64 `tf.Variable` which will keep track of the | |
number of train steps. This is used for artifacts created like | |
summaries, or outputs in the root_dir. | |
agent: `tf_agent.TFAgent` instance to train with. | |
max_num_sequences: The max number of sequences to read from the input | |
dataset in `run`. Defaults to None, in which case `run` will terminate | |
when reach the end of the dataset (for instance when the rate limiter | |
times out). | |
minibatch_size: The minibatch size. The dataset used for training is | |
shaped [minibatch_size, 1, ...]. | |
shuffle_buffer_size: The buffer size for shuffling the trajectories before | |
splitting them into mini batches. Only required when mini batch | |
learning is enabled (minibatch_size is set). Otherwise it is ignored. | |
Commonly set to a number 1-3x the episode length of your environment. | |
after_train_strategy_step_fn: (Optional) callable of the form | |
`fn(sample, loss)` which can be used for example to update priorities in | |
a replay buffer where sample is pulled from the `experience_iterator` | |
and loss is a `LossInfo` named tuple returned from the agent. This is | |
called after every train step. It runs using `strategy.run(...)`. | |
triggers: List of callables of the form `trigger(train_step)`. After every | |
`run` call every trigger is called with the current `train_step` value | |
as an np scalar. | |
checkpoint_interval: Number of train steps in between checkpoints. Note | |
these are placed into triggers and so a check to generate a checkpoint | |
only occurs after every `run` call. Set to -1 to disable. This only | |
takes care of the checkpointing the training process. Policies must be | |
explicitly exported through triggers | |
summary_interval: Number of train steps in between summaries. Note these | |
are placed into triggers and so a check to generate a checkpoint only | |
occurs after every `run` call. | |
use_kwargs_in_agent_train: If True the experience from the replay buffer | |
is passed into the agent as kwargs. This requires samples from the RB to | |
be of the form `dict(experience=experience, kwarg1=kwarg1, ...)`. This | |
is useful if you have an agent with a custom argspec. | |
strategy: (Optional) `tf.distribute.Strategy` to use during training. | |
""" | |
    if minibatch_size is not None and shuffle_buffer_size is None:
      raise ValueError(
          'shuffle_buffer_size must be provided if minibatch_size is not None.'
      )

    if agent.update_normalizers_in_train:
      raise ValueError(
          'agent.update_normalizers_in_train should be set to False when '
          'PPOLearner is used.'
      )

    self._agent = agent
    self._max_num_sequences = max_num_sequences
    self._minibatch_size = minibatch_size
    self._shuffle_buffer_size = shuffle_buffer_size

    self._generic_learner = learner.Learner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn=None,
        after_train_strategy_step_fn=after_train_strategy_step_fn,
        triggers=triggers,
        checkpoint_interval=checkpoint_interval,
        summary_interval=summary_interval,
        use_kwargs_in_agent_train=use_kwargs_in_agent_train,
        strategy=strategy)
  def run(self, iterations, dataset):
    """Runs training until the dataset times out or max_num_sequences is reached.

    Args:
      iterations: Number of iterations/epochs to repeat over the collected
        sequences. (Schulman, 2017) sets this to 10 for Mujoco, 15 for
        Roboschool and 3 for Atari.
      dataset: A `tf.data.Dataset` where each sample is shaped
        [sample_batch_size, sequence_length, ...], commonly the output of
        `reverb_replay_buffer.as_dataset(sample_batch_size, preprocess_fn)`.

    Returns:
      The total loss computed before running the final step.
    """
    # TODO(b/160802425): Verify this setup works with distributed.
    if self._max_num_sequences:
      dataset = dataset.take(self._max_num_sequences)

    cached_dataset = dataset.cache()
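    # The cache lets us iterate the same collected sequences several times
    # (once for the advantage normalizer below, `iterations` times for
    # training, and once more for the observation/reward normalizers) without
    # re-reading from the underlying replay source.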
    self._update_advantage_normalizer(cached_dataset)

    new_dataset = cached_dataset.repeat(iterations)

    if self._minibatch_size:

      def squash_dataset_element(sequence, info):
        return tf.nest.map_structure(
            utils.BatchSquash(2).flatten, (sequence, info))

      # We unbatch the dataset shaped [B, T, ...] to a new dataset that
      # contains individual elements.
      # Note that we unbatch across the time dimension, which could result in
      # mini batches that contain subsets from more than one sequence. The
      # PPO agent can handle mini batches across episode boundaries.
      new_dataset = new_dataset.map(squash_dataset_element).unbatch()
      new_dataset = new_dataset.shuffle(self._shuffle_buffer_size)
      new_dataset = new_dataset.batch(1, drop_remainder=True)
      new_dataset = new_dataset.batch(
          self._minibatch_size, drop_remainder=True)
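      # The two batch() calls above rebuild a [minibatch_size, 1, ...] shape:
      # the inner batch(1) restores a time dimension of length 1 and the
      # outer batch(minibatch_size) forms the batch dimension, matching the
      # [batch, time, ...] layout the agent's train function expects.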
    # TODO(b/161133726): use learner.run once it supports None iterations.
    def _summary_record_if():
      return tf.math.equal(
          self._generic_learner.train_step %
          tf.constant(self._generic_learner.summary_interval), 0)

    with self._generic_learner.train_summary_writer.as_default(), \
         common.soft_device_placement(), \
         tf.compat.v2.summary.record_if(_summary_record_if), \
         self._generic_learner.strategy.scope():
      loss_info = self.multi_train_step(iter(new_dataset))

      train_step_val = self._generic_learner.train_step_numpy
      for trigger in self._generic_learner.triggers:
        trigger(train_step_val)

    self._update_normalizers(cached_dataset)

    return loss_info
  @common.function(autograph=True)
  def multi_train_step(self, iterator):
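    # Pull the first element outside the loop so `loss_info` is defined
    # before the autograph-converted for-loop that carries it across
    # iterations; variables created only inside the loop could not be
    # returned afterwards.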
    experience, sample_info = next(iterator)
    loss_info = self.single_train_step(experience, sample_info)
    for experience, sample_info in iterator:
      loss_info = self.single_train_step(experience, sample_info)
    return loss_info
  @common.function(autograph=False)
  def single_train_step(self, experience, sample_info):
    """Trains a single (mini) batch of Trajectories."""
    if self._generic_learner.use_kwargs_in_agent_train:
      loss_info = self._generic_learner.strategy.run(
          self._agent.train, kwargs=experience)
    else:
      loss_info = self._generic_learner.strategy.run(
          self._agent.train, args=(experience,))

    if self._generic_learner.after_train_strategy_step_fn:
      # Note: the flag and strategy live on the wrapped generic learner, not
      # on PPOLearner itself.
      if self._generic_learner.use_kwargs_in_agent_train:
        self._generic_learner.strategy.run(
            self._generic_learner.after_train_strategy_step_fn,
            kwargs=dict(
                experience=(experience, sample_info), loss_info=loss_info))
      else:
        self._generic_learner.strategy.run(
            self._generic_learner.after_train_strategy_step_fn,
            args=((experience, sample_info), loss_info))
    return loss_info
  @common.function(autograph=True)
  def _update_normalizers(self, dataset):
    iterator = iter(dataset)
    traj, _ = next(iterator)
    self._agent.update_observation_normalizer(traj.observation)
    self._agent.update_reward_normalizer(traj.reward)
    for traj, _ in iterator:
      self._agent.update_observation_normalizer(traj.observation)
      self._agent.update_reward_normalizer(traj.reward)

  @common.function(autograph=True)
  def _update_advantage_normalizer(self, dataset):
    self._agent._reset_advantage_normalizer()  # pylint: disable=protected-access
    iterator = iter(dataset)
    traj, _ = next(iterator)
    self._agent._update_advantage_normalizer(traj.policy_info['advantage'])  # pylint: disable=protected-access
    for traj, _ in iterator:
      self._agent._update_advantage_normalizer(traj.policy_info['advantage'])  # pylint: disable=protected-access
  @property
  def train_step_numpy(self):
    """The current train_step.

    Returns:
      The current `train_step`. Note this will return a scalar numpy array
      which holds the `train_step` value when this was called.
    """
    return self._generic_learner.train_step_numpy
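# A minimal usage sketch for PPOLearner (hypothetical values; assumes an
# `agent`, a `train_step` variable, and a Reverb-backed `dataset` have been
# built as in the train_eval script below):
#
#   ppo_learner = PPOLearner(
#       root_dir='/tmp/ppo',
#       train_step=train_step,
#       agent=agent,
#       minibatch_size=64,
#       shuffle_buffer_size=2048)
#   loss_info = ppo_learner.run(iterations=10, dataset=dataset)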
"""Train and Eval PPOClipAgent in the Mujoco environments. | |
All hyperparameters come from the PPO paper | |
https://arxiv.org/abs/1707.06347.pdf | |
""" | |
import os | |
from absl import logging | |
import gin | |
import reverb | |
import tensorflow.compat.v2 as tf | |
from tf_agents.agents.ppo import ppo_clip_agent | |
from tf_agents.environments import suite_mujoco | |
#import .ppo_learner | |
from tf_agents.experimental.train import actor | |
from tf_agents.experimental.train import learner | |
from tf_agents.experimental.train import triggers | |
from tf_agents.experimental.train.utils import spec_utils | |
from tf_agents.experimental.train.utils import train_utils | |
from tf_agents.replay_buffers import reverb_replay_buffer | |
from tf_agents.replay_buffers import reverb_utils | |
from tf_agents.metrics import py_metrics | |
from tf_agents.networks import actor_distribution_network | |
from tf_agents.networks import value_network | |
from tf_agents.policies import py_tf_eager_policy | |
actor_fc_layers=(64, 64) | |
value_fc_layers=(64, 64) | |
@gin.configurable
def train_eval(
    root_dir,
    env_name='Hedge',
    # Training params
    num_iterations=20000,
    actor_fc_layers=actor_fc_layers,
    value_fc_layers=value_fc_layers,
    learning_rate=1e-5,
    collect_sequence_length=2048,
    minibatch_size=64,
    num_epochs=10,
    # Agent params
    importance_ratio_clipping=0.2,
    lambda_value=0.95,
    discount_factor=0.99,
    entropy_regularization=0.,
    value_pred_loss_coef=0.5,
    use_gae=True,
    use_td_lambda_return=True,
    gradient_clipping=None,
    value_clipping=None,
    # Replay params
    reverb_port=None,
    replay_capacity=10000,
    # Others
    policy_save_interval=5000,
    summary_interval=1000,
    eval_interval=10000,
    eval_episodes=30,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    env=None,
):
"""Trains and evaluates PPO (Importance Ratio Clipping). | |
Args: | |
root_dir: Main directory path where checkpoints, saved_models, and summaries | |
will be written to. | |
env_name: Name for the Mujoco environment to load. | |
num_iterations: The number of iterations to perform collection and training. | |
actor_fc_layers: List of fully_connected parameters for the actor network, | |
where each item is the number of units in the layer. | |
value_fc_layers: : List of fully_connected parameters for the value network, | |
where each item is the number of units in the layer. | |
learning_rate: Learning rate used on the Adam optimizer. | |
collect_sequence_length: Number of steps to take in each collect run. | |
minibatch_size: Number of elements in each mini batch. If `None`, the entire | |
collected sequence will be treated as one batch. | |
num_epochs: Number of iterations to repeat over all collected data per data | |
collection step. (Schulman,2017) sets this to 10 for Mujoco, 15 for | |
Roboschool and 3 for Atari. | |
importance_ratio_clipping: Epsilon in clipped, surrogate PPO objective. For | |
more detail, see explanation at the top of the doc. | |
lambda_value: Lambda parameter for TD-lambda computation. | |
discount_factor: Discount factor for return computation. Default to `0.99` | |
which is the value used for all environments from (Schulman, 2017). | |
entropy_regularization: Coefficient for entropy regularization loss term. | |
Default to `0.0` because no entropy bonus was used in (Schulman, 2017). | |
value_pred_loss_coef: Multiplier for value prediction loss to balance with | |
policy gradient loss. Default to `0.5`, which was used for all | |
environments in the OpenAI baseline implementation. This parameters is | |
irrelevant unless you are sharing part of actor_net and value_net. In that | |
case, you would want to tune this coeeficient, whose value depends on the | |
network architecture of your choice. | |
use_gae: If True (default False), uses generalized advantage estimation for | |
computing per-timestep advantage. Else, just subtracts value predictions | |
from empirical return. | |
use_td_lambda_return: If True (default False), uses td_lambda_return for | |
training value function; here: `td_lambda_return = gae_advantage + | |
value_predictions`. `use_gae` must be set to `True` as well to enable TD | |
-lambda returns. If `use_td_lambda_return` is set to True while | |
`use_gae` is False, the empirical return will be used and a warning will | |
be logged. | |
gradient_clipping: Norm length to clip gradients. | |
value_clipping: Difference between new and old value predictions are clipped | |
to this threshold. Value clipping could be helpful when training | |
very deep networks. Default: no clipping. | |
reverb_port: Port for reverb server, if None, use a randomly chosen unused | |
port. | |
replay_capacity: The maximum number of elements for the replay buffer. Items | |
will be wasted if this is smalled than collect_sequence_length. | |
policy_save_interval: How often, in train_steps, the policy will be saved. | |
summary_interval: How often to write data into Tensorboard. | |
eval_interval: How often to run evaluation, in train_steps. | |
eval_episodes: Number of episodes to evaluate over. | |
debug_summaries: Boolean for whether to gather debug summaries. | |
summarize_grads_and_vars: If true, gradient summaries will be written. | |
""" | |
  collect_env = env
  eval_env = env
  num_environments = 1

  observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  actor_net = actor_distribution_network.ActorDistributionNetwork(
      observation_tensor_spec,
      action_tensor_spec,
      fc_layer_params=actor_fc_layers,
      activation_fn=tf.nn.tanh,
      kernel_initializer=tf.keras.initializers.Orthogonal())
  value_net = value_network.ValueNetwork(
      observation_tensor_spec,
      fc_layer_params=value_fc_layers,
      kernel_initializer=tf.keras.initializers.Orthogonal())

  train_step = train_utils.create_train_step()
  current_iteration = tf.Variable(0, dtype=tf.int64)

  def learning_rate_fn():
    # Linearly decay the learning rate.
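    # e.g., with the defaults above (learning_rate=1e-5,
    # num_iterations=20000), the rate falls linearly from 1e-5 on the first
    # iteration to ~0 on the last.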
    return learning_rate * (1 - current_iteration / num_iterations)

  agent = ppo_clip_agent.PPOClipAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      optimizer=tf.compat.v1.train.AdamOptimizer(
          learning_rate=learning_rate_fn, epsilon=1e-5),
      actor_net=actor_net,
      value_net=value_net,
      importance_ratio_clipping=importance_ratio_clipping,
      lambda_value=lambda_value,
      discount_factor=discount_factor,
      entropy_regularization=entropy_regularization,
      value_pred_loss_coef=value_pred_loss_coef,
      # This is a legacy argument for the number of times we repeat the data
      # inside of the train function, incompatible with mini batch learning.
      # We set the epoch number from the replay buffer and tf.Data instead.
      num_epochs=1,
      use_gae=use_gae,
      use_td_lambda_return=use_td_lambda_return,
      gradient_clipping=gradient_clipping,
      value_clipping=value_clipping,
      # TODO(b/150244758): Default compute_value_and_advantage_in_train to
      # False after Reverb open source.
      compute_value_and_advantage_in_train=False,
      # Skips updating normalizers in the agent, as it's handled in the
      # learner.
      update_normalizers_in_train=False,
      debug_summaries=debug_summaries,
      summarize_grads_and_vars=summarize_grads_and_vars,
      train_step_counter=train_step)
  agent.initialize()

  table_name = 'uniform_table'
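  # With max_times_sampled=1 and a FIFO remover, each collected sequence is
  # served at most once, so the table behaves like a queue of fresh
  # on-policy data rather than a long-lived replay buffer.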
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1),
      max_times_sampled=1)

  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=collect_sequence_length,
      table_name=table_name,
      server_address='localhost:{}'.format(reverb_server.port),
      # The only collected sequence is used to populate the batches.
      max_cycle_length=1,
      rate_limiter_timeout_ms=1000)

  # TODO(b/162244134): move to using the episodic observer after the
  # performance issue caused by the bug is resolved.
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client,
      table_name,
      sequence_length=collect_sequence_length,
      stride_length=collect_sequence_length,
      # allow_multi_episode_sequences=True
  )
  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  collect_env_step_metric = py_metrics.EnvironmentSteps()
  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={
              triggers.ENV_STEP_METADATA_KEY: collect_env_step_metric
          }),
      triggers.StepPerSecondLogTrigger(train_step, interval=summary_interval),
  ]

  agent_learner = PPOLearner(
      root_dir,
      train_step,
      agent,
      minibatch_size=minibatch_size,
      shuffle_buffer_size=collect_sequence_length,
      triggers=learning_triggers)
  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_collect_policy, use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=collect_sequence_length,
      observers=[rb_observer],
      metrics=actor.collect_metrics(buffer_size=10) +
      [collect_env_step_metric],
      reference_metrics=[collect_env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
      summary_interval=summary_interval)

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
      tf_greedy_policy, use_tf_function=True)

  if eval_interval:
    logging.info('Initial evaluation.')
    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
        episodes_per_run=eval_episodes)
    eval_actor.run_and_log()
  logging.info('Training.')
  dataset = reverb_replay.as_dataset(
      sample_batch_size=num_environments,
      sequence_preprocess_fn=agent.preprocess_sequence)
  for _ in range(num_iterations):
    collect_actor.run()
    # TODO(b/159490625): Get rid of the reset call once the
    # multi_episode_sequences flag is gone.
    # TODO(b/159615593): Update to use observer.flush.
    # Reset the reverb observer to make sure the data collected is flushed
    # and written to the RB.
    rb_observer.reset()
    agent_learner.run(iterations=num_epochs, dataset=dataset)
    reverb_replay.clear()
    current_iteration.assign_add(1)

    if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
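# A minimal usage sketch ('HalfCheetah-v2' and the /tmp path are example
# values, not part of the original script; the env_name default 'Hedge'
# above refers to a custom environment, so a standard Mujoco env is
# substituted here for illustration):
#
#   env = suite_mujoco.load('HalfCheetah-v2')
#   train_eval(
#       root_dir='/tmp/ppo_train_eval',
#       env=env,
#       num_iterations=100,
#       eval_interval=0)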