# Derived from https://towardsdatascience.com/how-to-fine-tune-gpt-2-for-text-generation-ae2ea53bc272
import os
import re
import csv
import random
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.optim import AdamW  # transformers.AdamW is deprecated; use torch's implementation
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
tok_delim = re.compile(r'\s+')
SONG_COLS = ['Artist', 'SName']
# ----- Data Prep -----
### Prepare data
lyrics = pd.read_csv('data/lyrics-data.csv')
lyrics = lyrics[lyrics.language=='en']
artists = pd.read_csv('data/artists-data.csv')
artists.loc[:,"Genres"] = artists.Genres.str.split(";")
artists = artists.explode("Genres")
artists.loc[:,"Genres"] = artists.Genres.str.strip()
### Keep only artists with genre Rock or Pop and Popularity above 5
artists = artists[(artists['Genres'].isin(['Rock', 'Pop'])) & (artists['Popularity'] > 5)]
### Drop duplicated artist rows, keeping 'Rock' over 'Pop'
### ('Rock' > 'Pop' lexically, so the descending sort puts Rock rows first)
artists.sort_values('Genres', ascending=False, inplace=True)
artists.drop_duplicates(subset=list(set(artists.columns) - {'Genres'}), keep='first', inplace=True)
### Join lyrics, artists
df = lyrics.merge(artists[['Artist', 'Genres', 'Link']], left_on='ALink', right_on='Link', how='inner')
df.drop(columns=['ALink','SLink','Genres','Link'], inplace=True)
### Tokenize lyric text, add columns to df
tmp = df.Lyric.str.split(tok_delim)
tmp = tmp.apply(lambda toks: [t for t in toks if t])  # drop empty strings left by the split
lyric_nwords = tmp.apply(len)
df.insert(df.shape[1], 'lyric_nwords', lyric_nwords)
df.insert(df.shape[1], 'lyric_tokens', tmp)
### ... overwrite original lyric strings with simplified versions
df.loc[:, "Lyric"] = df.lyric_tokens.apply(' '.join)
### filter out songs with too few (<25) or too many (>=350) words
df = df[(lyric_nwords >= 25) & (lyric_nwords < 350)].reset_index(drop=True)
del lyric_nwords, tmp
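## (optional, not in the original flow) quick check that the word-count filter held;
## `lyric_nwords` is still a column of df even though the local Series was deleted:
# print(df.lyric_nwords.describe())  # min should be >= 25, max < 350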
### Create a very small test set for comparing generated text against the real endings
test_set = df.sample(n=200, random_state=106)
train_set = df.drop(index=test_set.index).copy()
test_set.reset_index(drop=True, inplace=True)
train_set.reset_index(drop=True, inplace=True)
### sanity checks
### ... row counts
assert df.shape[0] == (train_set.shape[0] + test_set.shape[0])
### ... confirm no overlapping songs
shared_songs = train_set.loc[:,SONG_COLS].merge(test_set.loc[:,SONG_COLS], how='inner')
assert shared_songs.shape[0] == 0, "ERROR: overlapping songs in test, train sets"
### For the test set only, keep last 20 words in a new column, then remove them from original column
test_set.insert(test_set.shape[1], 'True_end_lyrics', test_set.lyric_tokens.str[-20:].apply(' '.join))
test_set.loc[:, 'Lyric'] = test_set.lyric_tokens.str[:-20].apply(' '.join)
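## (optional sanity check, added for illustration) the truncated prompt should
## end exactly where the held-out 20-word ending begins:
# print(test_set.Lyric.iloc[0][-60:])
# print('---> held-out ending:', test_set.True_end_lyrics.iloc[0])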
class SongLyrics(Dataset):
    """Tokenized song lyrics, each wrapped in BOS/EOS tokens."""

    def __init__(self, lyrics: pd.Series, gpt2_type="gpt2", max_length=1022, truncate=0, **kwargs):
        # max_length=1022 leaves room for BOS + EOS within GPT-2's 1024-token context
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type, **kwargs)
        self.lyrics = []
        for i, text in lyrics.items():  # .iteritems() was removed in pandas 2.0
            if (truncate > 0) and (i == truncate):
                break
            lyric_toks = self.tokenizer.tokenize(text)
            if len(lyric_toks) > max_length:
                # keep a random max_length-token window of the song
                istart = np.random.randint(len(lyric_toks) - max_length)
                lyric_toks = lyric_toks[istart:(istart + max_length)]
            self.lyrics.append(torch.tensor([
                self.tokenizer.bos_token_id,
                *self.tokenizer.convert_tokens_to_ids(lyric_toks),
                self.tokenizer.eos_token_id,
            ]))
        self.lyrics_count = len(self.lyrics)

    def __len__(self):
        return self.lyrics_count

    def __getitem__(self, idx):
        return idx, self.lyrics[idx]
dataset = SongLyrics(train_set.Lyric, gpt2_type="gpt2")
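## (optional) decode one encoded lyric to confirm the BOS/EOS wrapping round-trips;
## `dataset[0]` returns an (index, tensor) pair per __getitem__ above:
# _, sample = dataset[0]
# print(dataset.tokenizer.decode(sample))  # should start and end with '<|endoftext|>'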
# Get the tokenizer and model
tokenizer = dataset.tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2')
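## Note: GPT-2 has no dedicated BOS or PAD token; bos_token_id and eos_token_id
## are both 50256 ('<|endoftext|>'), which is one reason the training loop below
## packs songs into batch_size=1 sequences instead of using padded batches.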
# Dynamically pack (encoded) lyric tensors into mini-batch-sized sequences
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    if packed_tensor is None:
        return new_tensor, True, None
    if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:
        # packed sequence is full; hand the new tensor back as the remainder
        return packed_tensor, False, new_tensor
    else:
        # prepend the new tensor, dropping the packed tensor's leading BOS token
        packed_tensor = torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1)
        return packed_tensor, True, None
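## (optional) toy illustration of pack_tensor, using made-up 1 x N token tensors
## shaped like SongLyrics items; not part of the original training flow:
# a = torch.tensor([[50256, 1, 2, 3]])  # pretend this is already packed
# b = torch.tensor([[50256, 4, 5]])     # new song to add
# packed, carry_on, remainder = pack_tensor(b, a, max_seq_len=10)
# # packed == [[50256, 4, 5, 1, 2, 3]]: b is prepended and a's leading BOS dropped;
# # carry_on is True and remainder is None because everything fit under max_seq_len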
# ----- Train -----
def train(
    dataset, model, tokenizer,
    batch_size=16, epochs=5, lr=2e-5,
    max_seq_len=768, warmup_steps=200,
    output_dir="_scratch",
    output_prefix="lyric_gpt2demo",
    test_mode=False,
    save_model_on_epoch=False,
):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.train()

    # batch_size=1 here: songs are packed into ~max_seq_len sequences by pack_tensor,
    # and gradients are accumulated over `batch_size` packed sequences below
    train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True, pin_memory=True)

    optimizer = AdamW(model.parameters(), lr=lr)
    # the source article passed num_training_steps=-1, which zeroes the LR after
    # warmup under a linear schedule; use an (approximate) count of optimizer steps
    num_training_steps = epochs * ((len(train_dataloader) + batch_size - 1) // batch_size)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_training_steps
    )

    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = 0
        input_tensor = None

        ### DEBUG/
        # (optional) vector of loss values per minibatch
        losses = []
        # (optional) flags marking the minibatches where `optimizer.step` was called
        accumulate = torch.zeros(len(train_dataloader), dtype=torch.bool)
        ### /DEBUG

        print(f"Training epoch {epoch}")
        for batch_idx, (idx, entry) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), mininterval=15, maxinterval=60, miniters=200, leave=False):
            (input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, max_seq_len)
            # keep packing songs until the sequence is full (or the epoch ends)
            if carry_on and ((batch_idx + 1) != len(train_dataloader)):
                continue

            input_tensor = input_tensor.to(device)
            outputs = model(input_tensor, labels=input_tensor)
            loss = outputs[0]
            loss.backward()

            # gradient accumulation: update weights every `batch_size` packed sequences
            if (((batch_idx + 1) % batch_size) == 0) or ((batch_idx + 1) == len(train_dataloader)):
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                accumulate[batch_idx] = 1

            # start the next packed sequence from whatever didn't fit
            input_tensor = remainder
            losses.append(loss.detach().cpu().item())

        print(f"avg loss: {np.mean(losses):.4f} for epoch {epoch}")
        if save_model_on_epoch:
            print('saving epoch state')
            torch.save({
                "epoch": epoch,
                "accum_batches": batch_size,
                "lr": lr,
                "max_seq_len": max_seq_len,
                "state_dict": model.state_dict(),
                "losses": losses,
                "accumulate": accumulate,
            },
                os.path.join(output_dir, f"{output_prefix}-{epoch}.torch"),
            )
    return model
model = train(dataset, model, tokenizer, save_model_on_epoch=True)
## (optional) Save fine-tuned model
# torch.save( model.state_dict(), '_scratch/gpt2demo_finetuned.STATE_DICT.torch' )
# model.save_pretrained( '_scratch/lyrics_gpt2demo/model' )
# tokenizer.save_pretrained( '_scratch/lyrics_gpt2demo/tokenizer' )
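## (optional) Reload a fine-tuned model later, mirroring the saves above:
# model = GPT2LMHeadModel.from_pretrained('_scratch/lyrics_gpt2demo/model')
# tokenizer = GPT2Tokenizer.from_pretrained('_scratch/lyrics_gpt2demo/tokenizer')
# ## ... or restore just the weights into a freshly constructed model:
# model.load_state_dict(torch.load('_scratch/gpt2demo_finetuned.STATE_DICT.torch'))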
# ----- Generate -----
def generate(
    model,
    tokenizer,
    prompt,
    entry_length=30,  # maximum number of tokens to generate
    top_p=0.8,
    temperature=1.0,
):
    model.eval()
    filter_value = -float("Inf")
    with torch.inference_mode():
        entry_finished = False
        generated = torch.tensor(tokenizer.encode(prompt), device=model.device).unsqueeze(0)
        nstart = generated.shape[-1]
        for nth in range(entry_length):
            # loss comes back first because labels are passed; it is unused here
            loss, logits, __ = model(generated, labels=generated).to_tuple()
            logits = logits[:, -1, :] / temperature

            ## nucleus (top-p) filtering: mask out tokens beyond the top_p probability mass
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_indices_to_remove = cumulative_probs > top_p
            # shift the mask right so the first token crossing the threshold is kept
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            indices_to_remove = sorted_indices[sorted_indices_to_remove]
            logits[:, indices_to_remove] = filter_value

            # sample the next token from the filtered distribution and append it
            next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)

            ## flag whether the next token is the end-of-sequence special token
            entry_finished = (next_token.item() == tokenizer.eos_token_id)
            ## stop early if the end-of-sequence token is reached:
            if entry_finished:
                break

        ngenerated = generated.shape[-1] - nstart
        assert ngenerated == (nth + 1), "sanity check failed; check loop"
        output_list = list(generated.cpu().squeeze().numpy())
        ### only return the new (generated) text:
        generated_list = output_list[-ngenerated:]
        generated_text = f"{tokenizer.decode(generated_list)}{'' if entry_finished else tokenizer.eos_token}"
    return generated_text
# generate lyrics for test_set
generated_lyrics = [''] * test_set.shape[0]
for i in trange(test_set.shape[0], leave=False):
    generated_lyrics[i] = generate(model, tokenizer, test_set.Lyric.iloc[i])
test_set.insert(test_set.shape[1], 'Generated_lyrics', generated_lyrics)
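## (optional, illustrative addition) one simple way to score the generations:
## per-song word overlap (Jaccard) between generated and held-out endings.
def word_jaccard(a: str, b: str) -> float:
    sa, sb = set(tok_delim.split(a.strip())), set(tok_delim.split(b.strip()))
    return len(sa & sb) / max(len(sa | sb), 1)

# overlap = [
#     word_jaccard(gen, true)
#     for gen, true in zip(test_set.Generated_lyrics, test_set.True_end_lyrics)
# ]
# print(f"mean word overlap: {np.mean(overlap):.3f}")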