Skip to content

Instantly share code, notes, and snippets.

View gante's full-sized avatar

Joao Gante gante

View GitHub Profile
@gante
gante / dola_demo.py
Created July 10, 2024 14:57
DoLa demo
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load the Llama 3 8B Instruct tokenizer and model. Weights are requested in bfloat16 and
# `device_map="auto"` leaves device placement to the library.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
)
# Overwrite the generation config's EOS id with its PAD id.
# NOTE(review): if `pad_token_id` is unset (None) this clears the EOS id, so generation would
# only stop on a length limit -- confirm this direction is intended and is not a swapped
# `pad_token_id = eos_token_id`.
model.generation_config.eos_token_id = model.generation_config.pad_token_id
# Prompt used by the demo.
question = 'What does Darth Vader say to Luke in "The Empire Strikes Back"?'
@gante
gante / yarn_checks.py
Created July 8, 2024 10:37
yarn checks
"""
Assumes:
1. transformers on this branch (https://github.com/huggingface/transformers/pull/30910)
2. yarn installed via pip (https://github.com/jquesnelle/yarn)
3. HF login with read token (`huggingface-cli login`)
"""
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoTokenizer
# `torch.compile`-enabled Llama 3
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch, time, os
# Set TOKENIZERS_PARALLELISM to "false" before the tokenizer is created
# (presumably to silence the tokenizers parallelism/fork warning -- TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Tokenizer configured to pad on the left.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", padding_side="left")
# Reuse the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
# Load the model in float16; `device_map="auto"` leaves device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto", torch_dtype=torch.float16
)
@gante
gante / llama2_compile.py
Created March 21, 2024 15:42
v4.39 Llama 2 + torch.compile
# `torch.compile` enabled Llama 2 🏎️
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch, time
# Load the TinyLlama 1.1B chat checkpoint in float16; `device_map="auto"` leaves device
# placement to the library.
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0", device_map="auto", torch_dtype=torch.float16
)
# Replace the model's forward pass with a `torch.compile`-wrapped version.
# `fullgraph=True` asks the compiler to capture the forward as a single graph, and
# `mode="reduce-overhead"` selects the compiler's low-overhead mode (see torch.compile docs).
model.forward = torch.compile(model.forward, fullgraph=True, mode="reduce-overhead")
@gante
gante / galactica_contrastive_search.py
Created November 18, 2022 11:33
Galactica (1.3b) + contrastive search examples
from transformers import AutoTokenizer, OPTForCausalLM
# Load Galactica 1.3B through the OPT causal-LM class; `device_map="auto"` leaves device
# placement to the library.
tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-1.3b")
model = OPTForCausalLM.from_pretrained("facebook/galactica-1.3b", device_map="auto")
# Candidate prompts -- keep exactly one `input_text` assignment uncommented.
# input_text = "Question: How small is a human cell? Answer:" # they should get the same short answers
input_text = "Question: What do Maxwell's equations represent? Answer:" # better with repetitions
# input_text = "Question: Simplify the following Python code using math:```pythondef calc_sum(n): i = 0 s = 0 while i <= n: s += i i += 1 return s```Answer:" # better with early stop
# input_text = "Question: What technology will revolutionize language models? Answer:"
# Tokenize the prompt and move the ids to the model's own device instead of a hard-coded
# "cuda": with `device_map="auto"` the weights are not guaranteed to be on a CUDA device
# (e.g. on a CPU-only machine), and a hard-coded target would fail or mismatch there.
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
@gante
gante / benchmark_whisper.py
Last active October 7, 2022 12:15
OpenAI Whisper Benchmark
import time
from datetime import timedelta
from functools import wraps
from tqdm import tqdm
# PyTorch imports and settings
import torch
from transformers.testing_utils import torch_device
# Allow TF32 for CUDA matrix multiplications in PyTorch.
# NOTE(review): per the trailing comment, the other benchmarked frameworks also run with TF32,
# so this presumably keeps the comparison like-for-like -- confirm against the rest of the script.
torch.backends.cuda.matmul.allow_tf32 = True # All frameworks using TF32
@gante
gante / pt_img_gen.py
Last active July 29, 2023 20:08
Portuguese image generation
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch import autocast
# Portuguese prompt; presumably translated to English and passed to the image pipeline
# further down the script (not visible here).
PT_PROMPT = "Um gato com um chapéu, pintura a aguarelas" # A cat with a hat, watercolor painting
# translation PT -> EN
# Checkpoint id for the translation model; the tokenizer and the seq2seq model are both
# loaded from it.
transl_model_id = "Narrativa/mbart-large-50-finetuned-opus-pt-en-translation"
tokenizer = AutoTokenizer.from_pretrained(transl_model_id)
text_model = AutoModelForSeq2SeqLM.from_pretrained(transl_model_id)
@gante
gante / generate_benchmark.py
Last active December 19, 2022 16:39
Benchmark Hugging Face generation for PT, FLAX, TF with Eager Execution, and TF with XLA.
import os
import time
from datetime import timedelta
from functools import wraps, partial
from tqdm import tqdm
# JAX imports and settings
# The environment variable is set before `import jax` so it is already in place when JAX is
# imported. Per the JAX GPU memory allocation docs, "false" disables up-front GPU memory
# preallocation (presumably so the other frameworks in this benchmark can share the GPU --
# TODO confirm).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import jax