Train Mamba
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
import wandb
from datasets import load_dataset
import torch
import os
import argparse
import numpy as np
import pandas as pd
from transformers import EvalPrediction
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling
os.environ["WANDB_MODE"] = "offline" | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
# Load model and tokenizer | |
model = AutoModelForCausalLM.from_pretrained('Q-bert/Mamba-130M', trust_remote_code=True) | |
model.to(device) | |
tokenizer = AutoTokenizer.from_pretrained('Q-bert/Mamba-130M') | |
tokenizer.add_special_tokens({'pad_token': '[PAD]'}) | |
# Move model to appropriate device | |
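# Note (assumption, not verified against this checkpoint's config): if the new [PAD] id
# falls outside the model's embedding table, the embeddings would need resizing, e.g.
# model.resize_token_embeddings(len(tokenizer)); many Mamba checkpoints pad their vocab,
# in which case the extra token already fits.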
# Load dataset
dataset = load_dataset("Abirate/english_quotes", split='train')
max_size = 25
min_size = 10

# Preprocessing function to tokenize the 'quote' field, padding/truncating to max_size tokens
def tokenize_function(examples):
    return tokenizer(examples["quote"], padding="max_length", truncation=True, max_length=max_size)

# Apply the tokenizer to the dataset
tokenized_dataset = dataset.map(tokenize_function, batched=True)
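# Illustrative peek (not part of the training flow): every example now carries fixed-length
# 'input_ids' and 'attention_mask' columns of max_size entries each.
print(tokenized_dataset[0].keys(), len(tokenized_dataset[0]["input_ids"]))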
# Filter function to keep quotes whose unpadded length is between min_size and max_size tokens
def filter_quotes(batch):
    # Calculate actual (unpadded) lengths for each example in the batch
    actual_lengths = [sum(mask) for mask in batch["attention_mask"]]
    # Determine which examples to keep based on their actual length
    keep = [min_size <= length <= max_size for length in actual_lengths]
    return keep

# Apply the filter to the dataset
filtered_dataset = tokenized_dataset.filter(filter_quotes, batched=True)
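# Illustrative sanity check (added, not in the original flow): every surviving example's
# unpadded length, i.e. the sum of its attention mask, falls within [min_size, max_size].
_kept_lengths = [sum(mask) for mask in filtered_dataset["attention_mask"]]
assert all(min_size <= length <= max_size for length in _kept_lengths)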
# Splitting the dataset into training and evaluation sets
split_dataset = filtered_dataset.train_test_split(test_size=0.1)  # 10% for evaluation

# Derive default hyperparameters from the data, declare every argument once, parse once
block_size_default = int(np.max([len(t) for t in filtered_dataset['input_ids']]))
target_tokens_default = 1200  # 1200 = 16GB & 130M
batch_size_default = int(np.round(target_tokens_default / block_size_default))
gradient_steps_default = 4
# Optimizer updates per epoch: examples per epoch / (sequences per step * accumulation steps)
epoch_iters_default = int(np.round(len(split_dataset['train']) / (batch_size_default * gradient_steps_default)))

parser = argparse.ArgumentParser()
parser.add_argument("--block_size", type=int, default=block_size_default)
parser.add_argument("--target_tokens", type=int, default=target_tokens_default)
parser.add_argument("--batch_size", type=int, default=batch_size_default)
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--gradient_steps", type=int, default=gradient_steps_default)
parser.add_argument("--epoch_iters", type=int, default=epoch_iters_default)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=0.1)
args = parser.parse_args()

print(len(filtered_dataset))
print(args.block_size)
print(args.batch_size)
print(args.epoch_iters)
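# Worked example of the sizing arithmetic with the defaults above: padding to max_length=25
# makes block_size 25, so batch_size = round(1200 / 25) = 48 sequences (1200 tokens) per
# device step, and with gradient_steps = 4 each optimizer update covers roughly 4800 tokens.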
# Define custom trainer
class MambaTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        input_ids = inputs.pop("input_ids")
        lm_logits = model(input_ids)[0]
        labels = input_ids.to(lm_logits.device)
        # Shift so that position t predicts token t+1 (standard causal LM objective)
        shift_logits = lm_logits[:, :-1, :].contiguous()
        labels = labels[:, 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
        return (lm_loss, {"logits": lm_logits}) if return_outputs else lm_loss
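# Tiny shape demo of the shifted cross-entropy above (illustrative, random tensors):
# logits of shape (batch, seq_len, vocab) are cut to positions 0..T-2 and compared against
# the tokens at positions 1..T-1, i.e. each position predicts the next token.
_demo_logits = torch.randn(2, 5, 10)
_demo_labels = torch.randint(0, 10, (2, 5))
_demo_loss = torch.nn.CrossEntropyLoss()(
    _demo_logits[:, :-1, :].reshape(-1, 10),
    _demo_labels[:, 1:].reshape(-1),
)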
# Training arguments with logging and evaluation strategies
training_args = TrainingArguments(
    output_dir="./mamba_trainer_output",
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    num_train_epochs=args.epochs,
    logging_strategy="steps",        # Log training metrics every logging_steps steps
    logging_steps=1,                 # Log every step
    evaluation_strategy="epoch",     # Evaluate at the end of each epoch
    save_steps=args.epoch_iters,
    save_total_limit=2,
    weight_decay=args.weight_decay,
    learning_rate=args.learning_rate,
    gradient_accumulation_steps=args.gradient_steps
)
# Initialize trainer
trainer = MambaTrainer(
    model=model,
    args=training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"]
)
# Manually compute evaluation loss before training
# Create a dictionary containing the input data (first four test examples)
input_data = {"input_ids": torch.tensor(split_dataset["test"]["input_ids"][0:4], dtype=torch.long).to(device)}

# Manually compute evaluation loss using compute_loss
eval_loss = trainer.compute_loss(model=trainer.model, inputs=input_data)

# Print the evaluation loss
print("Evaluation Loss:", eval_loss.item())

# Start training
trainer.train()
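# Post-training sketch (illustrative addition, not in the original gist): re-run the same
# manual loss check to see how far the loss dropped, then persist the fine-tuned weights and
# tokenizer. The output paths are arbitrary placeholders.
post_train_loss = trainer.compute_loss(
    model=trainer.model,
    inputs={"input_ids": torch.tensor(split_dataset["test"]["input_ids"][0:4], dtype=torch.long).to(device)}
)
print("Post-training Evaluation Loss:", post_train_loss.item())
trainer.save_model("./mamba_trainer_output/final")
tokenizer.save_pretrained("./mamba_trainer_output/final")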