Created
July 5, 2019 08:17
-
-
Save ernestum/95b72069549209366ef0188bb05e4df0 to your computer and use it in GitHub Desktop.
An efficient way of generating a large number of trajectories for a QuadrotorEnvironment
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import numpy as np | |
from quadrotor_environment.quadrotor_model import SysState | |
from quadrotor_environment.simulation_parameter_randomization import SPRWrappedQuadrotorEnvironment | |
from stable_baselines import PPO2 | |
# Structured dtype describing the full quadrotor state at a single point in time.
# The builtin ``float`` is used instead of ``np.float``: the latter was only an alias for the builtin, was deprecated
# in NumPy 1.20 and removed in NumPy 1.24.
# NOTE(review): ``np.quaternion`` is only available once the numpy-quaternion package has been imported; presumably
# this happens as a side effect of importing quadrotor_environment -- confirm.
state_dtype = np.dtype([("position", float, 3),
                        ("velocity", float, 3),
                        ("rotation", np.quaternion),
                        ("angular_velocity", float, 3),
                        ("propeller_speed", float, 4)])

# One sample of a trajectory: the simulation time, the action stored with the sample and the state at that time.
trajectory_dtype = np.dtype([("time", float), ("action", float, 4), ("state", state_dtype)])

# Wrapper dtype with a single "trajectory" field holding one trajectory sample.
trajectory_list_dtype = np.dtype([("trajectory", trajectory_dtype)])
def record_episodes(model, env_params, initial_states, num_steps, spr_seeds=None, batch_size=20):
    """
    Records a number of episodes and yields them one by one in a structured numpy array.

    The episodes are recorded in parallel in batches, to make use of the improved speed of parallel model predictions.
    No more than `batch_size` episodes are recorded at the same time, so this generator is suitable for generating
    and processing a large number of episodes without keeping all of them in memory.
    Note that the last batch might be smaller than batch_size.

    :param model: The model to control the quadrotor. We use it to make deterministic predictions.
    :param env_params: The parameters of the environment to record the episodes in. Refer to the documentation of
        quadrotor_environment.simulation_parameter_randomization.SPRWrappedQuadrotorEnvironment for allowed values.
    :param initial_states: A sequence containing an initial state for each episode to generate. It is passed on in
        slices to `_record_episodes`, which decides how the entries are interpreted.
    :param num_steps: The number of steps to simulate forward for each episode.
    :param spr_seeds: An optional sequence with one seed per episode to pass to the SPRWrappedQuadrotorEnvironment.
        `None` or an empty sequence means that no seeds are passed on.
    :param batch_size: The number of trajectories to generate in parallel. Must be at least 1.
    :return: yields structured numpy arrays containing a trajectory each. Refer to
        trajectory_generation.trajectory_dtype for the actual structure.
    :raises ValueError: If batch_size is smaller than 1.
    """
    # A zero step already made range() fail; a negative step silently produced no episodes at all.
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1, got {}".format(batch_size))

    # Explicit None/length test instead of truthiness, so that array-like seed containers (e.g. numpy arrays) work.
    # An empty seed sequence is still treated like None.
    has_seeds = spr_seeds is not None and len(spr_seeds) > 0
    num_episodes = len(initial_states)

    for batch_start in range(0, num_episodes, batch_size):
        # Clamp to the number of episodes so that a longer seed sequence never yields more seeds than states.
        batch_end = min(batch_start + batch_size, num_episodes)
        batch = initial_states[batch_start:batch_end]
        seed_batch = spr_seeds[batch_start:batch_end] if has_seeds else None
        yield from _record_episodes(model, env_params, batch, num_steps, spr_seeds=seed_batch)
def _record_episodes(model: PPO2, env_params: dict, initial_states: list, num_steps: int, spr_seeds:list = None): | |
""" | |
Records a number of episodes and returns them in a structured numpy array. | |
The episodes are all recorded in parallel, each in its own copy of the environment to make use of the improved | |
speed of parallel model predictions. | |
:param model: The model to control the quadrotor. We use it to make deterministic predictions. | |
:param env_params: The parameters of the environment to record the episodes in. Refer to the documentation of | |
quadrotor_environment.simulation_parameter_randomization.SPRWrappedQuadrotorEnvironment for allowed values. | |
:param initial_states: A list containing an initial state for each episode to generate. If the list contains | |
something else than initial states, then a random state is generated using the specified environment. | |
:param num_steps: The number of steps to simulate forward for each episode. | |
:param spr_seeds: A list of seeds to pass to the SPRWrappedQuadrotorEnvironment | |
:return: A list of structured numpy array containing a trajectory. Refer to trajectory_generation.trajectory_dtype | |
for the actual structure. | |
""" | |
# Use 0 seed if spr_seeds is None | |
if spr_seeds is None: | |
spr_seeds = [0 for _ in initial_states] | |
assert(len(spr_seeds) == len(initial_states)) | |
# Generate an environment for each trajectory to generate and reset it 'all_env_observations' is a list containing | |
# the list of the current observation for each environment | |
envs = [SPRWrappedQuadrotorEnvironment(**env_params) for _ in initial_states] | |
if isinstance(initial_states[0], SysState): | |
all_env_observations = [env.reset(initial_state, spr_seed=seed) for | |
env, initial_state, seed in zip(envs, initial_states, spr_seeds)] | |
else: | |
all_env_observations = [env.reset(spr_seed=seed) for env, seed in zip(envs, spr_seeds)] | |
# Set up list structured numpy arrays to contain the trajectories and set the initial state | |
numpy_trajectories = [np.empty(num_steps +1,dtype=trajectory_dtype) for _ in initial_states] | |
for env, numpy_trajectory in zip(envs, numpy_trajectories): | |
state = env.get_current_state() | |
numpy_trajectory[0]['time'] = env.time | |
numpy_trajectory[0]['action'] = np.zeros(4) | |
numpy_trajectory[0]['state']['position'] = state.position | |
numpy_trajectory[0]['state']['velocity'] = state.velocity | |
numpy_trajectory[0]['state']['rotation'] = state.rotation | |
numpy_trajectory[0]['state']['angular_velocity'] = state.angular_velocity | |
numpy_trajectory[0]['state']['propeller_speed'] = state.propeller_speed | |
# Simulate all trajectories in parallel | |
for step in range(num_steps): | |
# Do one step in each environment | |
all_env_actions = model.predict(all_env_observations, deterministic=True)[0] | |
all_env_observations = [] | |
for env, action, numpy_trajectory in zip(envs, all_env_actions, numpy_trajectories): | |
observation, reward, done, _ = env.step(action) | |
all_env_observations.append(observation) | |
env.unwrapped.observation_history.prune_history() # HACK: this prevents our memory from exploding | |
state = env.get_current_state() | |
numpy_trajectory[step+1]['time'] = env.time | |
numpy_trajectory[step+1]['action'] = action | |
numpy_trajectory[step+1]['state']['position'] = state.position | |
numpy_trajectory[step+1]['state']['velocity'] = state.velocity | |
numpy_trajectory[step+1]['state']['rotation'] = state.rotation | |
numpy_trajectory[step+1]['state']['angular_velocity'] = state.angular_velocity | |
numpy_trajectory[step+1]['state']['propeller_speed'] = state.propeller_speed | |
return numpy_trajectories |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment