# test.py
from typing import Optional, Union
import inspect
import math
import os
import textwrap
import numpy as np
import torch
import triton
import triton.language as tl
from numpy.random import RandomState
from triton.runtime.jit import TensorWrapper, reinterpret
torch.manual_seed(0)
int_dtypes = ["int8", "int16", "int32", "int64"]
uint_dtypes = ["uint8", "uint16", "uint32", "uint64"]
float_dtypes = ["float16", "float32", "float64"]
dtypes = int_dtypes + uint_dtypes + float_dtypes
dtypes_with_bfloat16 = dtypes + ["bfloat16"]
torch_float8_dtypes = ["float8_e4m3fn", "float8_e5m2"]
torch_dtypes = ["bool"] + int_dtypes + ["uint8"] + float_dtypes + ["bfloat16"]
def is_interpreter():
    return os.environ.get('TRITON_INTERPRET', '0') == '1'
def is_cpu():
    return not is_interpreter() and \
        triton.runtime.driver.active.get_current_target().backend in ["cpu", "cpu_v2"]
def patch_kernel(template, to_replace):
    if is_interpreter():
        local_namespace = {}
        src = textwrap.dedent(inspect.getsource(template.fn))
        for k, v in to_replace.items():
            src = src.replace(k, v)
        exec(src, globals(), local_namespace)
        return local_namespace[template.fn.__name__]
    else:
        kernel = triton.JITFunction(template.fn)
        for key, value in to_replace.items():
            kernel.src = kernel.src.replace(key, value)
        return kernel
def to_triton(x: np.ndarray, device, dst_type=None) -> Union[TensorWrapper, torch.Tensor]:
    """
    Note: We need dst_type because the type of x can be different from dst_type.
    For example: x is of type `float32`, dst_type is `bfloat16`.
    If dst_type is None, we infer dst_type from x.
    """
    t = x.dtype.name
    if t in uint_dtypes:
        signed_type_name = t.lstrip("u")  # e.g. "uint16" -> "int16"
        x_signed = x.astype(getattr(np, signed_type_name))
        return reinterpret(torch.tensor(x_signed, device=device), getattr(tl, t))
    else:
        if dst_type and "float8" in dst_type:
            return reinterpret(torch.tensor(x, device=device), getattr(tl, dst_type))
        if t == "float32" and dst_type == "bfloat16":
            return torch.tensor(x, device=device).bfloat16()
        return torch.tensor(x, device=device)
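# Note: `to_numpy` below references `torch_dtype_name`, which this snippet never
# defines. The helper here is a minimal stand-in (an assumption, modeled loosely on
# the helper of the same name in Triton's test suite), not the author's original:
# it only maps a Triton or torch dtype to the matching NumPy dtype name,
# e.g. tl.uint16 -> "uint16", torch.int64 -> "int64".
def torch_dtype_name(dtype) -> str:
    if isinstance(dtype, tl.dtype):
        # Triton dtypes carry their name directly (e.g. "uint32").
        return dtype.name
    if isinstance(dtype, torch.dtype):
        # "torch.int64" -> "int64"
        return str(dtype).split(".")[-1]
    raise TypeError(f"not a triton or torch dtype: {type(dtype)}")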
def to_numpy(x):
    if isinstance(x, TensorWrapper):
        return x.base.cpu().numpy().astype(getattr(np, torch_dtype_name(x.dtype)))
    elif isinstance(x, torch.Tensor):
        if x.dtype is torch.bfloat16:
            return x.cpu().float().numpy()
        return x.cpu().numpy()
    else:
        raise ValueError(f"Not a triton-compatible tensor: {x}")
def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
    """
    Override `rs` if you're calling this function twice and don't want the same
    result for both calls.
    """
    if isinstance(shape, int):
        shape = (shape, )
    if rs is None:
        rs = RandomState(seed=17)
    if dtype_str in int_dtypes + uint_dtypes:
        iinfo = np.iinfo(getattr(np, dtype_str))
        low = iinfo.min if low is None else max(low, iinfo.min)
        high = iinfo.max if high is None else min(high, iinfo.max)
        dtype = getattr(np, dtype_str)
        x = rs.randint(low, high, shape, dtype=dtype)
        x[x == 0] = 1  # Workaround. Never return zero so tests of division don't error out.
        return x
    elif dtype_str and "float8" in dtype_str:
        x = rs.randint(20, 40, shape, dtype=np.int8)
        return x
    elif dtype_str in float_dtypes:
        return rs.normal(0, 1, shape).astype(dtype_str)
    elif dtype_str == "bfloat16":
        return (rs.normal(0, 1, shape).astype("float32").view("uint32") & np.uint32(0xFFFF0000)).view("float32")
    elif dtype_str in ["bool", "int1", "bool_"]:
        return rs.normal(0, 1, shape) > 0.0
    else:
        raise RuntimeError(f"Unknown dtype {dtype_str}")
def test_add(dtype_str, BLOCK_SIZE, device, num_warps=1, experimental=False):
    if device == "cpu":
        triton.runtime.driver.set_active_to_cpu(experimental)
    else:
        triton.runtime.driver.set_active_to_gpu()

    @triton.jit
    def kernel_add(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x_ptrs = x_ptr + offsets
        x = tl.load(x_ptrs, mask=mask)
        y_ptrs = y_ptr + offsets
        y = tl.load(y_ptrs, mask=mask)
        output = x + y
        tl.store(output_ptr + offsets, output, mask=mask)

    size = BLOCK_SIZE * 2 + 41
    x = torch.arange(1, size + 1, dtype=getattr(torch, dtype_str), device=device)
    y = torch.arange(1, size + 1, dtype=getattr(torch, dtype_str), device=device)
    output = torch.zeros_like(x)
    grid = (triton.cdiv(size, BLOCK_SIZE), )
    kernel_add[grid](x, y, output, size, BLOCK_SIZE, num_warps=num_warps)
    expected = x + y
    print(f"add-{dtype_str}-{BLOCK_SIZE}-{device}-{num_warps}-{experimental}:", torch.allclose(output, expected))
if True:
    # The current GPU and CPU
    test_add("float32", 1024, "cuda")
    # test_add("float32", 128, "cpu")
    # A new experimental CPU
    test_add("float32", 128, "cpu", experimental=True)
    # We can run with multiple warps for a kernel that doesn't need any inter-thread/warp communication.
    # This effectively reduces the size of blocks, reducing compilation time.
    # os.environ["TRITON_CPU_ALLOW_MULTI_WARPS"] = "1"
    # test_add("float32", 128, "cpu", num_warps=4, experimental=True)
def test_softmax(shape, device, num_warps=1, experimental=False):
    if device == "cpu":
        triton.runtime.driver.set_active_to_cpu(experimental)
    else:
        triton.runtime.driver.set_active_to_gpu()

    @triton.jit
    def kernel_softmax(
        output_ptr,
        input_ptr,
        input_row_stride,
        output_row_stride,
        n_cols,
        BLOCK_SIZE: tl.constexpr,
    ):
        # NOTE: the real softmax body is commented out below; the active code only
        # stores sqrt(exp(x)) of the first element, so the allclose check at the
        # end is expected to report False.
        x = tl.load(input_ptr)
        y = tl.exp(x)
        y = tl.sqrt(y)
        tl.store(output_ptr, y)
        # row_idx = tl.program_id(0)
        # row_start_ptr = input_ptr + row_idx * input_row_stride
        # col_offsets = tl.arange(0, BLOCK_SIZE)
        # input_ptrs = row_start_ptr + col_offsets
        # mask = col_offsets < n_cols
        # row = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        # a = tl.sqrt(row)
        # row_minus_max = row - tl.max(row, axis=0) + a
        # numerator = tl.exp(row_minus_max)
        # denominator = tl.sum(numerator, axis=0)
        # softmax_output = numerator / denominator
        # output_row_start_ptr = output_ptr + row_idx * output_row_stride
        # output_ptrs = output_row_start_ptr + col_offsets
        # tl.store(output_ptrs, softmax_output, mask=mask)

    x = torch.randn(shape, device=device)
    y = torch.empty_like(x)
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    kernel_softmax[(n_rows, )](
        y,
        x,
        x.stride(0),
        y.stride(0),
        n_cols,
        num_warps=num_warps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    print(x)
    print(y)
    expected = torch.softmax(x, axis=1)
    print(f"softmax-{shape}-{device}-{num_warps}-{experimental}:", torch.allclose(y, expected))
if False:
    # test_softmax((1020, 59), "cuda")
    # test_softmax((1020, 59), "cpu")
    test_softmax((1020, 59), "cpu", num_warps=1, experimental=True)
    # It fails with 4 warps due to incomplete reductions.
    os.environ["TRITON_CPU_ALLOW_MULTI_WARPS"] = "1"
    # test_softmax((1020, 59), "cpu", num_warps=4, experimental=True)
def test_transpose(device, num_warps=1, experimental=False):
    if device == "cpu":
        triton.runtime.driver.set_active_to_cpu(experimental)
    else:
        triton.runtime.driver.set_active_to_gpu()
    SIZE_M = 8
    SIZE_N = 4

    @triton.jit
    def kernel_2d_simple(Z, X, SIZE_M: tl.constexpr, SIZE_N: tl.constexpr):
        off = tl.arange(0, SIZE_M)
        off2d = off[None, :] + (tl.arange(0, SIZE_N) * SIZE_M)[:, None]
        x = tl.load(X + off2d)
        z = x + 42
        tl.store(Z + off2d, z)

    @triton.jit
    def kernel_transpose(Z, X, SIZE_M: tl.constexpr, SIZE_N: tl.constexpr):
        off = tl.arange(0, SIZE_M)
        off2d = off[None, :] + (tl.arange(0, SIZE_N) * SIZE_M)[:, None]
        x = tl.load(X + off2d)
        z = x.T + 42
        tl.store(Z + off2d.T, z)

    x = torch.randint(low=0, high=8, size=(SIZE_M, SIZE_N), device=device)
    z = torch.zeros_like(x.T)
    kernel_transpose[(1, )](z, x, SIZE_M, SIZE_N, num_warps=num_warps)
    expected = x.T + 42
    print(f"transpose-{device}-{num_warps}-{experimental}:", torch.allclose(z, expected))
if False:
    test_transpose("cuda")
    test_transpose("cpu")
    test_transpose("cpu", experimental=True)
    # It crashes due to the shared memory.
    # os.environ["TRITON_CPU_ALLOW_MULTI_WARPS"] = "1"
    # test_transpose("cpu", num_warps=4, experimental=True)
def test_trans_4d(dtype_str, shape, perm, device, num_warps=1, experimental=False):
    if device == "cpu":
        triton.runtime.driver.set_active_to_cpu(experimental)
    else:
        triton.runtime.driver.set_active_to_gpu()

    @triton.jit
    def kernel(In, Out,  #
               in_shape1: tl.constexpr, in_shape2: tl.constexpr, in_shape3: tl.constexpr, in_shape4: tl.constexpr,
               ou_shape1: tl.constexpr, ou_shape2: tl.constexpr, ou_shape3: tl.constexpr, ou_shape4: tl.constexpr,
               trans1: tl.constexpr, trans2: tl.constexpr, trans3: tl.constexpr, trans4: tl.constexpr):
        in_ptr = tl.make_block_ptr(
            base=In,
            shape=(in_shape1, in_shape2, in_shape3, in_shape4),
            strides=(in_shape4 * in_shape3 * in_shape2, in_shape4 * in_shape3, in_shape4, 1),
            offsets=(0, 0, 0, 0),
            block_shape=(in_shape1, in_shape2, in_shape3, in_shape4),
            order=(3, 2, 1, 0),
        )
        out_ptr = tl.make_block_ptr(
            base=Out,
            shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4),
            strides=(ou_shape4 * ou_shape3 * ou_shape2, ou_shape4 * ou_shape3, ou_shape4, 1),
            offsets=(0, 0, 0, 0),
            block_shape=(ou_shape1, ou_shape2, ou_shape3, ou_shape4),
            order=(3, 2, 1, 0),
        )
        tl.store(out_ptr, tl.load(in_ptr).permute((trans1, trans2, trans3, trans4)))

    input = torch.arange(math.prod(shape), dtype=getattr(torch, dtype_str), device=device).reshape(shape)
    expected = torch.permute(input, perm)
    # Don't do zeros_like -- that copies the layout, which we don't want.
    actual = torch.zeros(expected.shape, dtype=getattr(torch, dtype_str), device=device)
    kernel[(1, )](input, actual, *shape, *[shape[i] for i in perm], *perm, num_warps=num_warps)
    print(f"trans4d-{shape}-{perm}-{device}-{num_warps}-{experimental}:", torch.allclose(actual, expected))
if False:
    # test_trans_4d('int32', (2, 2, 2, 2), (3, 1, 2, 0), 'cuda')
    # test_trans_4d('int32', (2, 2, 2, 2), (3, 1, 2, 0), 'cpu')
    test_trans_4d('int32', (2, 2, 2, 2), (3, 1, 2, 0), 'cpu', experimental=True)
    # os.environ["TRITON_CPU_ALLOW_MULTI_WARPS"] = "1"
    # test_trans_4d('int32', (2, 2, 2, 2), (3, 1, 2, 0), 'cpu', num_warps=4, experimental=True)
def test_dot_without_load(dtype_str, device, num_warps=1, experimental=False):
    if device == "cpu":
        triton.runtime.driver.set_active_to_cpu(experimental)
    else:
        triton.runtime.driver.set_active_to_gpu()
    SIZE = 16

    @triton.jit
    def _kernel_dot_no_load(out, SIZE: tl.constexpr):
        a = GENERATE_TEST_HERE
        b = GENERATE_TEST_HERE
        c = tl.dot(a, b)
        out_ptr = out + tl.arange(0, SIZE)[:, None] * SIZE + tl.arange(0, SIZE)[None, :]
        tl.store(out_ptr, c)

    kernel = patch_kernel(_kernel_dot_no_load, {'GENERATE_TEST_HERE': f"tl.full((SIZE, SIZE), 1.0, tl.{dtype_str})"})
    a = torch.ones((SIZE, SIZE), dtype=getattr(torch, dtype_str), device=device)
    b = torch.ones((SIZE, SIZE), dtype=getattr(torch, dtype_str), device=device)
    if is_cpu() and dtype_str == "float16":
        # torch.matmul not implemented for Half float (float16) cpu
        out_ref = torch.tensor(np.matmul(to_numpy(a), to_numpy(b)), dtype=getattr(torch, dtype_str), device=device)
    else:
        out_ref = torch.matmul(a, b)
    out = torch.zeros((SIZE, SIZE), dtype=getattr(torch, dtype_str), device=device)
    kernel[(1, )](out, SIZE, num_warps=num_warps)
    print(f"dot-wo-load-{device}-{num_warps}-{experimental}:", torch.allclose(out, out_ref))
# test_dot_without_load('float32', 'cpu')
# test_dot_without_load('float32', 'cpu', experimental=True)
def test_atomic_cas(sem, num_ctas, device, experimental=False):
    if device == "cpu":
        triton.runtime.driver.set_active_to_cpu(experimental)
    else:
        triton.runtime.driver.set_active_to_gpu()

    # 1. make sure that atomic_cas changes the original value (Lock)
    @triton.jit
    def change_value(Lock):
        tl.atomic_cas(Lock, 0, 1)

    Lock = torch.zeros((1, ), device=device, dtype=torch.int32)
    change_value[(1, )](Lock)
    assert (Lock[0] == 1)

    # 2. only one block enters the critical section
    @triton.jit
    def serialized_add(data, Lock, SEM: tl.constexpr):
        ptrs = data + tl.arange(0, 16)
        while tl.atomic_cas(Lock, 0, 1, SEM) == 1:
            pass
        tl.store(ptrs, tl.load(ptrs) + 1.0)
        # insert barrier to set a fence between tl.store and
        # tl.atomic_xchg in a block.
        tl.debug_barrier()
        # release lock
        tl.atomic_xchg(Lock, 0)

    Lock = torch.zeros((1, ), device=device, dtype=torch.int32)
    data = torch.zeros((16, ), device=device, dtype=torch.float32)
    ref = torch.full((16, ), 2000.0)
    h = serialized_add[(2000, )](data, Lock, SEM=sem, num_ctas=num_ctas)
    sem_str = "acq_rel" if sem is None else sem
    np.testing.assert_allclose(to_numpy(data), to_numpy(ref))
# test_atomic_cas("acq_rel", 1, "cuda")
# test_atomic_cas("acq_rel", 1, "cpu", experimental=True)
def test_tensor_atomic_cas(sem, num_ctas, device, experimental=False):
    if device == "cpu":
        triton.runtime.driver.set_active_to_cpu(experimental)
    else:
        triton.runtime.driver.set_active_to_gpu()

    @triton.jit
    def change_value(X, BLOCK_SIZE: tl.constexpr, sem: tl.constexpr):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        t1 = tl.full((BLOCK_SIZE, ), 0, dtype=tl.int64)
        t2 = tl.full((BLOCK_SIZE, ), 2, dtype=tl.int64)
        tl.atomic_cas(X + offsets, t1, t2, sem=sem)

    X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=torch.int64)
    Y = torch.tensor([2, 1, 2, 1, 2, 1, 2, 1], device=device, dtype=torch.int64)
    change_value[(2, )](X, 4, sem)
    assert (torch.equal(X, Y))
# test_tensor_atomic_cas("acq_rel", 1, "cpu", experimental=True)
def test_atomic_rmw(op, dtype_x_str, mode, sem, device, experimental=False):
    if device == "cpu":
        triton.runtime.driver.set_active_to_cpu(experimental)
    else:
        triton.runtime.driver.set_active_to_gpu()
    n_programs = 5

    # triton kernel
    @triton.jit
    def kernel(X, Z):
        pid = tl.program_id(0)
        x = tl.load(X + pid)
        old = GENERATE_TEST_HERE
        tl.static_assert(old.dtype == x.dtype)

    sem_arg = sem if sem is None else f'"{sem}"'
    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'})
    numpy_op = {'add': np.sum, 'max': np.max, 'min': np.min}[op]
    max_neutral = float('-inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).min
    min_neutral = float('inf') if dtype_x_str in float_dtypes else np.iinfo(getattr(np, dtype_x_str)).max
    neutral = {'add': 0, 'max': max_neutral, 'min': min_neutral}[op]
    # triton result
    rs = RandomState(17)
    # x = np.array([2**i for i in range(n_programs)], dtype=getattr(np, dtype_x_str))
    x = np.array([3, 2, 4, 16, -8], dtype=getattr(np, dtype_x_str))
    print(">>>>>>>>>>>", x)
    if mode == 'all_neg':
        x = -np.abs(x)
    if mode == 'all_pos':
        x = np.abs(x)
    if mode == 'min_neg':
        idx = rs.randint(n_programs, size=(1, )).item()
        x[idx] = -np.max(np.abs(x)) - 1
    if mode == 'max_pos':
        idx = rs.randint(n_programs, size=(1, )).item()
        x[idx] = np.max(np.abs(x)) + 1
    x_tri = to_triton(x, device=device)
    z_tri = to_triton(np.array([neutral], dtype=getattr(np, dtype_x_str)), device=device)
    print(">>>>>>>>>>>", x_tri)
    print(">>>>>>>>>>>", z_tri)
    h = kernel[(5, )](x_tri, z_tri)
    # torch result
    z_ref = numpy_op(x).astype(getattr(np, dtype_x_str))
    # compare
    exact = op not in ['add']
    print(">>>>>>>>>>>", z_ref)
    print(">>>>>>>>>>>", z_tri)
    if exact:
        assert z_ref.item() == to_numpy(z_tri).item()
    else:
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
if False:
    test_atomic_rmw("max", "float32", "all_pos", "acq_rel", "cuda")
    # test_atomic_rmw("max", "float32", "all_pos", "acq_rel", "cpu", False)
    # test_atomic_rmw("max", "float32", "all_pos", "acq_rel", "cpu", True)
    test_atomic_rmw("max", "float64", "all_pos", "acq_rel", "cpu", True)
    # test_atomic_rmw("max", "int32", "all_pos", "acq_rel", "cpu", True)
def test_tensor_atomic_rmw_block(device, experimental=False):
    torch.manual_seed(2024)
    if device == "cpu":
        triton.runtime.driver.set_active_to_cpu(experimental)
    else:
        triton.runtime.driver.set_active_to_gpu()
    shape = (4, 4)

    @triton.jit
    def kernel(X, Y, cut, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
        off0 = tl.arange(0, SHAPE0)
        off1 = tl.arange(0, SHAPE1)
        offs = off0[:, None] * SHAPE1 + off1[None, :]
        y = tl.load(Y + offs)
        x = X + offs
        tl.atomic_min(x, y, mask=offs < cut)

    x = torch.randint(low=1, high=16, size=shape, dtype=torch.float32)
    # Convert to GPU if necessary.
    x = x.to(device)
    x_copy = x.clone()
    y = torch.full(shape, 7, device=device, dtype=torch.float32)
    kernel[(2, )](x, y, 12, shape[0], shape[1], num_ctas=1)
    expected = torch.minimum(x, y)
    expected[3] = x_copy[3]
    assert torch.allclose(x, expected)
if False:
    # test_tensor_atomic_rmw_block("cuda")
    # test_tensor_atomic_rmw_block("cpu")
    test_tensor_atomic_rmw_block("cpu", experimental=True)