import os

import numpy as np
import torch
from sentence_transformers import SentenceTransformer

# Load a pre-trained model for multilingual sentence embeddings
model = SentenceTransformer("paraphrase-multilingual-mpnet-base-v2")

# Set up the LSTM model
input_size = 768  # Size of the sentence embeddings
hidden_size = 2048  # Number of hidden units in the LSTM
num_layers = 1  # Number of layers in the LSTM
batch_size = 32  # Batch size for training the LSTM
learning_rate = 0.001  # Initial learning rate for the optimizer
accumulation_steps = 4  # Number of batches to accumulate gradients over

# Define the LSTM model, plus a linear projection from the LSTM's hidden
# space back to embedding space so its outputs can be scored against the
# sentence embeddings (without it, the loss below has a shape mismatch)
lstm_model = torch.nn.LSTM(
    input_size=input_size, hidden_size=hidden_size, num_layers=num_layers
)
projection = torch.nn.Linear(hidden_size, input_size)

# Set up the optimizer, scheduler, and loss function for training the LSTM;
# the optimizer must also update the projection layer's parameters
optimizer = torch.optim.Adam(
    list(lstm_model.parameters()) + list(projection.parameters()),
    lr=learning_rate,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=10
)
criterion = torch.nn.CrossEntropyLoss()

# Load data from disk: each file in the folder is one document,
# one sentence per line
data_path = "/path/to/data/folder"
filenames = os.listdir(data_path)

# Train the LSTM model: each position in a window of sentences is trained
# to pick out the index of the sentence that follows it
for filename in filenames:
    # Load the file as a list of sentence strings
    with open(os.path.join(data_path, filename), "r", encoding="utf-8") as f:
        sentences = f.read().splitlines()

    # Calculate sentence embeddings using the pre-trained model
    embeddings = model.encode(sentences)

    # Convert the embeddings to a (num_sentences, 768) tensor; the LSTM
    # input additionally needs a batch dimension of 1
    all_embeddings = torch.from_numpy(np.asarray(embeddings))
    embeddings_tensor = all_embeddings.view(len(sentences), 1, -1)

    # Train the LSTM model on this file
    total_loss = 0.0
    num_batches = 0
    optimizer.zero_grad()
    for step, i in enumerate(range(0, len(sentences) - batch_size, batch_size)):
        # The target for each position in the window is the index of the
        # next sentence, i.e. the window's indices shifted by one
        batch_embeddings = embeddings_tensor[i : i + batch_size]
        targets = torch.arange(i + 1, i + batch_size + 1)

        # Forward pass: run the window through the LSTM, project its
        # outputs back to embedding space, and score them against every
        # sentence in the file
        outputs, _ = lstm_model(batch_embeddings)
        logits = projection(outputs.squeeze(1)) @ all_embeddings.T

        # Compute the loss, scaled down for gradient accumulation,
        # and perform backpropagation
        loss = criterion(logits, targets) / accumulation_steps
        loss.backward()
        total_loss += loss.item() * accumulation_steps
        num_batches += 1

        # Step the optimizer only once every `accumulation_steps` batches,
        # letting gradients accumulate in between
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    # Adjust the learning rate using the scheduler and log the loss
    mean_loss = total_loss / max(num_batches, 1)
    scheduler.step(mean_loss)
    print(f"File: {filename}, Loss: {mean_loss}")