Created
May 12, 2021 22:16
-
-
Save fantix/01fbbf0ad57c3faf580952b66e7ea4d7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import functools | |
import signal as signal_mod | |
import warnings | |
def _release_waiter(waiter, *args):
    """Resolve *waiter* with ``None`` unless it already has an outcome.

    Installed as a done-callback (via ``functools.partial``), so any extra
    positional arguments -- the finished future -- are accepted and ignored.
    """
    if waiter.done():
        return
    waiter.set_result(None)
class SignalError(Exception):
    """Raised by ``SignalController.wait_for`` when a watched signal arrives.

    Attributes:
        signo: the signal number that interrupted the wait.
    """

    def __init__(self, signo):
        # Forward to Exception so ``args`` is always ``(signo,)``.  Without
        # this, keyword construction leaves ``args == ()``, which makes the
        # exception print as an empty string and breaks unpickling.
        super().__init__(signo)
        self.signo = signo
class SignalController:
    """Route Unix signals delivered to the running event loop to coroutines.

    Entering the controller (``with SignalController(SIGUSR1, ...) as sc``)
    installs an event-loop signal handler for each requested signal.  Leaving
    it removes a handler once no other controller on the same loop still
    watches that signal.  While the controller is active:

    * :meth:`wait_for` runs an awaitable and cancels it, raising
      :class:`SignalError`, when one of its ``cancel_on`` signals arrives.
    * :meth:`wait_for_signals` yields signal numbers as they are delivered.

    The constructor must be called while an event loop is running, because
    it captures :func:`asyncio.get_running_loop`.
    """

    # Class-wide map: event loop -> signal number -> set of active
    # controllers.  It is shared so that several controllers on one loop can
    # watch the same signal through a single installed handler.
    _registry = {}

    def __init__(self, *signals):
        # De-duplicate while keeping order.  A repeated signal would otherwise
        # make __exit__ fail with KeyError on its second removal, and would
        # leave a stray waiter behind in wait_for_signals().
        self._signals = tuple(dict.fromkeys(signals))
        self._loop = asyncio.get_running_loop()
        # signal number -> set of futures to resolve when it is delivered.
        self._waiters = {}

    def __enter__(self):
        installed = []
        try:
            for signal in self._signals:
                self._loop.add_signal_handler(
                    signal, self._signal_callback, signal
                )
                installed.append(signal)
                self._registry.setdefault(self._loop, {}).setdefault(
                    signal, set()
                ).add(self)
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so undo the
            # handlers and registrations made so far instead of leaking them.
            self._unregister(installed)
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._waiters:
            warnings.warn(
                "SignalController exited before wait_for() completed."
            )
        self._unregister(self._signals)

    def _unregister(self, signals):
        """Detach this controller from *signals* on its loop.

        A signal's loop handler is removed once no controller on this loop
        watches it any more.  The loop's registry entry is dropped when it
        becomes empty, so the class-level registry does not keep finished
        event loops alive.
        """
        loop_registry = self._registry.get(self._loop, {})
        for signal in signals:
            controllers = loop_registry.get(signal)
            if controllers is not None:
                controllers.discard(self)
                if controllers:
                    # Another controller still relies on this handler.
                    continue
                del loop_registry[signal]
            self._loop.remove_signal_handler(signal)
        if not loop_registry:
            self._registry.pop(self._loop, None)

    def _on_signal(self, signal):
        """Resolve every pending waiter for *signal* with the signal number."""
        for waiter in self._waiters.get(signal, []):
            if not waiter.done():
                waiter.set_result(signal)

    async def _cancel_and_wait(self, fut):
        """Cancel *fut* and wait until it has actually finished."""
        waiter = self._loop.create_future()
        cb = functools.partial(_release_waiter, waiter)
        fut.add_done_callback(cb)
        try:
            fut.cancel()
            # We cannot wait on *fut* directly to make
            # sure _cancel_and_wait itself is reliably cancellable.
            await waiter
        finally:
            fut.remove_done_callback(cb)

    def _create_waiter(self, signal):
        """Create a future on this loop and register it for *signal*."""
        waiter = self._loop.create_future()
        self._register_waiter(signal, waiter)
        return waiter

    def _register_waiter(self, signal, waiter):
        """Make *waiter* resolve when *signal* is delivered."""
        self._waiters.setdefault(signal, set()).add(waiter)

    def _discard_waiter(self, signal, waiter):
        """Unregister *waiter* for *signal*; drop the entry once it is empty."""
        waiters = self._waiters.get(signal)
        if waiters:
            waiters.discard(waiter)
            if not waiters:
                del self._waiters[signal]

    async def wait_for(self, fut, *, cancel_on):
        """Await *fut*, cancelling it if a signal in *cancel_on* arrives first.

        Args:
            fut: an awaitable; it is wrapped with :func:`asyncio.ensure_future`.
            cancel_on: iterable of signal numbers that should interrupt the
                wait.  Only signals this controller was created with are
                dispatched to it, so any other number never fires.

        Returns:
            The result of *fut* if it finishes first.

        Raises:
            SignalError: a ``cancel_on`` signal arrived.  *fut* was cancelled
                and awaited before raising.  The error is chained from the
                CancelledError when *fut* ended up cancelled.
            Exception: any exception *fut* raised is propagated, including
                one raised while it was being cancelled.
            asyncio.CancelledError: the calling task was cancelled.  *fut* is
                cancelled too, unless it had already finished.
        """
        fut = asyncio.ensure_future(fut)
        # Materialise once: the signals are iterated again in ``finally``,
        # and a one-shot iterator would leave the waiters registered.
        cancel_on = tuple(cancel_on)
        waiter = self._loop.create_future()
        cb = functools.partial(_release_waiter, waiter)
        # *waiter* resolves with None when *fut* finishes, or with the
        # signal number when one of the cancel_on signals is delivered.
        fut.add_done_callback(cb)
        for signal in cancel_on:
            self._register_waiter(signal, waiter)
        try:
            try:
                signal = await waiter
            except asyncio.CancelledError:
                if fut.done():
                    # NOTE(review): the caller's cancellation is swallowed
                    # here in favour of the finished result; confirm this is
                    # intended.
                    return fut.result()
                fut.remove_done_callback(cb)
                await self._cancel_and_wait(fut)
                raise
            if fut.done():
                return fut.result()
            # The waiter was resolved by a signal, not by *fut* finishing.
            fut.remove_done_callback(cb)
            await self._cancel_and_wait(fut)
            try:
                fut.result()
            except asyncio.CancelledError as exc:
                raise SignalError(signal) from exc
            else:
                raise SignalError(signal)
        finally:
            for signo in cancel_on:
                self._discard_waiter(signo, waiter)

    async def wait_for_signals(self):
        """Yield signal numbers as this controller's signals are delivered.

        Each round arms one waiter per watched signal and waits until at
        least one fires.  It then disarms them all and yields the delivered
        signal numbers in the order the signals were given to the
        constructor.  Signals delivered while the consumer is handling a
        yielded value (that is, between rounds) are not captured.
        """
        while True:
            waiters = {
                signal: self._create_waiter(signal)
                for signal in self._signals
            }
            try:
                await asyncio.wait(waiters.values())
            finally:
                for signal, waiter in waiters.items():
                    self._discard_waiter(signal, waiter)
            for waiter in waiters.values():
                if waiter.done():
                    yield waiter.result()

    @classmethod
    def _signal_callback(cls, signal):
        """Loop handler: pass *signal* to every controller on this loop."""
        registry = cls._registry.get(asyncio.get_running_loop())
        if not registry:
            return
        controllers = registry.get(signal)
        if not controllers:
            return
        for controller in controllers:
            controller._on_signal(signal)
async def main():
    """Demonstrate SignalController with SIGUSR1.

    First run an hour-long sleep that SIGUSR1 interrupts.  If it is
    interrupted, print the signal number carried by the SignalError.  Then
    print every further SIGUSR1 delivery until the task is cancelled.
    """
    watched = signal_mod.SIGUSR1
    with SignalController(watched) as controller:
        long_sleep = asyncio.sleep(3600)
        try:
            await controller.wait_for(long_sleep, cancel_on={watched})
        except SignalError as err:
            print(err.signo)
        async for signo in controller.wait_for_signals():
            print(signo)
if __name__ == "__main__":
    # Run the demo only when executed as a script.  Importing this module
    # (e.g. to reuse SignalController) must not block forever in main().
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment