Created
June 18, 2024 09:37
-
-
Save folkertdev/06a7060f0ea0689a721659a8317f9801 to your computer and use it in GitHub Desktop.
tsp async pyo3 experiment
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from dataclasses import dataclass | |
import tsp_python | |
from tsp_python import AsyncStore, OwnedVid, ReceivedTspMessageVariant, FlatReceivedTspMessage | |
class ReceivedTspMessage: | |
@staticmethod | |
def from_flat(msg: FlatReceivedTspMessage): | |
match msg.variant: | |
case ReceivedTspMessageVariant.GenericMessage: | |
return AcceptRelationship(msg.sender, msg.nonconfidential_data, msg.message, msg.message_type) | |
case ReceivedTspMessageVariant.RequestRelationship: | |
raise ValueError("todo!") | |
case ReceivedTspMessageVariant.AcceptRelationship: | |
return AcceptRelationship(msg.sender) | |
case ReceivedTspMessageVariant.CancelRelationship: | |
return CancelRelationship(msg.sender) | |
case ReceivedTspMessageVariant.ForwardRequest: | |
raise ValueError("todo!") | |
case ReceivedTspMessageVariant.PendingMessage: | |
raise ValueError("todo!") | |
case other: | |
raise ValueError(f"Unrecognized variant: {other}") | |
@dataclass | |
class GenericMessage(ReceivedTspMessage): | |
sender: str | |
nonconfidential_data = None | |
message = [] | |
message_type = None | |
@dataclass | |
class AcceptRelationship(ReceivedTspMessage): | |
sender: str | |
@dataclass | |
class AcceptRelationship(ReceivedTspMessage): | |
sender: str | |
@dataclass | |
class CancelRelationship(ReceivedTspMessage): | |
sender: str | |
class AsyncStore: | |
inner: tsp_python.AsyncStore | |
def __init__(self): | |
self.inner = tsp_python.AsyncStore() | |
async def receive(self, address): | |
return ReceivedTspMessageStream(await self.inner.receive(address)) | |
def add_private_vid(self, *args, **kwargs): | |
return self.inner.add_private_vid(*args, **kwargs) | |
def verify_vid(self, *args, **kwargs): | |
return self.inner.verify_vid(*args, **kwargs) | |
async def send(self, *args, **kwargs): | |
# seal | |
# await transport_send | |
return await self.inner.send(*args, **kwargs) | |
class ReceivedTspMessageStream: | |
def __init__(self, future): | |
self.future = future | |
def __aiter__(self): | |
return self | |
async def __anext__(self): | |
result = await self.future.next() | |
if result is None: | |
raise StopAsyncIteration | |
else: | |
return ReceivedTspMessage.from_flat(result) | |
async def main(): | |
# bob database | |
print("bob database"); | |
bob_db = AsyncStore(); | |
bob_vid = await OwnedVid.from_file("../examples/test/bob.json") | |
bob_db.add_private_vid(bob_vid) | |
await bob_db.verify_vid("did:web:did.tsp-test.org:user:alice") | |
bobs_messages = await bob_db.receive("did:web:did.tsp-test.org:user:bob") | |
print("got the stream") | |
# alice database | |
print("alice database"); | |
alice_db = AsyncStore(); | |
alice_vid = await OwnedVid.from_file("../examples/test/alice.json") | |
alice_db.add_private_vid(alice_vid) | |
await alice_db.verify_vid("did:web:did.tsp-test.org:user:bob") | |
# send a message | |
print("send a message") | |
await alice_db.send( | |
"did:web:did.tsp-test.org:user:alice", | |
"did:web:did.tsp-test.org:user:bob", | |
b"extra non-confidential data", | |
b"hello world", | |
) | |
print("receive message") | |
match await anext(bobs_messages): | |
case GenericMessage(message): | |
print("success: {}", message) | |
case other: | |
print(f"failure {other}") | |
import asyncio | |
loop = asyncio.get_event_loop() | |
loop.run_until_complete(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::future::Future; | |
use pyo3::{exceptions::PyException, prelude::*}; | |
fn tokio() -> &'static tokio::runtime::Runtime { | |
use std::sync::OnceLock; | |
static RT: OnceLock<tokio::runtime::Runtime> = OnceLock::new(); | |
RT.get_or_init(|| tokio::runtime::Runtime::new().unwrap()) | |
} | |
fn py_exception<E: std::fmt::Debug>(e: E) -> PyErr { | |
PyException::new_err(format!("{e:?}")) | |
} | |
#[pyclass] | |
struct AsyncStore(tsp::Store); | |
async fn spawn<F>(fut: F) -> PyResult<F::Output> | |
where | |
F: Future + Send + 'static, | |
F::Output: Send + 'static, | |
{ | |
tokio().spawn(fut).await.map_err(py_exception) | |
} | |
#[pymethods] | |
impl AsyncStore { | |
#[new] | |
fn new() -> Self { | |
Self(tsp::Store::default()) | |
} | |
fn add_private_vid(&self, vid: OwnedVid) -> PyResult<()> { | |
self.0.add_private_vid(vid.0).unwrap(); | |
Ok(()) | |
} | |
async fn verify_vid(&mut self, vid: String) -> PyResult<()> { | |
let verified_vid = spawn(async move { tsp::vid::verify_vid(&vid).await }) | |
.await? | |
.map_err(py_exception)?; | |
self.0.add_verified_vid(verified_vid).map_err(py_exception) | |
} | |
#[pyo3(signature = (sender, receiver, nonconfidential_data, message))] | |
async fn send( | |
&self, | |
sender: String, | |
receiver: String, | |
nonconfidential_data: Option<Vec<u8>>, | |
message: Vec<u8>, | |
) -> PyResult<Vec<u8>> { | |
let (url, bytes) = self | |
.0 | |
.seal_message( | |
&sender, | |
&receiver, | |
nonconfidential_data.as_deref(), | |
&message, | |
) | |
.map_err(py_exception)?; | |
let fut = async move { | |
tsp::transport::send_message(&url, &bytes).await?; | |
Ok::<Vec<_>, tsp::transport::TransportError>(bytes) | |
}; | |
spawn(fut).await?.map_err(py_exception) | |
} | |
/// Send TSP broadcast message to the specified VIDs | |
pub async fn send_anycast( | |
&self, | |
sender: String, | |
receivers: Vec<String>, | |
nonconfidential_message: Vec<u8>, | |
) -> PyResult<()> { | |
let message = self | |
.0 | |
.sign_anycast(&sender, &nonconfidential_message) | |
.unwrap(); | |
let inner = self.0.clone(); | |
let fut = async move { | |
for vid in receivers { | |
let receiver = inner.get_verified_vid(vid.as_ref()).unwrap(); | |
tsp::transport::send_message(receiver.endpoint(), &message) | |
.await | |
.unwrap(); | |
} | |
}; | |
spawn(fut).await | |
} | |
pub async fn receive(&self, vid: String) -> PyResult<ReceivedTspMessageStream> { | |
let receiver = self.0.get_private_vid(&vid).map_err(py_exception)?; | |
let messages = | |
spawn(async move { tsp::transport::receive_messages(receiver.endpoint()).await }) | |
.await? | |
.map_err(py_exception)?; | |
use futures::StreamExt; | |
let db = self.0.clone(); | |
Ok(ReceivedTspMessageStream(Box::pin(messages.then( | |
move |message| { | |
let db_inner = db.clone(); | |
async move { | |
match message { | |
Ok(mut m) => match db_inner.open_message(&mut m) { | |
Err(tsp::Error::UnverifiedSource(unknown_vid)) => { | |
Ok(tsp::ReceivedTspMessage::PendingMessage { | |
unknown_vid, | |
payload: m.to_vec(), | |
}) | |
} | |
maybe_message => maybe_message, | |
}, | |
Err(e) => Err(e.into()), | |
} | |
} | |
}, | |
)))) | |
} | |
// pub async fn receive(&self, vid: &str) -> Result<TSPStream<ReceivedTspMessage, Error>, Error> { | |
// let receiver = self.inner.get_private_vid(vid)?; | |
// let messages = crate::transport::receive_messages(receiver.endpoint()).await?; | |
// | |
// let db = self.inner.clone(); | |
// Ok(Box::pin(messages.then(move |message| { | |
// let db_inner = db.clone(); | |
// async move { | |
// match message { | |
// Ok(mut m) => match db_inner.open_message(&mut m) { | |
// Err(Error::UnverifiedSource(unknown_vid)) => { | |
// Ok(ReceivedTspMessage::PendingMessage { | |
// unknown_vid, | |
// payload: m.to_vec(), | |
// }) | |
// } | |
// maybe_message => maybe_message, | |
// }, | |
// Err(e) => Err(e.into()), | |
// } | |
// } | |
// }))) | |
// } | |
} | |
#[pyclass] | |
struct ReceivedTspMessageStream(tsp::definitions::TSPStream<tsp::ReceivedTspMessage, tsp::Error>); | |
#[pyclass] | |
#[derive(Clone, Copy)] | |
enum ReceivedTspMessageVariant { | |
GenericMessage, | |
RequestRelationship, | |
AcceptRelationship, | |
CancelRelationship, | |
ForwardRequest, | |
PendingMessage, | |
} | |
impl From<&tsp::ReceivedTspMessage> for ReceivedTspMessageVariant { | |
fn from(value: &tsp::ReceivedTspMessage) -> Self { | |
match value { | |
tsp::ReceivedTspMessage::GenericMessage { .. } => Self::GenericMessage, | |
tsp::ReceivedTspMessage::RequestRelationship { .. } => Self::RequestRelationship, | |
tsp::ReceivedTspMessage::AcceptRelationship { .. } => Self::AcceptRelationship, | |
tsp::ReceivedTspMessage::CancelRelationship { .. } => Self::CancelRelationship, | |
tsp::ReceivedTspMessage::ForwardRequest { .. } => Self::ForwardRequest, | |
tsp::ReceivedTspMessage::PendingMessage { .. } => Self::PendingMessage, | |
} | |
} | |
} | |
#[pyclass] | |
#[derive(Clone, Copy)] | |
enum MessageType { | |
Signed, | |
SignedAndEncrypted, | |
} | |
#[pyclass] | |
struct FlatReceivedTspMessage { | |
#[pyo3(get, set)] | |
variant: ReceivedTspMessageVariant, | |
#[pyo3(get, set)] | |
sender: Option<String>, | |
#[pyo3(get, set)] | |
nonconfidential_data: Option<Option<Vec<u8>>>, | |
#[pyo3(get, set)] | |
message: Option<Vec<u8>>, | |
#[pyo3(get, set)] | |
message_type: Option<MessageType>, | |
#[pyo3(get, set)] | |
route: Option<Option<Vec<Vec<u8>>>>, | |
#[pyo3(get, set)] | |
thread_id: Option<[u8; 32]>, | |
#[pyo3(get, set)] | |
next_hop: Option<String>, | |
#[pyo3(get, set)] | |
payload: Option<Vec<u8>>, | |
#[pyo3(get, set)] | |
opaque_payload: Option<Vec<u8>>, | |
#[pyo3(get, set)] | |
unknown_vid: Option<String>, | |
} | |
impl From<tsp::ReceivedTspMessage> for FlatReceivedTspMessage { | |
fn from(value: tsp::ReceivedTspMessage) -> Self { | |
let variant = ReceivedTspMessageVariant::from(&value); | |
let mut this = FlatReceivedTspMessage { | |
variant, | |
sender: None, | |
nonconfidential_data: None, | |
message: None, | |
message_type: None, | |
route: None, | |
thread_id: None, | |
next_hop: None, | |
payload: None, | |
opaque_payload: None, | |
unknown_vid: None, | |
}; | |
match value { | |
tsp::ReceivedTspMessage::GenericMessage { | |
sender, | |
nonconfidential_data, | |
message, | |
message_type, | |
} => { | |
this.sender = Some(sender); | |
this.nonconfidential_data = Some(nonconfidential_data); | |
this.message = Some(message); | |
this.message_type = match message_type { | |
tsp::definitions::MessageType::Signed => Some(MessageType::Signed), | |
tsp::definitions::MessageType::SignedAndEncrypted => { | |
Some(MessageType::SignedAndEncrypted) | |
} | |
}; | |
} | |
tsp::ReceivedTspMessage::RequestRelationship { | |
sender, | |
route, | |
thread_id, | |
} => { | |
this.sender = Some(sender); | |
this.route = Some(route); | |
this.thread_id = Some(thread_id); | |
} | |
tsp::ReceivedTspMessage::AcceptRelationship { sender } => { | |
this.sender = Some(sender); | |
} | |
tsp::ReceivedTspMessage::CancelRelationship { sender } => { | |
this.sender = Some(sender); | |
} | |
tsp::ReceivedTspMessage::ForwardRequest { | |
sender, | |
next_hop, | |
route, | |
opaque_payload, | |
} => { | |
this.sender = Some(sender); | |
this.next_hop = Some(next_hop); | |
this.route = Some(Some(route)); | |
this.opaque_payload = Some(opaque_payload); | |
} | |
tsp::ReceivedTspMessage::PendingMessage { | |
unknown_vid, | |
payload, | |
} => { | |
this.unknown_vid = Some(unknown_vid); | |
this.payload = Some(payload); | |
} | |
}; | |
this | |
} | |
} | |
#[pymethods] | |
impl ReceivedTspMessageStream { | |
async fn next(&mut self) -> PyResult<Option<FlatReceivedTspMessage>> { | |
use futures::prelude::*; | |
match self.0.next().await { | |
None => Ok(None), | |
Some(Ok(value)) => Ok(Some(value.into())), | |
Some(Err(e)) => Err(py_exception(e)), | |
} | |
} | |
} | |
#[pyclass] | |
#[derive(Clone)] | |
struct OwnedVid(tsp::OwnedVid); | |
#[pymethods] | |
impl OwnedVid { | |
#[staticmethod] | |
async fn from_file(path: String) -> PyResult<OwnedVid> { | |
let fut = async move { | |
let owned_vid = tsp::OwnedVid::from_file(&path) | |
.await | |
.map_err(py_exception)?; | |
Ok(Self(owned_vid)) | |
}; | |
tokio().spawn(fut).await.unwrap() | |
} | |
} | |
/// A Python module implemented in Rust. | |
#[pymodule] | |
fn tsp_python(m: &Bound<'_, PyModule>) -> PyResult<()> { | |
m.add_class::<AsyncStore>()?; | |
m.add_class::<OwnedVid>()?; | |
m.add_class::<ReceivedTspMessageVariant>()?; | |
m.add_class::<FlatReceivedTspMessage>()?; | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment