Last active
April 27, 2019 10:00
-
-
Save costrouc/44b5dbcd7783c9b0fb3a95f065f577d4 to your computer and use it in GitHub Desktop.
Asyncio Majordomo Protocol (18/MDP)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import logging | |
import zmq | |
import zmq.asyncio | |
from scheduler import SchedulerCode | |
class Client:
    """Majordomo (18/MDP) client talking to the broker over a ZMQ DEALER socket.

    The socket is created and connected in ``__init__``; requests are sent
    with :meth:`submit` and replies are read with :meth:`get`.
    """

    DEFAULT_PROTOCOL = "tcp"
    DEFAULT_PORT = 8000
    DEFAULT_HOSTNAME = '0.0.0.0'

    def __init__(self, protocol=DEFAULT_PROTOCOL, port=DEFAULT_PORT, hostname=DEFAULT_HOSTNAME, loop=None):
        """Create the ZMQ context and DEALER socket and connect to ``protocol://hostname:port``.

        ``loop`` is only stored on the instance; when omitted,
        ``asyncio.get_event_loop()`` is stored instead.
        """
        self.loop = loop or asyncio.get_event_loop()
        self.logger = logging.getLogger('mdp.client')
        self.context = zmq.asyncio.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.uri = f'{protocol}://{hostname}:{port}'
        self.logger.info(f'Connecting ZMQ socket to {self.uri}')
        self.socket.connect(self.uri)

    async def submit(self, service, message):
        """Send ``message`` (an iterable of bytes frames) to the broker for ``service``.

        Wire frames: ``[b'', MDPC01, service, *message]``.
        """
        self.logger.debug(f'sending message to service {service}')
        await self.socket.send_multipart([
            b'', SchedulerCode.CLIENT, service, *message
        ])

    async def get(self):
        """Wait for the next reply and return ``(service, frames)``.

        The reply is expected as ``[b'', MDPC01, service, *frames]``;
        ``frames`` is returned as a list (possibly empty).
        """
        multipart_message = await self.socket.recv_multipart()
        # Skip the empty delimiter and the protocol header frame.
        _, _, service, *message = multipart_message
        self.logger.debug(f'recieving message from service {service}')
        return service, message

    def disconnect(self):
        """Close the client socket."""
        self.socket.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
from multiprocessing import Process, Event | |
import asyncio | |
import zmq.asyncio | |
from client import Client | |
from worker import Worker | |
from scheduler import Scheduler | |
def init_logging():
    """Configure root logging at INFO level (no-op if the root logger already has handlers)."""
    threshold = logging.INFO
    logging.basicConfig(level=threshold)
def init_event_loop():
    """Create a selector-based asyncio event loop, make it current, and return it.

    pyzmq >= 17 integrates with the standard asyncio loop, so the deprecated
    ``zmq.asyncio.ZMQEventLoop`` (itself a selector loop) is not needed. A
    selector loop is used explicitly because pyzmq's native asyncio support
    relies on a selector-style loop rather than Windows' Proactor loop.
    """
    loop = asyncio.SelectorEventLoop()
    asyncio.set_event_loop(loop)
    return loop
def init_client(stop_event):
    """Process entry point for a demo client.

    Submits ``N`` requests to the ``hello.world`` service, waits for ``N``
    replies, then disconnects. ``stop_event`` is accepted for signature
    parity with the other ``init_*`` entry points but is not used here.
    """
    init_logging()

    async def create_work(loop):
        client = Client(loop=loop)
        N = 10
        MSG_SIZE = 1_000
        REPORT_EVERY = 1  # print a progress line every REPORT_EVERY jobs
        try:
            for i in range(N):
                if i % REPORT_EVERY == 0:
                    print(f'[ Client ] {i+1} jobs submitted')
                await client.submit(b'hello.world', [b'o'*MSG_SIZE])
            for i in range(N):
                await client.get()
                if i % REPORT_EVERY == 0:
                    print(f'[ Client ] {i+1} jobs completed')
            print('[ Client ] === DONE ===')
        finally:
            # Release the socket even if a submit/get fails part-way.
            client.disconnect()

    loop = init_event_loop()
    try:
        loop.run_until_complete(create_work(loop))
    finally:
        loop.close()
def init_scheduler(stop_event):
    """Process entry point for the broker: build a Scheduler on a fresh loop and run it."""
    init_logging()
    event_loop = init_event_loop()
    broker = Scheduler(stop_event, loop=event_loop)
    broker.run()
def init_worker(stop_event):
    """Process entry point for a demo worker serving the ``hello.world`` service.

    Every request is counted and answered with the fixed frames
    ``(b'1', b'2', b'3')``; the request payload is ignored.
    """
    processed = 0

    async def hello_world_worker(*message):
        nonlocal processed
        processed += 1
        print(f'[ Worker ] processing message {processed}')
        return (b'1', b'2', b'3')

    init_logging()
    event_loop = init_event_loop()
    worker = Worker(stop_event, loop=event_loop)
    event_loop.run_until_complete(worker.run(b'hello.world', hello_world_worker))
if __name__ == "__main__":
    NUM_CLIENTS = 1
    NUM_WORKERS = 1
    # Shared with every process; the broker/worker loops poll it between messages.
    stop_event = Event()

    worker_processes = [Process(target=init_worker, args=(stop_event,)) for _ in range(NUM_WORKERS)]
    for worker in worker_processes:
        worker.start()

    scheduler_process = Process(target=init_scheduler, args=(stop_event,))
    scheduler_process.start()

    client_processes = [Process(target=init_client, args=(stop_event,)) for _ in range(NUM_CLIENTS)]
    for client in client_processes:
        client.start()

    # Wait until every client has received all of its replies.
    for client in client_processes:
        client.join()

    # Signal shutdown. The broker and workers only re-check stop_event after a
    # socket receive returns, so they can stay blocked indefinitely; terminate
    # them explicitly so the demo exits instead of hanging.
    stop_event.set()
    for process in (*worker_processes, scheduler_process):
        process.terminate()
        process.join()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import concurrent | |
import collections | |
import logging | |
import datetime as dt | |
import uuid | |
import zmq | |
import zmq.asyncio | |
class Message:
    """Bookkeeping record for one client request held by the scheduler.

    Attributes:
        date_added: naive UTC ``datetime`` captured at construction.
        client_id: identity of the requesting client, stored as given.
        message: request payload frames, stored as given.
    """

    def __init__(self, client_id, message):
        received_at = dt.datetime.utcnow()
        self.date_added = received_at
        self.client_id = client_id
        self.message = message
class SchedulerCode:
    """Frame constants for the Majordomo Protocol (18/MDP).

    ``WORKER`` / ``CLIENT`` are the protocol header frames that identify the
    sender; the single-byte values are the worker command codes.
    """
    WORKER = b"MDPW01"
    CLIENT = b"MDPC01"
    READY = b"\x01"
    REQUEST = b"\x02"
    REPLY = b"\x03"
    HEARTBEAT = b"\x04"
    DISCONNECT = b"\x05"
