From 9f8157800adbbbd627b7d0443c244a4a2b8afaf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Capretto?= Date: Fri, 14 Jul 2023 23:43:19 -0300 Subject: [PATCH] Fix bug when intercept is added after categorical variable (#102) --- docs/CHANGELOG.md | 2 ++ formulae/contrasts.py | 2 -- formulae/terms/terms.py | 13 ++++++++++++- tests/test_design_matrices.py | 21 +++++++++++++++++++++ tests/test_eval_new_data.py | 2 +- 5 files changed, 36 insertions(+), 4 deletions(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index c025d5e..4835bc3 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -6,6 +6,8 @@ ### Maintenance and fixes +- Fix bug when intercept is inserted after categorical variable (#102) + ### Documentation ### Deprecation diff --git a/formulae/contrasts.py b/formulae/contrasts.py index 499f4c7..c87ddd6 100644 --- a/formulae/contrasts.py +++ b/formulae/contrasts.py @@ -173,8 +173,6 @@ def pick_contrasts(group): used_subterms = set() codings = {} - for name, components in group.items(): codings[name] = ExpandedTerm(name, components).pick_contrast(used_subterms) - return codings diff --git a/formulae/terms/terms.py b/formulae/terms/terms.py index 1ce3921..6eef9b1 100644 --- a/formulae/terms/terms.py +++ b/formulae/terms/terms.py @@ -1162,7 +1162,18 @@ def set_types(self, data, env): def _get_encoding_groups(self): components = {} - for term in self.common_terms: + # This is not the best fix, but we need the intercept to be in the first position + common_terms = self.common_terms.copy() + intercept_idx = -1 + for i, term in enumerate(common_terms): + if isinstance(term, Intercept): + intercept_idx = i + break + + if intercept_idx != -1: + common_terms.insert(0, common_terms.pop(intercept_idx)) + + for term in common_terms: if term.kind == "interaction": components[term.name] = {c.name: c.kind for c in term.components} else: diff --git a/tests/test_design_matrices.py b/tests/test_design_matrices.py index d2743dd..b1178ae 100644 --- a/tests/test_design_matrices.py +++ b/tests/test_design_matrices.py @@ -1217,3 +1217,24 @@ def f(*args, **kwargs): dm = design_matrices('f(x, y="abcd")', df) assert 'f(x, y="abcd")' in dm.common.terms + + +def test_add_and_remove_intercept_works(): + df = pd.DataFrame({"g": ["a", "a", "a", "b", "b"]}) + dm = design_matrices("0 + g + 1", df) + assert np.allclose( + np.asarray(dm.common), + np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 1.0], [1.0, 1.0]]), + ) + + dm = design_matrices("-1 + g + 1", df) + assert np.allclose( + np.asarray(dm.common), + np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 1.0], [1.0, 1.0]]), + ) + + dm = design_matrices("0 + g + 0 + 1", df) + assert np.allclose( + np.asarray(dm.common), + np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 1.0], [1.0, 1.0]]), + ) diff --git a/tests/test_eval_new_data.py b/tests/test_eval_new_data.py index f8cf246..81eb298 100644 --- a/tests/test_eval_new_data.py +++ b/tests/test_eval_new_data.py @@ -385,7 +385,7 @@ def test_new_group_specific_groups(): group = design_matrices("1 + (1 + x | g1:g2)", df).group group_new = group.evaluate_new_data(df_2) - assert group_new.factors_with_new_levels == ("g1:g2", ) + assert group_new.factors_with_new_levels == ("g1:g2",) assert np.array_equal( np.array(group_new), np.array([[0, 0, 0, 0, 1, 0, 0, 0, 0, 5], [0, 0, 0, 0, 1, 0, 0, 0, 0, 6]]),