@wjurkowlaniec
Created August 2, 2024 13:43
import json
import os
import pathlib
import time

import torch
from transformers import AutoModel, AutoTokenizer
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer

# Settings
codebase_dir = "project/path"
cache_dir = "./cache"
model_name = "distilbert-base-uncased"
exclude_dirs = ["venv", "node_modules"]

# Load model and tokenizer
model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

class ChangeHandler(FileSystemEventHandler):
    def __init__(self, last_update_dates):
        self.last_update_dates = last_update_dates

    def on_modified(self, event):
        # Re-embed a .py file whenever it changes, unless it lives in an excluded directory
        if (
            not event.is_directory
            and event.src_path.endswith(".py")
            and exclude_dirs_func(pathlib.Path(event.src_path))
        ):
            load_file(pathlib.Path(event.src_path), self.last_update_dates)

def load_file(file_path, last_update_dates):
    try:
        print(f"Loading file: {file_path}", end=" ", flush=True)

        # Load file contents
        with open(file_path, "r") as f:
            contents = f.read()
        if not contents.strip():
            print(f"File is empty: {file_path}")
            return

        # Split the contents into fixed-size character chunks; each chunk is
        # tokenized separately and padded/truncated to 512 tokens (adjust as needed)
        chunk_size = 512
        tokenized_chunks = []
        for i in range(0, len(contents), chunk_size):
            chunk = contents[i : i + chunk_size]
            # Tokenize the chunk with padding
            inputs = tokenizer(
                chunk,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=512,
            )
            tokenized_chunks.append(inputs)

        # Combine the tokenized chunks into batched tensors
        input_ids = torch.cat([c["input_ids"] for c in tokenized_chunks], dim=0)
        attention_mask = torch.cat(
            [c["attention_mask"] for c in tokenized_chunks], dim=0
        )

        # Run the batch through the model, masking out the padding tokens
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # Save the outputs to the cache
        file_name = file_path.name
        file_dir = file_path.parent
        path_name = os.path.join(
            cache_dir, str(file_dir).replace(os.sep, "_") + "_" + file_name + ".pt"
        )
        os.makedirs(os.path.dirname(path_name), exist_ok=True)
        torch.save(outputs, path_name)

        # Record the last modification time of the file in a JSON file
        last_update_dates[str(file_path)] = file_path.stat().st_mtime
        with open("last_update_dates.json", "w") as f:
            json.dump(last_update_dates, f)
        print("Saved")
    except Exception as e:
        print(f"Error processing file: {file_path} - {e}")

def exclude_dirs_func(file_path):
    # Return True if the file is NOT inside any excluded directory
    for excluded in exclude_dirs:
        if excluded in file_path.parts:
            return False
    return True

def main():
    print("Starting cache update...")

    # Load the last update dates, creating the JSON file if it is missing or invalid
    last_update_dates = {}
    try:
        with open("last_update_dates.json", "r") as f:
            contents = f.read().strip()
        if contents:  # Only parse if the file is not empty
            last_update_dates = json.loads(contents)
    except (FileNotFoundError, json.JSONDecodeError):
        with open("last_update_dates.json", "w") as f:
            json.dump({}, f)

    # Watch the codebase for changes and re-embed modified .py files
    event_handler = ChangeHandler(last_update_dates)
    observer = Observer()
    observer.schedule(event_handler, path=codebase_dir, recursive=True)
    observer.start()
    try:
        while True:
            time.sleep(1)  # Avoid busy-waiting while the observer runs in the background
    except KeyboardInterrupt:
        observer.stop()
    observer.join()
    print("Cache update complete!")


if __name__ == "__main__":
    main()
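
# A minimal usage sketch: reading a cached embedding back, assuming the
# cache-file naming scheme used in load_file() above and a hypothetical
# source file "project/path/utils.py":
#
#     cached = torch.load("./cache/project_path_utils.py.pt")
#     print(cached.last_hidden_state.shape)  # (num_chunks, 512, hidden_size)
#
# Note: on newer torch versions you may need torch.load(..., weights_only=False),
# since the saved object is a full transformers ModelOutput rather than a plain tensor.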