From f6199f13c8c6b4e3561e50607ffa16dff305a022 Mon Sep 17 00:00:00 2001
From: kddubey
Date: Mon, 22 Jan 2024 04:51:46 -0800
Subject: [PATCH] Update is_categorical_dtype for pandas>=2.1.0 (#105)

---
 formulae/terms/call.py     |  3 ++-
 formulae/terms/variable.py |  3 ++-
 formulae/utils.py          | 30 ++++++++++++++++++++++++++++++
 3 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/formulae/terms/call.py b/formulae/terms/call.py
index ef6c0d4..2dffa76 100644
--- a/formulae/terms/call.py
+++ b/formulae/terms/call.py
@@ -4,12 +4,13 @@
 import numpy as np
 import pandas as pd
 
-from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_string_dtype
+from pandas.api.types import is_numeric_dtype, is_string_dtype
 
 from formulae.categorical import ENCODINGS, CategoricalBox, Treatment
 from formulae.config import config
 from formulae.transforms import TRANSFORMS, Proportion, Offset
 from formulae.terms.call_utils import CallVarsExtractor
+from formulae.utils import is_categorical_dtype
 
 
 class Call:
diff --git a/formulae/terms/variable.py b/formulae/terms/variable.py
index 88e4dc9..cb89c20 100644
--- a/formulae/terms/variable.py
+++ b/formulae/terms/variable.py
@@ -4,10 +4,11 @@
 import numpy as np
 import pandas as pd
 
-from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_string_dtype
+from pandas.api.types import is_numeric_dtype, is_string_dtype
 
 from formulae.config import config
 from formulae.categorical import Treatment
+from formulae.utils import is_categorical_dtype
 
 
 class Variable:
diff --git a/formulae/utils.py b/formulae/utils.py
index ea420e2..5c10f0e 100644
--- a/formulae/utils.py
+++ b/formulae/utils.py
@@ -1,6 +1,7 @@
 from copy import deepcopy
 
 import numpy as np
+import pandas as pd
 
 
 def listify(obj):
@@ -38,3 +39,32 @@ def get_interaction_matrix(x, y):
         for j2 in range(y.shape[1]):
             l.append(x[:, j1] * y[:, j2])
     return np.column_stack(l)
+
+
+def is_categorical_dtype(arr_or_dtype):
+    """Check whether an array-like or dtype is of the pandas Categorical dtype.
+
+    Parameters
+    ----------
+    arr_or_dtype : array-like or dtype
+        Array-like (anything exposing a ``dtype`` attribute) or dtype to check.
+
+    Returns
+    -------
+    bool
+        ``True`` if the dtype is categorical, ``False`` otherwise.
+    """
+    # pandas 2.1.0 deprecated ``pd.api.types.is_categorical_dtype``, see
+    # https://pandas.pydata.org/docs/whatsnew/v2.1.0.html#other-deprecations
+    # Compare the release numbers as integers: version strings compare lexicographically,
+    # so for example "10.0.0" < "2.1.0" is True.
+    try:
+        release = tuple(int(number) for number in pd.__version__.split(".")[:2])
+    except ValueError:
+        # Unparseable version (e.g. a local build): use the check that is not deprecated
+        release = (2, 1)
+    if release < (2, 1):
+        return pd.api.types.is_categorical_dtype(arr_or_dtype)
+    # Array-likes carry their dtype in ``.dtype``; anything else is treated as a dtype
+    dtype = getattr(arr_or_dtype, "dtype", arr_or_dtype)
+    return isinstance(dtype, pd.CategoricalDtype)