Created
January 3, 2019 20:30
-
-
Save cipher982/6bf4474b7a1daee6656d968a1e269a23 to your computer and use it in GitHub Desktop.
The learn and soft-update steps for a DQN agent (PyTorch).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def learn(self, experiences, gamma):
    """Run one gradient step on the local Q-network from a batch of experience.

    Params
    ======
        experiences (Tuple[torch.Tensor]): batch of (s, a, r, s', done) tensors
        gamma (float): discount factor
    """
    states, actions, rewards, next_states, dones = experiences

    # Greedy one-step lookahead using the (frozen) target network;
    # detach so no gradients flow into the target parameters.
    next_q = self.qnetwork_target(next_states).detach()
    best_next_q = next_q.max(1)[0].unsqueeze(1)

    # TD target: immediate reward, plus discounted future value unless terminal.
    targets = rewards + gamma * best_next_q * (1 - dones)

    # Q-values the local network currently assigns to the actions actually taken.
    predicted = self.qnetwork_local(states).gather(1, actions)

    # Minimise the mean-squared TD error.
    self.optimizer.zero_grad()
    F.mse_loss(predicted, targets).backward()
    self.optimizer.step()

    # Nudge the target network toward the local network (Polyak averaging).
    self.soft_update(self.qnetwork_local, self.qnetwork_target, TAU)
def soft_update(self, local_model, target_model, tau):
    """Blend the local network's weights into the target network.

    Performs a Polyak update: θ_target ← τ·θ_local + (1 − τ)·θ_target.

    Params
    ======
        local_model (PyTorch model): source of the new weights
        target_model (PyTorch model): model whose weights are updated in place
        tau (float): interpolation factor in [0, 1]
    """
    param_pairs = zip(target_model.parameters(), local_model.parameters())
    for tgt, src in param_pairs:
        blended = tau * src.data + (1.0 - tau) * tgt.data
        tgt.data.copy_(blended)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment