Skip to content

Instantly share code, notes, and snippets.

@epignatelli
Created November 18, 2020 18:20
Show Gist options
  • Save epignatelli/b97f26a90d60460b85c0b011bc13e5a9 to your computer and use it in GitHub Desktop.
Save epignatelli/b97f26a90d60460b85c0b011bc13e5a9 to your computer and use it in GitHub Desktop.
A template for building deep NNs with stax
from functools import partial

import jax
import jax.numpy as jnp
@jax.jit
def compute_loss(y_hat: jnp.ndarray, y: jnp.ndarray):
    """Template hook for the loss function.

    Replace the body with a task-specific loss (e.g. MSE, cross-entropy)
    taking predictions `y_hat` and targets `y`.

    Raises:
        NotImplementedError: always, until a concrete loss is supplied.
    """
    raise NotImplementedError
@partial(jax.jit, static_argnums=0)
def forward(model, params, x, y):
    """Forward pass: apply `model` to `x` and score against `y`.

    `model` is marked static so jit retraces per model object, not per call.

    Returns:
        A `(loss, y_hat)` tuple, where `y_hat` is the model output.
    """
    predictions = model.apply(params, x)
    loss = compute_loss(predictions, y)
    return loss, predictions
@partial(jax.jit, static_argnums=0)
def backward(model, params, x, y):
    """Loss, model output, and gradients of `forward` w.r.t. `params`.

    `argnums=1` selects `params`; `has_aux=True` matches forward's
    `(loss, y_hat)` return shape.

    Returns:
        `((loss, y_hat), gradients)`.
    """
    grad_fn = jax.value_and_grad(forward, argnums=1, has_aux=True)
    return grad_fn(model, params, x, y)
@partial(jax.jit, static_argnums=(0, 1))
def update(model, optimiser, iteration, optimiser_state, x, y):
    """One optimisation step using a legacy `(init_fn, update_fn, params_fn)` optimiser.

    Both `model` and `optimiser` are static arguments for jit.

    Returns:
        `(loss, y_hat, new_optimiser_state)`.
    """
    current_params = optimiser.params_fn(optimiser_state)
    (loss, y_hat), gradients = backward(model, current_params, x, y)
    new_state = optimiser.update_fn(iteration, gradients, optimiser_state)
    return loss, y_hat, new_state
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment