@@ -678,7 +804,7 @@ Methods
Expand source code
def predict_proba(self, X, *args, **kwargs):
- if hasattr(self.estimator_, 'predict_proba'):
+ if hasattr(self.estimator_, "predict_proba"):
return self.estimator_.predict_proba(X, *args, **kwargs)
else:
return NotImplemented
@@ -694,7 +820,7 @@ Methods
Expand source code
def score(self, X, y, *args, **kwargs):
- if hasattr(self.estimator_, 'score'):
+ if hasattr(self.estimator_, "score"):
return self.estimator_.score(X, y, *args, **kwargs)
else:
return NotImplemented
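Both hunks above are quote-style reformats only; the behavior is unchanged: delegate to the wrapped estimator when it exposes the method, otherwise return NotImplemented. A minimal standalone sketch of that delegation pattern (the Wrapped class is hypothetical, not part of imodels):

class Wrapped:
    def __init__(self, estimator_):
        self.estimator_ = estimator_

    def score(self, X, y, *args, **kwargs):
        # Forward only if the wrapped estimator actually implements score;
        # returning NotImplemented signals the capability is absent.
        if hasattr(self.estimator_, "score"):
            return self.estimator_.score(X, y, *args, **kwargs)
        return NotImplemented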
@@ -729,12 +855,17 @@ Params
Expand source code
class HSTreeClassifier(HSTree, ClassifierMixin):
- def __init__(self, estimator_: BaseEstimator = DecisionTreeClassifier(max_leaf_nodes=20),
- reg_param: float = 1, shrinkage_scheme_: str = 'node_based'):
- super().__init__(estimator_=estimator_,
- reg_param=reg_param,
- shrinkage_scheme_=shrinkage_scheme_,
- )
+ def __init__(
+ self,
+ estimator_: BaseEstimator = DecisionTreeClassifier(max_leaf_nodes=20),
+ reg_param: float = 1,
+ shrinkage_scheme_: str = "node_based",
+ ):
+ super().__init__(
+ estimator_=estimator_,
+ reg_param=reg_param,
+ shrinkage_scheme_=shrinkage_scheme_,
+ )
Ancestors
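The reformatted constructor keeps the same defaults, so usage is unchanged; a minimal sketch grounded in the signature above (X_train and y_train are assumed to exist):

from sklearn.tree import DecisionTreeClassifier
from imodels import HSTreeClassifier

# Wrap a depth-limited CART tree; reg_param sets the shrinkage strength.
m = HSTreeClassifier(
    estimator_=DecisionTreeClassifier(max_leaf_nodes=20),
    reg_param=1,
    shrinkage_scheme_="node_based",
)
m.fit(X_train, y_train)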
@@ -767,11 +898,17 @@ Params
Expand source code
class HSTreeClassifierCV(HSTreeClassifier):
- def __init__(self, estimator_: BaseEstimator = None,
- reg_param_list: List[float] = [0, 0.1, 1, 10, 50, 100, 500],
- shrinkage_scheme_: str = 'node_based',
- max_leaf_nodes: int = 20,
- cv: int = 3, scoring=None, *args, **kwargs):
+ def __init__(
+ self,
+ estimator_: BaseEstimator = None,
+ reg_param_list: List[float] = [0, 0.1, 1, 10, 50, 100, 500],
+ shrinkage_scheme_: str = "node_based",
+ max_leaf_nodes: int = 20,
+ cv: int = 3,
+ scoring=None,
+ *args,
+ **kwargs
+ ):
"""Cross-validation is used to select the best regularization parameter for hierarchical shrinkage.
Params
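A quick usage sketch of the CV variant, following the signature above (the name of the attribute holding the selected parameter is an assumption, not shown in this hunk):

from imodels import HSTreeClassifierCV

# Cross-validate over candidate shrinkage strengths, then refit on all data.
m = HSTreeClassifierCV(reg_param_list=[0.1, 1, 10, 100], cv=3)
m.fit(X_train, y_train)
print(m.reg_param)  # assumed attribute for the selected regularization value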
@@ -801,7 +938,7 @@ Params
def fit(self, X, y, *args, **kwargs):
self.scores_ = [[] for _ in self.reg_param_list]
- scorer = kwargs.get('scoring', log_loss)
+ scorer = kwargs.get("scoring", log_loss)
kf = KFold(n_splits=self.cv)
for train_index, test_index in kf.split(X):
X_out, y_out = X[test_index, :], y[test_index]
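Note that the scorer is read from the fit kwargs, with log_loss as the fallback, rather than from the constructor's scoring argument. If the kwarg is honored end-to-end (an assumption; only the lookup is visible in this hunk), an explicit override looks like:

from sklearn.metrics import log_loss
from imodels import HSTreeClassifierCV

m = HSTreeClassifierCV(cv=5)
# Any replacement metric must match log_loss's (y_true, y_pred) call
# convention and its lower-is-better selection rule.
m.fit(X_train, y_train, scoring=log_loss)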
@@ -818,8 +955,13 @@ Params
super().fit(X=X, y=y, *args, **kwargs)
def __repr__(self):
- attr_list = ["estimator_", "reg_param_list", "shrinkage_scheme_",
- "cv", "scoring"]
+ attr_list = [
+ "estimator_",
+ "reg_param_list",
+ "shrinkage_scheme_",
+ "cv",
+ "scoring",
+ ]
s = self.__class__.__name__
s += "("
for attr in attr_list:
@@ -846,7 +988,7 @@ Methods
def fit(self, X, y, *args, **kwargs):
self.scores_ = [[] for _ in self.reg_param_list]
- scorer = kwargs.get('scoring', log_loss)
+ scorer = kwargs.get("scoring", log_loss)
kf = KFold(n_splits=self.cv)
for train_index, test_index in kf.split(X):
X_out, y_out = X[test_index, :], y[test_index]
@@ -892,12 +1034,17 @@ Params
Expand source code
class HSTreeRegressor(HSTree, RegressorMixin):
- def __init__(self, estimator_: BaseEstimator = DecisionTreeRegressor(max_leaf_nodes=20),
- reg_param: float = 1, shrinkage_scheme_: str = 'node_based'):
- super().__init__(estimator_=estimator_,
- reg_param=reg_param,
- shrinkage_scheme_=shrinkage_scheme_,
- )
+ def __init__(
+ self,
+ estimator_: BaseEstimator = DecisionTreeRegressor(max_leaf_nodes=20),
+ reg_param: float = 1,
+ shrinkage_scheme_: str = "node_based",
+ ):
+ super().__init__(
+ estimator_=estimator_,
+ reg_param=reg_param,
+ shrinkage_scheme_=shrinkage_scheme_,
+ )
Ancestors
@@ -930,11 +1077,17 @@ Params
Expand source code
class HSTreeRegressorCV(HSTreeRegressor):
- def __init__(self, estimator_: BaseEstimator = None,
- reg_param_list: List[float] = [0, 0.1, 1, 10, 50, 100, 500],
- shrinkage_scheme_: str = 'node_based',
- max_leaf_nodes: int = 20,
- cv: int = 3, scoring=None, *args, **kwargs):
+ def __init__(
+ self,
+ estimator_: BaseEstimator = None,
+ reg_param_list: List[float] = [0, 0.1, 1, 10, 50, 100, 500],
+ shrinkage_scheme_: str = "node_based",
+ max_leaf_nodes: int = 20,
+ cv: int = 3,
+ scoring=None,
+ *args,
+ **kwargs
+ ):
"""Cross-validation is used to select the best regularization parameter for hierarchical shrinkage.
Params
@@ -965,7 +1118,7 @@ Params
def fit(self, X, y, *args, **kwargs):
self.scores_ = [[] for _ in self.reg_param_list]
kf = KFold(n_splits=self.cv)
- scorer = kwargs.get('scoring', mean_squared_error)
+ scorer = kwargs.get("scoring", mean_squared_error)
for train_index, test_index in kf.split(X):
X_out, y_out = X[test_index, :], y[test_index]
X_in, y_in = X[train_index, :], y[train_index]
@@ -981,8 +1134,13 @@ Params
super().fit(X=X, y=y, *args, **kwargs)
def __repr__(self):
- attr_list = ["estimator_", "reg_param_list", "shrinkage_scheme_",
- "cv", "scoring"]
+ attr_list = [
+ "estimator_",
+ "reg_param_list",
+ "shrinkage_scheme_",
+ "cv",
+ "scoring",
+ ]
s = self.__class__.__name__
s += "("
for attr in attr_list:
@@ -1010,7 +1168,7 @@ Methods
def fit(self, X, y, *args, **kwargs):
self.scores_ = [[] for _ in self.reg_param_list]
kf = KFold(n_splits=self.cv)
- scorer = kwargs.get('scoring', mean_squared_error)
+ scorer = kwargs.get("scoring", mean_squared_error)
for train_index, test_index in kf.split(X):
X_out, y_out = X[test_index, :], y[test_index]
X_in, y_in = X[train_index, :], y[train_index]
diff --git a/docs/util/data_util.html b/docs/util/data_util.html
index 032d277e..68e2853b 100644
--- a/docs/util/data_util.html
+++ b/docs/util/data_util.html
@@ -147,10 +147,12 @@
elif dataset_name == 'california_housing':
data = sklearn.datasets.fetch_california_housing(
data_home=oj(data_path, 'sklearn_data'))
+ elif dataset_name == 'breast_cancer':
+ data = sklearn.datasets.load_breast_cancer()
return data['data'], data['target'], _clean_feat_names(data['feature_names'])
elif data_source == 'openml':  # note: this API may change in newer sklearn - pass the dataset id, not its name
data = sklearn.datasets.fetch_openml(
- data_id=dataset_name, data_home=oj(data_path, 'openml_data'))
+ data_id=dataset_name, data_home=oj(data_path, 'openml_data'), parser='auto')
X, y, feature_names = data['data'], data['target'], _clean_feat_names(
data['feature_names'])
if isinstance(X, pd.DataFrame):
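The parser='auto' addition tracks sklearn >= 1.2, where fetch_openml emits a FutureWarning when no parser is specified; a standalone sketch of the call (data_id=31 is an arbitrary illustrative OpenML id):

import sklearn.datasets

# parser="auto" avoids the warning about the default parser changing.
data = sklearn.datasets.fetch_openml(data_id=31, parser="auto")
X, y = data["data"], data["target"]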
@@ -372,10 +374,12 @@ Example
elif dataset_name == 'california_housing':
data = sklearn.datasets.fetch_california_housing(
data_home=oj(data_path, 'sklearn_data'))
+ elif dataset_name == 'breast_cancer':
+ data = sklearn.datasets.load_breast_cancer()
return data['data'], data['target'], _clean_feat_names(data['feature_names'])
elif data_source == 'openml':  # note: this API may change in newer sklearn - pass the dataset id, not its name
data = sklearn.datasets.fetch_openml(
- data_id=dataset_name, data_home=oj(data_path, 'openml_data'))
+ data_id=dataset_name, data_home=oj(data_path, 'openml_data'), parser='auto')
X, y, feature_names = data['data'], data['target'], _clean_feat_names(
data['feature_names'])
if isinstance(X, pd.DataFrame):
diff --git a/imodels/__init__.py b/imodels/__init__.py
index 68ec5e30..c1306908 100644
--- a/imodels/__init__.py
+++ b/imodels/__init__.py
@@ -5,7 +5,7 @@
# Github repo available [here](https://github.com/csinva/imodels)
from .algebraic.slim import SLIMRegressor, SLIMClassifier
-from .algebraic.gam import TreeGAMClassifier
+from .algebraic.tree_gam import TreeGAMClassifier, TreeGAMRegressor
from .discretization.discretizer import RFDiscretizer, BasicDiscretizer
from .discretization.mdlp import MDLPDiscretizer, BRLDiscretizer
from .experimental.bartpy import BART
@@ -23,26 +23,62 @@
from .rule_set.skope_rules import SkopeRulesClassifier
from .rule_set.slipper import SlipperClassifier
from .tree.c45_tree.c45_tree import C45TreeClassifier
-from .tree.cart_ccp import DecisionTreeCCPClassifier, DecisionTreeCCPRegressor, HSDecisionTreeCCPClassifierCV, \
- HSDecisionTreeCCPRegressorCV
+from .tree.cart_ccp import (
+ DecisionTreeCCPClassifier,
+ DecisionTreeCCPRegressor,
+ HSDecisionTreeCCPClassifierCV,
+ HSDecisionTreeCCPRegressorCV,
+)
+
# from .tree.iterative_random_forest.iterative_random_forest import IRFClassifier
# from .tree.optimal_classification_tree import OptimalTreeModel
from .tree.cart_wrapper import GreedyTreeClassifier, GreedyTreeRegressor
from .tree.figs import FIGSRegressor, FIGSClassifier, FIGSRegressorCV, FIGSClassifierCV
from .tree.gosdt.pygosdt import OptimalTreeClassifier
-from .tree.gosdt.pygosdt_shrinkage import HSOptimalTreeClassifier, HSOptimalTreeClassifierCV
-from .tree.hierarchical_shrinkage import HSTreeRegressor, HSTreeClassifier, HSTreeRegressorCV, HSTreeClassifierCV
+from .tree.gosdt.pygosdt_shrinkage import (
+ HSOptimalTreeClassifier,
+ HSOptimalTreeClassifierCV,
+)
+from .tree.hierarchical_shrinkage import (
+ HSTreeRegressor,
+ HSTreeClassifier,
+ HSTreeRegressorCV,
+ HSTreeClassifierCV,
+)
from .tree.tao import TaoTreeClassifier, TaoTreeRegressor
from .util.data_util import get_clean_dataset
from .util.distillation import DistilledRegressor
from .util.explain_errors import explain_classification_errors
-CLASSIFIERS = [BayesianRuleListClassifier, GreedyRuleListClassifier, SkopeRulesClassifier,
- BoostedRulesClassifier, SLIMClassifier, SlipperClassifier, BayesianRuleSetClassifier,
- C45TreeClassifier, OptimalTreeClassifier, OptimalRuleListClassifier, OneRClassifier,
- SlipperClassifier, RuleFitClassifier, TaoTreeClassifier,
- FIGSClassifier, HSTreeClassifier, HSTreeClassifierCV] # , IRFClassifier
-REGRESSORS = [RuleFitRegressor, SLIMRegressor, GreedyTreeClassifier, FIGSRegressor,
- TaoTreeRegressor, HSTreeRegressor, HSTreeRegressorCV, BART]
+CLASSIFIERS = [
+ BayesianRuleListClassifier,
+ GreedyRuleListClassifier,
+ SkopeRulesClassifier,
+ BoostedRulesClassifier,
+ SLIMClassifier,
+ SlipperClassifier,
+ BayesianRuleSetClassifier,
+ C45TreeClassifier,
+ OptimalTreeClassifier,
+ OptimalRuleListClassifier,
+ OneRClassifier,
+ RuleFitClassifier,
+ TaoTreeClassifier,
+ TreeGAMClassifier,
+ FIGSClassifier,
+ HSTreeClassifier,
+ HSTreeClassifierCV,
+] # , IRFClassifier
+REGRESSORS = [
+ RuleFitRegressor,
+ SLIMRegressor,
+    GreedyTreeRegressor,
+ FIGSRegressor,
+ TaoTreeRegressor,
+ TreeGAMRegressor,
+ HSTreeRegressor,
+ HSTreeRegressorCV,
+ BART,
+]
ESTIMATORS = CLASSIFIERS + REGRESSORS
DISCRETIZERS = [RFDiscretizer, BasicDiscretizer, MDLPDiscretizer, BRLDiscretizer]
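Since CLASSIFIERS and REGRESSORS are plain lists of classes, registry-style iteration works directly; a minimal sketch (instantiation is left out, since some entries require optional dependencies):

import imodels

for cls in imodels.CLASSIFIERS:
    print(cls.__name__)  # TreeGAMClassifier now appears in this registry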
diff --git a/imodels/algebraic/gam.py b/imodels/algebraic/tree_gam.py
similarity index 93%
rename from imodels/algebraic/gam.py
rename to imodels/algebraic/tree_gam.py
index 474918f2..3b3acc80 100644
--- a/imodels/algebraic/gam.py
+++ b/imodels/algebraic/tree_gam.py
@@ -16,8 +16,10 @@
import imodels
+from sklearn.base import RegressorMixin, ClassifierMixin
-class TreeGAMClassifier(BaseEstimator):
+
+class TreeGAM(BaseEstimator):
"""Tree-based GAM classifier.
Uses cyclical boosting to fit a GAM with small trees.
Simplified version of the explainable boosting machine described in https://github.com/interpretml/interpret
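For intuition about the cyclical boosting the docstring mentions: each round fits one tiny tree per feature on the current residual, restricted to that single feature, so every tree is an additive shape-function increment. A toy sketch of the idea, not the library's implementation (parameter names mirror the class):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

def cyclic_boost(X, y, n_boosting_rounds=100, max_leaf_nodes=3, learning_rate=0.01):
    pred = np.full(len(y), np.mean(y))  # start from the bias term
    trees = []
    for _ in range(n_boosting_rounds):
        for j in range(X.shape[1]):  # cycle over features
            t = DecisionTreeRegressor(max_leaf_nodes=max_leaf_nodes)
            t.fit(X[:, [j]], y - pred)  # one-feature tree on the residual
            pred = pred + learning_rate * t.predict(X[:, [j]])
            trees.append((j, t))
    return trees, pred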
@@ -30,6 +32,7 @@ def __init__(
n_boosting_rounds=100,
max_leaf_nodes=3,
reg_param=0.0,
+ learning_rate: float = 0.01,
n_boosting_rounds_marginal=0,
max_leaf_nodes_marginal=2,
reg_param_marginal=0.0,
@@ -45,6 +48,8 @@ def __init__(
Maximum number of leaf nodes for the trees in the cyclic boosting.
reg_param : float
Regularization parameter for the cyclic boosting.
+ learning_rate: float
+ Learning rate for the cyclic boosting.
n_boosting_rounds_marginal : int
Number of boosting rounds for the marginal boosting.
max_leaf_nodes_marginal : int
@@ -56,21 +61,24 @@ def __init__(
NNLS for non-negative least squares
ridge for ridge regression
None for no linear model
+
random_state : int
Random seed.
"""
self.n_boosting_rounds = n_boosting_rounds
self.max_leaf_nodes = max_leaf_nodes
self.reg_param = reg_param
+ self.learning_rate = learning_rate
self.max_leaf_nodes_marginal = max_leaf_nodes_marginal
self.reg_param_marginal = reg_param_marginal
self.n_boosting_rounds_marginal = n_boosting_rounds_marginal
self.fit_linear_marginal = fit_linear_marginal
self.random_state = random_state
- def fit(self, X, y, sample_weight=None, learning_rate=0.01, validation_frac=0.15):
+ def fit(self, X, y, sample_weight=None, validation_frac=0.15):
X, y = check_X_y(X, y, accept_sparse=False, multi_output=False)
- check_classification_targets(y)
+ if isinstance(self, ClassifierMixin):
+ check_classification_targets(y)
sample_weight = _check_sample_weight(sample_weight, X, dtype=None)
# split into train and validation for early stopping
@@ -91,7 +99,6 @@ def fit(self, X, y, sample_weight=None, learning_rate=0.01, validation_frac=0.15
self.estimators_marginal = []
self.estimators_ = []
- self.learning_rate = learning_rate
self.bias_ = np.mean(y)
if self.n_boosting_rounds_marginal > 0:
@@ -208,7 +215,10 @@ def predict_proba(self, X):
return np.array([1 - probs1, probs1]).T
def predict(self, X):
- return np.argmax(self.predict_proba(X), axis=1)
+ if isinstance(self, RegressorMixin):
+ return self.predict_proba(X)[:, 1]
+ elif isinstance(self, ClassifierMixin):
+ return np.argmax(self.predict_proba(X), axis=1)
def get_shape_function_vals(self, X, max_evals=100):
"""Uses predict_proba to compute shape_function
@@ -236,6 +246,14 @@ def get_shape_function_vals(self, X, max_evals=100):
return feature_vals_list, shape_function_vals_list
+class TreeGAMRegressor(TreeGAM, RegressorMixin):
+ ...
+
+
+class TreeGAMClassifier(TreeGAM, ClassifierMixin):
+ ...
+
+
if __name__ == "__main__":
X, y, feature_names = imodels.get_clean_dataset("heart")
X, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
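With the marker subclasses in place, both variants share TreeGAM's machinery and differ only in the mixin; a usage sketch mirroring the __main__ block above:

from sklearn.model_selection import train_test_split
import imodels
from imodels import TreeGAMClassifier

X, y, feature_names = imodels.get_clean_dataset("heart")
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

m = TreeGAMClassifier(n_boosting_rounds=10, learning_rate=0.01)
m.fit(X_train, y_train)
print(m.predict(X_test)[:5])  # class labels; TreeGAMRegressor returns scores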
diff --git a/readme.md b/readme.md
index 15cf95d4..06aa8806 100644
--- a/readme.md
+++ b/readme.md
@@ -83,7 +83,7 @@ Install with `pip install imodels` (see [here](https://github.com/csinva/imodels
| TAO rule tree | [🗂️](https://csinva.io/imodels/tree/tao.html), , ㅤㅤ[📄](https://proceedings.neurips.cc/paper/2018/hash/185c29dc24325934ee377cfda20e414c-Abstract.html) | Fits tree using alternating optimization |
| Iterative random forest | [🗂️](https://csinva.io/imodels/tree/iterative_random_forest/iterative_random_forest.html), [🔗](https://github.com/Yu-Group/iterative-Random-Forest), [📄](https://www.pnas.org/content/115/8/1943) | Repeatedly fit random forest, giving features with high importance a higher chance of being selected |
| Sparse integer linear model | [🗂️](https://csinva.io/imodels/algebraic/slim.html), , ㅤㅤ[📄](https://link.springer.com/article/10.1007/s10994-015-5528-6) | Sparse linear model with integer coefficients |
-| Tree GAM | [🗂️](https://csinva.io/imodels/algebraic/gam.html), [🔗](https://github.com/interpretml/interpret), [📄](https://dl.acm.org/doi/abs/10.1145/2339530.2339556) | Generalized additive model fit with short boosted trees |
+| Tree GAM | [🗂️](https://csinva.io/imodels/algebraic/tree_gam.html), [🔗](https://github.com/interpretml/interpret), [📄](https://dl.acm.org/doi/abs/10.1145/2339530.2339556) | Generalized additive model fit with short boosted trees |
| Greedy tree sums | [🗂️](https://csinva.io/imodels/tree/figs.html#imodels.tree.figs), , ㅤㅤ[📄](https://arxiv.org/abs/2201.11931) | Sum of small trees with very few total rules (FIGS) |
| Hierarchical shrinkage wrapper | [🗂️](https://csinva.io/imodels/tree/hierarchical_shrinkage.html), , ㅤㅤ[📄](https://arxiv.org/abs/2202.00858) | Improve a decision tree, random forest, or gradient-boosting ensemble with ultra-fast, post-hoc regularization |
| Distillation wrapper | [🗂️](https://csinva.io/imodels/util/distillation.html) | Train a black-box model, then distill it into an interpretable model |
@@ -175,6 +175,7 @@ Different models support different machine-learning tasks. Current support for d
| TAO rule tree | [TaoTreeClassifier](https://csinva.io/imodels/tree/tao.html#imodels.tree.tao.TaoTreeClassifier) | [TaoTreeRegressor](https://csinva.io/imodels/tree/tao.html#imodels.tree.tao.TaoTreeRegressor) | |
| Iterative random forest | [IRFClassifier](https://csinva.io/imodels/tree/iterative_random_forest/iterative_random_forest.html#imodels.tree.iterative_random_forest.iterative_random_forest.IRFClassifier) | | Requires [irf](https://pypi.org/project/irf/) |
| Sparse integer linear model | [SLIMClassifier](https://csinva.io/imodels/algebraic/slim.html#imodels.algebraic.slim.SLIMClassifier) | [SLIMRegressor](https://csinva.io/imodels/algebraic/slim.html#imodels.algebraic.slim.SLIMRegressor) | Requires extra dependencies for speed |
+| Tree GAM | [TreeGAMClassifier](https://csinva.io/imodels/algebraic/tree_gam.html) | [TreeGAMRegressor](https://csinva.io/imodels/algebraic/tree_gam.html) | |
| Greedy tree sums (FIGS) | [FIGSClassifier](https://csinva.io/imodels/tree/figs.html#imodels.tree.figs.FIGSClassifier) | [FIGSRegressor](https://csinva.io/imodels/tree/figs.html#imodels.tree.figs.FIGSRegressor) | |
| Hierarchical shrinkage | [HSTreeClassifierCV](https://csinva.io/imodels/tree/hierarchical_shrinkage.html#imodels.tree.hierarchical_shrinkage.HSTreeClassifierCV) | [HSTreeRegressorCV](https://csinva.io/imodels/tree/hierarchical_shrinkage.html#imodels.tree.hierarchical_shrinkage.HSTreeRegressorCV) | Wraps any sklearn tree-based model |
| Distillation | | [DistilledRegressor](https://csinva.io/imodels/util/distillation.html#imodels.util.distillation.DistilledRegressor) | Wraps any sklearn-compatible models |
diff --git a/setup.py b/setup.py
index 6c450a26..b2f13255 100644
--- a/setup.py
+++ b/setup.py
@@ -26,7 +26,7 @@
setuptools.setup(
name="imodels",
- version="1.3.18",
+ version="1.4.0",
author="Chandan Singh, Keyan Nasseri, Matthew Epland, Yan Shuo Tan, Omer Ronen, Tiffany Tang, Abhineet Agarwal, Theo Saarinen, Bin Yu, and others",
author_email="chandan_singh@berkeley.edu",
description="Implementations of various interpretable models",
diff --git a/tests/regression_test.py b/tests/regression_test.py
index 4d1210fb..908015bc 100644
--- a/tests/regression_test.py
+++ b/tests/regression_test.py
@@ -6,7 +6,7 @@
from sklearn.tree import DecisionTreeRegressor
from imodels import RuleFitRegressor, SLIMRegressor, GreedyTreeRegressor, HSTreeRegressor, HSTreeRegressorCV, \
- FIGSRegressor, DistilledRegressor, TaoTreeRegressor, BoostedRulesRegressor
+ FIGSRegressor, DistilledRegressor, TaoTreeRegressor, BoostedRulesRegressor, TreeGAMRegressor
class TestClassRegression:
@@ -26,6 +26,7 @@ def test_regression(self):
BoostedRulesRegressor,
partial(DistilledRegressor, teacher=RandomForestRegressor(n_estimators=3),
student=DecisionTreeRegressor()),
+ TreeGAMRegressor,
]:
if model_type == RuleFitRegressor:
m = model_type(include_linear=False, max_rules=3)