Skip to content

Commit

Permalink
PyArrow: Pass in null-mask (#1264)
Browse files Browse the repository at this point in the history
* PyArrow: Pass in null-mask

* Add missing flag
  • Loading branch information
Fokko authored Oct 29, 2024
1 parent 58a7be3 commit 9a6a9a1
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
6 changes: 5 additions & 1 deletion pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,7 +1807,11 @@ def struct(
else:
raise ResolveError(f"Field is required, and could not be found in the file: {field}")

return pa.StructArray.from_arrays(arrays=field_arrays, fields=pa.struct(fields))
return pa.StructArray.from_arrays(
arrays=field_arrays,
fields=pa.struct(fields),
mask=struct_array.is_null() if isinstance(struct_array, pa.StructArray) else None,
)

def field(self, field: NestedField, _: Optional[pa.Array], field_array: Optional[pa.Array]) -> Optional[pa.Array]:
return field_array
Expand Down
41 changes: 41 additions & 0 deletions tests/integration/test_writes/test_writes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,47 @@ def test_rewrite_manifest_after_partition_evolution(session_catalog: Catalog) ->
)


@pytest.mark.integration
def test_writing_null_structs(session_catalog: Catalog) -> None:
import pyarrow as pa

schema = pa.schema([
pa.field(
"struct_field_1",
pa.struct([
pa.field("string_nested_1", pa.string()),
pa.field("int_item_2", pa.int32()),
pa.field("float_item_2", pa.float32()),
]),
),
])

records = [
{
"struct_field_1": {
"string_nested_1": "nest_1",
"int_item_2": 1234,
"float_item_2": 1.234,
},
},
{},
]

try:
session_catalog.drop_table(
identifier="default.test_writing_null_structs",
)
except NoSuchTableError:
pass

table = session_catalog.create_table("default.test_writing_null_structs", schema)

pyarrow_table: pa.Table = pa.Table.from_pylist(records, schema=schema)
table.append(pyarrow_table)

assert pyarrow_table.to_pandas()["struct_field_1"].tolist() == table.scan().to_pandas()["struct_field_1"].tolist()


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_abort_table_transaction_on_exception(
Expand Down

0 comments on commit 9a6a9a1

Please sign in to comment.