Last active
August 25, 2023 13:45
-
-
Save emrul/38fcade5eb4f5e6f639fe8273f5da3f7 to your computer and use it in GitHub Desktop.
Tianshou DQN with temporally-extended epsilon-greedy exploration
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Union | |
from tianshou.data import Batch | |
from tianshou.policy import RainbowPolicy | |
import numpy as np | |
import argparse | |
# See https://arxiv.org/abs/2006.01782 for paper - Temporally-Extended ε-Greedy Exploration
# See https://www.youtube.com/watch?v=Gi_B0IqscBE for video explaining where it's helpful
# See https://github.com/thu-ml/tianshou/blob/master/examples/atari/atari_rainbow.py for full example of how to setup args/net/collectors/etc.
def get_args():
    """
    Parse the command-line arguments of the current process.

    The individual argument definitions are elided in this snippet.

    Returns:
    - The argparse namespace produced by parsing the command line
    """
    cli_parser = argparse.ArgumentParser()
    # ...
    parsed_args = cli_parser.parse_args()
    return parsed_args
def run_rainbow(args=None):
    """
    Build a RainbowPolicy whose exploration noise is replaced by ez-greedy.

    This is an outline only: the network, optimizer, collectors and training
    loop are elided (``# ...``). See the Tianshou ``atari_rainbow.py`` example
    for a complete setup.

    Parameters:
    - args: Parsed argument namespace; this function reads gamma, num_atoms,
      v_min, v_max, n_step, target_update_freq and device from it. When left
      as None, the namespace is obtained from get_args() at call time.
    """
    # Resolve the default lazily. A ``get_args()`` default in the signature is
    # evaluated once, when the function is defined, which parses sys.argv as a
    # side effect of merely importing this module.
    if args is None:
        args = get_args()
    # ...
    # Patch at class level so that RainbowPolicy uses the temporally-extended
    # exploration instead of its stock ``exploration_noise``.
    RainbowPolicy.exploration_noise = ez_greedy_exploration_noise
    # net = Rainbow(...
    # define policy
    # NOTE: ``net`` and ``optim`` are expected to come from the elided setup above.
    policy = RainbowPolicy(
        net,
        optim,
        args.gamma,
        args.num_atoms,
        args.v_min,
        args.v_max,
        args.n_step,
        target_update_freq=args.target_update_freq,
    ).to(args.device)
    # ...
def generate_action_sequence(num_sequences, num_actions, min_repeat=2, max_repeat=20):
    """
    Draw one random (repeat count, action) pair per sequence.

    Parameters:
    - num_sequences: Number of sequences (rows) to generate
    - num_actions: Number of discrete actions; actions are drawn from 0 to num_actions - 1
    - min_repeat: Smallest repeat count that can be drawn (inclusive)
    - max_repeat: Largest repeat count that can be drawn (inclusive)

    Returns:
    - Numpy array of shape (num_sequences, 2): column 0 holds the repeat
      count, column 1 holds the action
    """
    # randint excludes its upper bound, hence the +1 to keep max_repeat reachable.
    repeat_upper = max_repeat + 1
    # Repeat counts are drawn before actions; keep this order so seeded runs
    # produce the same values.
    repeat_col = np.random.randint(min_repeat, repeat_upper, size=num_sequences).reshape(-1, 1)
    action_col = np.random.randint(0, num_actions, size=num_sequences).reshape(-1, 1)
    return np.concatenate((repeat_col, action_col), axis=1)
def ez_greedy_exploration_noise(
    self,
    act: Union[np.ndarray, Batch],
    batch: Batch,
) -> Union[np.ndarray, Batch]:
    """
    Temporally-extended epsilon-greedy replacement for ``exploration_noise``.

    State is kept on the policy as ``self.ez_greedies``, an integer array with
    one row per batch entry: column 0 is the number of remaining repeats and
    column 1 is the action being repeated. While a row has repeats left, its
    entry in ``act`` is overwritten with the stored action and the counter is
    decremented. Once no row has repeats left, a fresh set of sequences is
    drawn for the whole batch with probability ``self.eps``.

    Parameters:
    - self: Policy object; ``eps`` and ``max_action_num`` are read from it and
      ``ez_greedies`` is read from / written to it
    - act: Actions chosen by the policy; modified in place when it is a numpy array
    - batch: Current batch; if ``batch.obs`` has a ``mask`` attribute it is
      indexed as mask[row, action] and a falsy entry marks an unavailable action

    Returns:
    - ``act``, with the rows that are inside a repeat sequence replaced
    """
    # Nothing to do for non-array actions or when exploration is switched off.
    if not isinstance(act, np.ndarray) or np.isclose(self.eps, 0.0):
        return act

    bsz = len(act)
    ez_greedies = getattr(self, "ez_greedies", None)
    if ez_greedies is None or len(ez_greedies) != bsz:
        # Start with no active repeats, both on the first call and whenever the
        # batch size changes (stored rows can no longer be matched to batch
        # entries, and reusing them raised an IndexError). Starting from zeros
        # also means the first call explores with probability eps instead of
        # unconditionally.
        ez_greedies = np.zeros((bsz, 2), dtype=np.int64)

    valid_repeats = ez_greedies[:, 0] > 0
    # A new set of sequences is only started once every row has run out.
    if not valid_repeats.any() and np.random.random() < self.eps:
        ez_greedies = generate_action_sequence(bsz, self.max_action_num)
        valid_repeats = ez_greedies[:, 0] > 0

    if valid_repeats.any():
        indices = np.flatnonzero(valid_repeats)
        if hasattr(batch.obs, "mask"):
            actions_to_validate = ez_greedies[indices, 1]
            # Cast to bool so that 0/1 integer masks are negated logically;
            # a bitwise ~ on integers would flag every action as invalid.
            actions_valid = np.asarray(
                batch.obs.mask[indices, actions_to_validate], dtype=bool
            )
            invalid_action_indices = indices[~actions_valid]
            for idx in invalid_action_indices:
                available_actions = np.flatnonzero(batch.obs.mask[idx])
                assert available_actions.size > 0, f"No available actions: {available_actions}"
                # Replace with a random valid action; it is stored so the
                # remaining repeats of this row reuse it.
                ez_greedies[idx, 1] = np.random.choice(available_actions)
        act[indices] = ez_greedies[indices, 1]
        ez_greedies[indices, 0] -= 1

    self.ez_greedies = ez_greedies
    return act
if __name__ == "__main__":
    # Script entry point: parse the command line, then launch the run.
    cli_args = get_args()
    run_rainbow(cli_args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment