diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 5320d0a5343e..3b46d5729a1f 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -476,6 +476,33 @@ def test_tensor_array():
 
     del b
 
+
+def test_empty_recordbatch_with_row_count():
+    """
+    A pyarrow.RecordBatch with no columns but with `num_rows` set.
+
+    `datafusion-python` gets this as the result of a `count(*)` query.
+    """
+
+    # Build a batch that has rows, then project away every column: the result
+    # has an empty schema but still reports the original row count.
+    values = [1, 2, 3, 4]
+    batch = pa.RecordBatch.from_pydict({"a": values}).select([])
+    assert batch.num_rows == len(values)
+    assert batch.num_columns == 0
+
+    b = rust.round_trip_record_batch(batch)
+    assert b == batch
+    assert b.schema == batch.schema
+    assert b.schema.metadata == batch.schema.metadata
+
+    # With no columns the row count must survive the round trip on its own.
+    assert b.num_rows == batch.num_rows
+    assert b.num_columns == 0
+
+    del b
+
+
 def test_record_batch_reader():
     """
     Python -> Rust -> Python
diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs
index 43cdb4fe0919..336398cbf22f 100644
--- a/arrow/src/pyarrow.rs
+++ b/arrow/src/pyarrow.rs
@@ -59,7 +59,7 @@ use std::convert::{From, TryFrom};
 use std::ptr::{addr_of, addr_of_mut};
 use std::sync::Arc;
 
-use arrow_array::{RecordBatchIterator, RecordBatchReader, StructArray};
+use arrow_array::{RecordBatchIterator, RecordBatchOptions, RecordBatchReader, StructArray};
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::ffi::Py_uintptr_t;
 use pyo3::import_exception;
@@ -361,6 +361,7 @@ impl FromPyArrow for RecordBatch {
                     "Expected Struct type from __arrow_c_array.",
                 ));
             }
+            let options = RecordBatchOptions::default().with_row_count(Some(array_data.len()));
             let array = StructArray::from(array_data);
             // StructArray does not embed metadata from schema. We need to override
             // the output schema with the schema from the capsule.
@@ -371,7 +372,7 @@ impl FromPyArrow for RecordBatch {
                 0,
                 "Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
             );
-            return RecordBatch::try_new(schema, columns).map_err(to_py_err);
+            return RecordBatch::try_new_with_options(schema, columns, &options).map_err(to_py_err);
         }
 
         validate_class("RecordBatch", value)?;
@@ -386,7 +387,17 @@ impl FromPyArrow for RecordBatch {
             .map(|a| Ok(make_array(ArrayData::from_pyarrow_bound(&a)?)))
             .collect::<PyResult<_>>()?;
 
-        let batch = RecordBatch::try_new(schema, arrays).map_err(to_py_err)?;
+        // Pass the Python object's `num_rows` along explicitly so a batch with
+        // no columns keeps its row count. The lookup is best-effort: a missing
+        // or non-integer attribute yields `None` instead of an error.
+        let row_count = value
+            .getattr("num_rows")
+            .ok()
+            .and_then(|x| x.extract().ok());
+        let options = RecordBatchOptions::default().with_row_count(row_count);
+
+        let batch =
+            RecordBatch::try_new_with_options(schema, arrays, &options).map_err(to_py_err)?;
         Ok(batch)
     }
 }