Skip to content

Instantly share code, notes, and snippets.

@davidselassie
Created August 25, 2023 00:10
Show Gist options
  • Save davidselassie/27e59163ce8d859a9462a6b8ac3efb21 to your computer and use it in GitHub Desktop.
Save davidselassie/27e59163ce8d859a9462a6b8ac3efb21 to your computer and use it in GitHub Desktop.
Hack Day Mixed Rust-Python Tracebacks
use std::panic::Location;
use pyo3::exceptions::PyValueError;
use pyo3::ffi;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::types::PyFrame;
use pyo3::types::PyTraceback;
use pyo3::AsPyPointer;
/// Extension trait that lets the traceback attached to an error be replaced.
trait PyErrTracebackEx {
/// Sets the traceback of `self` to `tb` while holding the GIL token `py`.
///
/// The `PyErr` implementation in this file treats `None` as "clear the
/// traceback".
fn set_traceback<'py>(&self, py: Python<'py>, tb: Option<&'py PyTraceback>);
}
impl PyErrTracebackEx for PyErr {
    /// Replaces the traceback stored on this error's exception value.
    ///
    /// Passing `None` hands `Py_None` to the C API, which clears the
    /// traceback.
    ///
    /// # Panics
    ///
    /// Panics if `PyException_SetTraceback` reports failure. The Python
    /// error raised by the C API is fetched (clearing the interpreter's
    /// error indicator) and included in the panic message.
    fn set_traceback<'py>(&self, py: Python<'py>, tb: Option<&'py PyTraceback>) {
        let value = self.value(py);
        // PyException_SetTraceback says to use Py_None, not NULL, to
        // clear.
        let none = py.None();
        let tb = tb.map_or(none.as_ref(py), |tb| tb.as_ref());
        // SAFETY: We ensure that we have the GIL and only pass None
        // or a PyTraceback as the tb arg. The call only borrows a
        // reference to tb.
        let status = unsafe { ffi::PyException_SetTraceback(value.as_ptr(), tb.as_ptr()) };
        if status != 0 {
            // Surface the real cause instead of discarding it.
            let err = PyErr::fetch(py);
            panic!("ERROR SETTING TRACEBACK: {}", err);
        }
    }
}
/// Extension trait for setting the context (the implicitly chained
/// exception) of an error.
trait PyErrContextEx: Sized {
    /// Installs `ctx` as the context of `self` while holding the GIL token
    /// `py`.
    ///
    /// `ctx` is taken by value. The `PyErr` implementation in this file
    /// passes a null pointer to the C API when `ctx` is `None`.
    fn set_context(&self, py: Python<'_>, ctx: Option<Self>);
}
impl PyErrContextEx for PyErr {
    /// Sets `ctx` as the context of this error's exception value.
    ///
    /// When `ctx` is `None`, a null pointer is handed to
    /// `PyException_SetContext`.
    fn set_context(&self, py: Python<'_>, ctx: Option<Self>) {
        let target = self.value(py).as_ptr();
        // PyException_SetContext _steals_ a reference to the context, so
        // the exception value must be converted with
        // [`IntoPyPointer::into_ptr`], which gives up ownership.
        let ctx_ptr = match ctx {
            Some(err) => IntoPyPointer::into_ptr(err.into_value(py)),
            None => std::ptr::null_mut(),
        };
        // SAFETY: `py` proves the GIL is held, `target` points at this
        // error's exception value, and `ctx_ptr` is either null or an
        // owned reference whose ownership passes to the callee.
        unsafe {
            ffi::PyException_SetContext(target, ctx_ptr);
        }
    }
}
/// Extension trait for tagging a fallible value with a named label.
trait PyErrLabelEx<T> {
/// Consumes `self` and returns a `PyResult<T>`.
///
/// In the `PyResult<T>` implementation in this file, `name` is passed to
/// `forge_tb` as the function name of the forged traceback entry when
/// `self` is an error.
fn label(self, name: &'static str) -> PyResult<T>;
}
fn forge_tb<'py>(
py: Python<'py>,
orig_tb: Option<&'py PyTraceback>,
name: &'static str,
caller: &'static Location,
) -> PyResult<&'py PyTraceback> {
let filename = caller.file();
let firstlineno = caller
.line()
.try_into()
.map_err(|_| PyValueError::new_err("Line number too large"))?;
// SAFETY: We are passing ownership on to [`ffi::PyFrame_New`].
let code = unsafe {
ffi::PyCode_NewEmpty(
filename.as_ptr() as *const i8,
name.as_ptr() as *const i8,
firstlineno,
)
};
// SAFETY: We have the GIL token.
let thread_state = unsafe { ffi::PyThreadState_Get() };
// No globals or locals since this isn't Python
// code.
let globals = PyDict::new(py);
let locals = PyDict::new(py);
// SAFETY:
let frame = unsafe {
let ptr = ffi::PyFrame_New(thread_state, code, globals.as_ptr(), locals.as_ptr());
py.from_owned_ptr::<PyFrame>(ptr as *mut ffi::PyObject)
};
let lineno = caller.line();
let fake_tb = py
.import("types")?
.call_method1("TracebackType", (orig_tb, frame, 0, lineno))?
.downcast_exact::<PyTraceback>()?;
Ok(fake_tb)
}
impl<T> PyErrLabelEx<T> for PyResult<T> {
    /// Adds a forged traceback entry named `name`, located at the call site
    /// of `label`, to the error in `self`.
    ///
    /// `Ok` values pass through untouched. For an `Err`, the GIL is acquired,
    /// a traceback entry is forged on top of the error's current traceback
    /// and installed as the error's new traceback.
    ///
    /// # Errors
    ///
    /// Returns the original error with the extended traceback. If forging
    /// the traceback fails, the forging error is returned instead, with the
    /// original error attached as its context.
    #[track_caller]
    fn label(self, name: &'static str) -> PyResult<T> {
        // Must be read here rather than inside the closures below: closures
        // do not inherit `#[track_caller]`, so `Location::caller()` called
        // inside them would report a line in this method, not the call site.
        let caller = Location::caller();
        self.map_err(|err| {
            Python::with_gil(|py| {
                let orig_tb = err.traceback(py);
                match forge_tb(py, orig_tb, name, caller) {
                    Ok(fake_tb) => {
                        // The forged entry already chains to the original
                        // traceback through `tb_next`.
                        err.set_traceback(py, Some(fake_tb));
                        err
                    }
                    // If something went wrong injecting into the
                    // traceback, chain exceptions.
                    Err(inject_err) => {
                        inject_err.set_context(py, Some(err));
                        inject_err
                    }
                }
            })
        })
    }
}
// Invokes `call` with no arguments and labels any resulting error with
// "testtesttest". Plain `//` comments are used on purpose so the
// Python-visible docstring of this function is left unchanged.
#[pyfunction]
fn test_label(py: Python, call: PyObject) -> PyResult<PyObject> {
    let outcome = call.call1(py, ());
    outcome.label("testtesttest")
}
// Module initializer: registers `test_label` on the module. Plain `//`
// comments are used on purpose so the Python-visible module docstring is
// left unchanged.
#[pymodule]
fn pyplayground(_py: Python, m: &PyModule) -> PyResult<()> {
    let wrapped = wrap_pyfunction!(test_label, m)?;
    m.add_function(wrapped)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment