# Step 1: train a PPO expert on CartPole and save its demonstrations under
# quickstart/ (the rollouts used below live at quickstart/rollouts/final.pkl).
python -m imitation.scripts.expert_demos \
    with cartpole log_dir=quickstart

# Step 2: train GAIL from those demonstrations.
python -m imitation.scripts.train_adversarial \
    with gail cartpole rollout_path=quickstart/rollouts/final.pkl

# Step 3: train AIRL from the same demonstrations.
python -m imitation.scripts.train_adversarial \
    with airl cartpole rollout_path=quickstart/rollouts/final.pkl

# Tip: `python -m imitation.scripts.<script> print_config` lists that script's
# Sacred options, which are documented in `src/imitation/scripts/`.
# For more information on configuring Sacred options, see the docs at
# https://sacred.readthedocs.io/en/stable/.
import pickle

import gym
import stable_baselines3 as sb3

from imitation.algorithms import adversarial, bc
from imitation.data import types
from imitation.util import logger, util
# Load pickled test demonstrations.
# NOTE: `pickle.load` can execute arbitrary code; only load rollout files you
# trust, such as this fixture shipped with the repository.
with open("tests/data/expert_models/cartpole_0/rollouts/final.pkl", "rb") as f:
    # This is a list of `types.Trajectory`, where every instance contains
    # observations and actions for a single expert demonstration.
    trajectories = pickle.load(f)

# Convert List[types.Trajectory] to an instance of `types.Transitions`.
# This is a more general dataclass containing unordered
# (observation, action, next_observation) transitions.
transitions = types.flatten_trajectories(trajectories)

# Two parallel environments: together with `n_steps=1024` below, each PPO
# rollout collects 1024 * 2 = 2048 generator timesteps.
venv = util.make_vec_env("CartPole-v1", n_envs=2)

# Train BC on expert data.
logger.configure("quickstart/tensorboard_dir_bc/")
bc_trainer = bc.BC(venv.observation_space, venv.action_space, expert_data=transitions)
bc_trainer.train(n_epochs=2)

# Train GAIL on expert data.
# Stable Baselines3 algorithms take the policy first and the environment
# second, so the generator must be built as `PPO("MlpPolicy", venv, ...)`.
logger.configure("quickstart/tensorboard_dir_gail/")
gail_trainer = adversarial.GAIL(
    venv,
    expert_data=transitions,
    expert_batch_size=32,
    gen_algo=sb3.PPO("MlpPolicy", venv, n_steps=1024),
)
# NOTE(review): the adversarial trainers appear to train in rounds of one full
# PPO rollout (n_steps * n_envs timesteps); 2048 covers exactly one round with
# the generator settings above — confirm against the installed imitation version.
gail_trainer.train(total_timesteps=2048)

# Train AIRL on expert data (same generator settings as GAIL).
logger.configure("quickstart/tensorboard_dir_airl/")
airl_trainer = adversarial.AIRL(
    venv,
    expert_data=transitions,
    expert_batch_size=32,
    gen_algo=sb3.PPO("MlpPolicy", venv, n_steps=1024),
)
airl_trainer.train(total_timesteps=2048)
BC, GAIL, and AIRL also accept, as `expert_data`, any PyTorch-style DataLoader
that iterates over dictionaries containing observations, actions, and next observations.