@glenn-jocher
Last active September 23, 2024 04:12
Threading vs Multiprocessing benchmarks for YOLOv8
# M3 MacBook Air results:
# Average Threading Time: 0.38 seconds
# Average Multiprocessing Time: 2.39 seconds
# Average Concurrent Futures (ThreadPool) Time: 0.34 seconds
# Average Concurrent Futures (ProcessPool) Time: 2.39 seconds
import time
import torch
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from multiprocessing import Process, set_start_method
from threading import Thread
from ultralytics import YOLO, ASSETS

def predict(model_path, image_path):
    """Performs prediction on an image using a YOLO model."""
    model = YOLO(model_path)
    model.predict(image_path)  # results are discarded to keep the benchmark focused on prediction time

def run_thread_benchmark(model_path, image_paths, num_threads):
    """Times num_threads threading.Thread workers running predict() concurrently."""
    start_time = time.time()
    threads = []
    for i in range(num_threads):
        t = Thread(target=predict, args=(model_path, image_paths[i % len(image_paths)]))
        t.start()
        threads.append(t)
    for t in threads:
        t.join()
    return time.time() - start_time

def run_process_benchmark(model_path, image_paths, num_processes):
    """Times num_processes multiprocessing.Process workers running predict() concurrently."""
    start_time = time.time()
    processes = []
    for i in range(num_processes):
        p = Process(target=predict, args=(model_path, image_paths[i % len(image_paths)]))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    return time.time() - start_time

def run_concurrent_futures_benchmark(model_path, image_paths, num_workers, use_processes=False):
    """Times num_workers concurrent.futures workers (thread or process pool) running predict()."""
    start_time = time.time()
    executor_class = ProcessPoolExecutor if use_processes else ThreadPoolExecutor
    with executor_class(max_workers=num_workers) as executor:
        futures = [executor.submit(predict, model_path, image_paths[i % len(image_paths)]) for i in range(num_workers)]
        for future in futures:
            future.result()
    return time.time() - start_time

def main():
    if torch.cuda.is_available():
        set_start_method("spawn", force=True)  # CUDA cannot be re-initialized in forked subprocesses

    model_path = "yolov8n.pt"
    image_paths = [ASSETS / "zidane.jpg", ASSETS / "bus.jpg", ASSETS / "zidane.jpg", ASSETS / "bus.jpg"]
    num_runs = 5
    num_workers = len(image_paths)
    print(f"Running benchmarks with {num_workers} workers...")

    thread_times = []
    process_times = []
    concurrent_thread_times = []
    concurrent_process_times = []

    for _ in range(num_runs):
        thread_time = run_thread_benchmark(model_path, image_paths, num_workers)
        thread_times.append(thread_time)
        print(f"Threading run completed in {thread_time:.2f} seconds")

        process_time = run_process_benchmark(model_path, image_paths, num_workers)
        process_times.append(process_time)
        print(f"Multiprocessing run completed in {process_time:.2f} seconds")

        concurrent_thread_time = run_concurrent_futures_benchmark(
            model_path, image_paths, num_workers, use_processes=False
        )
        concurrent_thread_times.append(concurrent_thread_time)
        print(f"Concurrent Futures (ThreadPool) run completed in {concurrent_thread_time:.2f} seconds")

        concurrent_process_time = run_concurrent_futures_benchmark(
            model_path, image_paths, num_workers, use_processes=True
        )
        concurrent_process_times.append(concurrent_process_time)
        print(f"Concurrent Futures (ProcessPool) run completed in {concurrent_process_time:.2f} seconds")

    avg_thread_time = sum(thread_times) / num_runs
    avg_process_time = sum(process_times) / num_runs
    avg_concurrent_thread_time = sum(concurrent_thread_times) / num_runs
    avg_concurrent_process_time = sum(concurrent_process_times) / num_runs

    print(f"\nAverage Threading Time: {avg_thread_time:.2f} seconds")
    print(f"Average Multiprocessing Time: {avg_process_time:.2f} seconds")
    print(f"Average Concurrent Futures (ThreadPool) Time: {avg_concurrent_thread_time:.2f} seconds")
    print(f"Average Concurrent Futures (ProcessPool) Time: {avg_concurrent_process_time:.2f} seconds")

    times = [avg_thread_time, avg_process_time, avg_concurrent_thread_time, avg_concurrent_process_time]
    methods = ["Threading", "Multiprocessing", "Concurrent Futures (ThreadPool)", "Concurrent Futures (ProcessPool)"]
    best_method = methods[times.index(min(times))]
    print(f"\n{best_method} performed best in this benchmark.")

if __name__ == "__main__":
    main()
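
Note that predict() above loads a fresh YOLO model on every call, so all four timings include model construction as well as inference. Below is a minimal sketch (not part of the benchmark) of a variant where each worker thread builds its own model once and reuses it across images; keeping one model instance per thread also avoids sharing a single model across threads, which is generally not thread-safe. The worker() helper and batch layout are illustrative assumptions.

import time
from concurrent.futures import ThreadPoolExecutor
from ultralytics import ASSETS, YOLO

def worker(image_paths):
    """Each thread builds its own YOLO instance once, then reuses it for every image."""
    model = YOLO("yolov8n.pt")  # per-thread instance; a single shared model is not thread-safe
    for image_path in image_paths:
        model.predict(image_path)

if __name__ == "__main__":
    batches = [[ASSETS / "zidane.jpg", ASSETS / "bus.jpg"]] * 4  # 4 threads, 2 images each
    start = time.time()
    with ThreadPoolExecutor(max_workers=len(batches)) as executor:
        list(executor.map(worker, batches))
    print(f"All threads completed in {time.time() - start:.2f} seconds")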
@lakshanthad commented Sep 16, 2024

@glenn-jocher We might need to add the update below to run on CUDA devices as well. Otherwise it fails with this error:

RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

Updates:

from multiprocessing import Process, set_start_method
import torch
........
def main():
    if torch.cuda.is_available():
        set_start_method("spawn", force=True)
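
(As a side note, a sketch of an alternative that avoids mutating the global start method: the standard library's multiprocessing.get_context returns a spawn-based context that only affects processes created from it, and concurrent.futures accepts it via the mp_context parameter (Python 3.7+). The toy targets below are just placeholders.)

from concurrent.futures import ProcessPoolExecutor
from multiprocessing import get_context

if __name__ == "__main__":
    ctx = get_context("spawn")  # spawn context, independent of the global start method

    # Plain multiprocessing: create Process objects from the spawn context
    p = ctx.Process(target=print, args=("hello from a spawned process",))
    p.start()
    p.join()

    # concurrent.futures: hand the same context to the pool via mp_context
    with ProcessPoolExecutor(max_workers=2, mp_context=ctx) as executor:
        print(list(executor.map(abs, [-1, -2])))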

EDIT:

5700X, RTX3060, Ubuntu 22.04 results:

# Average Threading Time: 0.59 seconds
# Average Multiprocessing Time: 4.01 seconds
# Average Concurrent Futures (ThreadPool) Time: 0.36 seconds
# Average Concurrent Futures (ProcessPool) Time: 4.01 seconds

@glenn-jocher (Author) commented

Oh interesting, I haven't actually tried this in a CUDA environment.

Code updated now!
