Created
September 28, 2017 21:11
-
-
Save virantha/54a7d02d4d50a9fcbce8082c34ff774d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from curio import Queue, CancelledError | |
class Port:
    """Common base for channel endpoints.

    A port starts out unattached; ``connect`` later stores the shared
    channel object in ``chan``.
    """

    def __init__(self):
        # Unattached until connect() assigns a channel.
        self.chan = None
class InputPort(Port):
    """Receiving endpoint: pulls tokens out of the attached channel."""

    async def recv(self):
        """Await the next token from ``self.chan`` and hand it back."""
        return await self.chan.recv()
class OutputPort(Port):
    """Sending endpoint: pushes tokens into the attached channel."""

    async def send(self, val):
        """Forward ``val`` to ``self.chan.send`` and await its completion."""
        await self.chan.send(val)
def connect(a, b, name=''):
    """Join two ports with a freshly created Channel.

    Exactly one of ``a``/``b`` must be an InputPort and the other an
    OutputPort; the argument order does not matter.  The channel records the
    sending (output) port as ``chan.l`` and the receiving (input) port as
    ``chan.r``, and both ports get ``chan`` pointing at it.

    Args:
        a: one port to connect; must not already be connected.
        b: the other port; must not already be connected.
        name: label stored on the new channel.

    Returns:
        The new Channel (previously nothing was returned; callers that
        ignore the result are unaffected).

    Raises:
        AssertionError: if either port is already connected, or both ports
            have the same direction.  NOTE(review): these are ``assert``
            checks, so they are skipped under ``python -O``.
    """
    # Validate everything before creating any channel.
    assert not a.chan, f"Port {a} has already been connected!"
    assert not b.chan, f"Port {b} has already been connected!"

    # Check to make sure the two ports are of opposite type (input/output)
    if isinstance(a, InputPort):
        assert isinstance(b, OutputPort), f"Channel {a} and {b} are both input ports!"
        # b ---chan---> a
        sender, receiver = b, a
    else:
        assert isinstance(b, InputPort), f"Channel {a} and {b} are both output ports!"
        # a ---chan---> b
        sender, receiver = a, b

    chan = Channel(name)
    # Store the ports this channel is connected to
    chan.l = sender
    chan.r = receiver
    # Now assign the channel to the two ports
    a.chan = chan
    b.chan = chan
    return chan
class Process:
    """Base class for every process in the simulated system.

    Each instance receives an integer ``id`` from a class-wide counter and is
    registered in one of two class-level registries: ``producer_processes``
    (instances of Producer) or ``non_producer_processes`` (everything else).
    ``run_all`` joins the former and cancels the latter.

    Ports are declared as class annotations (e.g. ``R: OutputPort``).
    ``__init__`` creates one fresh port per annotation whose type is a Port
    subclass and stores it on the instance under the annotated name.
    Annotations inherited from base classes are included.
    """

    next_id = 0
    # Registries keyed by process id, shared across all instances.
    non_producer_processes = {}
    producer_processes = {}

    def __init__(self, name):
        self.name = name
        self.id = Process.next_id
        Process.next_id += 1

        # Keep track of all source processes (join on these at the end), and
        # non-source processes (cancel on these at the end)
        if isinstance(self, Producer):
            Process.producer_processes[self.id] = self
        else:
            Process.non_producer_processes[self.id] = self

        # Collect port annotations from the whole class hierarchy, base-first,
        # so a subclass that re-declares a port overrides its parent.  Reading
        # them from each class (rather than self.__annotations__) avoids an
        # AttributeError on classes with no annotations and keeps ports
        # declared on base classes.
        declared = {}
        for klass in reversed(type(self).__mro__):
            declared.update(getattr(klass, '__annotations__', {}))

        # Inject the ports
        for attr, annotation in declared.items():
            # Non-class annotations (strings, generic aliases, ...) would make
            # issubclass() raise TypeError; they are simply not ports.
            if isinstance(annotation, type) and issubclass(annotation, Port):
                print(f'injecting port({annotation}) {attr} onto {self.__class__}:{self.name}')
                setattr(self, attr, annotation())

    def __str__(self):
        return f"{self.name}.{self.id}"

    def __repr__(self):
        return f"{type(self).__name__}('{self.name}')"

    def message(self, m):
        """Print ``m`` prefixed with this process's ``name.id`` label."""
        print(f"{self}: {m}")
class Producer(Process):
    """Marker base for processes that drive the system.

    Anything that injects values onto channels unconditionally must derive
    from this class: ``Process.__init__`` files such instances under
    ``producer_processes``, which ``run_all`` waits on instead of cancelling.
    """
