Created April 28, 2023 13:45
Save saharNooby/e1d871e93522d9c50c5a6fa59f356ba9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
import torch

# USAGE EXAMPLE (illustrative only; `llm` and `sample` stand for your own model call and sampler):
#   logits = llm(...)                            # Get raw logits from an LLM
#   logits = tail_free_sampling(logits, z=0.95)  # Cut off logits in the tail
#   token = sample(logits, temperature=1.0)      # Do your usual sampling with temp/top-p

def tail_free_sampling(logits: torch.Tensor, z: float = 0.95, mask_value: float = -float('inf')) -> torch.Tensor: | |
""" | |
See https://www.trentonbricken.com/Tail-Free-Sampling/ | |
Code copied from https://github.com/finetunej/transformers/blob/c83109932f4592b871ec4c60326df3b4173b021a/src/transformers/generation_logits_process.py#L243-L284 | |
:param logits: Logits. | |
:param z: Hyperparameter for tail-free sampling. | |
:param mask_value: Tokens that should be excluded from sampling would have their logit set to this value. | |
:return: Masked logits. | |
""" | |
assert len(logits.shape) == 1, str(logits.shape) | |
# numpy sort is faster than PyTorch (5 ms vs 7 ms) | |
logits_np = logits.detach().cpu().numpy() | |
sorted_indices_np = np.argsort(logits_np, kind='quicksort') | |
sorted_indices_np = np.ascontiguousarray(np.flip(sorted_indices_np)) | |
sorted_logits_np = logits_np[sorted_indices_np] | |
sorted_indices = torch.tensor(sorted_indices_np, device=logits.device) | |
sorted_logits = torch.tensor(sorted_logits_np, device=logits.device) | |
d = sorted_logits.softmax(dim=-1) | |
d = d[1:] - d[:-1] | |
d = d[1:] - d[:-1] | |
d = d.abs() | |
d = d / d.sum(dim=-1).item() | |
cumulative_probs = d.cumsum(dim=-1) | |
# Remove tokens with cumulative top_p above the threshold (token with 0 are kept) | |
sorted_indices_to_remove = torch.zeros(sorted_indices.shape, dtype=torch.bool, device=logits.device) | |
sorted_indices_to_remove[:-2] = (cumulative_probs > z)[:] | |
# Always keep the first token | |
sorted_indices_to_remove[0] = 0 | |
# Always remove two last tokens -- they should have negligible probability anyway | |
sorted_indices_to_remove[len(sorted_indices_to_remove) - 1] = 1 | |
sorted_indices_to_remove[len(sorted_indices_to_remove) - 2] = 1 | |
# Scatter sorted tensors to original indexing | |
indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove) | |
return logits.masked_fill(indices_to_remove, mask_value) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.