This is an optimized implementation of an RMSNorm inference kernel using Triton, a Python-based GPU programming library. It is a modified version of the excellent RMSNorm kernel from the Unsloth project.
It has two improvements over the original kernel:
- `int64` for pointer offsets: we use `int64` instead of the default `int32` to compute the pointer offset value. This change prevents overflow when dealing with large sequence lengths, where the offset can exceed the maximum `int32` value of 2^31 - 1 (about 2.1 billion).
- In-place computation: our kernel writes the result back to the input buffer, eliminating the need for an additional memory allocation. This approach halves the memory usage compared to traditional implementations that use a separate output buffer.
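To see why the `int32` offset overflows, consider the dimensions used in the example later in this post. A quick sanity check in plain Python (the numbers below just mirror that example):

# With seq_length = 1_000_000 rows of hidden_size = 4096 elements each,
# the element offset of the last row no longer fits in a signed 32-bit int.
seq_length = 1_000_000
hidden_size = 4096
last_row_offset = (seq_length - 1) * hidden_size  # 4_095_995_904
int32_max = 2**31 - 1                             # 2_147_483_647
assert last_row_offset > int32_max  # an int32 offset would wrap around

The in-place computation matters at this scale too: a (1, 1_000_000, 4096) bfloat16 tensor occupies roughly 8 GB, so skipping a separate output buffer saves about that much again.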
import torch
import triton
import triton.language as tl
MAX_FUSED_SIZE = 65536  # largest row width a single Triton program will handle
next_power_of_2 = triton.next_power_of_2
def calculate_settings(n):
    # Pick a power-of-two block wide enough to cover one row, and scale
    # the warp count with the block size.
    BLOCK_SIZE = next_power_of_2(n)
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(
f"Cannot launch Triton kernel since n = {n} exceeds "
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
)
num_warps = 4
if BLOCK_SIZE >= 32768:
num_warps = 32
elif BLOCK_SIZE >= 8192:
num_warps = 16
elif BLOCK_SIZE >= 2048:
num_warps = 8
return BLOCK_SIZE, num_warps
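# For example, the rules above map:
#   n = 4096  -> BLOCK_SIZE = 4096, num_warps = 8
#   n = 5000  -> BLOCK_SIZE = 8192, num_warps = 16
#   n = 70000 -> RuntimeError, since next_power_of_2(70000) = 131072 > MAX_FUSED_SIZE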
@triton.jit
def _inplace_rms_norm(
X, X_row_stride, W, W_row_stride, n_cols, eps, BLOCK_SIZE: tl.constexpr
):
    # Cast the row index to int64 so that row_idx * X_row_stride cannot
    # overflow int32 for very long sequences.
    row_idx = tl.program_id(0).to(tl.int64)
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
X += row_idx * X_row_stride
    # Upcast to float32 so the mean of squares is accumulated at full precision.
    X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)
W_row = tl.load(W + col_offsets, mask=mask, other=0)
    # RMSNorm: x / sqrt(mean(x^2) + eps), then scaled by the weight vector.
    row_var = tl.sum(X_row * X_row, axis=0) / n_cols
    inv_var = tl.math.rsqrt(row_var + eps)
normed = X_row * inv_var
output = normed * W_row
output = output.to(X.dtype.element_ty)
    # Store over the input row: the kernel is in-place by design.
    tl.store(X + col_offsets, output, mask=mask)
def inplace_rms_norm(X, W, eps):
    # Flatten all leading dimensions into rows. X must be viewable as
    # (-1, dim), e.g. contiguous; note that X is modified in place.
    shape = X.shape
    dim = shape[-1]
    X = X.view(-1, dim)
n_rows, n_cols = X.shape
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    # Launch one Triton program per row.
    _inplace_rms_norm[(n_rows,)](
X,
X.stride(0),
W,
W.stride(0),
n_cols,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return X.view(*shape)
Here's how you can use it:
batch_size = 1
seq_length = 1_000_000
hidden_size = 4096
X = torch.randn(
batch_size, seq_length, hidden_size, device="cuda", dtype=torch.bfloat16
)
W = torch.randn(hidden_size, device="cuda", dtype=torch.bfloat16)
eps = 1e-5
normalized_X = inplace_rms_norm(X, W, eps)  # X itself now holds the normalized values
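To sanity-check the kernel, you can compare it against a straightforward PyTorch RMSNorm on a smaller tensor. This is a minimal sketch: reference_rms_norm is our own helper written for this comparison (not part of the kernel), and the tolerances are just reasonable picks for bfloat16. Since the kernel overwrites its input, we clone X before calling it:

def reference_rms_norm(X, W, eps):
    # Plain PyTorch RMSNorm in float32, used only as a correctness reference.
    X32 = X.float()
    inv_rms = torch.rsqrt(X32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (X32 * inv_rms * W.float()).to(X.dtype)

X_small = torch.randn(4, 128, hidden_size, device="cuda", dtype=torch.bfloat16)
expected = reference_rms_norm(X_small, W, eps)
actual = inplace_rms_norm(X_small.clone(), W, eps)
torch.testing.assert_close(actual, expected, rtol=1e-2, atol=1e-2)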