class Source(Producer):
    """Producer that emits the constant ``srcval`` on port ``R`` ``length`` times."""

    R: OutputPort

    def __init__(self, name, length, srcval):
        super().__init__(name)
        self.length = length
        self.val = srcval

    async def exec(self):
        """Send the value repeatedly, logging before and after each send."""
        out = self.R
        for _ in range(self.length):
            self.message(f"sending {self.val}")
            await out.send(self.val)
            self.message(f"sent {self.val}")
        self.message("terminated")
class Sink(Process):
    """Terminal process: consumes and logs every token arriving on port ``L``."""

    L: InputPort

    def __init__(self, name):
        super().__init__(name)

    async def exec(self):
        """Receive tokens until cancelled, then report how many arrived.

        The loop never ends on its own; it stops when the task is cancelled
        (``run_all`` cancels non-producer tasks).  After printing the count,
        the CancelledError is re-raised so the cancellation completes
        normally instead of being silently absorbed.
        """
        tok_count = 0
        try:
            while True:
                tok = await self.L.recv()
                tok_count += 1
                self.message(f"received {tok}")
        except CancelledError:
            self.message(f"{tok_count} tokens received")
            raise
class Buffer(Process):
    """Pipeline stage that relays each token from port ``L`` to port ``R``."""

    L: InputPort
    R: OutputPort

    def __init__(self, name):
        super().__init__(name)

    async def exec(self):
        """Forward tokens forever; runs until the task is cancelled."""
        inbound, outbound = self.L, self.R
        while True:
            tok = await inbound.recv()
            self.message(f"received {tok}")
            self.message(f"sending {tok}")
            await outbound.send(tok)
class Channel:
    """Bounded FIFO linking one OutputPort to one InputPort.

    ``l`` is the sending port and ``r`` the receiving port; both are None
    until ``connect`` fills them in.

    Args:
        name: label for the channel.
        maxsize: queue capacity, passed straight to ``curio.Queue`` (same
            semantics).  Defaults to 1, i.e. a buffering of one token.
    """

    def __init__(self, name, maxsize=1):
        self.name = name
        # Endpoints; set by connect() after construction.
        self.l = None
        self.r = None
        self.q = Queue(maxsize=maxsize)  # Default max buffering of 1

    async def send(self, val):
        """Put ``val`` on the queue, waiting while the queue is full."""
        await self.q.put(val)

    async def recv(self):
        """Take the next token and immediately mark it as processed.

        Calling ``task_done`` right after ``get`` means ``close`` only waits
        for tokens to be *taken*, not for downstream work on them.
        """
        tok = await self.q.get()
        await self.q.task_done()
        return tok

    async def close(self):
        """Wait until every token put on this channel has been received."""
        await self.q.join()
async def run_all():
    """Run every registered process to completion of the producers.

    All producer and non-producer processes are spawned as curio tasks.
    Producer tasks are joined; afterwards every non-producer task (those loop
    forever) is cancelled.  The cancellation happens in a ``finally`` so the
    workers are torn down even if joining a producer raises.
    """
    source_tasks = [await spawn(p.exec()) for p in Process.producer_processes.values()]
    other_tasks = [await spawn(p.exec()) for p in Process.non_producer_processes.values()]
    try:
        # Now wait for all sources to end
        for task in source_tasks:
            await task.join()
    finally:
        # NOTE(review): tokens still sitting in downstream channels/buffers
        # when the producers finish are dropped by this cancellation; confirm
        # that is intended, or drain the pipeline before cancelling.
        for task in other_tasks:
            await task.cancel()
from curio import run, spawn | |
async def system(n_buffers=10, n_tokens=10, srcval=1):
    """Build and run a linear pipeline: src -> buf[0] -> ... -> buf[n-1] -> snk.

    Args:
        n_buffers: how many Buffer stages sit between source and sink.  With
            0, the source feeds the sink directly (previously an IndexError).
        n_tokens: how many tokens the source emits.
        srcval: the value the source emits each time.

    The defaults reproduce the original hard-coded configuration.
    """
    # Instantiate the processes (same creation order as before, so ids match)
    src = Source('src1', n_tokens, srcval)
    buf = [Buffer(f'buf[{i}]') for i in range(n_buffers)]
    snk = Sink('snk')

    # Connect each stage's output R to the next stage's input L
    stages = [src, *buf, snk]
    for upstream, downstream in zip(stages, stages[1:]):
        connect(upstream.R, downstream.L)

    await run_all()
if __name__ == '__main__':
    # Launch the demo pipeline on the curio kernel; with_monitor=True asks
    # curio to start its debugging monitor as well.
    pipeline = system()
    run(pipeline, with_monitor=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment