# Rich progress bars for Joblib parallel tasks
#
# Source: GitHub gist g-simmons/d01db39bbb6aa69452db8e4c66efc1b0,
# created February 15, 2023 21:59.
import time
from threading import Thread

import numpy as np
from joblib import Parallel, delayed
from rich.console import Console
from rich.live import Live
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn

# Number of parallel tasks to run (one progress bar per task).
num_tasks = 4

# File-backed array holding each task's fractional progress in [0, 1].
# Worker processes write into it and the display thread reads from it.
# joblib reopens np.memmap arguments in its workers from the same backing
# file instead of copying them, which is what makes those writes visible here.
# mode="w+" creates (or overwrites) "progress.mmap2" in the working directory.
progress_array = np.memmap(
    "progress.mmap2", dtype=np.float32, mode="w+", shape=num_tasks
)
# Define a function that performs a task and updates the progress array
def perform_task(task_idx, progress_array, n_steps=100, delay=0.1):
    """Simulate a unit of work and publish its progress into ``progress_array``.

    The work is modelled as ``n_steps`` equal steps, each taking ``delay``
    seconds. After each step, ``progress_array[task_idx]`` is set to the
    fraction of steps completed so far, so the last step writes exactly 1.0.

    Args:
        task_idx: Index of this task's slot in ``progress_array``.
        progress_array: Writable, indexable container (in this script, the
            shared ``np.memmap``) holding one progress fraction per task.
        n_steps: Number of simulated work steps (default 100).
        delay: Seconds to sleep per step (default 0.1).
    """
    for step in range(1, n_steps + 1):
        # Do some work here (simulated by sleeping).
        time.sleep(delay)
        # Report steps *completed*: 1/n_steps after the first step (not 0),
        # and exactly 1.0 after the last one.
        progress_array[task_idx] = step / n_steps
    # Mark the task complete even when n_steps == 0 (the loop never ran);
    # the display loop waits for every slot to reach 1.
    progress_array[task_idx] = 1
# Create a console for the Rich progress bar
console = Console()


# Define a function to continuously update the Rich progress bar
def update_progress_bar(
    progress_array=progress_array,
    console=console,
    num_tasks=num_tasks,
    poll_interval=0.1,
):
    """Render one Rich progress bar per task until every task reports 1.

    Every ``poll_interval`` seconds, copies each task's progress fraction
    from ``progress_array`` into a Rich ``Progress`` display. Returns once
    every entry of ``progress_array`` is >= 1. The display is transient, so
    Rich removes it from the terminal on exit.

    Args:
        progress_array: Array of per-task progress fractions in [0, 1].
        console: Rich console to draw on.
        num_tasks: Number of bars to show; bar ``i`` mirrors
            ``progress_array[i]``.
        poll_interval: Seconds between refreshes from ``progress_array``.
    """
    total = 100
    # Progress drives its own live display, so the console/refresh/transient
    # settings go to it directly. Nesting it in a second, empty Live bound to
    # another console would only leave two live displays competing for the
    # same terminal.
    with Progress(
        BarColumn(),
        TextColumn("[bold green]{task.fields[status]}"),
        TextColumn("[bold blue]{task.fields[name]}"),
        TimeRemainingColumn(),
        console=console,
        refresh_per_second=4,
        transient=True,
    ) as progress:
        task_ids = [
            progress.add_task(
                description=f"Task {i}",
                name=f"Task {i}",
                status="pending",
                total=total,
            )
            for i in range(num_tasks)
        ]

        def sync():
            # Copy the current fractions (and a matching status) into the
            # bars. Use round() rather than int(): float32 stores 0.29 as
            # 0.28999..., which int() would truncate to 28.
            for i, task_id in enumerate(task_ids):
                fraction = float(progress_array[i])
                if fraction >= 1:
                    status = "done"
                elif fraction > 0:
                    status = "running"
                else:
                    status = "pending"
                progress.update(
                    task_id,
                    completed=round(fraction * total),
                    status=status,
                )

        while not np.all(progress_array >= 1):
            sync()
            time.sleep(poll_interval)
        # One final sync so every bar is drawn at 100% / "done" before exit.
        sync()
# Launch the progress bar update function in a separate thread
display_thread = Thread(
    target=update_progress_bar, args=[progress_array, console, num_tasks]
)
display_thread.start()
try:
    # Launch the tasks in parallel using joblib and the perform_task function.
    # Use one worker per task so all bars advance together. A negative n_jobs
    # such as -8 means cpu_count() + 1 + n_jobs workers, which on machines
    # with 8 or fewer CPUs runs the tasks one after another.
    Parallel(n_jobs=num_tasks, backend="loky")(
        delayed(perform_task)(i, progress_array) for i in range(num_tasks)
    )
finally:
    # If a task raised, its slot never reaches 1, so the (non-daemon) display
    # thread would poll forever and keep the process alive. Mark every slot
    # finished so the display exits cleanly; any exception from Parallel
    # still propagates after the join.
    progress_array[:] = 1
    display_thread.join()
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.