Skip to content

Commit

Permalink
Fix bug when intercept is added after categorical variable
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Jul 15, 2023
1 parent 61113a8 commit 1b425c2
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

### Maintenance and fixes

- Fix bug when intercept is inserted after categorical variable (#102)

### Documentation

### Deprecation
Expand Down
2 changes: 0 additions & 2 deletions formulae/contrasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,6 @@ def pick_contrasts(group):

used_subterms = set()
codings = {}

for name, components in group.items():
codings[name] = ExpandedTerm(name, components).pick_contrast(used_subterms)

return codings
13 changes: 12 additions & 1 deletion formulae/terms/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,18 @@ def set_types(self, data, env):

def _get_encoding_groups(self):
components = {}
for term in self.common_terms:
# This is not the best fix, but we need the intercept to be in the first position
common_terms = self.common_terms.copy()
intercept_idx = -1
for i, term in enumerate(common_terms):
if isinstance(term, Intercept):
intercept_idx = i
break

if intercept_idx != -1:
common_terms.insert(0, common_terms.pop(intercept_idx))

for term in common_terms:
if term.kind == "interaction":
components[term.name] = {c.name: c.kind for c in term.components}
else:
Expand Down
21 changes: 21 additions & 0 deletions tests/test_design_matrices.py
Original file line number Diff line number Diff line change
Expand Up @@ -1217,3 +1217,24 @@ def f(*args, **kwargs):

dm = design_matrices('f(x, y="abcd")', df)
assert 'f(x, y="abcd")' in dm.common.terms


def test_add_and_remove_intercept_works():
df = pd.DataFrame({"g": ["a", "a", "a", "b", "b"]})
dm = design_matrices("0 + g + 1", df)
assert np.allclose(
np.asarray(dm.common),
np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
)

dm = design_matrices("-1 + g + 1", df)
assert np.allclose(
np.asarray(dm.common),
np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
)

dm = design_matrices("0 + g + 0 + 1", df)
assert np.allclose(
np.asarray(dm.common),
np.array([[0.0, 1.0], [0.0, 1.0], [0.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
)
2 changes: 1 addition & 1 deletion tests/test_eval_new_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ def test_new_group_specific_groups():
group = design_matrices("1 + (1 + x | g1:g2)", df).group
group_new = group.evaluate_new_data(df_2)

assert group_new.factors_with_new_levels == ("g1:g2", )
assert group_new.factors_with_new_levels == ("g1:g2",)
assert np.array_equal(
np.array(group_new),
np.array([[0, 0, 0, 0, 1, 0, 0, 0, 0, 5], [0, 0, 0, 0, 1, 0, 0, 0, 0, 6]]),
Expand Down

0 comments on commit 1b425c2

Please sign in to comment.