Skip to content

Instantly share code, notes, and snippets.

@NTT123
Last active September 12, 2024 05:25
Show Gist options
  • Save NTT123/844001a21520d8db0c955e8e2f3e0b33 to your computer and use it in GitHub Desktop.
Save NTT123/844001a21520d8db0c955e8e2f3e0b33 to your computer and use it in GitHub Desktop.
Inplace RMSNorm Implementation

This is an optimized RMSNorm inference kernel written in Triton, a Python-based language for writing GPU kernels. It is a modified version of the excellent RMSNorm kernel from the Unsloth project.

It makes two improvements over the original:

  • int64 pointer offsets: the per-row pointer offset (row index × row stride) is computed in int64 instead of the default int32. This prevents overflow for long sequences, where the offset can exceed the maximum int32 value (2^31 − 1, about 2.1 billion). For example, 1,000,000 rows of 4,096 elements reach offsets of about 4.1 billion.
  • In-place computation: the kernel writes the result back into the input buffer, so no output buffer has to be allocated. This halves the activation memory compared with implementations that write to a separate output buffer of the same size. The trade-off is that the original input values are overwritten.
import torch
import triton
import triton.language as tl

# Upper bound on BLOCK_SIZE accepted by calculate_settings (2**16 == 65536).
# A row is handled by a single block, so this also caps the hidden size.
MAX_FUSED_SIZE = 1 << 16
# Triton utility used by calculate_settings to pick a power-of-two BLOCK_SIZE.
next_power_of_2 = triton.next_power_of_2


def calculate_settings(n):
    """Choose the Triton ``BLOCK_SIZE`` and ``num_warps`` for rows of length ``n``.

    ``BLOCK_SIZE`` is ``next_power_of_2(n)``, so a single block spans the
    whole row; ``num_warps`` grows with the block size (4, 8, 16 or 32).

    Args:
        n: Number of elements per row (the normalized dimension). Must be >= 1.

    Returns:
        A ``(BLOCK_SIZE, num_warps)`` tuple.

    Raises:
        ValueError: if ``n < 1``; there is no usable block size for an empty row.
        RuntimeError: if ``BLOCK_SIZE`` would exceed ``MAX_FUSED_SIZE``.
    """
    if n < 1:
        # Fail early with a clear message rather than handing a degenerate
        # block size to the kernel launch.
        raise ValueError(
            f"Cannot launch Triton kernel for n = {n}; n must be >= 1."
        )
    BLOCK_SIZE = next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
        )
    # Larger blocks get more warps so each thread handles a bounded slice.
    if BLOCK_SIZE >= 32768:
        num_warps = 32
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    else:
        num_warps = 4
    return BLOCK_SIZE, num_warps


@triton.jit
def _inplace_rms_norm(
    X, X_row_stride, W, W_row_stride, n_cols, eps, BLOCK_SIZE: tl.constexpr
):
    """Overwrite one row of ``X`` with ``x * rsqrt(mean(x**2) + eps) * W``.

    The row is selected by ``tl.program_id(0)``. Its offset is computed in
    int64 so that ``row * X_row_stride`` does not overflow int32 on very large
    inputs. Columns of both ``X`` and ``W`` are addressed with unit stride;
    ``W_row_stride`` is accepted but not used. The reduction runs in float32,
    and the result is cast back to ``X``'s element type before the store.
    """
    # Widen the program id first so the offset multiply happens in int64.
    row = tl.program_id(0).to(tl.int64)
    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < n_cols

    row_ptr = X + row * X_row_stride

    x = tl.load(row_ptr + cols, mask=in_bounds, other=0).to(tl.float32)
    w = tl.load(W + cols, mask=in_bounds, other=0)

    # Masked lanes load 0, so they contribute nothing to the sum of squares;
    # dividing by n_cols (not BLOCK_SIZE) gives the true mean over the row.
    mean_sq = tl.sum(x * x, axis=0) / n_cols
    scale = tl.math.rsqrt(mean_sq + eps)
    y = (x * scale) * w
    tl.store(row_ptr + cols, y.to(X.dtype.element_ty), mask=in_bounds)


def inplace_rms_norm(X, W, eps):
    """Apply RMSNorm to ``X`` in place along its last dimension.

    Each row ``x`` of ``X.view(-1, X.shape[-1])`` is overwritten with
    ``x * rsqrt(mean(x**2) + eps) * W``. The reduction is done in float32 and
    the result is cast back to ``X``'s dtype.

    Args:
        X: Input tensor with at least one dimension; it is overwritten with
            the result. Its last dimension must have unit stride, and the
            leading dimensions must be flattenable with ``view``. No copy is
            made, because the kernel writes directly into ``X``'s storage.
        W: Per-feature scale with exactly ``X.shape[-1]`` elements.
        eps: Value added to the mean square before the reciprocal square root.

    Returns:
        A view of ``X`` with its original shape, now holding the normalized
        values. For an empty ``X``, ``X`` itself is returned untouched.

    Raises:
        ValueError: if ``W`` has the wrong number of elements, or ``X``'s last
            dimension is not contiguous.
        RuntimeError: if the leading dimensions cannot be flattened with
            ``view``, or the hidden size is too large for a single block.
    """
    shape = X.shape
    n_cols = shape[-1]
    if W.numel() != n_cols:
        # The kernel reads W[0:n_cols]; a shorter W would be read out of bounds.
        raise ValueError(
            f"W must have {n_cols} elements to match X.shape[-1], "
            f"but has {W.numel()}."
        )
    if X.numel() == 0:
        # Nothing to normalize; also avoids a zero-sized view/launch.
        return X
    if n_cols > 1 and X.stride(-1) != 1:
        # The kernel addresses columns as X + col_offsets (unit stride). A
        # .contiguous() copy would defeat the in-place write, so reject instead.
        raise ValueError(
            f"X must be contiguous in its last dimension (got stride "
            f"{X.stride(-1)}); the kernel writes in place and cannot use a copy."
        )

    X_2d = X.view(-1, n_cols)
    n_rows = X_2d.shape[0]
    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
    # The kernel also reads W with unit stride; this is a no-op if already dense.
    W = W.contiguous()

    # One program per row of the flattened view.
    _inplace_rms_norm[(n_rows,)](
        X_2d,
        X_2d.stride(0),
        W,
        W.stride(0),
        n_cols,
        eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
    )
    return X_2d.view(*shape)

Here's how you can use it:

# 1 * 1,000,000 * 4096 = 4,096,000,000 elements, more than the int32 maximum
# (2,147,483,647), so per-row offsets must be computed in int64. In bfloat16
# (2 bytes/element), X alone occupies about 8.2 GB of GPU memory.
batch_size, seq_length, hidden_size = 1, 1_000_000, 4096
eps = 1e-5
X = torch.randn(
    batch_size, seq_length, hidden_size, dtype=torch.bfloat16, device="cuda"
)
W = torch.randn(hidden_size, dtype=torch.bfloat16, device="cuda")

# normalized_X is a view of X: X itself now holds the normalized values.
normalized_X = inplace_rms_norm(X, W, eps)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment