Skip to content

Instantly share code, notes, and snippets.

@shreyansh26
Created June 7, 2023 19:01
Show Gist options
  • Save shreyansh26/47f165588a360d0e10d519d5e0888cd5 to your computer and use it in GitHub Desktop.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Sanity check: greedy `generate()` should choose, at every step, the argmax of
# the logits that one teacher-forced forward pass over the finished sequence
# produces at the preceding position.
tok = AutoTokenizer.from_pretrained("distilgpt2")
model = AutoModelForCausalLM.from_pretrained("distilgpt2")
inputs = tok(["Hello how"], return_tensors="pt")
len_inp = inputs.input_ids.shape[-1]  # number of prompt tokens
print(len_inp)
# GPT-2 defines no pad token; passing EOS explicitly only silences the
# "Setting pad_token_id to eos_token_id" warning for this unpadded single prompt.
generated = model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=10,
    pad_token_id=tok.eos_token_id,
)
# A single forward pass suffices (the logits were previously computed twice);
# no_grad skips building an autograd graph that is never backpropagated.
# `generated` is one unpadded sequence, so no attention_mask is needed here.
with torch.no_grad():
    forward_pass = model(generated).logits
forward_confirmation = forward_pass.argmax(-1)
print(forward_pass.shape)
print(generated)
print(tok.decode(generated[0]))
print(forward_confirmation)
print(tok.decode(forward_confirmation[0]))
# Logits at position i predict token i + 1, so the predictions for the newly
# generated tokens generated[len_inp:] sit at positions [len_inp - 1 : -1].
# Prints True when greedy decoding and the teacher-forced pass agree.
print(generated[0][len_inp:].tolist() == forward_confirmation[0][len_inp-1:-1].tolist())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment