Skip to content

Instantly share code, notes, and snippets.

@corporatepiyush
Last active July 27, 2024 10:58
Show Gist options
  • Save corporatepiyush/2d65f1ca78ef519b0f31af70474e21b8 to your computer and use it in GitHub Desktop.
Disk-backed, in-memory, chunked-array DataFrame wrapper for Polars
import functools
import os
import time
import uuid
from threading import Event, Lock, RLock, Thread
from typing import List

import polars as pl
def synchronized(lock):
    """Decorator factory: serialize every call to the decorated function.

    Args:
        lock: Any context-manager lock (e.g. ``threading.Lock``/``RLock``)
            that is acquired for the duration of each call.

    Returns:
        A decorator that wraps the target function so the whole call body
        executes while holding *lock*.
    """
    def wrapper(f):
        @functools.wraps(f)  # preserve __name__/__doc__ of the wrapped function
        def wrapped(*args, **kwargs):
            with lock:
                return f(*args, **kwargs)
        return wrapped
    return wrapper
class ChunkDataFrame:
    """Lazily-loaded, disk-backed wrapper around a single Polars DataFrame.

    The underlying frame is read from its parquet file on first access and
    written back by ``save()`` only when marked dirty.  All access to the
    cached frame is guarded by a per-chunk re-entrant lock.
    """

    # Method names whose results represent a modified frame; calling one of
    # them caches the new frame and marks the chunk dirty for the saver.
    # ('with_columns' added alongside the legacy 'with_column' spelling.)
    _MUTATING_METHODS = frozenset({'with_column', 'with_columns', 'drop', 'apply'})

    def __init__(self, base_dir: str):
        # Unique parquet file per chunk so concurrent chunks never collide.
        self.file_path = os.path.join(base_dir, f"chunk_{uuid.uuid4()}.parquet")
        self._df = None          # cached frame; None until first access
        self.dirty = False       # True when in-memory state differs from disk
        self.lock = RLock()      # re-entrant: accessors may nest via __getattr__

    @property
    def df(self):
        """Return the cached DataFrame, loading it from disk on first use."""
        with self.lock:
            if self._df is None:
                self._load()
            return self._df

    @df.setter
    def df(self, new_df: pl.DataFrame):
        """Replace the cached frame and mark the chunk as needing a flush."""
        with self.lock:
            self._df = new_df
            self.dirty = True

    def _load(self):
        """Populate the cache from parquet, or with an empty frame if absent."""
        if os.path.exists(self.file_path):
            self._df = pl.read_parquet(self.file_path)
            print(f"Loaded DataFrame from {self.file_path}")
        else:
            self._df = pl.DataFrame()
            print(f"Initialized empty DataFrame for {self.file_path}")

    def save(self):
        """Persist the frame to parquet iff it has unsaved changes."""
        with self.lock:
            if self.dirty:
                self._df.write_parquet(self.file_path)
                self.dirty = False
                print(f"Saved DataFrame to {self.file_path}")

    def __getattr__(self, attr):
        # Delegate unknown attributes to the wrapped DataFrame.  __getattr__
        # only fires when normal lookup fails, so the wrapper's own
        # attributes (lock, dirty, file_path, ...) are unaffected.
        df_attr = getattr(self.df, attr)
        if not callable(df_attr):
            return df_attr

        def wrapper(*args, **kwargs):
            result = df_attr(*args, **kwargs)
            if attr in self._MUTATING_METHODS:
                with self.lock:
                    # BUG FIX: Polars DataFrames are immutable — a "mutating"
                    # call returns a NEW frame.  The original only set the
                    # dirty flag, so save() persisted the stale frame.  Cache
                    # the result so the flush writes the updated data.
                    if isinstance(result, pl.DataFrame):
                        self._df = result
                    self.dirty = True
            return result
        return wrapper
class ChunkArrayDataFrame:
    """A DataFrame-like container that shards rows across disk-backed chunks.

    Rows appended via ``append()`` are split into fixed-size
    ``ChunkDataFrame`` chunks.  A single class-wide daemon thread
    periodically flushes every chunk registered as dirty.
    """

    # Class-level registry shared by ALL instances: any chunk appended here
    # is persisted by the background saver thread.
    dirty_chunks = []
    dirty_chunks_lock = Lock()
    stop_event = Event()
    saver_thread = None

    @classmethod
    def start_saver_thread(cls):
        """Start the single, daemonized background flush thread (idempotent)."""
        if cls.saver_thread is None:
            cls.saver_thread = Thread(target=cls._save_dirty_chunks_periodically, daemon=True)
            cls.saver_thread.start()

    @classmethod
    def _save_dirty_chunks_periodically(cls):
        # Flush roughly once per second until asked to stop.
        while not cls.stop_event.is_set():
            time.sleep(1)
            cls._save_dirty_chunks()

    @classmethod
    @synchronized(dirty_chunks_lock)
    def _save_dirty_chunks(cls):
        """Persist and forget every registered chunk (registry lock held).

        Duplicates in the registry are harmless: ``chunk.save()`` is a no-op
        once the chunk is clean.
        """
        for chunk in cls.dirty_chunks:
            chunk.save()
        cls.dirty_chunks.clear()

    @classmethod
    def stop_saver_thread(cls):
        """Signal the saver thread to exit and wait for it to finish."""
        cls.stop_event.set()
        if cls.saver_thread:
            cls.saver_thread.join()

    def __init__(self, base_dir: str, chunk_size: int = 1000):
        """Create a chunked frame rooted at *base_dir*.

        Args:
            base_dir: Directory for the per-chunk parquet files (created if
                missing).
            chunk_size: Maximum number of rows per slice taken from an
                appended frame.
        """
        self.base_dir = base_dir
        self.chunk_size = chunk_size
        self.chunks: List[ChunkDataFrame] = []
        self.chunk_locks: List[RLock] = []
        os.makedirs(base_dir, exist_ok=True)

    def _get_or_create_chunk(self, chunk_index: int):
        """Return ``(chunk, lock)`` for *chunk_index*, growing lists as needed."""
        while chunk_index >= len(self.chunks):
            self.chunks.append(ChunkDataFrame(self.base_dir))
            self.chunk_locks.append(RLock())
        return self.chunks[chunk_index], self.chunk_locks[chunk_index]

    def append(self, df: pl.DataFrame):
        """Split *df* into chunk_size slices and append each to its chunk.

        NOTE(review): chunk_index is derived from the offset within *df*, so
        repeated append() calls keep extending chunk 0 onward rather than
        starting after existing data — confirm this is intended.
        """
        start_idx = 0
        while start_idx < len(df):
            chunk_index = start_idx // self.chunk_size
            end_idx = min(start_idx + self.chunk_size, len(df))
            chunk_df = df[start_idx:end_idx]
            chunk, chunk_lock = self._get_or_create_chunk(chunk_index)
            with chunk_lock:
                if chunk.df.is_empty():
                    chunk.df = chunk_df
                else:
                    chunk.df = pl.concat([chunk.df, chunk_df])
            # Register for the background saver thread.
            with ChunkArrayDataFrame.dirty_chunks_lock:
                ChunkArrayDataFrame.dirty_chunks.append(chunk)
            start_idx = end_idx

    def filter(self, condition):
        """Apply a Polars filter to every chunk and concatenate the results."""
        results = []
        for chunk, chunk_lock in zip(self.chunks, self.chunk_locks):
            # BUG FIX: the original unpacked chunk_lock but never acquired
            # it, so reads could race with concurrent append() calls.
            with chunk_lock:
                if not chunk.df.is_empty():
                    results.append(chunk.df.filter(condition))
        return pl.concat(results) if results else pl.DataFrame()

    def join(self, other, on=None, how='inner'):
        """Pairwise-join every chunk of self with every chunk of *other*."""
        results = []
        for chunk, chunk_lock in zip(self.chunks, self.chunk_locks):
            with chunk_lock:  # BUG FIX: locks were fetched but never held
                for other_chunk, other_lock in zip(other.chunks, other.chunk_locks):
                    with other_lock:
                        results.append(chunk.df.join(other_chunk.df, on=on, how=how))
        return pl.concat(results) if results else pl.DataFrame()

    def __getattr__(self, attr):
        # Fan any unknown method call out across all chunks; DataFrame
        # results are concatenated, anything else is returned as a list.
        def wrapper(*args, **kwargs):
            results = []
            for chunk, chunk_lock in zip(self.chunks, self.chunk_locks):
                with chunk_lock:  # BUG FIX: hold the per-chunk lock during the call
                    results.append(getattr(chunk.df, attr)(*args, **kwargs))
            if results and isinstance(results[0], pl.DataFrame):
                return pl.concat(results)
            return results
        return wrapper
# Module-level side effect: launch the class-wide background saver thread
# (a daemon, so it will not keep the interpreter alive) at import time.
ChunkArrayDataFrame.start_saver_thread()
# Example Usage
if __name__ == "__main__":
    base_dir = "./chunked_data"

    df1 = ChunkArrayDataFrame(base_dir)
    df2 = ChunkArrayDataFrame(base_dir)

    # Build two sample frames: 500 employees, and a bonus table whose ids
    # overlap the second half of the employee table.
    data1 = pl.DataFrame({
        'id': list(range(500)),
        'department': ['Engineering'] * 250 + ['Sales'] * 250,
        'salary': list(range(500)),
    })
    data2 = pl.DataFrame({
        'id': list(range(250, 750)),
        'bonus': list(range(250, 750)),
    })

    # Load the sample data into the chunked wrappers.
    df1.append(data1)
    df2.append(data2)

    # Filter: keep rows with salary above 200, across all chunks.
    filtered_df1 = df1.filter(pl.col('salary') > 200)
    print("Filtered DataFrame 1:")
    print(filtered_df1)

    # Inner-join the two chunked frames on their shared 'id' column.
    joined_df = df1.join(df2, on='id', how='inner')
    print("Joined DataFrame:")
    print(joined_df)

    # Adding a column routes through __getattr__ and marks chunks dirty.
    df1 = df1.with_column(pl.lit('temp').alias('new_column'))
    print("After Adding a New Column to DataFrame 1:")
    print(df1)

    # Dropping the column marks the chunks dirty again.
    df1 = df1.drop('new_column')
    print("After Dropping the New Column from DataFrame 1:")
    print(df1)

    # Aggregate: mean salary per department.
    grouped_df1 = df1.groupby("department").agg(pl.col("salary").mean().alias("average_salary"))
    print("Grouped DataFrame 1 with Average Salary:")
    print(grouped_df1)

    # Apply a Python function element-wise to the salary column.
    def custom_func(salary):
        return salary * 1.1

    df1 = df1.with_column(pl.col("salary").apply(custom_func).alias("salary"))
    print("After Applying Custom Function to Salary in DataFrame 1:")
    print(df1)

    # Shut the background saver down cleanly before exiting.
    ChunkArrayDataFrame.stop_saver_thread()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment