Created
January 3, 2019 20:22
-
-
Save cipher982/79d363ac24f8323286ebe36c0a285ff5 to your computer and use it in GitHub Desktop.
Replay buffer for a PyTorch DQN
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def sample(self):
    """Randomly sample a batch of experiences from memory.

    Draws ``self.batch_size`` entries from ``self.memory`` without
    replacement, drops any entries that are ``None``, and stacks each
    experience field into a tensor placed on the module-level ``device``.
    Because ``None`` entries are dropped, the batch may contain fewer than
    ``batch_size`` rows.

    Returns:
        tuple: ``(states, actions, rewards, next_states, dones)``.
        ``actions`` is a ``long`` tensor; the other four are ``float``
        tensors. ``dones`` holds the done flags cast to 0.0 / 1.0.
        Each tensor has one row per kept experience.

    Raises:
        ValueError: from ``random.sample`` if ``self.memory`` holds fewer
            than ``self.batch_size`` entries, or from ``np.vstack`` if every
            sampled entry is ``None``.
    """
    # Filter out empty slots once, instead of re-checking in every stack below.
    experiences = [
        e for e in random.sample(self.memory, k=self.batch_size) if e is not None
    ]

    states = torch.from_numpy(
        np.vstack([e.state for e in experiences])
    ).float().to(device)
    # Actions are used as indices (e.g. for gather), hence the integer dtype.
    actions = torch.from_numpy(
        np.vstack([e.action for e in experiences])
    ).long().to(device)
    rewards = torch.from_numpy(
        np.vstack([e.reward for e in experiences])
    ).float().to(device)
    next_states = torch.from_numpy(
        np.vstack([e.next_state for e in experiences])
    ).float().to(device)
    # Cast the flags to uint8 first so torch.from_numpy gets a numeric array,
    # then to float so they can be used directly in arithmetic.
    dones = torch.from_numpy(
        np.vstack([e.done for e in experiences]).astype(np.uint8)
    ).float().to(device)

    return (states, actions, rewards, next_states, dones)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment