Skip to content

Commit

Permalink
add simplebaggingregressor
Browse files Browse the repository at this point in the history
  • Loading branch information
csinva committed Mar 13, 2024
1 parent 3e56bbc commit d3cd37d
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion imodels/util/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,41 @@ def predict(self, X):
return predictions


class SimpleBaggingRegressor:
    """Minimal bagging ensemble for regression.

    ``fit`` trains ``n_estimators`` independent clones of ``estimator``,
    each on a bootstrap resample (rows drawn with replacement, same size as
    the training set). ``predict`` returns the unweighted mean of the
    clones' predictions.

    Parameters
    ----------
    estimator : object
        Base regressor. It is copied with ``clone`` and must expose
        ``fit(X, y)`` and ``predict(X)``.
    n_estimators : int, default=10
        Number of bootstrap replicates / fitted clones.
    random_state : int or None, default=None
        Seed used for the bootstrap index generator (and for the global
        NumPy RNG, see ``fit``).

    Attributes
    ----------
    estimators_ : list
        The fitted clones, populated by ``fit``.
    """

    def __init__(self, estimator, n_estimators=10, random_state=None):
        self.estimator = estimator
        self.n_estimators = n_estimators
        self.random_state = random_state

    @staticmethod
    def _take_rows(data, row_indices):
        """Return the rows of ``data`` at the given integer positions.

        pandas objects are indexed positionally through ``.iloc``; plain
        ``data[row_indices]`` on a DataFrame would select columns, and on a
        Series would do a label lookup instead of a positional one.
        """
        if hasattr(data, "iloc"):
            return data.iloc[row_indices]
        return data[row_indices]

    def fit(self, X, y):
        """Fit one clone of the base estimator per bootstrap resample.

        Parameters
        ----------
        X : array-like with a ``shape`` attribute
            Training inputs; rows are samples.
        y : array-like
            Training targets, row-aligned with ``X``.

        Returns
        -------
        self
            The fitted ensemble, so calls can be chained.
        """
        # NOTE(review): this reseeds the *global* NumPy RNG as a side effect.
        # Kept because base estimators that draw from the global RNG may rely
        # on it for reproducibility -- confirm before removing.
        np.random.seed(self.random_state)
        rng = np.random.default_rng(self.random_state)

        n_samples = X.shape[0]
        self.estimators_ = []
        for _ in range(self.n_estimators):
            # Bootstrap: draw n_samples row positions with replacement.
            sample_indices = rng.choice(
                n_samples, size=n_samples, replace=True)
            X_sample = self._take_rows(X, sample_indices)
            y_sample = self._take_rows(y, sample_indices)

            # Fit a fresh, unfitted copy so replicates do not share state.
            estimator = clone(self.estimator)
            estimator.fit(X_sample, y_sample)
            self.estimators_.append(estimator)
        return self

    def predict(self, X):
        """Return the mean of the fitted clones' predictions for ``X``."""
        # Stack per-estimator predictions along a new leading axis ...
        predictions = np.array([estimator.predict(X)
                                for estimator in self.estimators_])

        # ... and average across estimators.
        return np.mean(predictions, axis=0)


if __name__ == '__main__':
X, y, feature_names = imodels.get_clean_dataset('california_housing')
X_train, X_test, y_train, y_test = train_test_split(
Expand All @@ -83,7 +118,9 @@ def predict(self, X):
# estimator = DecisionTreeRegressor(max_depth=3)
estimator = imodels.algebraic.gam_multitask.MultiTaskGAMRegressor()
for n_estimators in [1, 3, 5]:
residual_boosting_regressor = ResidualBoostingRegressor(
# residual_boosting_regressor = ResidualBoostingRegressor(
# estimator=estimator, n_estimators=n_estimators)
residual_boosting_regressor = SimpleBaggingRegressor(
estimator=estimator, n_estimators=n_estimators)
residual_boosting_regressor.fit(X_train, y_train)

Expand Down

0 comments on commit d3cd37d

Please sign in to comment.