Skip to content

Commit

Permalink
Add estimator attribute _estimator_type
Browse files Browse the repository at this point in the history
  • Loading branch information
lixfz committed Jun 30, 2023
1 parent 540ebe9 commit 1065ac1
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 1 deletion.
1 change: 1 addition & 0 deletions hypernets/experiment/compete.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def __getstate__(self):
state = super().__getstate__()
# Don't pickle experiment
if 'experiment' in state.keys():
state = state.copy()
state['experiment'] = None
return state

Expand Down
12 changes: 11 additions & 1 deletion hypernets/model/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold
from hypernets.utils import const


class Estimator():
def __init__(self, space_sample, task='binary', discriminator=None):
def __init__(self, space_sample, task=const.TASK_BINARY, discriminator=None):
self.space_sample = space_sample
self.task = task
self.discriminator = discriminator
Expand All @@ -19,6 +20,15 @@ def __init__(self, space_sample, task='binary', discriminator=None):
self.cv_ = None
self.cv_models_ = None

@property
def _estimator_type(self):
if self.task in {const.TASK_BINARY, const.TASK_MULTICLASS, const.TASK_MULTILABEL}:
return 'classifier'
elif self.task in {const.TASK_REGRESSION, }:
return 'regressor'
else:
return None

def set_discriminator(self, discriminator):
self.discriminator = discriminator

Expand Down
10 changes: 10 additions & 0 deletions hypernets/model/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,16 @@ def __init__(self, estimators, x_vals, y_vals):
def classes_(self):
return self.estimators[0].classes_

@property
def _estimator_type(self):
try:
if len(self.classes_) > 1:
return 'classifier'
else:
return 'regressor'
except:
return 'regressor'

def predict(self, X, **kwargs):
rows = 0
for x_val in self.x_vals:
Expand Down
3 changes: 3 additions & 0 deletions hypernets/tests/model/test_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ def predict(self, *args, **kwargs):
def predict_proba(self, *args, **kwargs):
return self.cv_models_[0].predict_proba(*args, **kwargs)

@property
def _estimator_type(self):
return 'classifier'

class TestPredictionPerformanceObjective(BaseTestWithinModel):

Expand Down

0 comments on commit 1065ac1

Please sign in to comment.