import random

import torch
from datasets import load_dataset
from peft import PromptTuningConfig, PromptTuningInit, TaskType, get_peft_model
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling,
                          get_linear_schedule_with_warmup)

model_name_or_path = "bigscience/bloomz-560m" #@param
tokenizer_name_or_path = "bigscience/bloomz-560m" #@param

# Hyperparameters used below but never defined in the gist; the values here are assumed defaults.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
max_token_length = 128  # truncation length for tokenization
lr = 3e-3               # AdamW learning rate
num_epochs = 3          # number of training epochs
def load_data():
    """
    Load Yelp data for training and evaluation.
    """
    raw_datasets = load_dataset("yelp_review_full")
    raw_datasets = raw_datasets.shuffle(seed=42)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, max_length=max_token_length)

    # Subsample 100 training and 100 test examples to keep the run small.
    train_population = random.sample(range(len(raw_datasets["train"])), 100)
    test_population = random.sample(range(len(raw_datasets["test"])), 100)

    tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
    tokenized_datasets["train"] = tokenized_datasets["train"].select(train_population)
    tokenized_datasets["test"] = tokenized_datasets["test"].select(test_population)
    # The causal-LM objective used below does not consume the star rating,
    # so drop it along with the raw text.
    tokenized_datasets = tokenized_datasets.remove_columns(["text", "label"])

    # Pads each batch and copies input_ids into labels (padding -> -100) for the LM loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    train_loader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        batch_size=32,
        collate_fn=data_collator,
    )
    test_loader = DataLoader(
        tokenized_datasets["test"], batch_size=32, collate_fn=data_collator
    )
    return train_loader, test_loader, tokenizer
# Load train and test data
train_loader, test_loader, tokenizer = load_data()

# Instantiate the base model, wrap it for prompt tuning, and move it to the device.
# The gist omits this step; the PEFT configuration below (random init, 8 virtual
# tokens) is an assumed, typical setup rather than one taken from the gist.
peft_config = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM,
    prompt_tuning_init=PromptTuningInit.RANDOM,
    num_virtual_tokens=8,
)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
model = model.to(DEVICE)

# Define the optimizer and learning rate scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_loader) * num_epochs),
)
# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(tqdm(train_loader)):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        total_loss += loss.detach().item()
    avg_train_loss = total_loss / len(train_loader)

    model.eval()
    eval_loss = 0
    eval_preds = []
    for step, batch in enumerate(tqdm(test_loader)):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        loss = outputs.loss
        eval_loss += loss.detach().item()
        eval_preds.extend(
            tokenizer.batch_decode(
                torch.argmax(outputs.logits, -1).detach().cpu().numpy(),
                skip_special_tokens=True,
            )
        )
    avg_eval_loss = eval_loss / len(test_loader)
    eval_ppl = torch.exp(torch.tensor(avg_eval_loss))
    train_ppl = torch.exp(torch.tensor(avg_train_loss))
    print(f"Epoch {epoch}: Train PPL = {train_ppl:.4f}, Eval PPL = {eval_ppl:.4f}")