class Scheduler:
    """Majordomo (18/MDP) broker routing client requests to service workers.

    One ZMQ ROUTER socket is bound for both clients and workers. Every
    service owns a queue of pending message ids; while the service has a
    registered worker, a dispatcher task drains that queue and hands each
    message to a randomly chosen worker.
    """

    DEFAULT_PROTOCOL = "tcp"
    DEFAULT_PORT = 8000
    DEFAULT_HOSTNAME = '0.0.0.0'

    def __init__(self, stop_event, protocol=DEFAULT_PROTOCOL, port=DEFAULT_PORT, hostname=DEFAULT_HOSTNAME, loop=None):
        """Create the ROUTER socket and bind it to ``protocol://hostname:port``.

        ``stop_event`` is polled by :meth:`on_recv_message` between messages
        and set by :meth:`disconnect`.
        """
        self.stop_event = stop_event
        self.loop = loop or asyncio.get_event_loop()
        self.logger = logging.getLogger('mdp.scheduler')
        self.context = zmq.asyncio.Context()
        self.socket = self.context.socket(zmq.ROUTER)
        self.logger.info(f'Binding ZMQ socket client to {protocol}://{hostname}:{port}')
        self.socket.bind(f'{protocol}://{hostname}:{port}')
        # message uuid -> Message still waiting for a worker reply
        self.messages = {}
        # worker identity -> {'service': service name, 'messages': set of in-flight uuids}
        self.workers = {}
        # service name -> registered workers, queue of pending uuids, dispatcher task
        self.services = collections.defaultdict(lambda: {'workers': set(), 'queue': asyncio.Queue(), 'task': None})

    async def _handle_client_message(self, client_id, multipart_message):
        """Record a client request and enqueue its id on the target service's queue.

        ``multipart_message`` is ``[service, *payload]`` (frames after the
        ``MDPC01`` header). A request without a service frame is dropped.
        """
        if not multipart_message:
            self.logger.warning(f'dropping client {client_id} message without a service frame')
            return
        service, *message_data = multipart_message
        self.logger.debug(f'adding client {client_id} message for service {service} to queue')
        message_uuid = uuid.uuid4().bytes
        message = Message(client_id, message_data)
        self.messages[message_uuid] = message
        await self.services[service]['queue'].put(message_uuid)

    async def _next_worker(self, service):
        """Return the identity of a randomly chosen worker registered for ``service``.

        ``random.sample`` rejects sets since Python 3.11, so the worker set is
        materialised as a tuple and passed to ``random.choice``.
        """
        import random
        return random.choice(tuple(service['workers']))

    async def _handle_service_queue(self, service):
        """Dispatch queued message ids for ``service`` to its workers until cancelled."""
        counter = 0
        try:
            while True:
                message_uuid = await service['queue'].get()
                message = self.messages[message_uuid]
                worker_id = await self._next_worker(service)
                counter += 1
                print(f'[Scheduler] Count {counter} sent to worker {worker_id}')
                # Track as in flight so it is requeued if this worker disconnects.
                self.workers[worker_id]['messages'].add(message_uuid)
                await self.socket.send_multipart([
                    worker_id, b'', SchedulerCode.WORKER, SchedulerCode.REQUEST,
                    message_uuid, b'', *message.message
                ])
                service['queue'].task_done()
        except asyncio.CancelledError:
            self.logger.info('stopping worker for service')

    async def _handle_worker_message(self, worker_id, multipart_message):
        """Handle one worker command (READY, REPLY, HEARTBEAT or DISCONNECT).

        ``multipart_message`` holds the frames after the ``MDPW01`` header;
        its first frame is the command code. Unknown commands are ignored.
        """
        if not multipart_message:
            self.logger.warning(f'dropping empty message from worker {worker_id}')
            return
        message_type = multipart_message[0]
        if message_type == SchedulerCode.READY:
            service_name = multipart_message[1]
            service = self.services[service_name]
            self.logger.info(f'adding worker {worker_id} for service {service_name}')
            self.workers[worker_id] = {'service': service_name, 'messages': set()}
            service['workers'].add(worker_id)
            # Start exactly one dispatcher per service; a repeated READY must
            # not spawn a second task draining the same queue.
            if service['task'] is None or service['task'].done():
                service['task'] = asyncio.ensure_future(self._handle_service_queue(service))
        elif message_type == SchedulerCode.REPLY:
            worker = self.workers.get(worker_id)
            message_uuid = multipart_message[1] if len(multipart_message) > 1 else None
            # A reply may name a worker or message the broker no longer
            # associates with that worker (for example after the worker
            # disconnected and its in-flight requests were requeued). Drop it
            # instead of crashing the broker with a KeyError.
            if worker is None or message_uuid not in worker['messages']:
                self.logger.warning(f'ignoring unexpected reply from worker {worker_id}')
                return
            worker['messages'].remove(message_uuid)
            message = self.messages.pop(message_uuid)
            self.logger.debug(f'sending client {message.client_id} message response from worker {worker_id}')
            print(f'[Scheduler] message done from worker {worker_id} for client {message.client_id}')
            await self.socket.send_multipart([
                message.client_id, b'', SchedulerCode.CLIENT,
                worker['service'], *multipart_message[3:]
            ])
        elif message_type == SchedulerCode.HEARTBEAT:
            self.logger.debug('responding with heartbeat')
            await self.socket.send_multipart([
                worker_id, b'', SchedulerCode.WORKER, SchedulerCode.HEARTBEAT
            ])
        elif message_type == SchedulerCode.DISCONNECT:
            worker = self.workers.get(worker_id)
            if worker is None:
                return
            service = self.services[worker['service']]
            if len(service['workers']) == 1 and service['task'] is not None:  # last worker
                self.logger.info(f'canceling {worker["service"]} service queue task')
                service['task'].cancel()
                try:
                    await service['task']
                except asyncio.CancelledError:
                    # Raised when the dispatcher was cancelled before it began
                    # executing (it otherwise handles cancellation itself). Since
                    # Python 3.8 this is not concurrent.futures.CancelledError.
                    pass
                service['task'] = None
            self.logger.info(f'removing worker {worker_id} for service {worker["service"]} - rescheduling {len(worker["messages"])} messages')
            service['workers'].remove(worker_id)
            for message in worker['messages']:
                await service['queue'].put(message)
            self.workers.pop(worker_id)

    async def on_recv_message(self):
        """Receive and dispatch messages until ``stop_event`` is set.

        The event is checked only between messages, so after it is set this
        keeps waiting on the socket until one more message arrives.
        """
        while not self.stop_event.is_set():
            multipart_message = await self.socket.recv_multipart()
            if len(multipart_message) < 3:
                self.logger.warning(f'dropping malformed message with {len(multipart_message)} frames')
                continue
            client_id, _1, message_sender, *message = multipart_message
            if message_sender == SchedulerCode.WORKER:
                await self._handle_worker_message(client_id, message)
            elif message_sender == SchedulerCode.CLIENT:
                await self._handle_client_message(client_id, message)
            else:
                # Any peer can connect to the ROUTER; one bad header must not
                # take the whole broker down.
                self.logger.warning(f'dropping message with unknown protocol header {message_sender!r} from {client_id}')

    def run(self):
        """Run :meth:`on_recv_message` on ``self.loop`` until it returns."""
        self.loop.run_until_complete(self.on_recv_message())

    def disconnect(self):
        """Set ``stop_event`` and close the ROUTER socket."""
        self.stop_event.set()
        self.socket.close()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import logging | |
