@freckletonj
Created August 30, 2024 00:53
Generate tokens using past_key_values/kv-cache in transformers

With the KV cache, each decoding step feeds only the newest token through the model; the keys and values for all earlier positions are read back from `past_key_values` instead of being recomputed, which is where the speedup comes from.
import torch


def generate_with_cache(model, model_inputs, max_new_tokens):
    '''Greedy-decode max_new_tokens tokens, reusing past_key_values (the KV
    cache) so earlier positions are never re-encoded.'''
    generated_tokens = []
    past_key_values = None
    next_token = None
    input_ids = model_inputs['input_ids']
    attention_mask = model_inputs['attention_mask']
    for i in range(max_new_tokens):
        # For the first iteration, use the full prompt. For subsequent
        # iterations, use only the last generated token. `attention_mask`
        # keeps growing to cover the entire sequence seen so far.
        if i > 0:
            input_ids = next_token.unsqueeze(1)
            attention_mask = torch.cat([attention_mask, torch.ones_like(input_ids)], dim=-1)
        out = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True)
        next_token = out.logits[:, -1].argmax(dim=-1)  # greedy pick
        generated_tokens.append(next_token)
        past_key_values = out.past_key_values
    return torch.stack(generated_tokens, dim=-1)
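
A minimal usage sketch, assuming a Hugging Face causal LM; the gpt2 checkpoint and the prompt are placeholders, and any causal LM with a matching tokenizer should work the same way:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained('gpt2')
model.eval()

model_inputs = tokenizer('The capital of France is', return_tensors='pt')
with torch.no_grad():
    token_ids = generate_with_cache(model, model_inputs, max_new_tokens=10)
print(tokenizer.batch_decode(token_ids))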