import torch
import pytorch_lightning as pl
from datasets import load_dataset, load_metric
from transformers import T5Config, T5ForConditionalGeneration
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
)
class MyLightningModule(pl.LightningModule):
    def __init__(self, model_name, learning_rate, weight_decay):
        super().__init__()
        self.model_name = model_name
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        # Load the tokenizer that matches the model checkpoint
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Alternative: fine-tune the pre-trained weights instead of
        # training from scratch:
        # self.model = torch.compile(
        #     AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
        # )
        # Create a T5-small configuration and initialize the model with
        # random weights (training from scratch)
        config = T5Config.from_pretrained("t5-small")
        self.model = torch.compile(T5ForConditionalGeneration(config))
        # Load the ROUGE metric (datasets.load_metric is deprecated in recent
        # versions of datasets; the evaluate library is its successor)
        self.metric = load_metric("rouge")
        # Buffers for accumulating validation predictions and labels
        self.logits = []
        self.labels = []
    def forward(self, input_ids, attention_mask, labels=None):
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        return output.loss, output.logits
    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, on_epoch=True, on_step=True, prog_bar=True)
        return {"loss": loss, "logits": logits}
    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, logits = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, on_epoch=True, on_step=False)
        # Store greedy predictions and labels for the epoch-end ROUGE
        # computation, detached from the computational graph first
        self.logits.append(logits.argmax(dim=-1).detach().cpu())
        self.labels.append(labels.detach().cpu())
        return {"loss": loss, "logits": logits, "labels": labels}
    def on_validation_epoch_end(self):
        # Concatenate the per-batch tensors accumulated during validation
        pred_token_ids = torch.cat(self.logits, dim=0)
        true_labels = torch.cat(self.labels, dim=0)
        # Replace the -100 sentinel (used to mask padding in the loss) with
        # the pad token id so the labels can be decoded
        true_labels = torch.where(
            true_labels == -100, self.tokenizer.pad_token_id, true_labels
        )
        # Decode predictions and labels to text
        decoded_preds = self.tokenizer.batch_decode(
            pred_token_ids, skip_special_tokens=True
        )
        decoded_labels = self.tokenizer.batch_decode(
            true_labels, skip_special_tokens=True
        )
        # Compute ROUGE-1 scores and log the mid aggregate
        scores = self.metric.compute(
            predictions=decoded_preds, references=decoded_labels, rouge_types=["rouge1"]
        )["rouge1"].mid
        self.log("rouge1_precision", scores.precision, prog_bar=True)
        self.log("rouge1_recall", scores.recall, prog_bar=True)
        self.log("rouge1_fmeasure", scores.fmeasure, prog_bar=True)
        # Clear the buffers for the next validation epoch
        self.logits.clear()
        self.labels.clear()
    def predict(self, article: str, max_input_length: int = 512, max_output_length: int = 150) -> str:
        # Set the model to evaluation mode
        self.model.eval()
        # Tokenize the input article
        inputs = self.tokenizer(
            article,
            max_length=max_input_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Move the input tensors to the same device as the model
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        # Generate a summary without tracking gradients
        with torch.no_grad():
            output = self.model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_length=max_output_length,
                num_return_sequences=1,
            )
        # Decode and return the summary text
        summary = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return summary
    def configure_optimizers(self):
        # AdamW with the learning rate and weight decay set at construction
        optimizer = torch.optim.AdamW(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        return optimizer
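

# ---------------------------------------------------------------------------
# Example usage: a minimal end-to-end sketch, not part of the original class.
# The dataset ("cnn_dailymail"), its column names ("article", "highlights"),
# and all hyperparameters below are illustrative assumptions; swap in your own
# corpus and settings. Inputs and labels are padded to fixed lengths so that
# the torch.cat calls in on_validation_epoch_end see uniform shapes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    model = MyLightningModule(
        model_name="t5-small", learning_rate=1e-4, weight_decay=0.01
    )

    def preprocess(examples):
        # Tokenize source articles and target summaries to fixed lengths
        model_inputs = model.tokenizer(
            examples["article"], max_length=512, padding="max_length", truncation=True
        )
        targets = model.tokenizer(
            examples["highlights"], max_length=150, padding="max_length", truncation=True
        )
        # Mask label padding with -100 so it is ignored by the loss
        model_inputs["labels"] = [
            [(tok if tok != model.tokenizer.pad_token_id else -100) for tok in seq]
            for seq in targets["input_ids"]
        ]
        return model_inputs

    raw = load_dataset("cnn_dailymail", "3.0.0")
    tokenized = raw.map(
        preprocess, batched=True, remove_columns=raw["train"].column_names
    )

    # The collator tensorizes batches; features are already padded, so no
    # extra padding is added here
    collator = DataCollatorForSeq2Seq(model.tokenizer, label_pad_token_id=-100)
    train_loader = DataLoader(
        tokenized["train"], batch_size=8, shuffle=True, collate_fn=collator
    )
    val_loader = DataLoader(
        tokenized["validation"], batch_size=8, collate_fn=collator
    )

    trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1)
    trainer.fit(model, train_loader, val_loader)

    # Summarize a single article after training
    print(model.predict("A long news article to summarize ..."))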