Created
November 23, 2021 17:22
-
-
Save adriangb/ef1454e33cb06ca21d1952ff1cbe7930 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::cmp; | |
use std::hash; | |
use std::collections::HashSet; | |
use std::os::raw::c_int; | |
use pyo3::prelude::*; | |
use pyo3::ffi; | |
use pyo3::conversion::{AsPyPointer}; | |
use pyo3::{pyobject_native_type_base,pyobject_native_type_info,pyobject_native_type_extract,PyNativeType}; | |
use pyo3::types::{PyString}; | |
// Implement Hash and Eq for PyAny as a newtype so that it can be used in HashMap keys | |
// Unfortunately, these are unsafe implementations: | |
// they'll panic if __hash__ doesn't exist on the Python object or otherwise fails | |
// So we need to make sure to massage the panic into a Python error :( | |
struct HashablePyAny(PyAny); | |
// What follows is boilerplate copied from pyo3::types::any.
// It is needed only so that HashablePyAny gets a FromPyObject implementation,
// since Rust doesn't support trait delegation for newtypes.
impl AsPyPointer for HashablePyAny { | |
#[inline] | |
fn as_ptr(&self) -> *mut ffi::PyObject { | |
self.0.as_ptr() | |
} | |
} | |
impl HashablePyAny { | |
#[inline] | |
pub fn py(&self) -> Python<'_> { | |
PyNativeType::py(self) | |
} | |
pub fn repr(&self) -> PyResult<&PyString> { | |
unsafe { | |
self.py() | |
.from_owned_ptr_or_err(ffi::PyObject_Repr(self.as_ptr())) | |
} | |
} | |
pub fn str(&self) -> PyResult<&PyString> { | |
unsafe { | |
self.py() | |
.from_owned_ptr_or_err(ffi::PyObject_Str(self.as_ptr())) | |
} | |
} | |
pub fn hash(&self) -> PyResult<isize> { | |
let v = unsafe { ffi::PyObject_Hash(self.as_ptr()) }; | |
if v == -1 { | |
Err(PyErr::fetch(self.py())) | |
} else { | |
Ok(v) | |
} | |
} | |
} | |
pyobject_native_type_base!(HashablePyAny); | |
#[allow(non_snake_case)] | |
fn PyObject_Check(_: *mut ffi::PyObject) -> c_int { | |
1 | |
} | |
pyobject_native_type_info!( | |
HashablePyAny, | |
ffi::PyBaseObject_Type, | |
Some("builtins"), | |
#checkfunction=PyObject_Check | |
); | |
pyobject_native_type_extract!(HashablePyAny); | |
// End copied boilerplate
impl hash::Hash for HashablePyAny { | |
fn hash<H: hash::Hasher>(&self, state: &mut H) { | |
let v = unsafe { ffi::PyObject_Hash(self.as_ptr()) }; | |
if v == -1 { | |
panic!("Calling __hash__ on Python object failed") | |
} else { | |
state.write(&v.to_be_bytes()) | |
} | |
} | |
} | |
impl cmp::Eq for HashablePyAny {} | |
#[pyfunction] | |
fn func(ob: Py<HashablePyAny>) -> PyResult<()> { | |
let mut set: HashSet<Py<HashablePyAny>> = HashSet::new(); | |
Python::with_gil(|py| { | |
let a = ob.as_ref(py); | |
set.insert(a.into()); // error here | |
}); | |
println!("{:?}", set); | |
Ok(()) | |
} | |
#[pymodule] | |
fn di_lib(_py: Python, m: &PyModule) -> PyResult<()> { | |
m.add_function(wrap_pyfunction!(func, m)?)?; | |
Ok(()) | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment