Skip to content

Instantly share code, notes, and snippets.

@cipher982
Created January 3, 2019 20:22
Show Gist options
  • Save cipher982/79d363ac24f8323286ebe36c0a285ff5 to your computer and use it in GitHub Desktop.
Save cipher982/79d363ac24f8323286ebe36c0a285ff5 to your computer and use it in GitHub Desktop.
Replay buffer sampling for a PyTorch DQN agent
def sample(self):
    """Randomly sample a batch of experiences from replay memory.

    Draws ``self.batch_size`` entries from ``self.memory`` without
    replacement (``random.sample``), drops any ``None`` entries, and stacks
    each experience field into a tensor moved to the module-level ``device``.

    Returns:
        tuple: ``(states, actions, rewards, next_states, dones)``.
        ``actions`` is a ``long`` tensor; the other four are ``float``
        tensors. Each field is combined with ``np.vstack``, which yields one
        row per experience for scalar or 1-D fields.

    Raises:
        ValueError: if ``self.batch_size`` exceeds ``len(self.memory)``
            (raised by ``random.sample``), or if every sampled entry is
            ``None``.
    """
    sampled = random.sample(self.memory, k=self.batch_size)
    # Filter placeholders once so every field tensor is built from the same rows.
    experiences = [e for e in sampled if e is not None]
    if not experiences:
        raise ValueError("sampled batch contains no experiences (all entries are None)")

    def stack(field):
        # NOTE(review): np.vstack concatenates along axis 0, so fields with
        # more than one dimension (e.g. image states) are NOT given a new
        # batch axis -- confirm state shapes are scalar or 1-D.
        return np.vstack([getattr(e, field) for e in experiences])

    states = torch.from_numpy(stack("state")).float().to(device)
    actions = torch.from_numpy(stack("action")).long().to(device)
    rewards = torch.from_numpy(stack("reward")).float().to(device)
    next_states = torch.from_numpy(stack("next_state")).float().to(device)
    # Done flags go through uint8 before torch.from_numpy (presumably so a
    # bool array converts cleanly), then become float like the rewards.
    dones = torch.from_numpy(stack("done").astype(np.uint8)).float().to(device)
    return (states, actions, rewards, next_states, dones)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment