Skip to content

Commit

Permalink
Add preliminary polars dataframe support
Browse files Browse the repository at this point in the history
  • Loading branch information
knutwalker committed Sep 19, 2024
1 parent b70889c commit 858e635
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 5 deletions.
8 changes: 8 additions & 0 deletions lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ rust-version = "1.75.0"

[features]
json = ["serde_json"]
polars = ["polars_v0_43"]
polars_v0_43 = ["dep:polars"]
unstable-v1 = ["unstable-bolt-protocol-impl-v2", "unstable-result-summary"]
unstable-serde-packstream-format = []
unstable-result-summary = ["unstable-serde-packstream-format"]
Expand All @@ -38,6 +40,7 @@ nav-types = { version = "0.5.2", optional = true }
neo4rs-macros = { version = "0.3.0", path = "../macros" }
paste = "1.0.0"
pin-project-lite = "0.2.9"
polars-utils = { version = "0.43.1", default-features = false }
rustls-native-certs = "0.7.1"
rustls-pemfile = "2.1.2"
serde = { version = "1.0.185", features = ["derive"] } # TODO: eliminate derive
Expand All @@ -62,6 +65,11 @@ version = "0.26.0"
default-features = false
features = ["tls12", "ring"]

[dependencies.polars]
version = "0.43.0"
default-features = false
optional = true
features = ["rows"]

[dev-dependencies]
pretty_env_logger = "0.5.0"
Expand Down
4 changes: 4 additions & 0 deletions lib/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ pub enum Error {
#[error(transparent)]
ParseError(#[from] de::Error),

#[cfg(feature = "polars_v0_43")]
#[error(transparent)]
Polars(#[from] polars::error::PolarsError),

#[error("Unsupported URI scheme: {0}")]
UnsupportedScheme(String),

Expand Down
232 changes: 227 additions & 5 deletions lib/src/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
use futures::{stream::try_unfold, TryStream};
use serde::de::DeserializeOwned;

use std::collections::VecDeque;
use std::{collections::VecDeque, sync::Arc, sync::OnceLock};

#[cfg(feature = "unstable-result-summary")]
type BoxedSummary = Box<ResultSummary>;
Expand All @@ -42,7 +42,7 @@ pub struct RowStream {
available_after: i64,
state: State,
fetch_size: usize,
buffer: VecDeque<Row>,
buffer: VecDeque<BoltList>,
}

impl RowStream {
Expand Down Expand Up @@ -139,7 +139,8 @@ impl RowStream {
mut handle: impl TransactionHandle,
) -> Result<Option<RowItem>> {
loop {
if let Some(row) = self.buffer.pop_front() {
if let Some(record) = self.buffer.pop_front() {
let row = Row::new(self.fields.clone(), record);
return Ok(Some(RowItem::Row(row)));
}

Expand All @@ -161,8 +162,7 @@ impl RowStream {
.map(BoltType::from)
.collect::<Vec<BoltType>>(),
);
let row = Row::new(self.fields.clone(), record);
self.buffer.push_back(row);
self.buffer.push_back(record);
}
Response::Success(Streaming::HasMore) => break State::Ready,
Response::Success(Streaming::Done(mut s)) => {
Expand Down Expand Up @@ -383,6 +383,111 @@ impl RowStream {
}
})
}

#[cfg(all(feature = "polars_v0_43", not(feature = "unstable-result-summary")))]
pub async fn into_dataframe(
self,
mut handle: impl TransactionHandle,
) -> Result<polars::frame::DataFrame> {
self.into_df(handle).await
}

#[cfg(all(feature = "polars_v0_43", feature = "unstable-result-summary"))]
pub async fn into_dataframe(
self,
handle: impl TransactionHandle,
) -> Result<(polars::frame::DataFrame, Option<ResultSummary>)> {
let out_summary = Arc::new(OnceLock::new());
let df = self.into_df(handle, out_summary.clone()).await?;
let summary = Arc::into_inner(out_summary).and_then(|s| s.into_inner());
Ok((df, summary))
}

#[cfg(feature = "polars_v0_43")]
fn into_df(
mut self,
mut handle: impl TransactionHandle,
#[cfg(feature = "unstable-result-summary")] out_summary: Arc<OnceLock<ResultSummary>>,
) -> impl std::future::Future<Output = Result<polars::frame::DataFrame, Error>> {
let fields = self.fields.value.iter().filter_map(|x| match x {
BoltType::String(s) => Some(s.value.as_str()),
_ => None,
});

let mut buf = pl::DataBuf::new(fields);

for row in self.buffer.drain(..) {
buf.push(row.value);
}

async move {
while self.state == State::Ready {
#[cfg(feature = "unstable-bolt-protocol-impl-v2")]
{
let pull = Pull::some(self.fetch_size as i64).for_query(self.qid);
let connection = handle.connection();
connection.send_as(pull).await?;
self.state = loop {
let response = connection
.recv_as::<Response<Vec<Bolt>, Streaming>>()
.await?;
match response {
Response::Detail(record) => {
let record = BoltList::from(
record
.into_iter()
.map(BoltType::from)
.collect::<Vec<BoltType>>(),
);
buf.push(record.value);
}
Response::Success(Streaming::HasMore) => break State::Ready,
Response::Success(Streaming::Done(mut s)) => {
s.set_t_first(self.available_after);
break State::Complete(Some(s));
}
otherwise => return Err(otherwise.into_error("PULL")),
}
};
buf.flush()?;
}

#[cfg(not(feature = "unstable-bolt-protocol-impl-v2"))]
{
let pull = BoltRequest::pull(self.fetch_size, self.qid);
let connection = handle.connection();
connection.send(pull).await?;

self.state = loop {
match connection.recv().await {
Ok(BoltResponse::Success(s)) => {
break if s.get("has_more").unwrap_or(false) {
State::Ready
} else {
State::Complete(None)
};
}
Ok(BoltResponse::Record(record)) => {
buf.push(record.data);
}
Ok(msg) => return Err(msg.into_error("PULL")),
Err(e) => return Err(e),
}
};
buf.flush()?;
}
}

#[cfg(feature = "unstable-result-summary")]
if let State::Complete(ref mut summary) = self.state {
if let Some(summary) = summary.take() {
out_summary.set(*summary).expect("only one summary");
};
}

Ok(buf.into_df()?)
}
}
}

impl DetachedRowStream {
Expand Down Expand Up @@ -470,10 +575,127 @@ impl DetachedRowStream {
) -> impl TryStream<Ok = RowItem<T>, Error = Error> + 'this {
self.stream.column_to_items(&mut self.connection, column)
}

#[cfg(all(feature = "polars_v0_43", not(feature = "unstable-result-summary")))]
pub async fn into_dataframe(mut self) -> Result<polars::frame::DataFrame> {
self.stream.into_dataframe(&mut self.connection).await
}

#[cfg(all(feature = "polars_v0_43", feature = "unstable-result-summary"))]
pub async fn into_dataframe(
mut self,
) -> Result<(polars::frame::DataFrame, Option<ResultSummary>)> {
self.stream.into_dataframe(&mut self.connection).await
}
}

#[derive(Clone, PartialEq, Debug)]
enum State {
Ready,
Complete(Option<BoxedSummary>),
}

// mod pl {{{
#[cfg(feature = "polars_v0_43")]
mod pl {
use polars::{
error::PolarsError as Error,
frame::DataFrame,
prelude::{AnyValue, PlSmallStr},
series::Series,
};

use crate::BoltType;

#[derive(Debug, Clone)]
pub(super) struct DataBuf {
fields: Vec<PlSmallStr>,
buffers: Vec<ColBuf>,
}

impl DataBuf {
pub(super) fn new<S: Into<PlSmallStr>>(fields: impl IntoIterator<Item = S>) -> Self {
let fields = fields.into_iter().map(Into::into).collect::<Vec<_>>();
let buffers = vec![ColBuf::default(); fields.len()];
Self { fields, buffers }
}

pub(super) fn push(&mut self, values: Vec<BoltType>) {
assert_eq!(values.len(), self.fields.len());
for (buf, value) in self.buffers.iter_mut().zip(values) {
buf.push(value);
}
}

pub(super) fn flush(&mut self) -> Result<(), Error> {
for buf in &mut self.buffers {
buf.flush()?;
}
Ok(())
}

pub(super) fn into_df(self) -> Result<DataFrame, polars::error::PolarsError> {
let serieses = self
.buffers
.into_iter()
.zip(self.fields.into_iter())
.map(|(buf, field)| buf.into_series(field))
.collect::<Result<Vec<_>, _>>()?;

DataFrame::new(serieses)
}
}

// TODO: use AnyValueBuffer
#[derive(Debug, Default, Clone)]
struct ColBuf {
values: Vec<AnyValue<'static>>,
series: Option<Series>,
}

impl ColBuf {
fn push(&mut self, value: BoltType) {
let value = match value {
BoltType::String(v) => AnyValue::StringOwned(v.value.into()),
BoltType::Boolean(v) => AnyValue::Boolean(v.value),
BoltType::Map(_) => todo!(),
BoltType::Null(_) => AnyValue::Null,
BoltType::Integer(v) => AnyValue::Int64(v.value),
BoltType::Float(v) => AnyValue::Float64(v.value),
BoltType::List(_) => todo!(),
BoltType::Node(_) => todo!(),
BoltType::Relation(_) => todo!(),
BoltType::UnboundedRelation(_) => todo!(),
BoltType::Point2D(_) => todo!(),
BoltType::Point3D(_) => todo!(),
BoltType::Bytes(v) => AnyValue::BinaryOwned(v.value.into()),
BoltType::Path(_) => todo!(),
BoltType::Duration(_) => todo!(),
BoltType::Date(_) => todo!(),
BoltType::Time(_) => todo!(),
BoltType::LocalTime(_) => todo!(),
BoltType::DateTime(_) => todo!(),
BoltType::LocalDateTime(_) => todo!(),
BoltType::DateTimeZoneId(_) => todo!(),
};
self.values.push(value);
}

fn flush(&mut self) -> Result<(), Error> {
let chunk = Series::from_any_values(PlSmallStr::EMPTY, &self.values, false)?;
if let Some(series) = &mut self.series {
series.append(&chunk)?;
} else {
self.series = Some(chunk);
}

Ok(())
}

fn into_series(mut self, name: PlSmallStr) -> Result<Series, Error> {
self.flush()?;
Ok(self.series.unwrap().with_name(name))
}
}
}
// }}}
42 changes: 42 additions & 0 deletions lib/tests/result_as_dataframe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#![cfg(feature = "polars_v0_43")]

use neo4rs::query;
use polars::prelude::{AnyValue, DataType};

mod container;

#[tokio::test]
async fn result_as_dataframe() {
let neo4j = container::Neo4jContainer::new().await;
let graph = neo4j.graph();

let result = graph
.execute(query(
"UNWIND [TRUE, FALSE, NULL, 1, 420000, 13.37] AS values RETURN values",
))
.await
.unwrap();

let df = result.into_dataframe().await.unwrap();
#[cfg(feature = "unstable-result-summary")]
let df = df.0;

assert_eq!(df.get_column_names(), ["values"]);
assert_eq!(df.height(), 6);
assert_eq!(df.width(), 1);

let values = df.column("values").unwrap();

assert_eq!(values.dtype(), &DataType::Float64);
values
.iter()
.filter_map(|a| match a {
AnyValue::Float64(a) => Some(a),
AnyValue::Null => None,
_ => panic!("`{a:?} is not a float or null"),
})
.zip([1.0, 0.0, 1.0, 420000.0, 13.37])
.for_each(|(a, b)| {
assert!((a - b).abs() <= f64::EPSILON);
});
}

0 comments on commit 858e635

Please sign in to comment.