Created
October 19, 2023 08:45
-
-
Save eladn/1bc4f2eb1880e01c52cf9795c3610733 to your computer and use it in GitHub Desktop.
Python util for saving files to disk concurrently in the background without blocking the main process
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
__author__ = "Elad Nachmias" | |
__email__ = "eladnah@gmail.com" | |
__date__ = "2023-10-19" | |
import lzma
import os
import pickle
import queue
import threading
import time
from concurrent.futures import CancelledError
from typing import Any, NamedTuple, Set, Collection, Optional
# Suffix appended to every destination path when writing/reading saved files.
LZM_FILE_EXTENSION = '.lzma'
# Compression preset forwarded to lzma.open(); None lets lzma pick its default level.
LZM_PRESET = None
class ThreadSafeFuture:
    """
    Concurrent.futures.Future and asyncio.Future are not thread-safe. That is, one cannot
    share them across several threads. This class provides a custom future (adapts to the
    commonly known Future API) but which is thread-safe. One thread can set the result and
    the other thread can wait for it.
    """

    def __init__(self):
        self._result = None
        self._exception = None
        self._cancelled = False
        # Guards state transitions so a completion and a cancellation racing from
        # different threads cannot both "win" (the first writer wins).
        self._lock = threading.Lock()
        self._event = threading.Event()

    def set_result(self, result):
        """Mark the future as successfully completed with `result` (no-op if already done)."""
        with self._lock:
            if self._event.is_set():
                return  # already completed or cancelled - keep the first outcome
            self._result = result
            self._event.set()

    def set_exception(self, exception):
        """Mark the future as failed; `exception` is re-raised by `result()` (no-op if done)."""
        with self._lock:
            if self._event.is_set():
                return
            self._exception = exception
            self._event.set()

    def result(self):
        """
        Block until the future completes and return the stored result.
        :raises CancelledError: if the future was cancelled.
        :raises Exception: the exception stored via `set_exception()`, if any.
        """
        self._event.wait()
        if self._cancelled:
            raise CancelledError("Future was cancelled.")
        if self._exception is not None:
            raise self._exception
        return self._result

    def exception(self) -> Optional[Exception]:
        """Block until done; return the stored exception (None on success or cancellation)."""
        self._event.wait()
        return self._exception

    def done(self) -> bool:
        """Whether the future has completed (result, exception, or cancellation)."""
        return self._event.is_set()

    def cancel(self) -> bool:
        """
        Cancel the future. Matches the concurrent.futures API: a future that has
        already completed cannot be cancelled.
        :return: True if the future is (now) cancelled, False if it had already completed.
        """
        with self._lock:
            if self._event.is_set():
                return self._cancelled  # too late to cancel a finished future
            self._cancelled = True
            self._event.set()
            return True

    def cancelled(self) -> bool:
        """Whether the future was cancelled."""
        return self._cancelled

    @classmethod
    def wait(cls, futures: Collection['ThreadSafeFuture']):
        """Block until every future in `futures` is done (completed or cancelled)."""
        for future in futures:
            future._event.wait()
