Skip to content

Commit

Permalink
Generate correct metadata in to_arrow with preserve_index=True
Browse files Browse the repository at this point in the history
When preserve_index=True and the index is a RangeIndex, we must
materialize it as a column and describe it correctly in the metadata.

- Closes #14159
  • Loading branch information
wence- committed Mar 22, 2024
1 parent dea3ab7 commit dd53c60
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5485,14 +5485,18 @@ def from_arrow(cls, table):
return out

@_cudf_nvtx_annotate
def to_arrow(self, preserve_index=True):
def to_arrow(self, preserve_index=None):
"""
Convert to a PyArrow Table.
Parameters
----------
preserve_index : bool, default True
whether index column and its meta data needs to be saved or not
preserve_index : bool, optional
Whether the index column and its metadata should be saved.
The default of None stores the index as a column, except
for a RangeIndex, which is stored as metadata only. Setting
preserve_index to True forces a RangeIndex to be
materialized as a column.
Returns
-------
Expand Down Expand Up @@ -5523,34 +5527,35 @@ def to_arrow(self, preserve_index=True):

data = self.copy(deep=False)
index_descr = []
if preserve_index:
if isinstance(self.index, cudf.RangeIndex):
write_index = preserve_index is not False
keep_range_index = write_index and preserve_index is None
index = self.index
if write_index:
if isinstance(index, cudf.RangeIndex) and keep_range_index:
descr = {
"kind": "range",
"name": self.index.name,
"start": self.index._start,
"stop": self.index._stop,
"name": index.name,
"start": index._start,
"stop": index._stop,
"step": 1,
}
else:
if isinstance(self.index, MultiIndex):
if isinstance(index, cudf.RangeIndex):
index = index._as_int_index()
index.name = "__index_level_0__"
if isinstance(index, MultiIndex):
gen_names = tuple(
f"level_{i}"
for i, _ in enumerate(self.index._data.names)
f"level_{i}" for i, _ in enumerate(index._data.names)
)
else:
gen_names = (
self.index.names
if self.index.name is not None
else ("index",)
index.names if index.name is not None else ("index",)
)
for gen_name, col_name in zip(
gen_names, self.index._data.names
):
for gen_name, col_name in zip(gen_names, index._data.names):
data._insert(
data.shape[1],
gen_name,
self.index._data[col_name],
index._data[col_name],
)
descr = gen_names[0]
index_descr.append(descr)
Expand All @@ -5560,7 +5565,7 @@ def to_arrow(self, preserve_index=True):
columns_to_convert=[self[col] for col in self._data.names],
df=self,
column_names=out.schema.names,
index_levels=[self.index],
index_levels=[index],
index_descriptors=index_descr,
preserve_index=preserve_index,
types=out.schema.types,
Expand Down

0 comments on commit dd53c60

Please sign in to comment.