Python util for saving files to disk concurrently in the background, without blocking the main process
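
A minimal usage sketch (the object and destination path here are hypothetical; the full definitions follow below):

    saver = ConcurrentFileSaver.get_main()  # per-process singleton; or instantiate directly
    future = saver.enqueue_save_to_file(obj=my_obj, dst_path='/tmp/out/my_obj.pkl')
    # ... keep working; the dump runs in a background daemon thread ...
    future.result()              # block until this particular file is on disk
    saver.stop_on_empty_queue()  # let the workers finish all previously enqueued jobs
    saver.join()                 # wait for the worker and watchdog threads to terminate
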
__author__ = "Elad Nachmias"
__email__ = "eladnah@gmail.com"
__date__ = "2023-10-19"

import os
import threading
import queue
import time
import lzma
import pickle
from concurrent.futures import CancelledError
from typing import Any, NamedTuple, Set, Collection, Optional

LZM_FILE_EXTENSION = '.lzma'
LZM_PRESET = None  # None lets lzma pick its default compression preset


class ThreadSafeFuture:
    """
    asyncio.Future is not thread-safe, and concurrent.futures.Future is not meant to be instantiated
    directly by user code (outside of an executor). This class provides a custom future (adhering to the
    commonly known Future API) which is thread-safe: one thread can set the result and another thread
    can wait for it.
    """
    def __init__(self):
        self._result = None
        self._exception = None
        self._cancelled = False
        self._event = threading.Event()

    def set_result(self, result):
        self._result = result
        self._event.set()

    def set_exception(self, exception):
        self._exception = exception
        self._event.set()

    def result(self):
        self._event.wait()
        if self._cancelled:
            raise CancelledError("Future was cancelled.")
        if self._exception is not None:
            raise self._exception
        return self._result

    def exception(self) -> Optional[Exception]:
        self._event.wait()
        return self._exception

    def done(self) -> bool:
        return self._event.is_set()

    def cancel(self):
        self._cancelled = True
        self._event.set()

    def cancelled(self) -> bool:
        return self._cancelled

    @classmethod
    def wait(cls, futures: Collection['ThreadSafeFuture']):
        for future in futures:
            future._event.wait()
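

# Illustrative sketch of ThreadSafeFuture usage: one thread publishes the result while another blocks
# on `result()`. (This `_example_*` function is illustrative only and is not part of the gist's API.)
def _example_thread_safe_future_usage():
    future = ThreadSafeFuture()

    def _producer():
        time.sleep(0.1)  # simulate some work before publishing the result
        future.set_result(42)

    threading.Thread(target=_producer, daemon=True).start()
    assert future.result() == 42  # blocks in the calling thread until the producer sets the result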


class ConcurrentFileSaver:
    """
    Allows storing files to disk concurrently (using threading), without blocking the main thread.
    """
    _main: Optional['ConcurrentFileSaver'] = None

    def __init__(self):
        self._queue: queue.Queue[_ItemToBeStored] = queue.Queue()
        self._threads_lock = threading.Lock()
        self._should_stop_all = threading.Event()
        self._should_stop_all_on_empty_queue = threading.Event()
        self._threads: Set[threading.Thread] = set()
        self._throughput_watchdog_thread: Optional[threading.Thread] = None

    @classmethod
    def get_main(cls) -> 'ConcurrentFileSaver':
        """
        Use as a global (per-process) singleton instance.
        """
        if cls._main is None:
            cls._main = cls()
        return cls._main

    def __del__(self):
        self.stop()
        # Threads are daemons - no need to wait for them with join(); we only tell them to terminate.
        # The OS will make sure to kill them when the non-daemon threads exit.
        del self._threads
        del self._throughput_watchdog_thread

    def enqueue_save_to_file(self, obj: Any, dst_path: str) -> ThreadSafeFuture:
        """
        A non-blocking method that enqueues an object to be dumped to the FS. Returning a future allows
        waiting for the completion of concrete jobs without having to stop the entire instance (and thus
        destruct it).
        :param obj: the object to be dumped.
        :param dst_path: the destination path to dump the given object to.
        :return: A ThreadSafeFuture. Note: use ThreadSafeFuture.wait(...) to wait for a collection of such.
        """
        assert not self._should_stop_all.is_set()
        assert not self._should_stop_all_on_empty_queue.is_set()
        future = ThreadSafeFuture()
        self._queue.put(_ItemToBeStored(obj=obj, dst_path=dst_path, future=future))
        self._verify_enough_threads()
        self._verify_throughput_watchdog_thread()
        return future

    def stop(self):
        self._should_stop_all.set()
        # Drain the queue and cancel all jobs that have not been picked up by a worker yet.
        while not self._queue.empty():
            try:
                item = self._queue.get(block=False)
                item.future.cancel()
            except queue.Empty:
                continue

    def stop_on_empty_queue(self):
        """
        Continue running the worker threads until the completion of all previously assigned jobs.
        """
        self._should_stop_all_on_empty_queue.set()

    def join(self):
        while True:
            with self._threads_lock:
                threads = list(self._threads)
                if self._throughput_watchdog_thread is not None:
                    threads.append(self._throughput_watchdog_thread)
            if len(threads) == 0:
                return
            for thread in threads:
                thread.join(timeout=10.)
                if not thread.is_alive():
                    with self._threads_lock:
                        if thread in self._threads:
                            self._threads.remove(thread)
                        elif thread == self._throughput_watchdog_thread:
                            self._throughput_watchdog_thread = None

    def _verify_enough_threads(self):
        with self._threads_lock:
            # Prune threads that have already terminated.
            self._threads = {thread for thread in self._threads if thread.is_alive()}
            if len(self._threads) > 0:
                return
        # Queue.empty() acquires a mutex - thus we call it outside of the threads lock to avoid deadlock.
        if self._queue.empty():
            return
        self._create_worker_threads(n=10)

    def _verify_throughput_watchdog_thread(self):
        with self._threads_lock:
            if self._throughput_watchdog_thread is not None and not self._throughput_watchdog_thread.is_alive():
                self._throughput_watchdog_thread = None
            if self._throughput_watchdog_thread is None:
                self._throughput_watchdog_thread = threading.Thread(
                    target=self._throughput_watchdog_thread_main, daemon=True)
                self._throughput_watchdog_thread.start()

    def _create_worker_threads(self, n: int = 1):
        for _ in range(n):
            with self._threads_lock:
                if self._should_stop_all.is_set():
                    return
                new_thread = threading.Thread(target=self._worker_thread_main, daemon=True)
                self._threads.add(new_thread)
                new_thread.start()

    def _worker_thread_main(self):
        try:
            while not self._should_stop_all.is_set() and threading.main_thread().is_alive() and \
                    (not self._should_stop_all_on_empty_queue.is_set() or not self._queue.empty()):
                self._verify_throughput_watchdog_thread()
                item = None
                try:
                    try:
                        item = self._queue.get(timeout=30.)
                    except queue.Empty:
                        continue
                    self._save_object(file_path=item.dst_path, data=item.obj)
                    item.future.set_result(0)  # mark as successfully completed so the user could see it
                except Exception as e:
                    if item is not None:
                        item.future.set_exception(exception=e)
        except:
            pass
        finally:
            with self._threads_lock:
                # discard() rather than remove(): the watchdog might have already pruned this thread.
                self._threads.discard(threading.current_thread())

    @staticmethod
    def _save_object(file_path: str, data: Any):
        """
        The CPU-bound compression operations block the entire process. However, the IO-bound write
        operation potentially blocks the current thread only, without holding the GIL, and thus does
        not block the other threads in the current process.
        """
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with lzma.open(file_path + LZM_FILE_EXTENSION, 'wb', preset=LZM_PRESET) as file:
            pickle.dump(data, file)
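
    # Note: an object enqueued for `dst_path` lands on disk at `dst_path + LZM_FILE_EXTENSION`;
    # `_load_object` below applies the same suffix when reading it back.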

    def _throughput_watchdog_thread_main(self):
        """
        Responsible for periodically verifying that all the threads in the pool are still alive, and for
        conditionally instantiating new threads whenever needed.
        """
        try:
            while not self._should_stop_all.is_set() and threading.main_thread().is_alive() and \
                    (not self._should_stop_all_on_empty_queue.is_set() or not self._queue.empty()):
                # TODO: track the produce vs consume throughput and create/remove threads dynamically.
                #   Maybe use `self._queue.qsize()`?
                self._verify_enough_threads()
                time.sleep(30.)
        finally:
            with self._threads_lock:
                if self._throughput_watchdog_thread == threading.current_thread():
                    self._throughput_watchdog_thread = None
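

# Error-handling sketch: a failed dump surfaces through the job's future (the object and path below
# are hypothetical):
#
#     future = ConcurrentFileSaver.get_main().enqueue_save_to_file(obj=my_obj, dst_path='/tmp/out/x.pkl')
#     err = future.exception()  # blocks until the job finishes; returns None on success
#     if err is not None:
#         ...  # the worker thread caught the exception and attached it to this future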


class _ItemToBeStored(NamedTuple):
    obj: Any
    dst_path: str
    future: ThreadSafeFuture


def test_async_file_saver():
    rng = np.random.RandomState(seed=0)
    nr_items = 30
    with tempfile.TemporaryDirectory() as tmp_output_dir_name:
        saver = ConcurrentFileSaver()
        items = [rng.random(size=(rng.randint(1, 200), rng.randint(1, 200))) for _ in range(nr_items)]
        futures = []
        for item_idx, item in enumerate(items):
            future = saver.enqueue_save_to_file(
                obj=item, dst_path=os.path.join(tmp_output_dir_name, f'item_{item_idx}.pkl'))
            futures.append(future)
        saver.stop_on_empty_queue()
        saver.join()
        assert all(future.done() and not future.cancelled() for future in futures)
        loaded_items = []
        for idx in range(nr_items):
            loaded_items.append(_load_object(os.path.join(tmp_output_dir_name, f'item_{idx}.pkl')))
        assert all(np.allclose(orig_item, loaded_item) for orig_item, loaded_item in zip(items, loaded_items))


def test_async_file_saver_stop_without_wait_for_completion():
    rng = np.random.RandomState(seed=0)
    nr_items = 30
    with tempfile.TemporaryDirectory() as tmp_output_dir_name:
        saver = ConcurrentFileSaver()
        items = [rng.random(size=(rng.randint(1, 200), rng.randint(1, 200))) for _ in range(nr_items)]
        futures = []
        for item_idx, item in enumerate(items):
            future = saver.enqueue_save_to_file(
                obj=item, dst_path=os.path.join(tmp_output_dir_name, f'item_{item_idx}.pkl'))
            futures.append(future)
        saver.stop()
        ThreadSafeFuture.wait(futures)
        nr_finished = sum(future.done() and not future.cancelled() for future in futures)
        nr_canceled = sum(future.cancelled() for future in futures)
        assert nr_canceled + nr_finished == nr_items
        # We assume that the dispatching is non-blocking and fast, so that we reach the stop() call
        # before all jobs have completed, and thus at least some of them get cancelled.
        assert nr_canceled > 0


def _load_object(file_path: str) -> Any:
    # Note: `preset` is a compression-only parameter; it is not passed when opening for reading.
    with lzma.open(file_path + LZM_FILE_EXTENSION, 'rb') as file:
        return pickle.load(file)


if __name__ == '__main__':
    # os and pickle are already imported at the top of the module; the test-only dependencies are
    # imported here so that the util itself does not require them.
    import tempfile
    import numpy as np

    test_async_file_saver()
    test_async_file_saver_stop_without_wait_for_completion()