class ConcurrentFileSaver:
    """
    Allows storing files to disk concurrently (using threading), without blocking
    the main thread. Objects are pickled and lzma-compressed on daemon worker threads.
    """

    # Lazily-created process-wide singleton; see `get_main()`.
    _main: Optional['ConcurrentFileSaver'] = None

    def __init__(self):
        # Annotation is quoted: `_ItemToBeStored` is defined later in this module.
        self._queue: 'queue.Queue[_ItemToBeStored]' = queue.Queue()
        self._threads_lock = threading.Lock()
        self._should_stop_all = threading.Event()
        self._should_stop_all_on_empty_queue = threading.Event()
        self._threads: Set[threading.Thread] = set()
        self._throughput_watchdog_thread: Optional[threading.Thread] = None

    @classmethod
    def get_main(cls) -> 'ConcurrentFileSaver':
        """
        Use as a global static instance for process.
        """
        if cls._main is None:
            cls._main = cls()
        return cls._main

    def __del__(self):
        self.stop()
        # Threads are daemons - no need to wait for them with join(), we only tell them to
        # terminate. The OS will make sure to kill them when the non-daemon threads exit.
        del self._threads
        del self._throughput_watchdog_thread

    def enqueue_save_to_file(self, obj: Any, dst_path: str) -> 'ThreadSafeFuture':
        """
        A non-blocking method to enqueue a file to be dumped to FS. Returning a future
        allows waiting for completion of concrete jobs without having to stop the entire
        instance (and thus destruct it).
        :param obj: the object to be dumped.
        :param dst_path: the destination path to dump the given object to.
        :return: A ThreadSafeFuture. Note: use ThreadSafeFuture.wait(...) to wait for a
            collection of such.
        """
        assert not self._should_stop_all.is_set()
        assert not self._should_stop_all_on_empty_queue.is_set()
        # Fixed: the original referenced the undefined name `Future` here.
        future = ThreadSafeFuture()
        self._queue.put(_ItemToBeStored(obj=obj, dst_path=dst_path, future=future))
        self._verify_enough_threads()
        self._verify_throughput_watchdog_thread()
        return future

    def stop(self):
        """Signal all workers to stop and cancel every still-queued (unstarted) job."""
        self._should_stop_all.set()
        while not self._queue.empty():
            try:
                item = self._queue.get(block=False)
                item.future.cancel()
            except queue.Empty:
                continue

    def stop_on_empty_queue(self):
        """
        Continue running the worker threads until completion of all previously assigned jobs.
        """
        self._should_stop_all_on_empty_queue.set()

    def join(self):
        """Block until all worker threads (and the watchdog thread) have exited."""
        while True:
            with self._threads_lock:
                threads = list(self._threads)
                if self._throughput_watchdog_thread is not None:
                    threads.append(self._throughput_watchdog_thread)
            if len(threads) == 0:
                return
            for thread in threads:
                thread.join(timeout=10.)
                if not thread.is_alive():
                    with self._threads_lock:
                        if thread in self._threads:
                            self._threads.remove(thread)
                        elif thread == self._throughput_watchdog_thread:
                            self._throughput_watchdog_thread = None

    def _verify_enough_threads(self):
        """Prune dead workers; spawn a fresh pool if none are alive and work is pending."""
        with self._threads_lock:
            self._threads = {thread for thread in self._threads if thread.is_alive()}
            if len(self._threads) > 0:
                return
        if self._queue.empty():  # this operation acquires a mutex - thus we call it outside of the threads lock to avoid deadlock
            return
        self._create_worker_threads(n=10)

    def _verify_throughput_watchdog_thread(self):
        """Ensure the watchdog thread exists and is alive; (re)start it when needed."""
        with self._threads_lock:
            if self._throughput_watchdog_thread is not None and not self._throughput_watchdog_thread.is_alive():
                self._throughput_watchdog_thread = None
            if self._throughput_watchdog_thread is None:
                self._throughput_watchdog_thread = threading.Thread(
                    target=self._throughput_watchdog_thread_main, daemon=True)
                self._throughput_watchdog_thread.start()

    def _create_worker_threads(self, n: int = 1):
        """Create and start `n` daemon worker threads (no-op once stopping)."""
        for _ in range(n):
            with self._threads_lock:
                if self._should_stop_all.is_set():
                    return
                new_thread = threading.Thread(target=self._worker_thread_main, daemon=True)
                self._threads.add(new_thread)
                new_thread.start()

    def _worker_thread_main(self):
        """Worker loop: pull items off the queue and persist them until told to stop."""
        try:
            while not self._should_stop_all.is_set() and threading.main_thread().is_alive() and \
                    (not self._should_stop_all_on_empty_queue.is_set() or not self._queue.empty()):
                self._verify_throughput_watchdog_thread()
                item = None
                try:
                    try:
                        item = self._queue.get(timeout=30.)
                    except queue.Empty:
                        continue
                    self._save_object(file_path=item.dst_path, data=item.obj)
                    item.future.set_result(0)  # mark as successfully completed so the user could see it
                except Exception as e:
                    if item is not None:
                        item.future.set_exception(exception=e)
        except Exception:
            # Best-effort worker: never let an unexpected error escape the thread.
            pass
        finally:
            with self._threads_lock:
                # `discard` (not `remove`): this thread may already have been pruned
                # by `_verify_enough_threads()` or `join()`.
                self._threads.discard(threading.current_thread())

    @staticmethod
    def _save_object(file_path: str, data: Any):
        """
        Pickle `data` and write it lzma-compressed to `file_path` + '.lzma'.
        Compression cpu-bound operations block the entire process. However, IO-bound
        operation potentially blocks the current thread only, but without acquiring the
        GIL and thus not blocking the other threads in the current process.
        """
        # Fixed: original was missing `self`/@staticmethod and referenced the
        # undefined name `file_name`. Guard against a bare filename whose dirname
        # is '' (os.makedirs('') raises).
        dir_name = os.path.dirname(file_path)
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)
        with lzma.open(file_path + LZM_FILE_EXTENSION, 'wb', preset=LZM_PRESET) as file:
            pickle.dump(data, file)

    def _throughput_watchdog_thread_main(self):
        """
        Responsible for periodically verify that all the threads in the pool are still
        alive, and conditionally instantiating new threads whenever needed.
        """
        try:
            while not self._should_stop_all.is_set() and threading.main_thread().is_alive() and \
                    (not self._should_stop_all_on_empty_queue.is_set() or not self._queue.empty()):
                # TODO: track the produce vs consume throughput and create/remove threads
                #  dynamically. Maybe use `self._queue.qsize()`?
                self._verify_enough_threads()
                time.sleep(30.)
        finally:
            with self._threads_lock:
                if self._throughput_watchdog_thread == threading.current_thread():
                    self._throughput_watchdog_thread = None
