Last active
November 6, 2023 03:13
-
-
Save ThirVondukr/7bf7880ffbed1b445b4d573d51de8bd5 to your computer and use it in GitHub Desktop.
Redis/In-Memory pub-sub with multicast and client-side routing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import contextlib | |
import random | |
import traceback | |
from collections import defaultdict | |
from collections.abc import AsyncIterator, Callable, Hashable | |
from types import TracebackType | |
from typing import ( | |
TYPE_CHECKING, | |
Generic, | |
Protocol, | |
Self, | |
TypeVar, | |
) | |
import anyio | |
import redis.asyncio as redis | |
from anyio.streams.stapled import StapledObjectStream | |
from pydantic import BaseModel | |
# Generic payload type carried through streams, transports and the exchange.
TValue = TypeVar("TValue")
# Payload type for transports that (de)serialise pydantic models as JSON.
TBaseModel = TypeVar("TBaseModel", bound=BaseModel)
# Static type checkers see the parameterised client; at runtime the plain class
# is used (presumably because redis.Redis is not subscriptable at runtime in
# the redis version targeted -- TODO confirm).
if TYPE_CHECKING:
    RedisClient = redis.Redis[bytes]
else:
    RedisClient = redis.Redis
# Event addressed to a single user. ``user_id`` is what ``user_id_key``
# returns as the routing key. (Documented with a comment rather than a
# docstring because pydantic copies class docstrings into the JSON schema.)
class UserEvent(BaseModel):
    user_id: int


# Any UserEvent (or subclass); lets ``user_id_key`` accept subclasses.
TUserEvent = TypeVar("TUserEvent", bound=UserEvent)
class MulticastReceiver(Generic[TValue]):
    """Read side of one subscription to a ``MulticastStream``.

    The producer appends values to the shared ``container`` list and sets
    ``event``. This receiver drains the list in arrival order and clears the
    event once the list is empty, so the next ``recv`` waits again.
    """

    def __init__(self, event: asyncio.Event, container: list[TValue]) -> None:
        # Set by the producer whenever ``container`` gains items.
        self._event = event
        # Buffer shared with the producer; this receiver consumes from it.
        self._container = container

    def __aiter__(self) -> Self:
        return self

    async def __anext__(self) -> TValue:
        # Never raises StopAsyncIteration: iteration only ends when the caller
        # stops iterating or is cancelled.
        return await self.recv()

    async def recv(self) -> TValue:
        """Wait until a value is buffered, then return the oldest one."""
        await self._event.wait()
        # The producer appends at the tail, so take from the head to keep
        # publish order (FIFO). The previous ``pop()`` returned the newest
        # value first. pop(0) is O(len(buffer)); buffers are presumably short
        # because receivers drain them continuously.
        value = self._container.pop(0)
        if not self._container:
            self._event.clear()
        return value
class MulticastStream(Generic[TValue]):
    """Fan-out channel: each sent value is delivered to every open receiver.

    Every receiver is an ``asyncio.Event`` paired with its own buffer list.
    ``send`` appends the value to each buffer and sets each event.
    """

    def __init__(self) -> None:
        # Maps each receiver's wake-up event to that receiver's buffer.
        self._receivers: dict[asyncio.Event, list[TValue]] = {}

    async def send(self, value: TValue) -> None:
        """Append ``value`` to every receiver's buffer and wake it."""
        # Iterate over a snapshot. The ``await`` below lets other tasks run,
        # and they may open or close receivers. Iterating the live dict would
        # then raise "dictionary changed size during iteration".
        for event, container in list(self._receivers.items()):
            container.append(value)
            event.set()
            # Yield to the event loop so other tasks (e.g. woken receivers)
            # can run.
            await asyncio.sleep(0)

    @contextlib.asynccontextmanager
    async def _get_event(
        self,
    ) -> AsyncIterator[tuple[asyncio.Event, list[TValue]]]:
        """Register a fresh receiver for the duration of the ``async with``."""
        event = asyncio.Event()
        container: list[TValue] = []
        self._receivers[event] = container
        try:
            yield event, container
        finally:
            # Unregister even if the body raises or is cancelled. Otherwise
            # ``send`` keeps filling a buffer that nobody reads.
            self._receivers.pop(event, None)

    @contextlib.asynccontextmanager
    async def recv(self) -> AsyncIterator[MulticastReceiver[TValue]]:
        """Open a receiver that sees every value sent while it stays open."""
        async with self._get_event() as (event, container):
            yield MulticastReceiver(event, container)
class ExchangeTransport(Protocol[TValue]):
    """Structural interface for the message transport behind ``EventExchange``.

    ``send`` publishes one item, ``recv`` returns an async iterator over
    incoming items, and ``close`` releases resources. The default ``close``
    does nothing.
    """

    async def send(self, item: TValue) -> None:
        """Publish ``item`` on the transport."""

    def recv(self) -> AsyncIterator[TValue]:
        """Return an async iterator over items received from the transport."""

    async def close(self) -> None:
        """Release transport resources; the default implementation is a no-op."""
class InMemoryExchangeTransport(ExchangeTransport[TValue]):
    """Single-process transport backed by an anyio memory object stream.

    The stream is created with anyio's default ``max_buffer_size`` (0 in
    current anyio -- confirm for the pinned version), so ``send`` waits until
    the reader takes the item.
    """

    def __init__(self) -> None:
        self._stream = StapledObjectStream(*anyio.create_memory_object_stream())

    async def send(self, item: TValue) -> None:
        await self._stream.send(item)

    def recv(self) -> AsyncIterator[TValue]:
        # Every call returns the same receive stream, so concurrent readers
        # compete for items rather than each seeing all of them.
        return self._stream.receive_stream

    async def close(self) -> None:
        """Close both ends of the memory stream.

        Without this, the inherited no-op ``close`` leaves the stream open.
        A publisher blocked in ``send`` would then wait forever once the
        reader is gone. After closing, anyio makes such calls fail with one
        of its stream errors instead of hanging.
        """
        await self._stream.aclose()
class RedisExchangeTransport(ExchangeTransport[TBaseModel]):
    """Transport that publishes pydantic models as JSON on one Redis channel."""

    def __init__(
        self,
        client: RedisClient,
        channel: str,
        model: type[TBaseModel],
    ) -> None:
        self._client = client
        self._channel = channel
        # Model class used to decode incoming JSON payloads.
        self._model_cls = model

    async def send(self, item: TBaseModel) -> None:
        """Serialise ``item`` to JSON and publish it on the channel."""
        payload = item.model_dump_json()
        await self._client.publish(channel=self._channel, message=payload)

    async def recv(self) -> AsyncIterator[TBaseModel]:
        """Subscribe to the channel and yield every message decoded as the model.

        Each call creates its own ``PubSub`` object via ``self._client.pubsub()``.
        The subscription is released when the generator is closed.
        """
        async with self._client.pubsub() as pubsub:
            await pubsub.subscribe(self._channel)
            while True:
                # timeout=None: block until a message arrives.
                raw = await pubsub.get_message(
                    ignore_subscribe_messages=True,
                    timeout=None,  # type: ignore[arg-type]
                )
                if not raw:
                    # Falsy results (e.g. filtered subscribe confirmations)
                    # carry no payload.
                    continue
                yield self._model_cls.model_validate_json(raw["data"])

    async def close(self) -> None:
        """Close the underlying Redis client."""
        await self._client.aclose()  # type: ignore[attr-defined]
# Type of the key that messages are routed by. It must be hashable because
# keys index the exchange's per-key stream dict.
RoutingKey = TypeVar("RoutingKey", bound=Hashable)
# Generic alias for a function that maps a message to its routing key;
# parameterised as RouteKeyRouter[TValue, RoutingKey].
RouteKeyRouter = Callable[[TValue], RoutingKey]
def user_id_key(event: TUserEvent) -> int:
    """Routing-key function that routes a user event by its ``user_id``."""
    routing_key: int = event.user_id
    return routing_key
