@cccntu
Created August 8, 2024 15:15
"""Generates a document causal attention mask based on a document ID tensor"""
from typing import List, Union
import torch
from torch import Tensor
from torch.nn.attention.flex_attention import _mask_mod_signature, or_masks
from attn_gym.masks import causal_mask


def _offsets_to_doc_ids_tensor(offsets):
    """Expands cumulative offsets into a per-token document ID tensor."""
    device = offsets.device
    counts = offsets[1:] - offsets[:-1]
    return torch.repeat_interleave(
        torch.arange(len(counts), device=device, dtype=torch.int32), counts
    )


def length_to_offsets(lengths: List[int], device: Union[str, torch.device]) -> Tensor:
    """Converts a list of lengths to a tensor of cumulative offsets.

    Args:
        lengths: A list of lengths.
        device: Device to place the offsets tensor on.
    """
    offsets = [0]
    offsets.extend(lengths)
    offsets = torch.tensor(offsets, device=device, dtype=torch.int32)
    offsets = torch.cumsum(offsets, dim=-1)
    return offsets
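
# A minimal usage sketch for length_to_offsets (values taken from the docstring
# example in generate_doc_mask_mod below):
#
#   offsets = length_to_offsets([2, 4, 3], device="cpu")
#   offsets.tolist()  # -> [0, 2, 6, 9]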


def generate_doc_mask_mod(mask_mod: _mask_mod_signature, offsets: Tensor) -> _mask_mod_signature:
    """Generates mask mods that apply to inputs to flex attention in the sequence stacked
    format.

    Args:
        mask_mod: The mask mod to apply within each document.
        offsets: A tensor of shape (num_documents + 1,) containing the cumulative
            counts of document tokens, e.g. if you have 3 documents of lengths
            2, 4, and 3, then offsets = [0, 2, 6, 9].

    Note:
        What is the sequence stacked format? When assembling batches of inputs, we
        take multiple sequences and stack them together to form one large sequence. We
        then use masking to ensure that attention scores are only computed between
        tokens within the same document.
    """
    document_id = _offsets_to_doc_ids_tensor(offsets)

    def doc_mask_mod(b, h, q_idx, kv_idx):
        # Tokens may only attend to tokens in the same document.
        same_doc = document_id[q_idx] == document_id[kv_idx]
        # Convert global (stacked) indices to per-document (logical) indices
        # before applying the inner mask mod.
        q_logical = q_idx - offsets[document_id[q_idx]]
        kv_logical = kv_idx - offsets[document_id[kv_idx]]
        inner_mask = mask_mod(b, h, q_logical, kv_logical)
        return same_doc & inner_mask

    return doc_mask_mod
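
# A sketch of how a mask mod produced here could be compiled and used, assuming
# PyTorch's torch.nn.attention.flex_attention API (2.5+); illustrative only:
#
#   from torch.nn.attention.flex_attention import create_block_mask, flex_attention
#   doc_mask = generate_doc_mask_mod(causal_mask, offsets)
#   block_mask = create_block_mask(doc_mask, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len)
#   out = flex_attention(query, key, value, block_mask=block_mask)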


def generate_turn_mask_mod(offsets: Tensor, is_user: Tensor) -> _mask_mod_signature:
    """Allows full (bidirectional) attention between tokens within the same user turn."""
    turn_id = _offsets_to_doc_ids_tensor(offsets)

    def turn_mask_mod(b, h, q_idx, kv_idx):
        same_turn = turn_id[q_idx] == turn_id[kv_idx]
        return same_turn & is_user[turn_id[q_idx]]

    return turn_mask_mod
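
# Worked illustration of combining the two mask mods with or_masks, assuming one
# document containing a 2-token user turn (u0, u1) followed by a 2-token assistant
# turn (a0, a1):
#
#   u0 attends to: u0 u1          (bidirectional within the user turn)
#   u1 attends to: u0 u1
#   a0 attends to: u0 u1 a0       (causal within the document)
#   a1 attends to: u0 u1 a0 a1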


def main(device: str = "cpu"):
    """Visualize the attention scores of the document-packed chat tuning mask mod.

    Args:
        device (str): Device to use for computation. Defaults to "cpu".
    """
    from attn_gym import visualize_attention_scores
    import random

    random.seed(0)
    def generate_random_lengths(total_length, num_documents):
        # Initialize all lengths to 1 to ensure each document has at least one token
        lengths = [1] * num_documents
        remaining_length = total_length - num_documents

        # Randomly distribute the remaining length
        for _ in range(remaining_length):
            index = random.randint(0, num_documents - 1)
            lengths[index] += 1
        return lengths
    max_seq_len, doc_count, turn_count = 40, 2, 10
    B, H, SEQ_LEN, HEAD_DIM = 1, 1, max_seq_len, 8

    # Split the sequence into turns, then group the turns into documents.
    turn_lengths = generate_random_lengths(max_seq_len, turn_count)
    doc_turn_counts = generate_random_lengths(turn_count, doc_count)

    doc_lengths = []
    turn_is_user = []
    i = 0
    for doc_id in range(doc_count):
        doc_lengths.append(sum(turn_lengths[i:i + doc_turn_counts[doc_id]]))
        # Flag which turns in this document count as user (prompt) turns vs. assistant turns.
        turn_is_user.extend(
            [True]
            + [True, False] * ((doc_turn_counts[doc_id] - 1) // 2)
            + [True] * ((doc_turn_counts[doc_id] - 1) % 2)
        )
        i += doc_turn_counts[doc_id]

    doc_offsets = length_to_offsets(doc_lengths, device)
    turn_offsets = length_to_offsets(turn_lengths, device)

    def make_tensor():
        return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device)

    query, key = make_tensor(), make_tensor()

    document_causal_mask = generate_doc_mask_mod(causal_mask, doc_offsets)
    turn_is_user = torch.tensor(turn_is_user, device=device)
    turn_mask = generate_turn_mask_mod(turn_offsets, turn_is_user)
    # Allow attention that is causal within a document OR bidirectional within a user turn.
    chat_mask = or_masks(document_causal_mask, turn_mask)
    visualize_attention_scores(
        query,
        key,
        mask_mod=chat_mask,
        device=device,
        name="document_packing_chat_tuning_mask",
    )


if __name__ == "__main__":
    try:
        from jsonargparse import CLI
    except ImportError:
        raise ImportError("Be sure to run: pip install -e .[viz]")

    CLI(main)
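
# Example invocation, assuming attn_gym is installed with the viz extras and this
# file is saved as chat_mask.py (a hypothetical filename); jsonargparse exposes the
# `device` argument as a flag:
#
#   python chat_mask.py --device cpu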