Skip to content

Commit

Permalink
Update for new pyo3 API
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Dec 16, 2024
1 parent 2fb896f commit d51ba47
Showing 1 changed file with 36 additions and 15 deletions.
51 changes: 36 additions & 15 deletions datafusion/common/src/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use arrow_array::Array;
use pyo3::exceptions::PyException;
use pyo3::prelude::PyErr;
use pyo3::types::{PyAnyMethods, PyList};
use pyo3::{Bound, FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python};
use pyo3::{Bound, FromPyObject, IntoPyObject, PyAny, PyObject, PyResult, Python};

use crate::{DataFusionError, ScalarValue};

Expand All @@ -40,8 +40,8 @@ impl FromPyArrow for ScalarValue {
let val = value.call_method0("as_py")?;

// construct pyarrow array from the python value and pyarrow type
let factory = py.import_bound("pyarrow")?.getattr("array")?;
let args = PyList::new_bound(py, [val]);
let factory = py.import("pyarrow")?.getattr("array")?;
let args = PyList::new(py, [val])?;
let array = factory.call1((args, typ))?;

// convert the pyarrow array to rust array using C data interface
Expand Down Expand Up @@ -69,14 +69,25 @@ impl<'source> FromPyObject<'source> for ScalarValue {
}
}

impl IntoPy<PyObject> for ScalarValue {
fn into_py(self, py: Python) -> PyObject {
self.to_pyarrow(py).unwrap()
impl<'source> IntoPyObject<'source> for ScalarValue {
type Target = PyAny;

type Output = Bound<'source, Self::Target>;

type Error = PyErr;

fn into_pyobject(self, py: Python<'source>) -> Result<Self::Output, Self::Error> {
let array = self.to_array()?;
// convert to pyarrow array using C data interface
let pyarray = array.to_data().to_pyarrow(py)?;
let pyarray_bound = pyarray.bind(py);
pyarray_bound.call_method1("__getitem__", (0,))
}
}

#[cfg(test)]
mod tests {
use pyo3::ffi::c_str;
use pyo3::prepare_freethreaded_python;
use pyo3::py_run;
use pyo3::types::PyDict;
Expand All @@ -86,10 +97,12 @@ mod tests {
fn init_python() {
prepare_freethreaded_python();
Python::with_gil(|py| {
if py.run_bound("import pyarrow", None, None).is_err() {
let locals = PyDict::new_bound(py);
py.run_bound(
"import sys; executable = sys.executable; python_path = sys.path",
if py.run(c_str!("import pyarrow"), None, None).is_err() {
let locals = PyDict::new(py);
py.run(
c_str!(
"import sys; executable = sys.executable; python_path = sys.path"
),
None,
Some(&locals),
)
Expand Down Expand Up @@ -135,17 +148,25 @@ mod tests {
}

#[test]
fn test_py_scalar() {
fn test_py_scalar() -> PyResult<()> {
init_python();

Python::with_gil(|py| {
Python::with_gil(|py| -> PyResult<()> {
let scalar_float = ScalarValue::Float64(Some(12.34));
let py_float = scalar_float.into_py(py).call_method0(py, "as_py").unwrap();
let py_float = scalar_float
.into_pyobject(py)?
.call_method0("as_py")
.unwrap();
py_run!(py, py_float, "assert py_float == 12.34");

let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string()));
let py_string = scalar_string.into_py(py).call_method0(py, "as_py").unwrap();
let py_string = scalar_string
.into_pyobject(py)?
.call_method0("as_py")
.unwrap();
py_run!(py, py_string, "assert py_string == 'Hello!'");
});

Ok(())
})
}
}

0 comments on commit d51ba47

Please sign in to comment.