import torch
import torch.nn.functional as F
import unittest
import xformers.ops as xops
import math
import time
# Number of timed iterations used when benchmarking a kernel.
MAX_ITER = 100

# Keyword-argument sets for ``torch.backends.cuda.sdp_kernel``; each entry
# enables exactly one scaled-dot-product-attention backend.
SDP_BACKENDS = {
    "math": dict(enable_math=True, enable_flash=False, enable_mem_efficient=False),
    "flash": dict(enable_math=False, enable_flash=True, enable_mem_efficient=False),
    "mem_efficient": dict(enable_math=False, enable_flash=False, enable_mem_efficient=True),
}
def attn_fn(q, k, v, attn_mask=None, scale=None, dropout_p=0., is_causal=False, is_training=False):
    """Compute scaled dot-product attention with plain tensor operations.

    Python port of PyTorch's C++ "math" attention implementation.

    Args:
        q: Query tensor; its last two dimensions are ``(L, E)``.
        k: Key tensor; its last two dimensions are ``(S, E)``.
        v: Value tensor; its second-to-last dimension is ``S``.
        attn_mask: Optional mask broadcastable to the ``(..., L, S)`` logits.
            A boolean mask marks positions that take part in attention with
            ``True``; any other dtype is treated as an additive bias.
        scale: Optional multiplier applied to the logits. Defaults to
            ``1 / sqrt(E)``.
        dropout_p: Dropout probability applied to the attention weights.
        is_causal: If true, use a lower-triangular ``(L, S)`` mask.
        is_training: Dropout is only active when this is true.

    Returns:
        ``softmax(scale * q @ k^T + mask) @ v``.

    Raises:
        ValueError: If ``attn_mask`` is given together with ``is_causal=True``.
    """
    # Apply sqrt(scale) to both q and k before the matmul (instead of scaling
    # the product) for numerical stability.
    scale_factor = math.sqrt(1 / math.sqrt(q.size(-1)) if scale is None else scale)

    if is_causal:
        if attn_mask is not None:
            raise ValueError("Explicit attn_mask should not be set when is_causal=True")
        # Replace attn_mask with a causal mask; lower triangular elements take
        # part in attention. Build it on q's device so it can be combined with
        # the logits without a device mismatch.
        L, S = q.size(-2), k.size(-2)
        attn_mask = torch.ones(L, S, dtype=torch.bool, device=q.device).tril()

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Convert the boolean mask to an additive one; invert it to select
            # the positions that must be masked *out*.
            keep = attn_mask.to(q.device)
            attn_mask = torch.zeros_like(keep, dtype=q.dtype).masked_fill(~keep, -float('inf'))
        else:
            # Non-boolean masks are already additive; only align dtype/device.
            attn_mask = attn_mask.to(q)

    query_ = q * scale_factor
    key_ = k * scale_factor
    attn = query_ @ key_.transpose(-2, -1)
    if attn_mask is not None:
        attn = attn + attn_mask
    attn = torch.softmax(attn, dim=-1)
    if dropout_p > 0:
        attn = torch.dropout(attn, dropout_p, train=is_training)
    return attn @ v
class TestAttention(unittest.TestCase):
    """Compare ``attn_fn`` with PyTorch SDPA and xformers attention on GPU."""

    def test_attention(self):
        """Check that the three implementations agree and print their timings.

        The reference ``attn_fn`` is timed over a single call; the PyTorch
        "math" backend and xformers are averaged over ``MAX_ITER`` calls.
        """
        if not torch.cuda.is_available():
            self.skipTest("CUDA device required for the attention comparison")

        device = "cuda:0"
        dtype = torch.float16
        B = 16
        N = 1024
        H = 16
        D = 128
        q = torch.rand(B, H, N, D, dtype=dtype, device=device)
        k = torch.rand(B, H, N, D, dtype=dtype, device=device)
        v = torch.rand(B, H, N, D, dtype=dtype, device=device)
        # Boolean lower-triangular mask; set to None to exercise the unmasked path.
        attn_mask = torch.ones((B, H, N, N), dtype=torch.bool, device=device).tril()

        # CUDA kernels launch asynchronously, so synchronize around every
        # timed region to measure execution rather than launch time.
        torch.cuda.synchronize(device)
        s_math = time.perf_counter()
        out_math = attn_fn(q, k, v, attn_mask=attn_mask)
        torch.cuda.synchronize(device)
        e_math = time.perf_counter()

        # Enter the backend-selection context once so its overhead is not
        # paid on every timed iteration.
        with torch.backends.cuda.sdp_kernel(**SDP_BACKENDS["math"]):
            torch.cuda.synchronize(device)
            s_torch = time.perf_counter()
            for _ in range(MAX_ITER):
                out_torch = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
            torch.cuda.synchronize(device)
            e_torch = time.perf_counter()

        # xformers is given an additive bias instead of the boolean mask:
        # 0 where attention is allowed, -inf where it is masked out.
        if attn_mask is not None:
            attn_bias = torch.zeros_like(attn_mask, dtype=dtype).masked_fill(~attn_mask, -float('inf'))
        else:
            attn_bias = None

        # Swap the head and sequence axes: (B, H, N, D) -> (B, N, H, D),
        # the input layout documented for xformers.
        q_ = q.transpose(1, 2)
        k_ = k.transpose(1, 2)
        v_ = v.transpose(1, 2)
        torch.cuda.synchronize(device)
        s_xformer = time.perf_counter()
        for _ in range(MAX_ITER):
            out_xformer_ = xops.memory_efficient_attention(q_, k_, v_, attn_bias=attn_bias)
        torch.cuda.synchronize(device)
        e_xformer = time.perf_counter()
        # Back to (B, H, N, D) so the result is comparable with the others.
        out_xformer = out_xformer_.transpose(1, 2)

        print(f"The math implementation runs in {(e_math - s_math) * 1e6:.3f} microseconds")
        print(f"The torch implementation runs in {(e_torch - s_torch) / MAX_ITER * 1e6:.3f} microseconds")
        print(f"The xformer implementation runs in {(e_xformer - s_xformer) / MAX_ITER * 1e6:.3f} microseconds")

        self.assertTrue(torch.allclose(out_math, out_torch, atol=1e-3))
        self.assertTrue(torch.allclose(out_math, out_xformer, atol=1e-3))
if __name__ == '__main__':
