-
-
Save pveierland/029892f3d796afea4f429d1d762f7d8e to your computer and use it in GitHub Desktop.
Helper for tests with Celery: runs tasks in the main thread while keeping control over when they run (an alternative to eager mode).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
from celery.app.task import Task | |
from celery.app.utils import find_app | |
class CeleryTestTask(object):
    """A context manager that queues delayed tasks instead of sending them,
    and eventually runs them in the current thread. This is for tests.

    While the context is active, ``apply_async`` on every task registered in
    the given apps only records the call; nothing is sent to a broker and the
    patched ``apply_async`` returns ``None``.

    You may access ``tasks`` (still queued) and ``done_tasks`` (already run)
    in your test to verify which tasks were triggered / run.

    This class is useful when using e.g. transaction hooks and atomic requests.

    Simple usage, to run tasks sequentially at the end of the request in a
    functional test::

        with CeleryTestTask(apps="my_app") as celery_tasks:
            self.client.post("/my-url", {"foo": "bar"})
        # celery_tasks.done_tasks now lists what was executed
    """

    def __init__(self, apps, run=True):
        """
        :param list apps: name of celery apps to catch,
            you can use a string for a single app
        :param run: run queued tasks on exit.
            Note that if tasks launch new tasks they will be queued and played.
        """
        if isinstance(apps, str):
            apps = [apps]
        self.apps = apps
        self.run = run
        # (task, args, kwargs) tuples waiting to be executed, in call order
        self.tasks = []
        # (task, args, kwargs) tuples that have been executed on exit
        self.done_tasks = []
        # task class -> apply_async as it was before patching
        self.orig_apply_async = {}

    def queue_task(self, task_class, args, kwargs):
        """Record a task call so it can be run later by ``__exit__``.

        :param task_class: the task (any callable) to run later
        :param args: positional arguments for the call, or None
        :param kwargs: keyword arguments for the call, or None
        """
        self.tasks.append((task_class, args, kwargs))

    def __enter__(self):
        """Patch ``apply_async`` on every task of the configured apps.

        :return: this instance, so ``with ... as ctx`` gives access to
            ``tasks`` and ``done_tasks``.
        """
        # patch task
        self.task_apply_async_orig = Task.apply_async

        def apply_async_patch(task, args=None, kwargs=None, **options):
            # Execution options (countdown, eta, ...) are deliberately
            # ignored: the call is only recorded for later replay.
            self.queue_task(task, args, kwargs)

        # patch all tasks
        for app in self.apps:
            for task in find_app(app).tasks.values():
                task_class = task.__class__
                # setdefault: keep the first (true) original if a class is
                # met more than once.
                self.orig_apply_async.setdefault(task_class, task.apply_async)
                task_class.apply_async = partial(apply_async_patch, task)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Run the queued tasks (if enabled and no error occurred) and unpatch.

        Tasks are run in FIFO order; tasks queued by a running task are
        appended to the same queue and run too. The patch is always removed,
        even when a task raises, so a failure cannot leak into other tests.
        """
        try:
            if exc_type is None and self.run:
                while self.tasks:
                    task, args, kwargs = self.tasks.pop(0)
                    # apply_async() may have been called without args/kwargs,
                    # in which case None was queued.
                    task(*(args or ()), **(kwargs or {}))
                    self.done_tasks.append((task, args, kwargs))
        finally:
            # unpatch
            for task_class, orig_apply_async in self.orig_apply_async.items():
                task_class.apply_async = orig_apply_async
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment