@lucidrains
Created August 12, 2024 17:48
Tree Attention Decoding

import torch
from torch import einsum
import torch.distributed as dist

def tree_attn_decode(q, k, v):
    """
    Algorithm 3 proposed in Tree Attention
    https://arxiv.org/abs/2408.04093
    """
    rank = dist.get_rank() if dist.is_initialized() else 0
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    # scale queries
    scale = q.shape[-1] ** -0.5
    q = q * scale

    # each machine (rank) takes care of one chunk of the kv sequence within the world of many machines
    k = k.chunk(world_size, dim = -2)
    v = v.chunk(world_size, dim = -2)
    k, v = k[rank], v[rank]

    # calculate the local output, keeping the numerator and denominator unnormalized

    sim = einsum('... i d, ... j d -> ... i j', q, k)

    local_max = sim.amax(dim = -1, keepdim = True)
    sim = sim - local_max
    lse = sim.logsumexp(dim = -1, keepdim = True)

    attn = sim.softmax(dim = -1)
    out = einsum('... i j, ... j d -> ... i d', attn, v)

    den = lse.exp()
    num = out * den

    # first all reduce (max) - get the global max across ranks

    global_max = local_max.clone()
    dist.all_reduce(global_max, dist.ReduceOp.MAX)

    # renormalize the local numerator and denominator by the global max

    renorm_factor = (local_max - global_max).exp()

    den = den * renorm_factor
    num = num * renorm_factor

    # second and third all reduce (sum)

    dist.all_reduce(den)
    dist.all_reduce(num)

    return num / den

# regular attention for testing

def regular_decode(q, k, v):
    scale = q.shape[-1] ** -0.5
    q = q * scale
    sim = einsum('... i d, ... j d -> ... i j', q, k)
    attn = sim.softmax(dim = -1)
    return einsum('... i j, ... j d -> ... i d', attn, v)
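
# --- illustrative aside (not part of the original gist) -------------------------
# a minimal single-process sketch of the same merge: split the kv sequence into
# two chunks on one device, compute the unnormalized numerator / denominator per
# chunk, renormalize both by the global max, then check against `regular_decode`.
# the function name, dims and chunk count below are arbitrary choices.

def single_device_merge_check(dim = 64, seq_len = 10):
    q = torch.randn(1, 1, dim)
    k = torch.randn(1, seq_len, dim)
    v = torch.randn(1, seq_len, dim)

    scale = dim ** -0.5
    nums, dens, maxes = [], [], []

    # per-chunk partial results, as each rank would compute them locally
    for k_chunk, v_chunk in zip(k.chunk(2, dim = -2), v.chunk(2, dim = -2)):
        sim = einsum('... i d, ... j d -> ... i j', q * scale, k_chunk)
        local_max = sim.amax(dim = -1, keepdim = True)
        exp_sim = (sim - local_max).exp()
        nums.append(einsum('... i j, ... j d -> ... i d', exp_sim, v_chunk))
        dens.append(exp_sim.sum(dim = -1, keepdim = True))
        maxes.append(local_max)

    # merge step, mirroring the three all-reduces above (max, then two sums)
    global_max = torch.maximum(*maxes)
    num = sum(n * (m - global_max).exp() for n, m in zip(nums, maxes))
    den = sum(d * (m - global_max).exp() for d, m in zip(dens, maxes))

    assert torch.allclose(num / den, regular_decode(q, k, v), atol = 1e-5)

# call single_device_merge_check() directly to run this sanity check without multiprocessing
# ---------------------------------------------------------------------------------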

# for testing the above tree decoding function
# requires `pip install click`, in addition to `torch`

import os
import click
from math import ceil
import torch.multiprocessing as mp

def setup(
    rank,
    world_size,
    use_cuda
):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    backend = "gloo" if not use_cuda else "nccl"
    dist.init_process_group(backend, rank = rank, world_size = world_size)

    if use_cuda:
        torch.cuda.set_device(rank)

def cleanup():
    dist.destroy_process_group()

def start(
    rank,
    world_size,
    seq_len,
    use_cuda,
):
    setup(rank, world_size, use_cuda)

    is_main = rank == 0
    ring_seq_size = ceil(seq_len / world_size) # per-rank kv chunk length (not used below)

    # inputs

    q = torch.randn(1, 1, 512)
    k = torch.randn(1, seq_len, 512)
    v = torch.randn(1, seq_len, 512)

    if use_cuda:
        # nccl needs cuda tensors, so move the inputs onto this rank's device
        q, k, v = tuple(t.cuda(rank) for t in (q, k, v))

    # easy way of forcing q, k, v to be the same across all devices

    dist.all_reduce(q)
    dist.all_reduce(k)
    dist.all_reduce(v)

    # outputs

    out = regular_decode(q, k, v)
    tree_out = tree_attn_decode(q, k, v)

    # if not the main rank, early return

    if not is_main:
        return cleanup()

    # on the main rank, validate that the output with the kv sequence split across machines matches the output without

    tree_out = tree_out.cpu()
    out = out.cpu()

    output_atol = 1e-2 if use_cuda else 1e-5

    assert torch.allclose(tree_out, out, atol = output_atol), '🟥 output is not the same'
    print('✅ output is the same between tree and non-tree attention decoding')

    cleanup()

@click.command()
@click.option('--world-size', default = 8, help = 'number of machines / processes')
@click.option('--use-cuda', is_flag = True, help = 'whether to test with CUDA and NCCL')
@click.option('--seq-len', default = 31, help = 'sequence length to test')
def test(
    world_size: int,
    use_cuda: bool,
    seq_len: int,
):
    assert not use_cuda or world_size <= torch.cuda.device_count(), f'world size {world_size} must not exceed the number of cuda devices ({torch.cuda.device_count()})'

    mp.spawn(
        start,
        args = (world_size, seq_len, use_cuda),
        nprocs = world_size,
        join = True
    )

if __name__ == '__main__':
    test()
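
# example invocations (the script filename is assumed here; the gist does not name the file)
#   python tree_attn_decode.py                             # default: 8 gloo processes on cpu
#   python tree_attn_decode.py --world-size 4 --use-cuda --seq-len 1024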