Skip to content

Instantly share code, notes, and snippets.

@cipher982
Created January 3, 2019 20:30
Show Gist options
  • Save cipher982/6bf4474b7a1daee6656d968a1e269a23 to your computer and use it in GitHub Desktop.
Save cipher982/6bf4474b7a1daee6656d968a1e269a23 to your computer and use it in GitHub Desktop.
Learn and soft-update steps for a DQN agent (experience-replay learning plus Polyak averaging of the target network).
def learn(self, experiences, gamma):
    """
    Update value parameters using given batch of experience tuples.

    Params
    ======
        experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
        gamma (float): discount factor
    """
    states, actions, rewards, next_states, dones = experiences

    # Max predicted Q value for each next state, taken from the target
    # network and detached so no gradient flows through the target.
    max_next_q = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)

    # TD target: reward plus discounted future value, zeroed for terminal
    # transitions via the (1 - dones) mask.
    td_targets = rewards + gamma * max_next_q * (1 - dones)

    # Q value currently predicted by the local network for the taken action.
    predicted_q = self.qnetwork_local(states).gather(1, actions)

    # Mean-squared TD error, minimized by one optimizer step.
    td_loss = F.mse_loss(predicted_q, td_targets)
    self.optimizer.zero_grad()
    td_loss.backward()
    self.optimizer.step()

    # Slowly track the local network with the target network.
    self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
def soft_update(self, local_model, target_model, tau):
    """
    Soft-update model parameters (Polyak averaging):
        θ_target = τ*θ_local + (1 - τ)*θ_target

    Params
    ======
        local_model (PyTorch model): weights originate from
        target_model (PyTorch model): weights will be sent to
        tau (float): interpolation parameter
    """
    # Blend each target parameter in place toward its local counterpart.
    for tgt, src in zip(target_model.parameters(), local_model.parameters()):
        tgt.data.mul_(1.0 - tau).add_(tau * src.data)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment