Disk backed in-memory chunked array DataFrame wrapper for Polars
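The listing wraps Polars with two small classes: ChunkDataFrame holds one slice of rows, loads it lazily from its own Parquet file, and tracks whether it has unsaved changes, while ChunkArrayDataFrame splits appended rows across such chunks, forwards DataFrame operations to every chunk, and relies on a class-level daemon thread that flushes dirty chunks to disk roughly once per second. Every chunk is persisted as chunk_<uuid>.parquet under the chosen base directory.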
import polars as pl
import os
import time
from threading import Thread, Event, RLock, Lock
import uuid
def synchronized(lock):
    """Synchronization decorator: run the wrapped function while holding `lock`."""
    def wrapper(f):
        def wrapped(*args, **kwargs):
            with lock:
                return f(*args, **kwargs)
        return wrapped
    return wrapper
class ChunkDataFrame:
    """One chunk of rows, lazily loaded from and persisted to a single Parquet file."""

    def __init__(self, base_dir: str):
        self.file_path = os.path.join(base_dir, f"chunk_{uuid.uuid4()}.parquet")
        self._df = None
        self.dirty = False
        self.lock = RLock()

    @property
    def df(self):
        with self.lock:
            if self._df is None:
                self._load()
            return self._df

    @df.setter
    def df(self, new_df: pl.DataFrame):
        with self.lock:
            self._df = new_df
            self.dirty = True

    def _load(self):
        if os.path.exists(self.file_path):
            self._df = pl.read_parquet(self.file_path)
            print(f"Loaded DataFrame from {self.file_path}")
        else:
            self._df = pl.DataFrame()
            print(f"Initialized empty DataFrame for {self.file_path}")

    def save(self):
        with self.lock:
            if self.dirty:
                self._df.write_parquet(self.file_path)
                self.dirty = False
                print(f"Saved DataFrame to {self.file_path}")

    def __getattr__(self, attr):
        # Delegate unknown attributes to the underlying Polars DataFrame.
        df_attr = getattr(self.df, attr)
        if callable(df_attr):
            def wrapper(*args, **kwargs):
                result = df_attr(*args, **kwargs)
                # Note: these Polars methods return new DataFrames; calling them
                # through the wrapper only marks the chunk dirty, it does not
                # replace the stored data.
                if attr in ['with_columns', 'drop', 'map_rows']:
                    with self.lock:
                        self.dirty = True
                return result
            return wrapper
        else:
            return df_attr
class ChunkArrayDataFrame:
    # Chunks modified anywhere in the process are queued here and flushed to
    # disk by a single class-level background saver thread.
    dirty_chunks = []
    dirty_chunks_lock = Lock()
    stop_event = Event()
    saver_thread = None

    @classmethod
    def start_saver_thread(cls):
        if cls.saver_thread is None:
            cls.saver_thread = Thread(target=cls._save_dirty_chunks_periodically, daemon=True)
            cls.saver_thread.start()

    @classmethod
    def _save_dirty_chunks_periodically(cls):
        while not cls.stop_event.is_set():
            time.sleep(1)
            cls._save_dirty_chunks()

    @classmethod
    @synchronized(dirty_chunks_lock)
    def _save_dirty_chunks(cls):
        for chunk in cls.dirty_chunks:
            chunk.save()
        cls.dirty_chunks.clear()

    @classmethod
    def stop_saver_thread(cls):
        cls.stop_event.set()
        if cls.saver_thread:
            cls.saver_thread.join()
        # The saver loop may exit mid-sleep without a final pass, so flush
        # anything still pending before returning.
        cls._save_dirty_chunks()

    def __init__(self, base_dir: str, chunk_size: int = 1000):
        self.base_dir = base_dir
        self.chunk_size = chunk_size
        self.chunks = []
        self.chunk_locks = []
        os.makedirs(base_dir, exist_ok=True)

    def _get_or_create_chunk(self, chunk_index: int):
        while chunk_index >= len(self.chunks):
            self.chunks.append(ChunkDataFrame(self.base_dir))
            self.chunk_locks.append(RLock())
        return self.chunks[chunk_index], self.chunk_locks[chunk_index]

    def append(self, df: pl.DataFrame):
        # Route incoming rows to chunks based on how many rows are already
        # stored, so repeated appends keep each chunk at most `chunk_size` rows.
        start_idx = 0
        while start_idx < len(df):
            total_rows = sum(len(chunk.df) for chunk in self.chunks)
            chunk_index = total_rows // self.chunk_size
            chunk, chunk_lock = self._get_or_create_chunk(chunk_index)
            with chunk_lock:
                space_left = self.chunk_size - len(chunk.df)
                end_idx = min(start_idx + space_left, len(df))
                chunk_df = df[start_idx:end_idx]
                if chunk.df.is_empty():
                    chunk.df = chunk_df
                else:
                    chunk.df = pl.concat([chunk.df, chunk_df])
            with ChunkArrayDataFrame.dirty_chunks_lock:
                ChunkArrayDataFrame.dirty_chunks.append(chunk)
            start_idx = end_idx

    def filter(self, condition):
        results = [chunk.df.filter(condition) for chunk in self.chunks
                   if not chunk.df.is_empty()]
        return pl.concat(results) if results else pl.DataFrame()

    def join(self, other, on=None, how='inner'):
        # Pairwise join of every chunk of self against every chunk of other,
        # concatenated into a single plain DataFrame.
        results = [chunk.df.join(other_chunk.df, on=on, how=how)
                   for chunk in self.chunks
                   for other_chunk in other.chunks]
        return pl.concat(results) if results else pl.DataFrame()

    def __getattr__(self, attr):
        # Forward any other DataFrame method to every chunk and concatenate the
        # per-chunk results when they are DataFrames.
        def wrapper(*args, **kwargs):
            results = [getattr(chunk.df, attr)(*args, **kwargs) for chunk in self.chunks]
            return pl.concat(results) if results and isinstance(results[0], pl.DataFrame) else results
        return wrapper
# Start the global saver thread
ChunkArrayDataFrame.start_saver_thread()
# Example Usage
if __name__ == "__main__":
    base_dir = "./chunked_data"
    df1 = ChunkArrayDataFrame(base_dir)
    df2 = ChunkArrayDataFrame(base_dir)

    # Create sample DataFrames
    data1 = pl.DataFrame({
        'id': list(range(500)),
        'department': ['Engineering'] * 250 + ['Sales'] * 250,
        'salary': list(range(500))
    })
    data2 = pl.DataFrame({
        'id': list(range(250, 750)),
        'bonus': list(range(250, 750)),
    })

    # Append data to the ChunkArrayDataFrames
    df1.append(data1)
    df2.append(data2)

    # Perform a filtering operation on df1
    filtered_df1 = df1.filter(pl.col('salary') > 200)
    print("Filtered DataFrame 1:")
    print(filtered_df1)

    # Join df1 and df2 on 'id'
    joined_df = df1.join(df2, on='id', how='inner')
    print("Joined DataFrame:")
    print(joined_df)

    # Add a new column to df1; the call is forwarded to every chunk and the
    # concatenated result is a plain Polars DataFrame
    df1 = df1.with_columns(pl.lit('temp').alias('new_column'))
    print("After Adding a New Column to DataFrame 1:")
    print(df1)

    # Drop a column from df1
    df1 = df1.drop('new_column')
    print("After Dropping the New Column from DataFrame 1:")
    print(df1)

    # Group by operation on df1
    grouped_df1 = df1.group_by("department").agg(pl.col("salary").mean().alias("average_salary"))
    print("Grouped DataFrame 1 with Average Salary:")
    print(grouped_df1)

    # Apply a custom function to df1
    def custom_func(salary):
        return salary * 1.1

    df1 = df1.with_columns(pl.col("salary").map_elements(custom_func, return_dtype=pl.Float64).alias("salary"))
    print("After Applying Custom Function to Salary in DataFrame 1:")
    print(df1)

    # Stop the saver thread when done (this also flushes any pending chunks)
    ChunkArrayDataFrame.stop_saver_thread()
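Because every chunk is persisted as an ordinary Parquet file, flushed data can also be read back with plain Polars, outside of the wrapper. The snippet below is an illustration rather than part of the original listing; it assumes the classes above are in scope, and the directory and column names are made up for the example.

# Illustration only: round-trip a small frame through the on-disk chunks.
demo = ChunkArrayDataFrame("./roundtrip_demo", chunk_size=100)  # hypothetical directory
demo.append(pl.DataFrame({"id": list(range(250)), "value": list(range(250))}))

# Flush explicitly instead of waiting for the background saver thread.
for chunk in demo.chunks:
    chunk.save()

# Each chunk is a standalone Parquet file, readable without the wrapper.
restored = pl.concat([pl.read_parquet(chunk.file_path) for chunk in demo.chunks])
print(restored)  # 250 rows, stored as chunks of 100, 100 and 50 rows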