Last active
November 27, 2021 10:07
-
-
Save adriangb/1352d711966db89b94395e8f6fb83de6 to your computer and use it in GitHub Desktop.
Hashable Python objects in PyO3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::cmp; | |
use std::hash; | |
use pyo3::basic::CompareOp; | |
use pyo3::prelude::*; | |
// We can't put a Py<PyAny> directly into a HashMap key | |
// So to be able to hold references to arbitrary Python objects in HashMap as keys | |
// we wrap them in a struct that gets the hash() when it receives the object from Python | |
// and then just echoes back that hash when called Rust needs to hash it | |
#[derive(Clone)] | |
pub struct HashedAny(pub Py<PyAny>, isize); | |
impl <'source>FromPyObject<'source> for HashedAny | |
{ | |
fn extract(ob: &'source PyAny) -> PyResult<Self> { | |
Ok( | |
HashedAny(ob.into(), ob.hash()?) | |
) | |
} | |
} | |
impl hash::Hash for HashedAny { | |
fn hash<H: hash::Hasher>(&self, state: &mut H) { | |
self.1.hash(state) | |
} | |
} | |
impl cmp::PartialEq for HashedAny { | |
fn eq(&self, other: &Self) -> bool { | |
Python::with_gil(|py| -> PyResult<bool> { | |
Ok(self.0.as_ref(py).rich_compare(other.0.as_ref(py), CompareOp::Eq)?.is_true()?) | |
}).unwrap() | |
} | |
} | |
impl cmp::Eq for HashedAny {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Just a HashMap wrapper for testing purposes
use std::collections::HashMap; | |
use pyo3::{Py, PyAny, Python, exceptions}; | |
use pyo3::prelude::*; | |
mod hashedany; | |
use crate::hashedany::HashedAny; | |
#[pyclass] | |
#[derive(Debug, Clone)] | |
struct _PyHashMap { | |
map: HashMap<HashedAny, Py<PyAny>>, | |
} | |
#[pymethods] | |
impl _PyHashMap { | |
#[new] | |
fn new(map: HashMap<HashedAny, Py<PyAny>>) -> Self { | |
_PyHashMap { map } | |
} | |
fn __setitem__(&mut self, k: HashedAny, v: Py<PyAny>) -> () { | |
self.map.insert(k, v); | |
} | |
fn __getitem__(&self, k: HashedAny) -> PyResult<Py<PyAny>> { | |
match self.map.get(&k) { | |
Some(v) => Ok(v.clone()), | |
None => Err(exceptions::PyKeyError::new_err(format!("KeyError: {:?}", k))), | |
} | |
} | |
fn __delitem__(&mut self, k: HashedAny) -> PyResult<()> { | |
match self.map.remove(&k) { | |
Some(_) => Ok(()), | |
None => Err(exceptions::PyKeyError::new_err(format!("KeyError: {:?}", k))), | |
} | |
} | |
} | |
#[pymodule] | |
fn hashedpyany(_py: Python, m: &PyModule) -> PyResult<()> { | |
m.add_class::<_PyHashMap>()?; | |
Ok(()) | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Python tests using hypothesis
from __future__ import annotations | |
import unittest | |
from typing import Any, Dict, Hashable | |
import hypothesis.strategies as st | |
from hypothesis.stateful import Bundle, RuleBasedStateMachine, rule | |
from hashedpyany import PyHashMap | |
class HashMapComparison(RuleBasedStateMachine):
    """Stateful test that drives a built-in ``dict`` and a ``PyHashMap`` in lockstep.

    Each rule applies the same operation to both maps; ``get`` and ``delete``
    additionally assert that the two maps agree on the outcome.
    """

    def __init__(self):
        super().__init__()
        # Reference model (built-in dict) and the implementation under test.
        self.python: Dict[Hashable, Any] = {}
        self.rust: PyHashMap[Hashable, Any] = PyHashMap({})

    keys = Bundle("keys")
    values = Bundle("values")

    @rule(target=keys, k=st.tuples(st.integers()))
    def add_key(self, k):
        """Put a generated one-element integer tuple into the ``keys`` bundle."""
        return k

    @rule(target=values, v=st.binary())
    def add_value(self, v):
        """Put a generated ``bytes`` value into the ``values`` bundle."""
        return v

    @rule(k=keys, v=values)
    def insert(self, k, v):
        """Store ``v`` under ``k`` in both maps."""
        self.python[k] = v
        self.rust[k] = v

    @rule(k=keys)
    def get(self, k):
        """Look ``k`` up in both maps; a missing key is represented as ``None``."""
        try:
            expected = self.python[k]
        except KeyError:
            expected = None

        try:
            actual = self.rust[k]
        except KeyError:
            actual = None

        assert actual == expected

    @rule(k=keys)
    def delete(self, k):
        """Delete ``k`` from both maps and check they agree on whether it was missing."""
        try:
            del self.python[k]
        except KeyError:
            expected_missing = True
        else:
            expected_missing = False

        try:
            del self.rust[k]
        except KeyError:
            actual_missing = True
        else:
            actual_missing = False

        assert actual_missing == expected_missing
# Expose the state machine's generated TestCase under a module-level name.
TestHashMapComparison = HashMapComparison.TestCase

# Run the tests through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment