Skip to content

Commit

Permalink
pyarrow: Preserve RecordBatch's schema metadata (#5355)
Browse files Browse the repository at this point in the history
* pyarrow: Preserve RecordBatch's schema metadata

PyArrow's RecordBatch gets imported by using a StructArray, then
transforming it into a RecordBatch. This was losing the Schema's
metadata in the process (because a StructArray does not hold
metadata).

This commit changes the test to show the issue, and fixes it.

* FromPyArrow for RecordBatch: directly create RB with the right schema

---------

Co-authored-by: atwam <wam@atwam.com>
  • Loading branch information
atwam and atwam authored Feb 4, 2024
1 parent 5093b78 commit f303c9e
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
11 changes: 9 additions & 2 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,10 @@ def test_field_metadata_roundtrip():

def test_schema_roundtrip():
pyarrow_fields = zip(string.ascii_lowercase, _supported_pyarrow_types)
pyarrow_schema = pa.schema(pyarrow_fields)
pyarrow_schema = pa.schema(pyarrow_fields, metadata = {b'key1': b'value1'})
schema = rust.round_trip_schema(pyarrow_schema)
assert schema == pyarrow_schema
assert schema.metadata == pyarrow_schema.metadata


def test_primitive_python():
Expand Down Expand Up @@ -467,9 +468,11 @@ def test_tensor_array():
b = rust.round_trip_array(f32_array)
assert b == f32_array.storage

batch = pa.record_batch([f32_array], ["tensor"])
batch = pa.record_batch([f32_array], ["tensor"], metadata={b'key1': b'value1'})
b = rust.round_trip_record_batch(batch)
assert b == batch
assert b.schema == batch.schema
assert b.schema.metadata == batch.schema.metadata

del b

Expand All @@ -486,13 +489,15 @@ def test_record_batch_reader():
b = rust.round_trip_record_batch_reader(a)

assert b.schema == schema
assert b.schema.metadata == schema.metadata
got_batches = list(b)
assert got_batches == batches

# Also try the boxed reader variant
a = pa.RecordBatchReader.from_batches(schema, batches)
b = rust.boxed_reader_roundtrip(a)
assert b.schema == schema
assert b.schema.metadata == schema.metadata
got_batches = list(b)
assert got_batches == batches

Expand All @@ -511,6 +516,7 @@ def test_record_batch_reader_pycapsule():
b = rust.round_trip_record_batch_reader(wrapped)

assert b.schema == schema
assert b.schema.metadata == schema.metadata
got_batches = list(b)
assert got_batches == batches

Expand All @@ -519,6 +525,7 @@ def test_record_batch_reader_pycapsule():
wrapped = StreamWrapper(a)
b = rust.boxed_reader_roundtrip(wrapped)
assert b.schema == schema
assert b.schema.metadata == schema.metadata
got_batches = list(b)
assert got_batches == batches

Expand Down
11 changes: 10 additions & 1 deletion arrow/src/pyarrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,16 @@ impl FromPyArrow for RecordBatch {
));
}
let array = StructArray::from(array_data);
return Ok(array.into());
// StructArray does not embed metadata from schema. We need to override
// the output schema with the schema from the capsule.
let schema = Arc::new(Schema::try_from(schema_ptr).map_err(to_py_err)?);
let (_fields, columns, nulls) = array.into_parts();
assert_eq!(
nulls.map(|n| n.null_count()).unwrap_or_default(),
0,
"Cannot convert nullable StructArray to RecordBatch, see StructArray documentation"
);
return RecordBatch::try_new(schema, columns).map_err(to_py_err);
}

validate_class("RecordBatch", value)?;
Expand Down

0 comments on commit f303c9e

Please sign in to comment.