class _ItemToBeStored(NamedTuple): | |
obj: Any | |
dst_path: str | |
future: Future | |
def test_async_file_saver():
    """End-to-end check: enqueue several random arrays, let the saver drain its queue,
    then verify that every saved file loads back equal to its source array."""
    rng = np.random.RandomState(seed=0)
    item_count = 30
    with tempfile.TemporaryDirectory() as out_dir:
        saver = ConcurrentFileSaver()
        arrays = [
            rng.random(size=(rng.randint(1, 200), rng.randint(1, 200)))
            for _ in range(item_count)
        ]
        pending = [
            saver.enqueue_save_to_file(
                obj=arr, dst_path=os.path.join(out_dir, f'item_{idx}.pkl'))
            for idx, arr in enumerate(arrays)
        ]
        saver.stop_on_empty_queue()
        saver.join()
        assert all(fut.done() and not fut.cancelled() for fut in pending)
        reloaded = [
            _load_object(os.path.join(out_dir, f'item_{idx}.pkl'))
            for idx in range(item_count)
        ]
        assert all(np.allclose(src, dst) for src, dst in zip(arrays, reloaded))
def test_async_file_saver_stop_without_wait_for_completion():
    """Enqueue jobs and immediately stop(): every future must end up either finished
    or cancelled, and some should be cancelled since stop() outruns the workers."""
    rng = np.random.RandomState(seed=0)
    item_count = 30
    with tempfile.TemporaryDirectory() as out_dir:
        saver = ConcurrentFileSaver()
        arrays = [
            rng.random(size=(rng.randint(1, 200), rng.randint(1, 200)))
            for _ in range(item_count)
        ]
        pending = []
        for idx, arr in enumerate(arrays):
            pending.append(saver.enqueue_save_to_file(
                obj=arr, dst_path=os.path.join(out_dir, f'item_{idx}.pkl')))
        saver.stop()
        ThreadSafeFuture.wait(pending)
        finished = sum(fut.done() and not fut.cancelled() for fut in pending)
        cancelled = sum(fut.cancelled() for fut in pending)
        assert cancelled + finished == item_count
        # We assume dispatching is non-blocking and fast, so the stop() call is reached
        # before all jobs complete and at least one queued job gets cancelled.
        assert cancelled > 0
def _load_object(file_path: str) -> Any:
    """
    Load back an object stored by `ConcurrentFileSaver` (lzma-compressed pickle).
    :param file_path: the path the object was saved under (without the '.lzma' suffix).
    :return: the un-pickled object.
    """
    # Fixed: `preset` is a compression-only option; lzma.open raises ValueError if a
    # non-None preset is passed in read mode, so it must not be forwarded here.
    with lzma.open(file_path + LZM_FILE_EXTENSION, 'rb') as file:
        return pickle.load(file)
if __name__ == '__main__':
    # Test-only dependencies are imported lazily here so that importing this module
    # as a library stays lightweight (numpy/tempfile are only needed by the tests).
    import os
    import pickle
    import tempfile
    import numpy as np
    test_async_file_saver()
    test_async_file_saver_stop_without_wait_for_completion()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment