Skip to content

Commit

Permalink
support onehot prior
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Mar 11, 2024
1 parent 0efced5 commit 657216f
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions imodels/algebraic/gam_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,23 @@ def __init__(
multitask=True,
interactions=0.95,
linear_penalty='ridge',
onehot_prior=True,
random_state=42,
):
"""
Params
------
Note: args override ebm_kwargs if there are duplicates
        onehot_prior: bool
            If True and multitask, the linear model is fit with a prior that the ebm
            feature predicting the target should have a coefficient of 1
"""
self.ebm_kwargs = ebm_kwargs
self.multitask = multitask
self.linear_penalty = linear_penalty
self.random_state = random_state
self.interactions = interactions
self.onehot_prior = onehot_prior

# override ebm_kwargs
ebm_kwargs['random_state'] = random_state
Expand Down Expand Up @@ -96,7 +101,17 @@ def fit(self, X, y, sample_weight=None):
elif self.linear_penalty == 'lasso':
self.lin_model = LassoCV(n_alphas=7)

self.lin_model.fit(feats, y)
if self.onehot_prior:
coef_prior_ = np.zeros((feats.shape[1], ))
coef_prior_[:num_features] = 1
preds_prior = feats @ coef_prior_
residuals = y - preds_prior
self.lin_model.fit(feats, residuals, sample_weight=sample_weight)
self.lin_model.coef_ = self.lin_model.coef_ + coef_prior_

else:
self.lin_model.fit(feats, y, sample_weight=sample_weight)

return self

def _extract_ebm_features(self, X):
Expand Down Expand Up @@ -148,16 +163,13 @@ def test_multitask_extraction():
# unit test
gam = MultiTaskGAMRegressor(multitask=False)
gam.fit(X, y_train)
# ebm = gam.ebm_
gam2 = MultiTaskGAMRegressor(multitask=True)
gam2.fit(X, y_train)
preds_orig = gam.predict(X_test)
assert np.allclose(preds_orig, gam2.ebms_[-1].predict(X_test))

# extracted curves + intercept should sum to original predictions
feats_extracted = gam2._extract_ebm_features(X_test)
num_samples = X_test.shape[0]
num_features = X_test.shape[1]

# get features for ebm that predicts target
feats_extracted_target = feats_extracted[:,
Expand All @@ -177,7 +189,7 @@ def test_multitask_extraction():
# X, y, feature_names = imodels.get_clean_dataset("diabetes")

# remove some features to speed things up
X = X[:, :3]
X = X[:, :2]
X, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

kwargs = dict(
Expand All @@ -187,7 +199,8 @@ def test_multitask_extraction():
for gam in tqdm([
# AdaBoostRegressor(estimator=MultiTaskGAMRegressor(
# multitask=True), n_estimators=2),
# MultiTaskGAMRegressor(multitask=False),
MultiTaskGAMRegressor(multitask=False, onehot_prior=True),
MultiTaskGAMRegressor(multitask=False, onehot_prior=False),
MultiTaskGAMRegressor(multitask=True),
# ExplainableBoostingRegressor(n_jobs=1, interactions=0)
]):
Expand Down

0 comments on commit 657216f

Please sign in to comment.