import datetime as dt | |
import zmq | |
import zmq.asyncio | |
from scheduler import SchedulerCode | |
class Worker:
    """Majordomo (18/MDP) worker serving one service over a ZMQ DEALER socket.

    :meth:`run` registers with the broker and then runs four loops
    concurrently: sending heartbeats, watching for broker responses,
    executing queued requests with the service handler, and receiving
    broker frames.
    """

    DEFAULT_PROTOCOL = "tcp"
    DEFAULT_PORT = 8000
    DEFAULT_HOSTNAME = '0.0.0.0'

    def __init__(self, stop_event,
                 heartbeat_interval=2, heartbeat_timeout=10,
                 protocol=DEFAULT_PROTOCOL, port=DEFAULT_PORT, hostname=DEFAULT_HOSTNAME, loop=None):
        """Create the (not yet connected) DEALER socket and worker state.

        Args:
            stop_event: object whose ``is_set()`` ends each worker loop when it next checks.
            heartbeat_interval: seconds between heartbeats sent to the broker.
            heartbeat_timeout: seconds without a broker response before reconnecting.
            protocol, port, hostname: components of the broker URI.
            loop: stored on the instance; defaults to ``asyncio.get_event_loop()``.
        """
        self.stop_event = stop_event
        self.loop = loop or asyncio.get_event_loop()
        self.logger = logging.getLogger('mdp.worker')
        self.context = zmq.asyncio.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.uri = f'{protocol}://{hostname}:{port}'
        self.heartbeat_interval = heartbeat_interval
        self.heartbeat_timeout = heartbeat_timeout
        # Naive UTC time of the last REQUEST/HEARTBEAT received from the broker.
        self.heartbeat_last_response = dt.datetime.utcnow()
        self.service = None
        self.service_handler = None
        # (message uuid, payload frames) pairs waiting for the service handler.
        self.queued_messages = asyncio.Queue()

    async def _handle_send_heartbeat(self):
        """Send a HEARTBEAT to the broker every ``heartbeat_interval`` seconds."""
        while not self.stop_event.is_set():
            if not self.socket.closed:
                self.logger.debug('sending heartbeat')
                await self.socket.send_multipart([
                    b'', SchedulerCode.WORKER, SchedulerCode.HEARTBEAT
                ])
            await asyncio.sleep(self.heartbeat_interval)

    async def _handle_check_heartbeat(self):
        """Reconnect when nothing was heard from the broker during a ``heartbeat_timeout`` window."""
        while not self.stop_event.is_set():
            previous_heartbeat_check = dt.datetime.utcnow()
            await asyncio.sleep(self.heartbeat_timeout)
            if not self.socket.closed and \
                    self.heartbeat_last_response < previous_heartbeat_check:
                self.logger.info(f'no response from broker in {self.heartbeat_timeout} seconds -- reconnecting')
                await self.disconnect()
                await self.connect()

    async def _handle_queued_messages(self):
        """Run the service handler on queued requests and send each result back as a REPLY."""
        counter = 0
        while not self.stop_event.is_set():
            # The broker's REQUEST carries a message uuid that must be echoed in the REPLY.
            message_uuid, message = await self.queued_messages.get()
            result = await self.service_handler(*message)
            counter += 1
            print(f'[ Worker ] Counter {counter:5} completed Queue size: {self.queued_messages.qsize():5}')
            await self.socket.send_multipart([
                b'', SchedulerCode.WORKER, SchedulerCode.REPLY, message_uuid, b'', *result
            ])
            self.queued_messages.task_done()

    async def _on_recv_message(self):
        """Receive broker frames and act on REQUEST, HEARTBEAT and DISCONNECT commands.

        Raises:
            ValueError: if the broker sends an unknown command code.
        """
        while not self.stop_event.is_set():
            multipart_message = await self.socket.recv_multipart()
            message_type = multipart_message[2]
            if message_type == SchedulerCode.REQUEST:
                _, _, _, message_uuid, _, *message = multipart_message
                self.logger.debug('broker sent request message')
                await self.queued_messages.put((message_uuid, message))
                self.heartbeat_last_response = dt.datetime.utcnow()
            elif message_type == SchedulerCode.HEARTBEAT:
                self.logger.debug('broker response heartbeat')
                self.heartbeat_last_response = dt.datetime.utcnow()
            elif message_type == SchedulerCode.DISCONNECT:
                self.logger.info('broker requests disconnect and reconnect')
                await self.disconnect()
                await self.connect()
            else:
                raise ValueError(f'unknown broker message type {message_type!r}')

    async def run(self, service, service_handler):
        """Register for ``service`` and process requests with ``service_handler``.

        ``service_handler`` is awaited with the request frames as positional
        arguments and must return an iterable of bytes frames for the reply.
        """
        self.service = service
        self.service_handler = service_handler
        await self.connect()
        try:
            await asyncio.gather(
                self._handle_send_heartbeat(),
                self._handle_check_heartbeat(),
                self._handle_queued_messages(),
                self._on_recv_message()
            )
        finally:
            # Tell the broker we are leaving even if one of the loops failed,
            # so it can requeue the requests it had assigned to us.
            await self.disconnect()

    async def connect(self):
        """Connect the socket to the broker and announce READY for ``self.service``."""
        self.logger.info(f'connecting ZMQ socket to {self.uri}')
        self.socket.connect(self.uri)
        await self.socket.send_multipart([
            b'', SchedulerCode.WORKER, SchedulerCode.READY, self.service
        ])

    async def disconnect(self):
        """Send DISCONNECT to the broker and detach the socket from the broker URI.

        The socket itself stays open so :meth:`connect` can reuse it.
        """
        self.logger.info(f'disconnecting zmq socket from {self.uri}')
        await self.socket.send_multipart([
            b'', SchedulerCode.WORKER, SchedulerCode.DISCONNECT
        ])
        self.socket.disconnect(self.uri)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment