@NTT123
Created September 13, 2024 13:56
In-place RoPE inference kernel
"""
RoPE triton kernel
"""
import triton
import triton.language as tl

@triton.jit
def _rope_kernel(
    x_ptr, x_row_stride, x_head_stride,
    r_ptr, r_row_stride,
    N, H: tl.constexpr, D: tl.constexpr,
):
    # Persistent-kernel style: each program strides over the rows with a
    # step equal to the total number of programs in the grid.
    row_start = tl.program_id(0).to(tl.int64)
    row_step = tl.num_programs(0)
    cols = tl.arange(0, D // 2)
    for row_idx in tl.range(row_start, N, step=row_step):
        # Load this row's rotation angles once and reuse them for every
        # head; compute cos/sin in float32 for accuracy.
        r_row = tl.load(r_ptr + row_idx * r_row_stride + cols, mask=None).to(tl.float32)
        cos = tl.cos(r_row)
        sin = tl.sin(r_row)
        for head_idx in tl.range(0, H):
            # Interleaved pairing: lanes (2i, 2i+1) form the pair that is
            # rotated by angle r_row[i].
            x_row_0 = tl.load(x_ptr + row_idx * x_row_stride + head_idx * x_head_stride + cols * 2 + 0, mask=None)
            x_row_1 = tl.load(x_ptr + row_idx * x_row_stride + head_idx * x_head_stride + cols * 2 + 1, mask=None)
            # Apply the 2D rotation to each pair.
            output_0 = x_row_0 * cos - x_row_1 * sin
            output_1 = x_row_0 * sin + x_row_1 * cos
            output = tl.interleave(output_0, output_1)
            # Write the rotated values back over the input (in-place update).
            tl.store(
                x_ptr + row_idx * x_row_stride + head_idx * x_head_stride + tl.arange(0, D),
                output.to(x_row_0.dtype),
                mask=None,
            )

def rope(x, r):
    # x: input of shape (..., H, D); rotated in place (the input storage
    # is overwritten) and also returned.
    # r: rotation angles of shape (..., D/2), one angle per lane pair.
    shape = x.shape
    x = x.view(-1, shape[-2], shape[-1])
    r = r.view(-1, r.shape[-1])
    N, H, D = x.shape
    N1, D1 = r.shape
    assert D1 * 2 == D, "r must hold D/2 angles per row"
    assert N == N1, "x and r must have the same number of rows"
    # Launch roughly N/32 persistent programs; each one handles every
    # 32nd row of the flattened input.
    M = max(1, N // 32)
    _rope_kernel[(M,)](
        x, x.stride(0), x.stride(1),
        r, r.stride(0),
        N=N,
        H=H,
        D=D,
    )
    return x.view(*shape)
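
# A minimal pure-PyTorch reference for the same interleaved-pair rotation,
# handy for validating the kernel. This helper is an added illustration and
# is not part of the original gist; the pairing convention (even lane = first
# element of the pair, odd lane = second) is read off the kernel above.
def ref_rope(x, r):
    import torch
    # x: (N, H, D), r: (N, D/2); returns a new tensor (not in place).
    cos = torch.cos(r)[:, None, :]  # broadcast over heads: (N, 1, D/2)
    sin = torch.sin(r)[:, None, :]
    x0 = x[..., 0::2]  # first element of each pair
    x1 = x[..., 1::2]  # second element of each pair
    out = torch.empty_like(x)
    out[..., 0::2] = x0 * cos - x1 * sin
    out[..., 1::2] = x0 * sin + x1 * cos
    return out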
if __name__ == "__main__":
import torch
# Set default device to CUDA and default dtype to bfloat16
torch.set_default_device('cuda')
torch.set_default_dtype(torch.bfloat16)
# Create example input tensors
N, H, D = 128, 32, 128 # Batch size, number of heads, embedding dimension
x = torch.randn(N, H, D) # Input tensor
print(f"First few values of input: {x[0, 0, :10]}")
r = torch.randn(N, D // 2) # Rotation angles
# Apply RoPE (Rotary Position Embedding)
rotated_x = rope(x, r)
# Synchronize CUDA operations
torch.cuda.synchronize()
print(f"Input shape: {x.shape}")
print(f"Rotated output shape: {rotated_x.shape}")
print(f"First few values of rotated output: {rotated_x[0, 0, :10]}")