Skip to content

Instantly share code, notes, and snippets.

@wmayner
Last active May 11, 2022 19:05
Show Gist options
  • Save wmayner/be88c0cb6f8fb3708e8126774e86ce95 to your computer and use it in GitHub Desktop.
Save wmayner/be88c0cb6f8fb3708e8126774e86ce95 to your computer and use it in GitHub Desktop.
Map a Ray remote function, returning early if a particular value is found
import functools
import time
import ray
# Initialize Ray at import time so the remote tasks submitted below have a runtime to run on.
ray.init()
def as_completed(object_refs, num_returns=1):
    """Yield remote results in order of completion.

    Args:
        object_refs: Iterable of Ray ``ObjectRef`` objects to wait on.
        num_returns: How many refs to wait for per ``ray.wait`` call. The
            final batch may be smaller when fewer refs remain.

    Yields:
        The fetched value of each ref, batch by batch as the refs finish.
    """
    # ray.wait expects a list, so materialize tuples/generators up front.
    unfinished = list(object_refs)
    while unfinished:
        # ray.wait refuses a num_returns larger than the number of refs it is
        # given; clamp it so the last (possibly short) batch does not fail.
        batch_size = min(num_returns, len(unfinished))
        finished, unfinished = ray.wait(unfinished, num_returns=batch_size)
        yield from ray.get(finished)
def cancel_all(object_refs, *args, **kwargs):
    """Cancel all remote tasks.

    Calls ``ray.cancel`` on every ref in ``object_refs``, forwarding any extra
    positional and keyword arguments unchanged to each call (see
    ``ray.cancel`` for the options it accepts).

    Args:
        object_refs: Iterable of Ray ``ObjectRef`` objects to cancel.

    Returns:
        ``object_refs``, unchanged.
    """
    # NOTE: deliberately not decorated with functools.wraps(ray.cancel): that
    # would replace this function's __name__/__doc__ with ray.cancel's and make
    # inspect.signature report ray.cancel's single-ref signature.
    for ref in object_refs:
        ray.cancel(ref, *args, **kwargs)
    return object_refs
def shortcircuit(
    items,
    shortcircuit_value=None,
    shortcircuit_callback=None,
    shortcircuit_callback_args=None,
):
    """Yield from an iterable, stopping early if a certain value is found.

    Items are yielded in order up to and including the first one that
    compares equal (``==``) to ``shortcircuit_value``; after that, the rest of
    ``items`` is never consumed.

    Args:
        items: Iterable to draw values from.
        shortcircuit_value: Value that ends iteration. With the default
            ``None``, a ``None`` item stops iteration.
        shortcircuit_callback: Optional callable, invoked once with
            ``shortcircuit_callback_args`` as its single positional argument
            when the short-circuit value is found. Falsy values disable it.
        shortcircuit_callback_args: Object passed to the callback.

    Yields:
        Each item from ``items`` up to and including the short-circuit value.
    """
    for result in items:
        if result == shortcircuit_value:
            # Run the callback *before* handing the value to the consumer: a
            # generator paused at a yield is only resumed if the consumer asks
            # for another item, so a caller that breaks out right after
            # receiving this value would otherwise skip the callback entirely.
            if shortcircuit_callback:
                shortcircuit_callback(shortcircuit_callback_args)
            yield result
            return
        yield result
def map_remote_shortcircuit(
    func,
    *arg_lists,
    shortcircuit_value=None,
    shortcircuit_callback=cancel_all,
    shortcircuit_callback_args=None,
    **kwargs,
):
    """Map a remote function over argument lists, returning early on a sentinel result.

    One task is submitted immediately per position of the zipped
    ``arg_lists`` (so the shortest list sets the task count), and ``kwargs``
    are passed to every ``func.remote`` call. The returned generator yields
    results as tasks complete, up to and including the first result equal to
    ``shortcircuit_value``. At that point ``shortcircuit_callback`` is called
    with ``shortcircuit_callback_args``. When that argument is ``None``, the
    list of submitted ``ObjectRef``s is passed instead. The default callback,
    ``cancel_all``, calls ``ray.cancel`` on every submitted task.

    Args:
        func: Ray remote function (anything with a ``.remote`` method).
        *arg_lists: One iterable per positional parameter of ``func``.
        shortcircuit_value: Result value that ends iteration early.
        shortcircuit_callback: Callable invoked once when the value is found.
        shortcircuit_callback_args: Argument for the callback; defaults to
            the submitted object refs.
        **kwargs: Keyword arguments forwarded to every remote call.

    Returns:
        A generator over completed results.
    """
    object_refs = []
    for call_args in zip(*arg_lists):
        object_refs.append(func.remote(*call_args, **kwargs))

    # Unless the caller supplied something else, hand the refs themselves to
    # the callback so the default cancel_all can act on them.
    callback_args = (
        object_refs
        if shortcircuit_callback_args is None
        else shortcircuit_callback_args
    )

    completed = as_completed(object_refs)
    return shortcircuit(
        completed,
        shortcircuit_value=shortcircuit_value,
        shortcircuit_callback=shortcircuit_callback,
        shortcircuit_callback_args=callback_args,
    )
# Demo remote task used by the example run below.
@ray.remote
def f(x):
    """Sleep ``x / 10`` seconds, then return ``x``.

    A falsy ``x`` (such as 0) sleeps as if it were 10, i.e. one second.
    """
    delay = (x or 10) / 10
    time.sleep(delay)
    return x
# Run the demo only when executed as a script, so importing this module does
# not launch ~100 remote tasks that take seconds to finish.
if __name__ == "__main__":
    # Tasks 1..99 sleep 0.1-9.9 s; the sentinel 0 sleeps 1 s, so it finishes
    # long before the slowest tasks, and the default callback then calls
    # ray.cancel on every submitted task.
    shortcircuit_value = 0
    args = list(range(1, 100)) + [shortcircuit_value]
    print(
        list(
            map_remote_shortcircuit(f, args, shortcircuit_value=shortcircuit_value)
        )
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment