@davidADSP
Created December 1, 2019 21:48
def update_weights(optimizer: tf.train.Optimizer, network: Network, batch,
                   weight_decay: float):
  loss = 0
  for image, actions, targets in batch:
    # Initial step, from the real observation.
    value, reward, policy_logits, hidden_state = network.initial_inference(
        image)
    predictions = [(1.0, value, reward, policy_logits)]

    # Recurrent steps, from action and previous hidden state.
    for action in actions:
      value, reward, policy_logits, hidden_state = network.recurrent_inference(
          hidden_state, action)
      predictions.append((1.0 / len(actions), value, reward, policy_logits))

      # Halve the gradient flowing back through the recurrent hidden state.
      hidden_state = tf.scale_gradient(hidden_state, 0.5)

    # Accumulate the value, reward and policy losses at every unroll step,
    # scaling each step's loss by the gradient scale stored with it.
    for prediction, target in zip(predictions, targets):
      gradient_scale, value, reward, policy_logits = prediction
      target_value, target_reward, target_policy = target

      l = (
          scalar_loss(value, target_value) +
          scalar_loss(reward, target_reward) +
          tf.nn.softmax_cross_entropy_with_logits(
              logits=policy_logits, labels=target_policy))

      loss += tf.scale_gradient(l, gradient_scale)

  # L2 regularization over all network weights.
  for weights in network.get_weights():
    loss += weight_decay * tf.nn.l2_loss(weights)

  optimizer.minimize(loss)
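
Note that `tf.scale_gradient` and `scalar_loss` are not part of the TensorFlow API; the snippet assumes they are defined elsewhere, along with a `Network` class exposing `initial_inference`, `recurrent_inference`, and `get_weights`. Below is a minimal sketch of plausible helper definitions, assuming the usual stop-gradient trick for gradient scaling and a squared-error scalar loss; it is illustrative, not the author's exact code.

import tensorflow as tf

def scale_gradient(tensor, scale):
  # Forward value is unchanged; the gradient in the backward pass is
  # multiplied by `scale`.
  return tensor * scale + tf.stop_gradient(tensor) * (1.0 - scale)

def scalar_loss(prediction, target):
  # Assumption: a simple squared error between predicted and target scalars.
  return tf.reduce_sum(tf.square(prediction - target))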