import torch
import torch._functorch.config

def fn(values, offsets, w):
    for _ in range(10):
        # Pack the flat (20, 64) values into a jagged-layout NestedTensor of
        # shape (8, j, 64), split the last dim into 4 heads of size 16, and
        # move the head dim ahead of the jagged dim for attention.
        nt = torch.nested.nested_tensor_from_jagged(
            values, offsets, min_seqlen=1, max_seqlen=4
        ).view(-1, -1, 4, 16).transpose(1, 2)
        nt = torch.nn.functional.scaled_dot_product_attention(nt, nt, nt)
        # Undo the transpose/head split and recover the packed (20, 64) values.
        values = nt.transpose(1, 2).view(-1, -1, 64).values().cos()
        values = values @ w
    return values

# offsets describe 8 sequences of lengths 1, 2, 3, 4, 4, 3, 2, 1 (20 rows total).
values = torch.randn(20, 64, requires_grad=True, dtype=torch.bfloat16, device="cuda")
offsets = torch.tensor([0, 1, 3, 6, 10, 14, 17, 19, 20], device="cuda")
w = torch.randn(64, 64, requires_grad=True, dtype=torch.bfloat16, device="cuda")

# fn(values, offsets, w)  # eager-mode run for comparison

# Compile with the partitioner's activation memory budget capped at 80% of
# the default saved-activation memory, and print partitioner debug output.
with torch._functorch.config.patch(activation_memory_budget=0.8, debug_partitioner=True):
    torch.compile(fn)(values, offsets, w)
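
# A minimal follow-up sketch (an addition, not part of the original repro):
# since `values` and `w` require grad, calling backward() on the compiled
# output actually runs the recomputation that the budgeted partitioner
# planned for the joint forward/backward graph.
with torch._functorch.config.patch(activation_memory_budget=0.8, debug_partitioner=True):
    out = torch.compile(fn)(values, offsets, w)
    out.sum().backward()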