Skip to content

Commit

Permalink
support default max_rounds multitask gam
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Mar 13, 2024
1 parent dec7a15 commit 75c42ec
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion imodels/algebraic/gam_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class MultiTaskGAM(BaseEstimator):

def __init__(
self,
ebm_kwargs={'n_jobs': 1},
ebm_kwargs={'n_jobs': 1, 'max_rounds': 5000, },
multitask=True,
interactions=0.95,
linear_penalty='ridge',
Expand Down
12 changes: 8 additions & 4 deletions tests/gam_multitask_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,19 +151,23 @@ def _roc_no_error(y_true, y_pred):


def compare_models():
X, y, feature_names = imodels.get_clean_dataset("heart")
# X, y, feature_names = imodels.get_clean_dataset("bike_sharing")
# X, y, feature_names = imodels.get_clean_dataset("heart")
X, y, feature_names = imodels.get_clean_dataset("bike_sharing")
# X, y, feature_names = imodels.get_clean_dataset("water-quality_multitask")
# X, y, feature_names = imodels.get_clean_dataset("diabetes")

# remove some features to speed things up
X = X[:, :2]
# X = X[:, :2]
X, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

results = defaultdict(list)
for gam in tqdm([
MultiTaskGAMRegressor(),
MultiTaskGAMRegressor(fit_target_curves=False),
# MultiTaskGAMRegressor(fit_target_curves=False),
AdaBoostRegressor(
estimator=MultiTaskGAMRegressor(
ebm_kwargs={'max_rounds': 50}),
n_estimators=8),
# AdaBoostRegressor(estimator=MultiTaskGAMRegressor(
# multitask=True), n_estimators=2),
# MultiTaskGAMRegressor(multitask=True, onehot_prior=True),
Expand Down

0 comments on commit 75c42ec

Please sign in to comment.