class EventExchange(Generic[RoutingKey, TValue]):
    """Route messages read from a transport to in-process subscribers by key.

    ``__aenter__`` starts one background task that reads every message from
    the transport, computes its key with ``route_key``, and forwards it to the
    ``MulticastStream`` for that key if any local subscriber is listening.
    ``publish`` writes straight to the transport, so routing happens on the
    side that reads the message back.
    """

    def __init__(
        self,
        transport: ExchangeTransport[TValue],
        route_key: RouteKeyRouter[TValue, RoutingKey],
    ) -> None:
        self._routing_key = route_key
        self._transport = transport
        # One fan-out stream per routing key that currently has subscribers.
        self._streams: dict[Hashable, MulticastStream[object]] = defaultdict(
            MulticastStream,
        )
        # Active ``subscribe`` generators per key. Used to drop a key's stream
        # once its last subscriber leaves, so the dict does not grow forever.
        self._subscriber_counts: dict[Hashable, int] = defaultdict(int)
        self._consumer_task: asyncio.Task[None] | None = None

    async def _consume(self) -> None:
        """Pump messages from the transport into the per-key streams."""
        try:
            async for message in self._transport.recv():
                key = self._routing_key(message)
                # .get() instead of indexing the defaultdict: indexing would
                # create, and keep forever, a stream for every key ever seen,
                # including keys that nobody subscribes to.
                stream = self._streams.get(key)
                if stream is not None:
                    await stream.send(message)
        except Exception as e:
            # Report the failure here; ``__aexit__`` collects the task's
            # outcome without re-raising it.
            traceback.print_exception(e)
            raise

    async def __aenter__(self) -> Self:
        self._consumer_task = asyncio.create_task(self._consume())
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Stop the consumer task, then close the transport."""
        task, self._consumer_task = self._consumer_task, None
        try:
            if task is not None:
                task.cancel()
                # Wait until the task has actually stopped, so it cannot touch
                # the transport after it is closed. return_exceptions=True
                # gathers the task's CancelledError (or a crash already printed
                # by _consume) without raising it here. That avoids masking the
                # exception, if any, that is leaving the ``async with``.
                await asyncio.gather(task, return_exceptions=True)
        finally:
            await self._transport.close()

    def _get_receiver(self, key: RoutingKey) -> MulticastStream[TValue]:
        # NOTE(review): unused in this file. Indexing the defaultdict creates
        # a stream for ``key`` that is not tracked by ``_subscriber_counts``.
        return self._streams[key]  # type: ignore[return-value]

    async def publish(self, event: TValue) -> None:
        """Send ``event`` to the transport.

        Routing to subscribers happens when ``_consume`` reads it back.
        """
        await self._transport.send(event)

    async def subscribe(
        self,
        routing_key: RoutingKey,
    ) -> AsyncIterator[TValue]:
        """Yield every message routed to ``routing_key`` from now on."""
        # No ``await`` between fetching the stream and registering the
        # subscriber, so the cleanup below cannot race with this lookup.
        stream = self._streams[routing_key]
        self._subscriber_counts[routing_key] += 1
        try:
            async with stream.recv() as recv:
                async for message in recv:
                    yield message  # type: ignore[misc]
        finally:
            self._subscriber_counts[routing_key] -= 1
            if not self._subscriber_counts[routing_key]:
                # Last subscriber for this key is gone: forget the key so
                # unwatched keys do not accumulate.
                del self._subscriber_counts[routing_key]
                self._streams.pop(routing_key, None)
async def worker(
    exchange: EventExchange[int, UserEvent],
    test_duration: float,
    num_clients: int,
) -> None:
    """Publish events for random users as fast as possible for ``test_duration`` s.

    User ids are drawn from ``range(num_clients)``, which is the same id range
    ``main`` starts consumers for.
    """
    try:
        async with asyncio.timeout(test_duration):
            while True:
                await exchange.publish(
                    # randrange excludes the upper bound. randint(0, n) also
                    # produced ``num_clients``, an id that no consumer watches.
                    UserEvent(user_id=random.randrange(num_clients)),
                )
    except asyncio.TimeoutError:
        return
async def consumer(
    exchange: EventExchange[int, UserEvent],
    user_id: int,
    test_duration: float,
) -> int:
    """Count the messages routed to ``user_id`` during ``test_duration`` seconds.

    Returns the count once the timeout fires. Raises NotImplementedError if the
    subscription ends on its own before then.
    """
    received = 0
    try:
        async with asyncio.timeout(test_duration):
            async for _message in exchange.subscribe(user_id):
                received += 1
    except asyncio.TimeoutError:
        return received
    # The subscription is expected to be endless; reaching here means it
    # stopped before the timeout.
    raise NotImplementedError
async def main() -> None:
    """Benchmark the exchange with one publisher and ``num_clients`` subscribers.

    Prints each consumer's message count, then the mean per-consumer rate in
    messages per second, then the total number of messages received.
    """
    # Fixed start-up delay; presumably gives the "redis" host time to come up
    # (e.g. in a compose setup) -- TODO confirm.
    await asyncio.sleep(10)
    # To benchmark without Redis, use InMemoryExchangeTransport[UserEvent]().
    transport = RedisExchangeTransport[UserEvent](
        client=RedisClient(host="redis", db=1),
        channel="events",
        model=UserEvent,
    )
    test_duration = 60
    num_clients = 1000

    async with (
        EventExchange(transport=transport, route_key=user_id_key) as exchange,
        asyncio.TaskGroup() as tg,
    ):
        tg.create_task(
            worker(exchange, test_duration=test_duration, num_clients=num_clients),
        )
        consumer_tasks = [
            tg.create_task(
                consumer(exchange, user_id=user_id, test_duration=test_duration),
            )
            for user_id in range(num_clients)
        ]

    # The TaskGroup has awaited every task by this point, so results are ready.
    counts = [task.result() for task in consumer_tasks]
    total = sum(counts)
    print(counts)
    print(total / len(consumer_tasks) / test_duration)
    print(total)


if __name__ == "__main__":
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment