Last active
January 23, 2022 19:19
-
-
Save adriangb/4769659899abd24f5d184332a2cdbee8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import annotations | |
import cProfile | |
import pstats | |
from collections import deque | |
from dataclasses import dataclass, field | |
from random import Random | |
from timeit import default_timer | |
from typing import Awaitable, Callable, Deque, Iterable, List, Optional, Protocol, Union | |
import asyncio | |
import anyio | |
class Job(Protocol):
    """Structural type for a unit of work handled by ``worker``.

    A job is any zero-argument async callable. Awaiting it yields either
    ``None`` (which ``worker`` forwards as the stop signal) or an iterable of
    follow-up jobs to be scheduled.
    """

    async def __call__(self) -> Union[None, Iterable[Job]]:
        """Run the job; return follow-up jobs, or ``None`` when there is no more work."""
async def worker(job: Job, send: Callable[[Optional[Job]], Awaitable[None]]) -> None:
    """Await *job* once and hand whatever it produced to *send*.

    A ``None`` result is forwarded as-is; the scheduling loops treat it as the
    stop signal. Otherwise each follow-up job is sent individually, in order.
    An empty iterable of follow-ups sends nothing.
    """
    produced = await job()
    if produced is not None:
        for follow_up in produced:
            await send(follow_up)
    else:
        await send(None)
async def anyio_main(jobs: Iterable[Job]) -> None:
    """Run *jobs* to completion on an anyio task group.

    Seed jobs are pushed into an unbounded memory object stream. The loop below
    pulls items off the stream and starts one ``worker`` task per job, and the
    workers push their follow-up jobs back into the same stream. The first
    ``None`` received ends the loop. Leaving the task group then waits for every
    worker already started. Jobs still buffered in the stream at that point are
    not run.

    Args:
        jobs: initial jobs (any iterable). If it is empty, this returns
            immediately. Otherwise ``receive()`` would block forever waiting for
            an item nobody will ever send.
    """
    send, receive = anyio.create_memory_object_stream(float("inf"))
    async with send, receive:
        seeded = 0
        for job in jobs:
            await send.send(job)
            seeded += 1
        if not seeded:
            # No job can ever produce the terminating None; avoid the deadlock.
            return
        async with anyio.create_task_group() as taskgroup:
            while True:
                newjob = await receive.receive()
                if newjob is None:
                    return
                # Workers get the send half so they can feed new jobs back in.
                taskgroup.start_soon(worker, newjob, send.send)
async def asyncio_main(jobs: Iterable[Job]) -> None:
    """Run *jobs* to completion using plain asyncio tasks and a queue.

    Seed jobs go into an unbounded ``asyncio.Queue``. One ``worker`` task is
    created per job taken off the queue, and workers put their follow-up jobs
    back onto it. The first ``None`` taken off the queue ends the loop, after
    which the still-running workers are awaited. Jobs still queued at that point
    are not run.

    If a worker raises, its exception is re-raised from here instead of being
    silently dropped. Any workers still running are cancelled on the way out.

    Args:
        jobs: initial jobs (any iterable). If it is empty, this returns
            immediately. Otherwise ``queue.get()`` would block forever.
    """
    queue: asyncio.Queue[Optional[Job]] = asyncio.Queue()
    for job in jobs:
        queue.put_nowait(job)  # unbounded queue: never raises QueueFull
    if queue.empty():
        return

    # Tasks that have not (yet) finished successfully. A set maintained by a
    # done-callback gives O(1) bookkeeping per job. The previous approach
    # rebuilt a deque of every outstanding task on each loop iteration.
    pending: set[asyncio.Task[None]] = set()

    def _on_done(task: asyncio.Task[None]) -> None:
        if task.cancelled() or task.exception() is None:
            pending.discard(task)
        else:
            # Keep the failed task in `pending` so gather() below re-raises
            # its exception. Also wake the main loop, which could otherwise
            # wait forever for a None that no worker will send.
            queue.put_nowait(None)

    try:
        while True:
            newjob = await queue.get()
            if newjob is None:
                break
            task = asyncio.create_task(worker(newjob, queue.put))
            pending.add(task)
            task.add_done_callback(_on_done)
        await asyncio.gather(*pending)
    finally:
        # No-op after a clean run; on error/cancellation, don't orphan workers.
        for task in list(pending):
            task.cancel()
@dataclass
class JobGenerator(Job):
    """Self-replicating job used to drive the benchmarks.

    Every productive call consumes one unit of ``cycles``. That counter is
    shared by all references to the instance, because the follow-up jobs
    returned are the instance itself. The call then sleeps for ``sleep``
    seconds and returns between 1 and ``max_jobs_per_cycle`` (inclusive)
    references to ``self``. Once ``cycles`` is exhausted, calls return
    ``None``. The fan-out draws come from a ``Random`` seeded with 42 by
    default, so they are reproducible for a given call order.
    """

    cycles: int  # remaining cycle budget; decremented once per productive call
    max_jobs_per_cycle: int  # inclusive upper bound on follow-up jobs per call
    count: int = 0  # NOTE(review): not referenced anywhere in this file
    sleep: float = 1e-3  # simulated work per call, in seconds (via anyio.sleep)
    random: Random = field(default_factory=lambda: Random(42))

    async def __call__(self) -> Union[None, Iterable[Job]]:
        if self.cycles > 0:
            self.cycles -= 1
            await anyio.sleep(self.sleep)
            fan_out = self.random.randint(1, self.max_jobs_per_cycle)
            return [self] * fan_out
        return None
async def time(
    func: Callable[[Iterable[Job]], Awaitable[None]],
    backend: str,
    *,
    iters: int = 100,
) -> None:
    """Benchmark *func* and print its mean wall-clock time per run.

    Each iteration builds a fresh ``JobGenerator(75, 15)`` seed and times one
    full ``await func(seed)``. The generator's cycle counter is consumed by a
    run, so a seed cannot be reused across iterations.

    Args:
        func: scheduler entry point to benchmark (e.g. ``asyncio_main``).
        backend: label included in the printed line. It does not select the
            event loop; that is chosen by the ``anyio.run`` call invoking this.
        iters: number of timed runs to average over (default 100).

    Raises:
        ValueError: if *iters* is not positive, since the mean is undefined.
    """
    if iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")
    total = 0.0
    for _ in range(iters):
        seed = [JobGenerator(75, 15)]
        start = default_timer()
        await func(seed)
        total += default_timer() - start
    print(f"{func.__name__} on {backend} took {total / iters:e} sec/run")
async def profile(func: Callable[[Iterable[Job]], Awaitable[None]], backend: str) -> None:
    """Profile one run of *func* with cProfile and save the collected stats.

    The run is seeded with a fresh ``JobGenerator(75, 15)``. Stats are written
    to ``<func.__name__>_<backend>.prof`` (a relative path, so it lands in the
    current working directory) in the format loadable by ``pstats.Stats``.

    Args:
        func: scheduler entry point to profile (e.g. ``anyio_main``).
        backend: label used in the output filename only.

    Raises:
        Whatever *func* raises. Profiling is stopped first and no stats file
        is written in that case.
    """
    seed = [JobGenerator(75, 15)]
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        await func(seed)
    finally:
        # Always stop profiling, even if the run fails; otherwise the profiler
        # stays active and keeps tracing the rest of the process.
        profiler.disable()
    stats = pstats.Stats(profiler)
    stats.dump_stats(filename=f"{func.__name__}_{backend}.prof")
if __name__ == "__main__":
    # (scheduler entry point, anyio backend) pairs. asyncio_main drives asyncio
    # primitives directly, so it is only exercised on the asyncio backend.
    _RUNS = (
        (asyncio_main, "asyncio"),
        (anyio_main, "asyncio"),
        (anyio_main, "trio"),
    )
    # First time every configuration, then profile every configuration.
    for _harness in (time, profile):
        for _entry, _backend in _RUNS:
            anyio.run(_harness, _entry, _backend, backend=_backend)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment