This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Reset the environment to its initial state
state, info = env.reset()

# Step through 50 frames, rendering and displaying the environment each time
for _ in range(50):
    # Draw the current frame of the environment
    frame = env.render()
    plt.imshow(frame)
    plt.axis('off')
    # Push the current matplotlib figure to the notebook output
    display.display(plt.gcf())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Evaluate the trained agent over many episodes
n_eval_episodes = 10000
mean_reward, std_reward, episode_rewards = evaluate_agent(n_max_steps, n_eval_episodes, Qtable)

# Report summary statistics of the evaluation
print(f"Mean Reward = {mean_reward:.2f} +/- {std_reward:.2f}")
print(f"Min = {min(episode_rewards):.1f} and Max {max(episode_rewards):.1f}")

# Plot the distribution of per-episode rewards
plt.figure(figsize=(9,6), dpi=200)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def evaluate_agent(n_max_steps, n_eval_episodes, Qtable): | |
# Initialize an empty list to store rewards for each episode | |
episode_rewards=[] | |
# Evaluate for each episode | |
for episode in range(n_eval_episodes): | |
# Reset the environment at the start of each episode | |
state, info = env.reset() | |
t = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run training and capture the learned Q-table
Qtable = train(n_episodes, n_max_steps, start_epsilon, min_epsilon, decay_rate, Qtable)

# Display the learned Q-table (notebook cell output)
Qtable
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(n_episodes, n_max_steps, start_epsilon, min_epsilon, decay_rate, Qtable): | |
for episode in range(n_episodes): | |
# Reset the environment at the start of each episode | |
state, info = env.reset() | |
t = 0 | |
done = False | |
# Calculate epsilon value based on decay rate | |
epsilon = max(min_epsilon, (start_epsilon - min_epsilon)*np.exp(-decay_rate*episode)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Acting policy used during training: epsilon-greedy action selection.
def epsilon_greedy(Qtable, state, epsilon):
    """Choose an action for `state`.

    With probability `epsilon` the agent explores (uniformly random action);
    otherwise it exploits by taking the action with the highest Q-value.
    """
    # Draw a number in [0, 1); below epsilon means we explore this step.
    if np.random.uniform(0, 1) < epsilon:
        # Explore: sample a random action from the environment's action space.
        return env.action_space.sample()
    # Exploit: greedy action — index of the largest Q-value in this state's row.
    return np.argmax(Qtable[state, :])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build the initial Q-table, all zeros.
# Shape is (observation space size x action space size), i.e., 500 x 6 here.
Qtable = np.zeros((env.observation_space.n, env.action_space.n))

# Display the freshly initialized Q-table (notebook cell output)
Qtable
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- SARSA hyperparameters ---
alpha = 0.1    # learning rate
gamma = 0.95   # discount factor

# --- Training schedule ---
n_episodes = 100000   # number of training episodes
n_max_steps = 100     # cap on steps per episode

# --- Exploration / exploitation schedule ---
start_epsilon = 1.0   # begin fully exploratory: purely random actions
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Reset the environment to its initial state
state, info = env.reset()

# Take 30 iterations, rendering and displaying the agent in the environment each time
for _ in range(30):
    # Draw the current frame of the environment
    plt.imshow(env.render())
    plt.axis('off')
    # Show the current figure, then clear it so the next frame replaces it (animation effect)
    display.display(plt.gcf())
    display.clear_output(wait=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Print the environment's map description as an array
print("Environment Array: ")
print(env.desc)

# Inspect the observation and action spaces
state_obs_space = env.observation_space  # the environment's state (observation) space
action_space = env.action_space          # the environment's action space
print("State(Observation) space:", state_obs_space)
print("Action space:", action_space)
NewerOlder