diff --git a/CHANGELOG.md b/CHANGELOG.md index 162a0d0..a563ff2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ ### Deprecation +### v0.3.3 + +### Maintenance and fixes + +- Fixed a bug in `CategoricalBox`. Now it considers the order of the categories if `data` is ordered and `levels` is `None` (#73) + ### v0.3.2 ### Maintenance and fixes diff --git a/formulae/categorical.py b/formulae/categorical.py index 818030b..cbcd0fa 100644 --- a/formulae/categorical.py +++ b/formulae/categorical.py @@ -76,6 +76,9 @@ class CategoricalBox: """ def __init__(self, data, contrast, levels): + # If 'data' is ordered and no explicit levels have been passed, use order in 'data'. + if hasattr(data.dtype, "ordered") and data.dtype.ordered and levels is None: + levels = data.dtype.categories.tolist() self.data = data self.contrast = contrast self.levels = levels @@ -198,7 +201,7 @@ def __init__(self, omit=None): def _omit_index(self, levels): """Returns a number between 0 and len(levels) - 1""" if self.omit is None: - # By default, omit the lats level. + # By default, omit the last level. return len(levels) - 1 else: return levels.index(self.omit) diff --git a/formulae/tests/test_design_matrices.py b/formulae/tests/test_design_matrices.py index fdc19c0..1912921 100644 --- a/formulae/tests/test_design_matrices.py +++ b/formulae/tests/test_design_matrices.py @@ -1133,7 +1133,18 @@ def test_extra_namespace(data): assert df["myfunc(x3)"].equals(np.log(df["x3"])) -def test_categorical_series(): - data = pd.DataFrame({"x": list("abc") * 10}) - data["x"] = pd.Categorical(data["x"], list("abc"), ordered=True) +def test_categorical_ordered_series(): + # Test it works + data = pd.DataFrame({"x": list("abcd") * 10}) + data["x"] = pd.Categorical(data["x"], list("bcda"), ordered=True) design_matrices("S(x)", data) + + # Test it works and it respects original order + levels = design_matrices("x", data).common.terms["x"].levels + assert levels == list("cda") + + levels = design_matrices("T(x)", data).common.terms["T(x)"].levels + assert levels == list("cda") + + levels = design_matrices("S(x)", data).common.terms["S(x)"].levels + assert levels == list("bcd") diff --git a/formulae/version.py b/formulae/version.py index f9aa3e1..e19434e 100644 --- a/formulae/version.py +++ b/formulae/version.py @@ -1 +1 @@ -__version__ = "0.3.2" +__version__ = "0.3.3"