Skip to content

Commit

Permalink
Update is_categorical_dtype for pandas>=2.1.0 (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
kddubey authored Jan 22, 2024
1 parent e02d76d commit f6199f1
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
3 changes: 2 additions & 1 deletion formulae/terms/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import numpy as np
import pandas as pd

from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_string_dtype
from pandas.api.types import is_numeric_dtype, is_string_dtype

from formulae.categorical import ENCODINGS, CategoricalBox, Treatment
from formulae.config import config
from formulae.transforms import TRANSFORMS, Proportion, Offset
from formulae.terms.call_utils import CallVarsExtractor
from formulae.utils import is_categorical_dtype


class Call:
Expand Down
3 changes: 2 additions & 1 deletion formulae/terms/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import numpy as np
import pandas as pd

from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_string_dtype
from pandas.api.types import is_numeric_dtype, is_string_dtype

from formulae.config import config
from formulae.categorical import Treatment
from formulae.utils import is_categorical_dtype


class Variable:
Expand Down
14 changes: 14 additions & 0 deletions formulae/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from copy import deepcopy

import numpy as np
import pandas as pd


def listify(obj):
Expand Down Expand Up @@ -38,3 +39,16 @@ def get_interaction_matrix(x, y):
for j2 in range(y.shape[1]):
l.append(x[:, j1] * y[:, j2])
return np.column_stack(l)


def is_categorical_dtype(arr_or_dtype):
"""Check whether an array-like or dtype is of the pandas Categorical dtype."""
# https://pandas.pydata.org/docs/whatsnew/v2.1.0.html#other-deprecations
if pd.__version__ < "2.1.0":
return pd.api.types.is_categorical_dtype(arr_or_dtype)
else:
if hasattr(arr_or_dtype, "dtype"): # it's an array
dtype = getattr(arr_or_dtype, "dtype")
else:
dtype = arr_or_dtype
return isinstance(dtype, pd.CategoricalDtype)

0 comments on commit f6199f1

Please sign in to comment.