@aredden
Created August 24, 2024 15:57
Quantize to fp8 using a quantile-based scale vs. absmax scaling, for tensors with very out-of-distribution (OOD) values.
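Why the quantile scale helps here (a minimal sketch, assuming roughly unit-scale random data with a single injected outlier, which is the setup the script below uses): torch.finfo(torch.float8_e4m3fn).max is 448, so an absmax scale driven by the 62320 outlier comes out around 62320 / 448 ≈ 139, which squashes typical unit-magnitude values down to roughly 0.007, inside e4m3's subnormal range where relative precision is poor. A 0.999 quantile with "lower" interpolation never reaches the single outlier, so the scale stays small and the outlier is simply clamped during quantization:

import torch

x = torch.randn(1024, 2048)
x[0, 0] = 62320.0  # single out-of-distribution value, as injected in the script below
finfo = torch.finfo(torch.float8_e4m3fn)  # finfo.max == 448.0
# Absmax scale is dominated entirely by the outlier: ~62320 / 448, about 139.
print("absmax scale:  ", (x.abs().max() / finfo.max).item())
# Per-column 0.999 "lower" quantile (then max across columns) never reaches the outlier: about 0.01.
print("quantile scale:", (torch.quantile(x.abs(), 0.999, dim=0, interpolation="lower").max() / finfo.max).item())

The full script below runs the same comparison end to end through torch._scaled_mm on CUDA.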
import torch

torch.set_printoptions(precision=4, sci_mode=False)


def quantize_fp8_tensorwise(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # Absmax scaling: a single tensor-wide scale chosen so the largest absolute
    # value maps exactly to the fp8 dtype's maximum representable value.
    scale = x.abs().max() / torch.finfo(dtype).max
    x = x.float() / scale
    return x.to(dtype), scale.float()


def quantize_fp8_tensorwise_quantile(weight: torch.Tensor, dtype=torch.float8_e4m3fn):
    # Quantile scaling: use the 0.999 "lower" quantile of |values| per column,
    # then the max across columns, so a single extreme outlier cannot blow up
    # the scale. Values beyond the scaled range are clamped to the fp8 limits.
    finfo = torch.finfo(dtype)
    quantile = torch.quantile(
        weight.abs().float(), 0.999, dim=0, interpolation="lower"
    ).max()
    scale = quantile / finfo.max
    q_weight = (weight.float() / scale).clamp(min=-finfo.max, max=finfo.max).to(dtype)
    return q_weight, scale.float()


def fp8_linear_torch(
    x: torch.Tensor,
    x_scale: torch.Tensor,
    weight_fp8: torch.Tensor,
    weight_scale: torch.Tensor,
):
    # fp8 GEMM via torch._scaled_mm; the weight is passed transposed so its
    # layout is column-major, which _scaled_mm expects for the second operand.
    out = torch._scaled_mm(
        x, weight_fp8.T, scale_a=x_scale, scale_b=weight_scale, out_dtype=torch.bfloat16
    )
    return out


if __name__ == "__main__":
    act_bf16 = torch.randn(1024, 2048).bfloat16().cuda()
    weight_bf16 = torch.randn(4096, 2048).bfloat16().cuda()
    # Inject a single, very out-of-distribution activation value.
    act_bf16[0, 0] = 62320

    # bf16 reference result.
    ref = act_bf16 @ weight_bf16.T

    # Quantile-scaled fp8.
    q_weight, q_w_scale = quantize_fp8_tensorwise_quantile(weight_bf16)
    q_act, q_a_scale = quantize_fp8_tensorwise_quantile(act_bf16)

    # Absmax-scaled fp8.
    abs_weight, abs_w_scale = quantize_fp8_tensorwise(weight_bf16)
    abs_act, abs_a_scale = quantize_fp8_tensorwise(act_bf16)

    out_quantile = fp8_linear_torch(q_act, q_a_scale, q_weight, q_w_scale)
    out_abs = fp8_linear_torch(abs_act, abs_a_scale, abs_weight, abs_w_scale)

    print(
        "Median abs diff for fp8 quantile: ",
        (out_quantile - ref).abs().median(),
    )
    print(
        "Median abs diff for fp8 absmax: ",
        (out_abs - ref).abs().median(),
    )
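As a follow-up, here is a minimal sketch of how the quantile quantizer above could be wrapped as a drop-in replacement for an existing nn.Linear: the weight is quantized once up front, activations are quantized per call. The FP8QuantileLinear name and module structure are illustrative, not part of the gist, and the shapes still have to satisfy torch._scaled_mm's alignment requirements (as the 1024x2048 / 4096x2048 example sizes above do).

import torch
from torch import nn


class FP8QuantileLinear(nn.Module):
    """Hypothetical wrapper around the quantile quantizer above; illustrative only."""

    def __init__(self, linear: nn.Linear):
        super().__init__()
        # Quantize the weight once, offline, with the 0.999-quantile scale.
        q_weight, w_scale = quantize_fp8_tensorwise_quantile(linear.weight.data)
        self.register_buffer("q_weight", q_weight)
        self.register_buffer("w_scale", w_scale)
        self.bias = linear.bias

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Expects a 2-D activation; quantized per call (absmax would also work here).
        q_x, x_scale = quantize_fp8_tensorwise_quantile(x)
        out = fp8_linear_torch(q_x, x_scale, self.q_weight, self.w_scale)
        return out if self.bias is None else out + self.bias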