
Commit

Support for read_sql_to_numpy (#26)
milesgranger authored Sep 11, 2022
1 parent 86efefb commit be05a27
Showing 4 changed files with 130 additions and 76 deletions.
114 changes: 49 additions & 65 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
@@ -15,7 +15,8 @@ extension-module = ["pyo3/extension-module"]
[dependencies]
uuid = "0.8.2"
serde_json = "^1"
arrow2 = { version = "^0.13", features = ["io_ipc", "io_parquet", "io_ipc_compression"] }
numpy = "0.17"
arrow2 = { version = "^0.13", features = ["io_ipc", "io_parquet"] }
rust_decimal = { version = "1.16.0", features = ["db-postgres"] }
time = { version = "0.3.3", features = ["formatting"] }
postgres = { version = "0.19.1", features = ["with-time-0_3", "with-serde_json-1", "with-uuid-0_8"] }
16 changes: 8 additions & 8 deletions benchmarks/test_benchmarks.py
@@ -1,7 +1,6 @@
import pytest
import numpy as np
import pandas as pd
import connectorx as cx
import flaco
from memory_profiler import profile
from sqlalchemy import create_engine
@@ -119,22 +118,23 @@ def _table_setup(n_rows: int = 1_000_000, include_nulls: bool = False):
def memory_profile():
    stmt = "select * from test_table"
    flaco.read_sql_to_file(DB_URI, stmt, 'result.feather', flaco.FileFormat.Feather)

    data = flaco.read_sql_to_numpy(DB_URI, stmt)
    df = pd.DataFrame(data, copy=False).convert_dtypes()
    import duckdb
    import pyarrow as pa
    import pyarrow.parquet as pq
    import pyarrow.feather as pf
    import pyarrow.dataset as ds
    #table = pq.read_table('result.parquet', memory_map=True).to_pandas()
    with pa.memory_map('result.feather', 'rb') as source:
        mytable = pa.ipc.open_file(source).read_all()
    cur = duckdb.connect()
    v = cur.execute('select count(*) from mytable').fetchall()
    print(v)
    table = mytable.rename_columns([f"col_{i}" for i in range(10)])
    table_df = table.to_pandas()
    #print(v)
    print(type(mytable), len(mytable))
    print(type(table), len(table))
    print(pa.total_allocated_bytes() >> 20)
    #engine = create_engine(DB_URI)
    #_pandas_df = pd.read_sql(stmt, engine)
    engine = create_engine(DB_URI)
    _pandas_df = pd.read_sql(stmt, engine)


if __name__ == "__main__":
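For context, a minimal sketch of the Feather/IPC path exercised by the benchmark above: the query result is written to a Feather (Arrow IPC) file with read_sql_to_file, then memory-mapped through pyarrow so reading it back avoids copying the data onto the Python heap. The connection string and table name below are placeholders standing in for the benchmark's DB_URI and test_table setup.

import flaco
import pyarrow as pa

DB_URI = "postgresql://postgres:postgres@localhost:5432/postgres"  # placeholder connection string
stmt = "select * from test_table"  # placeholder table from the benchmark setup

# Write the query result straight to a Feather / Arrow IPC file on disk.
flaco.read_sql_to_file(DB_URI, stmt, "result.feather", flaco.FileFormat.Feather)

# Memory-map the file; pyarrow reads the IPC batches without materializing extra copies.
with pa.memory_map("result.feather", "rb") as source:
    table = pa.ipc.open_file(source).read_all()

print(table.num_rows, pa.total_allocated_bytes() >> 20)  # row count and MiB tracked by pyarrow
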
73 changes: 71 additions & 2 deletions src/lib.rs
@@ -1,7 +1,13 @@
use arrow2::array::{
    BinaryArray, BooleanArray, FixedSizeBinaryArray, MutableBinaryArray, MutableBooleanArray,
    MutableFixedSizeBinaryArray, MutablePrimitiveArray, MutableUtf8Array, PrimitiveArray,
    Utf8Array,
};
use arrow2::chunk::Chunk;
use arrow2::datatypes::{DataType, Schema};
use arrow2::io::{ipc, parquet};
use arrow2::{array, array::MutableArray};
use numpy::IntoPyArray;
use pyo3::create_exception;
use pyo3::exceptions::PyException;
use pyo3::prelude::*;
@@ -18,6 +24,7 @@ create_exception!(flaco, FlacoException, PyException);
fn flaco(py: Python, m: &PyModule) -> PyResult<()> {
    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
    m.add_function(wrap_pyfunction!(read_sql_to_file, m)?)?;
    m.add_function(wrap_pyfunction!(read_sql_to_numpy, m)?)?;
    m.add_class::<FileFormat>()?;
    m.add("FlacoException", py.get_type::<FlacoException>())?;
    Ok(())
@@ -35,6 +42,7 @@ fn to_py_err(err: impl ToString) -> PyErr {
    PyErr::new::<FlacoException, _>(err.to_string())
}

/// Read SQL to a file; Parquet or Feather/IPC format.
// TODO: Stream data into a file in chunks during query reading
#[pyfunction]
pub fn read_sql_to_file(uri: &str, stmt: &str, path: &str, format: FileFormat) -> PyResult<()> {
@@ -47,6 +55,23 @@ pub fn read_sql_to_file(uri: &str, stmt: &str, path: &str, format: FileFormat) -
    Ok(())
}

/// Read SQL to a dict of numpy arrays, where keys are column names.
/// NOTE: This is currently not very efficient; you likely should not use it.
#[pyfunction]
pub fn read_sql_to_numpy<'py>(
    py: Python<'py>,
    uri: &str,
    stmt: &str,
) -> PyResult<BTreeMap<String, PyObject>> {
    let mut client = postgres::Client::connect(uri, postgres::NoTls).map_err(to_py_err)?;
    let table = postgresql::read_sql(&mut client, stmt).map_err(to_py_err)?;
    let mut result = BTreeMap::new();
    for (name, column) in table {
        result.insert(name, column.into_pyarray(py));
    }
    Ok(result)
}
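From Python, the new function returns a plain dict keyed by column name with a numpy array per column, which can be handed straight to pandas. A rough usage sketch, with the connection string as a placeholder:

import flaco
import pandas as pd

DB_URI = "postgresql://postgres:postgres@localhost:5432/postgres"  # placeholder

# {column name: numpy array}, one entry per selected column
data = flaco.read_sql_to_numpy(DB_URI, "select * from test_table")

# Wrap the arrays without copying them again, then let pandas infer tighter dtypes,
# mirroring what the updated benchmark does.
df = pd.DataFrame(data, copy=False).convert_dtypes()
print(df.dtypes)
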

pub type Table = BTreeMap<String, Column>;

pub struct Column {
@@ -67,10 +92,50 @@ impl Column {
    pub fn inner_mut<T: Any + 'static>(&mut self) -> &mut T {
        self.array.as_mut_any().downcast_mut::<T>().unwrap()
    }
    pub fn inner<T: Any + 'static>(&self) -> &T {
        self.array.as_any().downcast_ref::<T>().unwrap()
    }
    pub fn push<V, T: array::TryPush<V> + Any + 'static>(&mut self, value: V) -> Result<()> {
        self.inner_mut::<T>().try_push(value)?;
        Ok(())
    }
    pub fn into_pyarray(mut self, py: Python) -> PyObject {
        macro_rules! to_pyarray {
            ($mut_arr:ty, $arr:ty) => {{
                self.inner_mut::<$mut_arr>()
                    .as_arc()
                    .as_ref()
                    .as_any()
                    .downcast_ref::<$arr>()
                    .unwrap()
                    .iter()
                    .map(|v| v.to_object(py))
                    .collect::<Vec<_>>()
                    .into_pyarray(py)
                    .to_object(py)
            }};
        }
        match self.dtype {
            DataType::Boolean => to_pyarray!(MutableBooleanArray, BooleanArray),
            DataType::Binary => to_pyarray!(MutableBinaryArray<i32>, BinaryArray<i32>),
            DataType::Utf8 => to_pyarray!(MutableUtf8Array<i32>, Utf8Array<i32>),
            DataType::Int8 => to_pyarray!(MutablePrimitiveArray<i8>, PrimitiveArray<i8>),
            DataType::Int16 => to_pyarray!(MutablePrimitiveArray<i16>, PrimitiveArray<i16>),
            DataType::Int32 => to_pyarray!(MutablePrimitiveArray<i32>, PrimitiveArray<i32>),
            DataType::UInt32 => to_pyarray!(MutablePrimitiveArray<u32>, PrimitiveArray<u32>),
            DataType::Int64 => to_pyarray!(MutablePrimitiveArray<i64>, PrimitiveArray<i64>),
            DataType::UInt64 => to_pyarray!(MutablePrimitiveArray<u64>, PrimitiveArray<u64>),
            DataType::Float32 => to_pyarray!(MutablePrimitiveArray<f32>, PrimitiveArray<f32>),
            DataType::Float64 => to_pyarray!(MutablePrimitiveArray<f64>, PrimitiveArray<f64>),
            DataType::FixedSizeBinary(_) => {
                to_pyarray!(MutableFixedSizeBinaryArray, FixedSizeBinaryArray)
            }
            _ => unimplemented!(
                "Dtype: {:?} not implemented for conversion to numpy",
                &self.dtype
            ),
        }
    }
}
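A note on the docstring's efficiency warning: into_pyarray converts every Arrow value to a Python object (to_object) before collecting into a numpy array, so the arrays handed back by read_sql_to_numpy are expected to hold Python objects rather than native scalars. A small check, with the connection string and table again being placeholders:

import flaco

DB_URI = "postgresql://postgres:postgres@localhost:5432/postgres"  # placeholder

data = flaco.read_sql_to_numpy(DB_URI, "select * from test_table")
for name, arr in data.items():
    # dtype is expected to be object here, since each element was boxed as a PyObject on the Rust side
    print(name, arr.dtype, arr[:3])
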

fn write_table_to_parquet(table: Table, path: &str) -> Result<()> {
@@ -225,13 +290,17 @@ pub mod postgresql {
                table
                    .entry(column_name)
                    .or_insert_with(|| Column::new(MutablePrimitiveArray::<f32>::new()))
                    .push::<_, MutablePrimitiveArray<f32>>(row.get::<_, Option<f32>>(idx))?;
                    .push::<_, MutablePrimitiveArray<f32>>(
                        row.get::<_, Option<f32>>(idx).or_else(|| Some(f32::NAN)),
                    )?;
            }
            &Type::FLOAT8 => {
                table
                    .entry(column_name)
                    .or_insert_with(|| Column::new(MutablePrimitiveArray::<f64>::new()))
                    .push::<_, MutablePrimitiveArray<f64>>(row.get::<_, Option<f64>>(idx))?;
                    .push::<_, MutablePrimitiveArray<f64>>(
                        row.get::<_, Option<f64>>(idx).or_else(|| Some(f64::NAN)),
                    )?;
            }
            &Type::TIMESTAMP => {
                table
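The FLOAT4/FLOAT8 change above substitutes NaN for SQL NULL before the value is pushed into the Arrow builder, so nullable float columns surface as NaN rather than None on the Python side. A rough illustration; the table and column names here are hypothetical:

import math
import flaco

DB_URI = "postgresql://postgres:postgres@localhost:5432/postgres"  # placeholder

# Assumes a table like:
#   create table floats_demo (x double precision);
#   insert into floats_demo values (1.5), (null);
data = flaco.read_sql_to_numpy(DB_URI, "select x from floats_demo order by x")
x = data["x"]
print(x[0], math.isnan(x[1]))  # 1.5 True -- the NULL row arrives as NaN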
