Skip to content

Commit

Permalink
update dispatcher
Browse files Browse the repository at this point in the history
  • Loading branch information
oaksharks committed May 6, 2023
1 parent 9cbc56c commit 44c19c0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion hypernets/core/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def to_df(self, include_params=False):

return df

def plot_best_trials(self, index=True, figsize=(5, 5), loc='upper left', bbox_to_anchor=None, xlim=None, ylim=None):
def plot_best_trials(self, index=True, figsize=(5, 5), loc=None, bbox_to_anchor=None, xlim=None, ylim=None):
try:
from matplotlib import pyplot as plt
except Exception:
Expand Down
5 changes: 3 additions & 2 deletions hypernets/dispatchers/in_process_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,20 @@ def dispatch(self, hyper_model, X, y, X_eval, y_eval, X_test, cv, num_folds, max
trial_no = 1
retry_counter = 0

importances = None
space_options = {}
if hyper_model.searcher.kind() == const.SEARCHER_MOO:
if 'feature_usage' in [_.name for _ in hyper_model.searcher.objectives]:
tb = get_tool_box(X, y)
preprocessor = tb.general_preprocessor(X)
estimator = tb.general_estimator(X, y, task=hyper_model.task)
estimator.fit(preprocessor.fit_transform(X, y), y)
importances = list(zip(estimator.feature_name_, estimator.feature_importances_))
space_options['importances'] = importances

while trial_no <= max_trials:
gc.collect()
try:
space_options = dict(importances=importances)

space_sample = hyper_model.searcher.sample(space_options=space_options)
if hyper_model.history.is_existed(space_sample):
if retry_counter >= retry_limit:
Expand Down

0 comments on commit 44c19c0

Please sign in to comment.