From 858e6359dd54d12988de2c16e73bf85b20dddd1f Mon Sep 17 00:00:00 2001 From: Paul Horn Date: Wed, 11 Sep 2024 22:55:43 +0200 Subject: [PATCH] Add preliminary polars dataframe support --- lib/Cargo.toml | 8 ++ lib/src/errors.rs | 4 + lib/src/stream.rs | 232 ++++++++++++++++++++++++++++++- lib/tests/result_as_dataframe.rs | 42 ++++++ 4 files changed, 281 insertions(+), 5 deletions(-) create mode 100644 lib/tests/result_as_dataframe.rs diff --git a/lib/Cargo.toml b/lib/Cargo.toml index 0967d74..ee8ca83 100644 --- a/lib/Cargo.toml +++ b/lib/Cargo.toml @@ -17,6 +17,8 @@ rust-version = "1.75.0" [features] json = ["serde_json"] +polars = ["polars_v0_43"] +polars_v0_43 = ["dep:polars"] unstable-v1 = ["unstable-bolt-protocol-impl-v2", "unstable-result-summary"] unstable-serde-packstream-format = [] unstable-result-summary = ["unstable-serde-packstream-format"] @@ -38,6 +40,7 @@ nav-types = { version = "0.5.2", optional = true } neo4rs-macros = { version = "0.3.0", path = "../macros" } paste = "1.0.0" pin-project-lite = "0.2.9" +polars-utils = { version = "0.43.1", default-features = false } rustls-native-certs = "0.7.1" rustls-pemfile = "2.1.2" serde = { version = "1.0.185", features = ["derive"] } # TODO: eliminate derive @@ -62,6 +65,11 @@ version = "0.26.0" default-features = false features = ["tls12", "ring"] +[dependencies.polars] +version = "0.43.0" +default-features = false +optional = true +features = ["rows"] [dev-dependencies] pretty_env_logger = "0.5.0" diff --git a/lib/src/errors.rs b/lib/src/errors.rs index 6efd65f..ff3731f 100644 --- a/lib/src/errors.rs +++ b/lib/src/errors.rs @@ -24,6 +24,10 @@ pub enum Error { #[error(transparent)] ParseError(#[from] de::Error), + #[cfg(feature = "polars_v0_43")] + #[error(transparent)] + Polars(#[from] polars::error::PolarsError), + #[error("Unsupported URI scheme: {0}")] UnsupportedScheme(String), diff --git a/lib/src/stream.rs b/lib/src/stream.rs index f905370..5869114 100644 --- a/lib/src/stream.rs +++ b/lib/src/stream.rs @@ -19,7 +19,7 @@ use crate::{ use futures::{stream::try_unfold, TryStream}; use serde::de::DeserializeOwned; -use std::collections::VecDeque; +use std::{collections::VecDeque, sync::Arc, sync::OnceLock}; #[cfg(feature = "unstable-result-summary")] type BoxedSummary = Box; @@ -42,7 +42,7 @@ pub struct RowStream { available_after: i64, state: State, fetch_size: usize, - buffer: VecDeque, + buffer: VecDeque, } impl RowStream { @@ -139,7 +139,8 @@ impl RowStream { mut handle: impl TransactionHandle, ) -> Result> { loop { - if let Some(row) = self.buffer.pop_front() { + if let Some(record) = self.buffer.pop_front() { + let row = Row::new(self.fields.clone(), record); return Ok(Some(RowItem::Row(row))); } @@ -161,8 +162,7 @@ impl RowStream { .map(BoltType::from) .collect::>(), ); - let row = Row::new(self.fields.clone(), record); - self.buffer.push_back(row); + self.buffer.push_back(record); } Response::Success(Streaming::HasMore) => break State::Ready, Response::Success(Streaming::Done(mut s)) => { @@ -383,6 +383,111 @@ impl RowStream { } }) } + + #[cfg(all(feature = "polars_v0_43", not(feature = "unstable-result-summary")))] + pub async fn into_dataframe( + self, + mut handle: impl TransactionHandle, + ) -> Result { + self.into_df(handle).await + } + + #[cfg(all(feature = "polars_v0_43", feature = "unstable-result-summary"))] + pub async fn into_dataframe( + self, + handle: impl TransactionHandle, + ) -> Result<(polars::frame::DataFrame, Option)> { + let out_summary = Arc::new(OnceLock::new()); + let df = self.into_df(handle, out_summary.clone()).await?; + let summary = Arc::into_inner(out_summary).and_then(|s| s.into_inner()); + Ok((df, summary)) + } + + #[cfg(feature = "polars_v0_43")] + fn into_df( + mut self, + mut handle: impl TransactionHandle, + #[cfg(feature = "unstable-result-summary")] out_summary: Arc>, + ) -> impl std::future::Future> { + let fields = self.fields.value.iter().filter_map(|x| match x { + BoltType::String(s) => Some(s.value.as_str()), + _ => None, + }); + + let mut buf = pl::DataBuf::new(fields); + + for row in self.buffer.drain(..) { + buf.push(row.value); + } + + async move { + while self.state == State::Ready { + #[cfg(feature = "unstable-bolt-protocol-impl-v2")] + { + let pull = Pull::some(self.fetch_size as i64).for_query(self.qid); + let connection = handle.connection(); + connection.send_as(pull).await?; + self.state = loop { + let response = connection + .recv_as::, Streaming>>() + .await?; + match response { + Response::Detail(record) => { + let record = BoltList::from( + record + .into_iter() + .map(BoltType::from) + .collect::>(), + ); + buf.push(record.value); + } + Response::Success(Streaming::HasMore) => break State::Ready, + Response::Success(Streaming::Done(mut s)) => { + s.set_t_first(self.available_after); + break State::Complete(Some(s)); + } + otherwise => return Err(otherwise.into_error("PULL")), + } + }; + buf.flush()?; + } + + #[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))] + { + let pull = BoltRequest::pull(self.fetch_size, self.qid); + let connection = handle.connection(); + connection.send(pull).await?; + + self.state = loop { + match connection.recv().await { + Ok(BoltResponse::Success(s)) => { + break if s.get("has_more").unwrap_or(false) { + State::Ready + } else { + State::Complete(None) + }; + } + Ok(BoltResponse::Record(record)) => { + buf.push(record.data); + } + Ok(msg) => return Err(msg.into_error("PULL")), + Err(e) => return Err(e), + } + }; + buf.flush()?; + } + } + + #[cfg(feature = "unstable-result-summary")] + if let State::Complete(ref mut summary) = self.state { + if let Some(summary) = summary.take() { + out_summary.set(*summary).expect("only one summary"); + }; + } + + Ok(buf.into_df()?) + } + } } impl DetachedRowStream { @@ -470,6 +575,18 @@ impl DetachedRowStream { ) -> impl TryStream, Error = Error> + 'this { self.stream.column_to_items(&mut self.connection, column) } + + #[cfg(all(feature = "polars_v0_43", not(feature = "unstable-result-summary")))] + pub async fn into_dataframe(mut self) -> Result { + self.stream.into_dataframe(&mut self.connection).await + } + + #[cfg(all(feature = "polars_v0_43", feature = "unstable-result-summary"))] + pub async fn into_dataframe( + mut self, + ) -> Result<(polars::frame::DataFrame, Option)> { + self.stream.into_dataframe(&mut self.connection).await + } } #[derive(Clone, PartialEq, Debug)] @@ -477,3 +594,108 @@ enum State { Ready, Complete(Option), } + +// mod pl {{{ +#[cfg(feature = "polars_v0_43")] +mod pl { + use polars::{ + error::PolarsError as Error, + frame::DataFrame, + prelude::{AnyValue, PlSmallStr}, + series::Series, + }; + + use crate::BoltType; + + #[derive(Debug, Clone)] + pub(super) struct DataBuf { + fields: Vec, + buffers: Vec, + } + + impl DataBuf { + pub(super) fn new>(fields: impl IntoIterator) -> Self { + let fields = fields.into_iter().map(Into::into).collect::>(); + let buffers = vec![ColBuf::default(); fields.len()]; + Self { fields, buffers } + } + + pub(super) fn push(&mut self, values: Vec) { + assert_eq!(values.len(), self.fields.len()); + for (buf, value) in self.buffers.iter_mut().zip(values) { + buf.push(value); + } + } + + pub(super) fn flush(&mut self) -> Result<(), Error> { + for buf in &mut self.buffers { + buf.flush()?; + } + Ok(()) + } + + pub(super) fn into_df(self) -> Result { + let serieses = self + .buffers + .into_iter() + .zip(self.fields.into_iter()) + .map(|(buf, field)| buf.into_series(field)) + .collect::, _>>()?; + + DataFrame::new(serieses) + } + } + + // TODO: use AnyValueBuffer + #[derive(Debug, Default, Clone)] + struct ColBuf { + values: Vec>, + series: Option, + } + + impl ColBuf { + fn push(&mut self, value: BoltType) { + let value = match value { + BoltType::String(v) => AnyValue::StringOwned(v.value.into()), + BoltType::Boolean(v) => AnyValue::Boolean(v.value), + BoltType::Map(_) => todo!(), + BoltType::Null(_) => AnyValue::Null, + BoltType::Integer(v) => AnyValue::Int64(v.value), + BoltType::Float(v) => AnyValue::Float64(v.value), + BoltType::List(_) => todo!(), + BoltType::Node(_) => todo!(), + BoltType::Relation(_) => todo!(), + BoltType::UnboundedRelation(_) => todo!(), + BoltType::Point2D(_) => todo!(), + BoltType::Point3D(_) => todo!(), + BoltType::Bytes(v) => AnyValue::BinaryOwned(v.value.into()), + BoltType::Path(_) => todo!(), + BoltType::Duration(_) => todo!(), + BoltType::Date(_) => todo!(), + BoltType::Time(_) => todo!(), + BoltType::LocalTime(_) => todo!(), + BoltType::DateTime(_) => todo!(), + BoltType::LocalDateTime(_) => todo!(), + BoltType::DateTimeZoneId(_) => todo!(), + }; + self.values.push(value); + } + + fn flush(&mut self) -> Result<(), Error> { + let chunk = Series::from_any_values(PlSmallStr::EMPTY, &self.values, false)?; + if let Some(series) = &mut self.series { + series.append(&chunk)?; + } else { + self.series = Some(chunk); + } + + Ok(()) + } + + fn into_series(mut self, name: PlSmallStr) -> Result { + self.flush()?; + Ok(self.series.unwrap().with_name(name)) + } + } +} +// }}} diff --git a/lib/tests/result_as_dataframe.rs b/lib/tests/result_as_dataframe.rs new file mode 100644 index 0000000..cf17dce --- /dev/null +++ b/lib/tests/result_as_dataframe.rs @@ -0,0 +1,42 @@ +#![cfg(feature = "polars_v0_43")] + +use neo4rs::query; +use polars::prelude::{AnyValue, DataType}; + +mod container; + +#[tokio::test] +async fn result_as_dataframe() { + let neo4j = container::Neo4jContainer::new().await; + let graph = neo4j.graph(); + + let result = graph + .execute(query( + "UNWIND [TRUE, FALSE, NULL, 1, 420000, 13.37] AS values RETURN values", + )) + .await + .unwrap(); + + let df = result.into_dataframe().await.unwrap(); + #[cfg(feature = "unstable-result-summary")] + let df = df.0; + + assert_eq!(df.get_column_names(), ["values"]); + assert_eq!(df.height(), 6); + assert_eq!(df.width(), 1); + + let values = df.column("values").unwrap(); + + assert_eq!(values.dtype(), &DataType::Float64); + values + .iter() + .filter_map(|a| match a { + AnyValue::Float64(a) => Some(a), + AnyValue::Null => None, + _ => panic!("`{a:?} is not a float or null"), + }) + .zip([1.0, 0.0, 1.0, 420000.0, 13.37]) + .for_each(|(a, b)| { + assert!((a - b).abs() <= f64::EPSILON); + }); +}