Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TST(string dtype): Resolve HDF5 xfails in test_put.py #60625

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions pandas/io/pytables.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
PeriodArray,
)
from pandas.core.arrays.datetimes import tz_to_dtype
from pandas.core.arrays.string_ import BaseStringArray
import pandas.core.common as com
from pandas.core.computation.pytables import (
PyTablesExpr,
Expand Down Expand Up @@ -3185,6 +3186,8 @@ def write_array(
# both self._filters and EA

value = extract_array(obj, extract_numpy=True)
if isinstance(value, BaseStringArray):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this logic live in extract_array?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The line from the docstring of extract_array states:

Extract the ndarray or ExtensionArray from a Series or Index.

and the extract_numpy argument:

Whether to extract the ndarray from a NumpyExtensionArray.

So I think no - I would expect that function to still return the ExtensionArray.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do the tests currently cover the "string" data type going through pytables? This seems like it might mangle the NA markers?

Not super familiar with pytables so not saying this is right or wrong - just want to double check

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are certainly tests with a DataFrame roundtrip that contain string columns; I'm not entirely sure whether there are also tests where those columns have missing values, though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jorisvandenbossche - I'm not seeing this particular line hit with string dtype. I added a test, including an NA value.

@WillAyd - the current behavior is to write out the underlying objects, and then infer upon loading. So if we start with string and future.infer_string=False, we get object. When that option is True, we get str.

value = value.to_numpy()

if key in self.group:
self._handle.remove_node(self.group, key)
Expand Down Expand Up @@ -3294,7 +3297,11 @@ def read(
index = self.read_index("index", start=start, stop=stop)
values = self.read_array("values", start=start, stop=stop)
result = Series(values, index=index, name=self.name, copy=False)
if using_string_dtype() and is_string_array(values, skipna=True):
if (
using_string_dtype()
and isinstance(values, np.ndarray)
and is_string_array(values, skipna=True)
):
result = result.astype(StringDtype(na_value=np.nan))
return result

Expand Down Expand Up @@ -3363,7 +3370,11 @@ def read(

columns = items[items.get_indexer(blk_items)]
df = DataFrame(values.T, columns=columns, index=axes[1], copy=False)
if using_string_dtype() and is_string_array(values, skipna=True):
if (
using_string_dtype()
and isinstance(values, np.ndarray)
and is_string_array(values, skipna=True)
):
Comment on lines -3366 to +3377
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same code pattern happens above in the SeriesFixed read() method. I'm not entirely sure if the same change should be applied there as well, but I would expect so (maybe it is not covered by any test?).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, added a test

df = df.astype(StringDtype(na_value=np.nan))
dfs.append(df)

Expand Down Expand Up @@ -4737,9 +4748,10 @@ def read(
df = DataFrame._from_arrays([values], columns=cols_, index=index_)
if not (using_string_dtype() and values.dtype.kind == "O"):
assert (df.dtypes == values.dtype).all(), (df.dtypes, values.dtype)
if using_string_dtype() and is_string_array(
values, # type: ignore[arg-type]
skipna=True,
if (
using_string_dtype()
and isinstance(values, np.ndarray)
and is_string_array(values, skipna=True)
):
df = df.astype(StringDtype(na_value=np.nan))
frames.append(df)
Expand Down
54 changes: 39 additions & 15 deletions pandas/tests/io/pytables/test_put.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas._libs.tslibs import Timestamp

import pandas as pd
Expand All @@ -26,7 +24,6 @@

pytestmark = [
pytest.mark.single_cpu,
pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False),
]


Expand Down Expand Up @@ -106,7 +103,7 @@ def test_put(setup_path):
)
df = DataFrame(
np.random.default_rng(2).standard_normal((20, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=20, freq="B"),
)
store["a"] = ts
Expand All @@ -133,7 +130,9 @@ def test_put(setup_path):

# overwrite table
store.put("c", df[:10], format="table", append=False)
tm.assert_frame_equal(df[:10], store["c"])
expected = df[:10]
result = store["c"]
tm.assert_frame_equal(result, expected)


def test_put_string_index(setup_path):
Expand Down Expand Up @@ -166,12 +165,14 @@ def test_put_compression(setup_path):
with ensure_clean_store(setup_path) as store:
df = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)

store.put("c", df, format="table", complib="zlib")
tm.assert_frame_equal(store["c"], df)
expected = df
result = store["c"]
tm.assert_frame_equal(result, expected)

# can't compress if format='fixed'
msg = "Compression not supported on Fixed format stores"
Expand All @@ -183,7 +184,7 @@ def test_put_compression(setup_path):
def test_put_compression_blosc(setup_path):
df = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)

Expand All @@ -194,17 +195,29 @@ def test_put_compression_blosc(setup_path):
store.put("b", df, format="fixed", complib="blosc")

store.put("c", df, format="table", complib="blosc")
tm.assert_frame_equal(store["c"], df)
expected = df
result = store["c"]
tm.assert_frame_equal(result, expected)


def test_put_mixed_type(setup_path, performance_warning):
def test_put_datetime_ser(setup_path, performance_warning, using_infer_string):
# https://github.com/pandas-dev/pandas/pull/60625
ser = Series(3 * [Timestamp("20010102").as_unit("ns")])
with ensure_clean_store(setup_path) as store:
store.put("ser", ser)
expected = ser.copy()
result = store.get("ser")
tm.assert_series_equal(result, expected)


def test_put_mixed_type(setup_path, performance_warning, using_infer_string):
df = DataFrame(
np.random.default_rng(2).standard_normal((10, 4)),
columns=Index(list("ABCD"), dtype=object),
columns=Index(list("ABCD")),
index=date_range("2000-01-01", periods=10, freq="B"),
)
df["obj1"] = "foo"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
df["obj1"] = "foo"
df["obj1"] = np.array(["foo"] * 10, dtype=object)
df["str"] = pd.Series(["a", None, "b", "c", "d"] * 2).array

To explicitly test here with both object dtype and string dtype, and with strings with missing values (this is passing locally, so that should address Will's comment at https://github.com/pandas-dev/pandas/pull/60625/files#r1900476124).

df["obj2"] = "bar"
df["obj2"] = pd.array([pd.NA] + 9 * ["bar"], dtype="string")
df["bool1"] = df["A"] > 0
df["bool2"] = df["B"] > 0
df["bool3"] = True
Expand All @@ -223,8 +236,13 @@ def test_put_mixed_type(setup_path, performance_warning):
with tm.assert_produces_warning(performance_warning):
store.put("df", df)

expected = store.get("df")
tm.assert_frame_equal(expected, df)
expected = df.copy()
if using_infer_string:
expected["obj2"] = expected["obj2"].astype("str")
else:
expected["obj2"] = expected["obj2"].astype("object")
result = store.get("df")
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("format", ["table", "fixed"])
Expand Down Expand Up @@ -253,7 +271,7 @@ def test_store_index_types(setup_path, format, index):
tm.assert_frame_equal(df, store["df"])


def test_column_multiindex(setup_path):
def test_column_multiindex(setup_path, using_infer_string):
# GH 4710
# recreate multi-indexes properly

Expand All @@ -264,6 +282,12 @@ def test_column_multiindex(setup_path):
expected = df.set_axis(df.index.to_numpy())

with ensure_clean_store(setup_path) as store:
if using_infer_string:
# TODO(infer_string) make this work for string dtype
msg = "Saving a MultiIndex with an extension dtype is not supported."
with pytest.raises(NotImplementedError, match=msg):
store.put("df", df)
return
store.put("df", df)
tm.assert_frame_equal(
store["df"], expected, check_index_type=True, check_column_type=True
Expand Down
Loading