@ejmejm
Created July 23, 2024 17:42
Test of vmapped matrix multiplication vs. batched matrix multiplication
from equinox import nn
import jax
import jax.numpy as jnp
#########################################
#                Test 1                 #
#########################################
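# Test 1: a single 2D matmul, (n, 256) @ (256, 512), vs. the same product
# computed by vmapping jnp.matmul over the rows of `a`.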
n = 100
a = jax.random.normal(jax.random.PRNGKey(0), (n, 256))
b = jax.random.normal(jax.random.PRNGKey(1), (256, 512))
matmul = jax.jit(jnp.matmul)
batch_matmul = jax.jit(jax.vmap(jnp.matmul, in_axes=(0, None)))
# Compile
jax.block_until_ready(matmul(jnp.ones_like(a), jnp.ones_like(b)))
jax.block_until_ready(batch_matmul(jnp.ones_like(a), jnp.ones_like(b)))
# Time each function
%timeit jax.block_until_ready(matmul(a, b))
%timeit jax.block_until_ready(batch_matmul(a, b))
# Output:
# 180 µs ± 44.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 214 µs ± 15.9 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
#########################################
#                Test 2                 #
#########################################
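# Test 2: jax.lax.batch_matmul on explicitly batched operands,
# (n, 1, 256) @ (n, 256, 512), vs. vmapping jnp.matmul over both inputs.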
n = 100
a = jax.random.normal(jax.random.PRNGKey(2), (n, 256))
b = jax.random.normal(jax.random.PRNGKey(3), (n, 256, 512))
batch_matmul = jax.jit(jax.lax.batch_matmul)
vmap_batch_matmul = jax.jit(jax.vmap(jnp.matmul))
# Compile
jax.block_until_ready(batch_matmul(jnp.ones_like(a)[:, None, :], jnp.ones_like(b)))
jax.block_until_ready(vmap_batch_matmul(jnp.ones_like(a), jnp.ones_like(b)))
# Time each function
%timeit jax.block_until_ready(batch_matmul(a[:, None, :], b))
%timeit jax.block_until_ready(vmap_batch_matmul(a, b))
# Output:
# 366 µs ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 1.25 ms ± 44.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
#########################################
#                Test 3                 #
#########################################
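# Test 3: a hand-written batched linear layer (a @ weight.T + bias) vs.
# vmapping an equinox nn.Linear over the batch dimension.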
n = 100
a = jax.random.normal(jax.random.PRNGKey(4), (n, 256))
linear = nn.Linear(256, 512, key=jax.random.PRNGKey(5))
def batch_linear(a, weight, bias):
    return a @ weight.T + bias
linear_fn = jax.jit(batch_linear)
vmap_linear_fn = jax.jit(jax.vmap(linear.__call__))
# Compile
jax.block_until_ready(linear_fn(jnp.ones_like(a), jnp.ones_like(linear.weight), jnp.ones_like(linear.bias)))
jax.block_until_ready(vmap_linear_fn(jnp.ones_like(a)))
# Time each function
%timeit jax.block_until_ready(linear_fn(a, linear.weight, linear.bias))
%timeit jax.block_until_ready(vmap_linear_fn(a))
# Output:
# 254 µs ± 7.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# 251 µs ± 3.77 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
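# Not part of the original timings: a quick sanity check (sketch) that the
# batched and vmapped formulations agree numerically, assuming the arrays and
# jitted functions from Tests 2 and 3 are still in scope. Tolerances are loose
# to allow for reduced-precision accelerator matmuls.
out_batched = batch_matmul(a[:, None, :], b)[:, 0, :]  # (n, 512)
out_vmapped = vmap_batch_matmul(a, b)                   # (n, 512)
assert jnp.allclose(out_batched, out_vmapped, rtol=1e-3, atol=1e-3)
assert jnp.allclose(linear_fn(a, linear.weight, linear.bias),
                    vmap_linear_fn(a), rtol=1e-3, atol=1e-3)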