diff --git a/flaml/automl.py b/flaml/automl.py index 010ea85f41..e5c5a52dbd 100644 --- a/flaml/automl.py +++ b/flaml/automl.py @@ -1,8 +1,7 @@ -"""! - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See LICENSE file in the - * project root for license information. -""" +# ! +# * Copyright (c) Microsoft Corporation. All rights reserved. +# * Licensed under the MIT License. See LICENSE file in the +# * project root for license information. import time from typing import Callable, Optional from functools import partial @@ -311,7 +310,7 @@ def size(state: AutoMLState, config: dict) -> float: class AutoML: - """The AutoML class + """The AutoML class. Example: @@ -359,10 +358,10 @@ def model(self): return self.__dict__.get("_trained_estimator") def best_model_for_estimator(self, estimator_name): - """Return the best model found for a particular estimator + """Return the best model found for a particular estimator. Args: - estimator_name: a str of the estimator's name + estimator_name: a str of the estimator's name. Returns: An object with `predict()` and `predict_proba()` method (for @@ -398,7 +397,7 @@ def best_config_per_estimator(self): @property def best_loss(self): - """A float of the best loss found""" + """A float of the best loss found.""" return self._state.best_loss @property @@ -421,7 +420,7 @@ def classes_(self): @property def time_to_find_best_model(self) -> float: - """Time taken to find best model in seconds""" + """Time taken to find best model in seconds.""" return self.__dict__.get("_time_taken_best_iter") def predict(self, X_test): @@ -490,7 +489,7 @@ def _preprocess(self, X): if issparse(X): X = X.tocsr() if self._transformer: - X = self._transformer.transform(X, self._state.task) + X = self._transformer.transform(X) return X def _validate_data( @@ -583,13 +582,11 @@ def _validate_data( X_val.shape[0] == y_val.shape[0] ), "# rows in X_val must match length of y_val." if self._transformer: - self._state.X_val = self._transformer.transform(X_val, self._state.task) + self._state.X_val = self._transformer.transform(X_val) else: self._state.X_val = X_val if self._label_transformer: - self._state.y_val = self._label_transformer.transform( - y_val, self._state.task - ) + self._state.y_val = self._label_transformer.transform(y_val) else: self._state.y_val = y_val else: @@ -852,26 +849,26 @@ def _prepare_data(self, eval_method, split_ratio, n_splits): ) def add_learner(self, learner_name, learner_class): - """Add a customized learner + """Add a customized learner. Args: - learner_name: A string of the learner's name - learner_class: A subclass of flaml.model.BaseEstimator + learner_name: A string of the learner's name. + learner_class: A subclass of flaml.model.BaseEstimator. """ self._state.learner_classes[learner_name] = learner_class def get_estimator_from_log(self, log_file_name, record_id, task): - """Get the estimator from log file + """Get the estimator from log file. Args: - log_file_name: A string of the log file name + log_file_name: A string of the log file name. record_id: An integer of the record ID in the file, - 0 corresponds to the first trial + 0 corresponds to the first trial. task: A string of the task type, - 'binary', 'multi', 'regression', 'ts_forecast', 'rank' + 'binary', 'multi', 'regression', 'ts_forecast', 'rank'. Returns: - An estimator object for the given configuration + An estimator object for the given configuration. 
""" with training_log_reader(log_file_name) as reader: @@ -910,16 +907,16 @@ def retrain_from_log( auto_augment=True, **fit_kwargs, ): - """Retrain from log file + """Retrain from log file. Args: - log_file_name: A string of the log file name - X_train: A numpy array of training data in shape n*m + log_file_name: A string of the log file name. + X_train: A numpy array or dataframe of training data in shape n*m. For 'ts_forecast' task, the first column of X_train must be the timestamp column (datetime type). Other columns in the dataframe are assumed to be exogenous variables (categorical or numeric). - y_train: A numpy array of labels in shape n*1 + y_train: A numpy array or series of labels in shape n*1. dataframe: A dataframe of training data including label column. For 'ts_forecast' task, dataframe must be specified and should have at least two columns: timestamp and label, where the first @@ -1080,11 +1077,13 @@ def _decide_eval_method(self, time_budget): @property def search_space(self) -> dict: - """Search space - Must be called after fit(...) (use max_iter=0 to prevent actual fitting) + """Search space. + + Must be called after fit(...) + (use max_iter=0 and retrain_final=False to prevent actual fitting). Returns: - A dict of the search space + A dict of the search space. """ estimator_list = self.estimator_list if len(estimator_list) == 1: @@ -1101,7 +1100,7 @@ def search_space(self) -> dict: @property def low_cost_partial_config(self) -> dict: - """Low cost partial config + """Low cost partial config. Returns: A dict. @@ -1112,7 +1111,6 @@ def low_cost_partial_config(self) -> dict: to each learner's low_cost_partial_config; the estimator index as an integer corresponding to the cheapest learner is appended to the list at the end. - """ if len(self.estimator_list) == 1: estimator = self.estimator_list[0] @@ -1146,7 +1144,6 @@ def cat_hp_cost(self) -> dict: a list of the cat_hp_cost's as the value, corresponding to each learner's cat_hp_cost; the cost relative to lgbm for each learner (as a list itself) is appended to the list at the end. - """ if len(self.estimator_list) == 1: estimator = self.estimator_list[0] @@ -1198,28 +1195,28 @@ def prune_attr(self) -> Optional[str]: @property def min_resource(self) -> Optional[float]: - """Attribute for pruning + """Attribute for pruning. Returns: - A float for the minimal sample size or None + A float for the minimal sample size or None. """ return self._min_sample_size if self._sample else None @property def max_resource(self) -> Optional[float]: - """Attribute for pruning + """Attribute for pruning. Returns: - A float for the maximal sample size or None + A float for the maximal sample size or None. """ return self._state.data_size if self._sample else None @property def trainable(self) -> Callable[[dict], Optional[float]]: - """Training function + """Training function. Returns: - A function that evaluates each config and returns the loss + A function that evaluates each config and returns the loss. """ self._state.time_from_start = 0 for estimator in self.estimator_list: @@ -1255,10 +1252,10 @@ def train(config: dict): @property def metric_constraints(self) -> list: - """Metric constraints + """Metric constraints. Returns: - A list of the metric constraints + A list of the metric constraints. """ constraints = [] if np.isfinite(self._pred_time_limit): @@ -1310,7 +1307,7 @@ def fit( use_ray=False, **fit_kwargs, ): - """Find a model for a given task + """Find a model for a given task. 
Args:
            X_train: A numpy array or a pandas dataframe of training data in
@@ -1499,6 +1496,7 @@ def custom_metric(
             and eval_method == "holdout"
             and self._state.X_val is None
             or eval_method == "cv"
+            and (max_iter > 0 or retrain_full is True)
             or max_iter == 1
         )
         self._auto_augment = auto_augment
diff --git a/flaml/data.py b/flaml/data.py
index bc0c9eb951..d420917ed1 100644
--- a/flaml/data.py
+++ b/flaml/data.py
@@ -1,8 +1,7 @@
-"""!
- * Copyright (c) Microsoft Corporation. All rights reserved.
- * Licensed under the MIT License.
-"""
-
+# !
+# * Copyright (c) Microsoft Corporation. All rights reserved.
+# * Licensed under the MIT License. See LICENSE file in the
+# * project root for license information.
 import numpy as np
 from scipy.sparse import vstack, issparse
 import pandas as pd
@@ -130,17 +129,15 @@ def get_output_from_log(filename, time_budget):
     """Get output from log file

     Args:
-        filename: A string of the log file name
-        time_budget: A float of the time budget in seconds
+        filename: A string of the log file name.
+        time_budget: A float of the time budget in seconds.

     Returns:
-        search_time_list: A list of the finished time of each logged iter
-        best_error_list:
-            A list of the best validation error after each logged iter
-        error_list: A list of the validation error of each logged iter
-        config_list:
-            A list of the estimator, sample size and config of each logged iter
-        logged_metric_list: A list of the logged metric of each logged iter
+        search_time_list: A list of the finished time of each logged iter.
+        best_error_list: A list of the best validation error after each logged iter.
+        error_list: A list of the validation error of each logged iter.
+        config_list: A list of the estimator, sample size and config of each logged iter.
+        logged_metric_list: A list of the logged metric of each logged iter.
     """

     best_config = None
@@ -208,9 +205,21 @@ def concat(X1, X2):


 class DataTransformer:
-    """transform X, y"""
+    """Transform input training data."""

     def fit_transform(self, X, y, task):
+        """Fit the transformer and process the input training data according to the task type.
+
+        Args:
+            X: A numpy array or a pandas dataframe of training data.
+            y: A numpy array or a pandas series of labels.
+            task: A string of the task type, e.g.,
+                'classification', 'regression', 'ts_forecast', 'rank'.
+
+        Returns:
+            X: Processed numpy array or pandas dataframe of training data.
+            y: Processed numpy array or pandas series of labels.
+        """
         if isinstance(X, pd.DataFrame):
             X = X.copy()
             n = X.shape[0]
@@ -320,9 +329,19 @@ def fit_transform(self, X, y, task):
                 y = self.label_transformer.fit_transform(y)
             else:
                 self.label_transformer = None
+        self._task = task
         return X, y

-    def transform(self, X, task):
+    def transform(self, X):
+        """Process data using the fitted transformer.
+
+        Args:
+            X: A numpy array or a pandas dataframe of data to be processed.
+
+        Returns:
+            X: Processed numpy array or pandas dataframe.
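
A short sketch of the updated `DataTransformer` API: the task type is now remembered from `fit_transform` instead of being passed again to `transform` (training and test data are assumed to be loaded):

.. code-block:: python

    from flaml.data import DataTransformer

    transformer = DataTransformer()
    X, y = transformer.fit_transform(X_train, y_train, task="classification")
    # no task argument at transform time; the fitted transformer reuses self._task
    X_test_processed = transformer.transform(X_test)
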
+ """ X = X.copy() if isinstance(X, pd.DataFrame): cat_columns, num_columns, datetime_columns = ( @@ -330,7 +352,7 @@ def transform(self, X, task): self._num_columns, self._datetime_columns, ) - if task == TS_FORECAST: + if self._task == TS_FORECAST: X = X.rename(columns={X.columns[0]: TS_TIMESTAMP_COL}) ds_col = X.pop(TS_TIMESTAMP_COL) if datetime_columns: @@ -357,7 +379,7 @@ def transform(self, X, task): X[column] = X[column].map(datetime.toordinal) del tmp_dt X = X[cat_columns + num_columns].copy() - if task == TS_FORECAST: + if self._task == TS_FORECAST: X.insert(0, TS_TIMESTAMP_COL, ds_col) for column in cat_columns: if X[column].dtype.name == "object": diff --git a/flaml/ml.py b/flaml/ml.py index 3b214256ad..02c523d251 100644 --- a/flaml/ml.py +++ b/flaml/ml.py @@ -1,8 +1,7 @@ -"""! - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. -""" - +# ! +# * Copyright (c) Microsoft Corporation. All rights reserved. +# * Licensed under the MIT License. See LICENSE file in the +# * project root for license information. import time import numpy as np import pandas as pd @@ -27,22 +26,20 @@ LRL1Classifier, LRL2Classifier, CatBoostEstimator, - ExtraTreeEstimator, + ExtraTreesEstimator, KNeighborsEstimator, Prophet, ARIMA, SARIMAX, ) from .data import CLASSIFICATION, group_counts, TS_FORECAST, TS_VALUE_COL - import logging logger = logging.getLogger(__name__) def get_estimator_class(task, estimator_name): - """when adding a new learner, need to add an elif branch""" - + # when adding a new learner, need to add an elif branch if "xgboost" == estimator_name: if "regression" == task: estimator_class = XGBoostEstimator @@ -59,7 +56,7 @@ def get_estimator_class(task, estimator_name): elif "catboost" == estimator_name: estimator_class = CatBoostEstimator elif "extra_tree" == estimator_name: - estimator_class = ExtraTreeEstimator + estimator_class = ExtraTreesEstimator elif "kneighbor" == estimator_name: estimator_class = KNeighborsEstimator elif "prophet" in estimator_name: @@ -84,7 +81,7 @@ def sklearn_metric_loss_score( sample_weight=None, groups=None, ): - """Loss using the specified metric + """Loss using the specified metric. Args: metric_name: A string of the metric name, one of @@ -487,15 +484,15 @@ def get_classification_objective(num_labels: int) -> str: def norm_confusion_matrix(y_true, y_pred): - """normalized confusion matrix + """normalized confusion matrix. Args: - estimator: A multi-class classification estimator - y_true: A numpy array or a pandas series of true labels - y_pred: A numpy array or a pandas series of predicted labels + estimator: A multi-class classification estimator. + y_true: A numpy array or a pandas series of true labels. + y_pred: A numpy array or a pandas series of predicted labels. Returns: - A normalized confusion matrix + A normalized confusion matrix. """ from sklearn.metrics import confusion_matrix @@ -505,19 +502,19 @@ def norm_confusion_matrix(y_true, y_pred): def multi_class_curves(y_true, y_pred_proba, curve_func): - """Binarize the data for multi-class tasks and produce ROC or precision-recall curves + """Binarize the data for multi-class tasks and produce ROC or precision-recall curves. Args: - y_true: A numpy array or a pandas series of true labels - y_pred_proba: A numpy array or a pandas dataframe of predicted probabilites - curve_func: A function to produce a curve (e.g., roc_curve or precision_recall_curve) + y_true: A numpy array or a pandas series of true labels. 
+ y_pred_proba: A numpy array or a pandas dataframe of predicted probabilites. + curve_func: A function to produce a curve (e.g., roc_curve or precision_recall_curve). Returns: - A tuple of two dictionaries with the same set of keys (class indices) + A tuple of two dictionaries with the same set of keys (class indices). The first dictionary curve_x stores the x coordinates of each curve, e.g., - curve_x[0] is an 1D array of the x coordinates of class 0 + curve_x[0] is an 1D array of the x coordinates of class 0. The second dictionary curve_y stores the y coordinates of each curve, e.g., - curve_y[0] is an 1D array of the y coordinates of class 0 + curve_y[0] is an 1D array of the y coordinates of class 0. """ from sklearn.preprocessing import label_binarize diff --git a/flaml/model.py b/flaml/model.py index 4f5d5cfd7f..2a3d8a291a 100644 --- a/flaml/model.py +++ b/flaml/model.py @@ -1,7 +1,7 @@ -"""! - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. -""" +# ! +# * Copyright (c) Microsoft Corporation. All rights reserved. +# * Licensed under the MIT License. See LICENSE file in the +# * project root for license information. from contextlib import contextmanager from functools import partial import signal @@ -66,17 +66,17 @@ def limit_resource(memory_limit, time_limit): class BaseEstimator: - """The abstract class for all learners + """The abstract class for all learners. - Typical example: - XGBoostEstimator: for regression - XGBoostSklearnEstimator: for classification - LGBMEstimator, RandomForestEstimator, LRL1Classifier, LRL2Classifier: - for both regression and classification + Typical examples: + * XGBoostEstimator: for regression. + * XGBoostSklearnEstimator: for classification. + * LGBMEstimator, RandomForestEstimator, LRL1Classifier, LRL2Classifier: + for both regression and classification. """ def __init__(self, task="binary", **config): - """Constructor + """Constructor. Args: task: A string of the task type, one of @@ -111,12 +111,12 @@ def n_features_in_(self): @property def model(self): - """Trained model after fit() is called, or None before fit() is called""" + """Trained model after fit() is called, or None before fit() is called.""" return self._model @property def estimator(self): - """Trained model after fit() is called, or None before fit() is called""" + """Trained model after fit() is called, or None before fit() is called.""" return self._model def _preprocess(self, X): @@ -149,15 +149,15 @@ def _fit(self, X_train, y_train, **kwargs): return train_time def fit(self, X_train, y_train, budget=None, **kwargs): - """Train the model from given training data + """Train the model from given training data. Args: - X_train: A numpy array of training data in shape n*m - y_train: A numpy array of labels in shape n*1 - budget: A float of the time budget in seconds + X_train: A numpy array or a dataframe of training data in shape n*m. + y_train: A numpy array or a series of labels in shape n*1. + budget: A float of the time budget in seconds. Returns: - train_time: A float of the training time in seconds + train_time: A float of the training time in seconds. """ if ( getattr(self, "limit_resource", None) @@ -190,14 +190,14 @@ def fit(self, X_train, y_train, budget=None, **kwargs): return train_time def predict(self, X_test): - """Predict label from features + """Predict label from features. Args: - X_test: A numpy array of featurized instances, shape n*m + X_test: A numpy array or a dataframe of featurized instances, shape n*m. 
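
A usage sketch for the `multi_class_curves` helper documented above, paired with scikit-learn's `roc_curve` (`y_true` and `y_pred_proba` are assumed):

.. code-block:: python

    from sklearn.metrics import roc_curve
    from flaml.ml import multi_class_curves

    curve_x, curve_y = multi_class_curves(y_true, y_pred_proba, roc_curve)
    # curve_x[0] and curve_y[0] are the FPR and TPR arrays for class 0
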
Returns: A numpy array of shape n*1. - Each element is the label for a instance + Each element is the label for a instance. """ if self._model is not None: X_test = self._preprocess(X_test) @@ -206,18 +206,17 @@ def predict(self, X_test): return np.ones(X_test.shape[0]) def predict_proba(self, X_test): - """Predict the probability of each class from features + """Predict the probability of each class from features. Only works for classification problems Args: - model: An object of trained model with method predict_proba() - X_test: A numpy array of featurized instances, shape n*m + X_test: A numpy array of featurized instances, shape n*m. Returns: - A numpy array of shape n*c. c is the # classes + A numpy array of shape n*c. c is the # classes. Each element at (i,j) is the probability for instance i to be in - class j + class j. """ assert ( self._task in CLASSIFICATION @@ -230,7 +229,7 @@ def cleanup(self): @classmethod def search_space(cls, **params): - """[required method] search space + """[required method] search space. Returns: A dictionary of the search space. @@ -238,16 +237,16 @@ def search_space(cls, **params): its domain (required) and low_cost_init_value, init_value, cat_hp_cost (if applicable). e.g., - {'domain': tune.randint(lower=1, upper=10), 'init_value': 1}. + `{'domain': tune.randint(lower=1, upper=10), 'init_value': 1}.` """ return {} @classmethod def size(cls, config: dict) -> float: - """[optional method] memory size of the estimator in bytes + """[optional method] memory size of the estimator in bytes. Args: - config - A dict of the hyperparameter config. + config: A dict of the hyperparameter config. Returns: A float of the memory size required by the estimator to train the @@ -257,19 +256,19 @@ def size(cls, config: dict) -> float: @classmethod def cost_relative2lgbm(cls) -> float: - """[optional method] relative cost compared to lightgbm""" + """[optional method] relative cost compared to lightgbm.""" return 1.0 @classmethod def init(cls): - """[optional method] initialize the class""" + """[optional method] initialize the class.""" pass def config2params(self, config: dict) -> dict: """[optional method] config dict to params dict Args: - config - A dict of the hyperparameter config. + config: A dict of the hyperparameter config. Returns: A dict that will be passed to self.estimator_class's constructor. 
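
The `search_space` format documented above, illustrated with a hypothetical custom learner (a sketch, not a built-in estimator; the `config2params` override drops `n_jobs`, which `GradientBoostingClassifier` does not accept):

.. code-block:: python

    from sklearn.ensemble import GradientBoostingClassifier
    from flaml import AutoML, tune
    from flaml.model import SKLearnEstimator

    class MyEstimator(SKLearnEstimator):
        def __init__(self, task="binary", **config):
            super().__init__(task, **config)
            self.estimator_class = GradientBoostingClassifier

        @classmethod
        def search_space(cls, data_size, **params):
            # each hyperparameter needs a 'domain'; the init values are optional
            return {
                "n_estimators": {
                    "domain": tune.lograndint(lower=4, upper=512),
                    "init_value": 4,
                    "low_cost_init_value": 4,
                },
            }

        def config2params(self, config: dict) -> dict:
            params = config.copy()
            params.pop("n_jobs", None)  # not a GradientBoostingClassifier argument
            return params

    automl = AutoML()
    automl.add_learner("my_gbm", MyEstimator)
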
@@ -278,6 +277,8 @@ def config2params(self, config: dict) -> dict: class SKLearnEstimator(BaseEstimator): + """The base class for tuning scikit-learn estimators.""" + def __init__(self, task="binary", **config): super().__init__(task, **config) @@ -298,6 +299,8 @@ def _preprocess(self, X): class LGBMEstimator(BaseEstimator): + """The class for tuning LGBM, using sklearn API.""" + ITER_HP = "n_estimators" HAS_CALLBACK = True @@ -469,7 +472,10 @@ def fit(self, X_train, y_train, budget=None, **kwargs): if self.params[self.ITER_HP] > 0: if self.HAS_CALLBACK: self._fit( - X_train, y_train, callbacks=self._callbacks(start_time, deadline), **kwargs + X_train, + y_train, + callbacks=self._callbacks(start_time, deadline), + **kwargs, ) best_iteration = ( self._model.get_booster().best_iteration @@ -503,7 +509,7 @@ def _callback(self, start_time, deadline, env) -> None: class XGBoostEstimator(SKLearnEstimator): - """not using sklearn API, used for regression""" + """The class for tuning XGBoost regressor, not using sklearn API.""" @classmethod def search_space(cls, data_size, **params): @@ -648,7 +654,7 @@ def after_iteration(self, model, epoch, evals_log) -> bool: class XGBoostSklearnEstimator(SKLearnEstimator, LGBMEstimator): - """using sklearn API, used for classification""" + """The class for tuning XGBoost (for classification), using sklearn API.""" @classmethod def search_space(cls, data_size, **params): @@ -693,6 +699,8 @@ def _callbacks(self, start_time, deadline) -> List[Callable]: class RandomForestEstimator(SKLearnEstimator, LGBMEstimator): + """The class for tuning Random Forest.""" + HAS_CALLBACK = False @classmethod @@ -746,7 +754,9 @@ def __init__( self.estimator_class = RandomForestClassifier -class ExtraTreeEstimator(RandomForestEstimator): +class ExtraTreesEstimator(RandomForestEstimator): + """The class for tuning Extra Trees.""" + @classmethod def cost_relative2lgbm(cls): return 1.9 @@ -760,6 +770,8 @@ def __init__(self, task="binary", **params): class LRL1Classifier(SKLearnEstimator): + """The class for tuning Logistic Regression with L1 regularization.""" + @classmethod def search_space(cls, **params): return { @@ -787,6 +799,8 @@ def __init__(self, task="binary", **config): class LRL2Classifier(SKLearnEstimator): + """The class for tuning Logistic Regression with L2 regularization.""" + limit_resource = True @classmethod @@ -811,6 +825,8 @@ def __init__(self, task="binary", **config): class CatBoostEstimator(BaseEstimator): + """The class for tuning CatBoost.""" + ITER_HP = "n_estimators" @classmethod @@ -1011,6 +1027,8 @@ def _preprocess(self, X): class Prophet(SKLearnEstimator): + """The class for tuning Prophet.""" + @classmethod def search_space(cls, **params): space = { @@ -1083,6 +1101,8 @@ def predict(self, X_test): class ARIMA(Prophet): + """The class for tuning ARIMA.""" + @classmethod def search_space(cls, **params): space = { @@ -1172,6 +1192,8 @@ def predict(self, X_test): class SARIMAX(ARIMA): + """The class for tuning SARIMA.""" + @classmethod def search_space(cls, **params): space = { @@ -1258,16 +1280,6 @@ def fit(self, X_train, y_train, budget=None, **kwargs): class suppress_stdout_stderr(object): - """ - A context manager for doing a "deep suppression" of stdout and stderr in - Python, i.e. will suppress all print, even if the print originates in a - compiled C/Fortran sub-function. - This will not suppress raised exceptions, since exceptions are printed - to stderr just before a script exits, and after the context manager has - exited. 
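
Note that the ExtraTreeEstimator to ExtraTreesEstimator rename above does not change the learner's registered name: get_estimator_class in flaml/ml.py still maps "extra_tree" to the renamed class, so existing estimator_list configurations keep working. A sketch (X_train/y_train assumed):

.. code-block:: python

    from flaml import AutoML

    automl = AutoML()
    automl.fit(
        X_train,
        y_train,
        task="classification",
        time_budget=60,
        estimator_list=["extra_tree", "lgbm"],  # names are unchanged by the rename
    )
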
- - """ - def __init__(self): # Open a pair of null files self.null_fds = [os.open(os.devnull, os.O_RDWR) for x in range(2)] diff --git a/flaml/onlineml/autovw.py b/flaml/onlineml/autovw.py index 8ccddd55e9..dbb611b138 100644 --- a/flaml/onlineml/autovw.py +++ b/flaml/onlineml/autovw.py @@ -1,6 +1,12 @@ from typing import Optional, Union import logging -from flaml.tune import Trial, Categorical, Float, PolynomialExpansionSet, polynomial_expansion_set +from flaml.tune import ( + Trial, + Categorical, + Float, + PolynomialExpansionSet, + polynomial_expansion_set, +) from flaml.onlineml import OnlineTrialRunner from flaml.scheduler import ChaChaScheduler from flaml.searcher import ChampionFrontierSearcher @@ -10,69 +16,82 @@ class AutoVW: - """The AutoML class - """ + """class for the AutoVW algorithm.""" WARMSTART_NUM = 100 - AUTOMATIC = '_auto' - VW_INTERACTION_ARG_NAME = 'interactions' - - def __init__(self, - max_live_model_num: int, - search_space: dict, - init_config: Optional[dict] = {}, - min_resource_lease: Optional[Union[str, float]] = 'auto', - automl_runner_args: Optional[dict] = {}, - scheduler_args: Optional[dict] = {}, - model_select_policy: Optional[str] = 'threshold_loss_ucb', - metric: Optional[str] = 'mae_clipped', - random_seed: Optional[int] = None, - model_selection_mode: Optional[str] = 'min', - cb_coef: Optional[float] = None, - ): - """Constructor + AUTOMATIC = "_auto" + VW_INTERACTION_ARG_NAME = "interactions" + + def __init__( + self, + max_live_model_num: int, + search_space: dict, + init_config: Optional[dict] = {}, + min_resource_lease: Optional[Union[str, float]] = "auto", + automl_runner_args: Optional[dict] = {}, + scheduler_args: Optional[dict] = {}, + model_select_policy: Optional[str] = "threshold_loss_ucb", + metric: Optional[str] = "mae_clipped", + random_seed: Optional[int] = None, + model_selection_mode: Optional[str] = "min", + cb_coef: Optional[float] = None, + ): + """Constructor. Args: - max_live_model_num: The maximum number of 'live' models, which, in other words, - is the maximum number of models allowed to update in each learning iteraction. - search_space: A dictionary of the search space. This search space includes both - hyperparameters we want to tune and fixed hyperparameters. In the latter case, - the value is a fixed value. + max_live_model_num: An int to specify the maximum number of + 'live' models, which, in other words, is the maximum number + of models allowed to update in each learning iteraction. + search_space: A dictionary of the search space. This search space + includes both hyperparameters we want to tune and fixed + hyperparameters. In the latter case, the value is a fixed value. init_config: A dictionary of a partial or full initial config, e.g. {'interactions': set(), 'learning_rate': 0.5} - min_resource_lease: The minimum resource lease assigned to a particular model/trial. - If set as 'auto', it will be calculated automatically. + min_resource_lease: string or float | The minimum resource lease + assigned to a particular model/trial. If set as 'auto', it will + be calculated automatically. automl_runner_args: A dictionary of configuration for the OnlineTrialRunner. - If set {}, default values will be used, which is equivalent to using the following configs. 
-                automl_runner_args =
-                    {"champion_test_policy": 'loss_ucb'       # specifcies how to do the statistic test for a better champion
-                     "remove_worse": False                    # specifcies whether to do worse than test
+                If set {}, default values will be used, which is equivalent to using
+                the following configs.
+
+                .. code-block:: python
+
+                    automl_runner_args =
+                        {"champion_test_policy": 'loss_ucb',  # the statistic test for a better champion
+                         "remove_worse": False,  # whether to do worse than test
                     }
+
             scheduler_args: A dictionary of configuration for the scheduler.
-                If set {}, default values will be used, which is equivalent to using the following configs.
-                scheduler_args =
-                    {"keep_challenger_metric": 'ucb'          # what metric to use when deciding the top performing challengers
-                     "keep_challenger_ratio": 0.5             # denotes the ratio of top performing challengers to keep live
-                     "keep_champion": True                    # specifcies whether to keep the champion always running
+                If set {}, default values will be used, which is equivalent to using the
+                following config.
+
+                .. code-block:: python
+
+                    scheduler_args =
+                        {"keep_challenger_metric": 'ucb',  # what metric to use when deciding the top performing challengers
+                         "keep_challenger_ratio": 0.5,  # denotes the ratio of top performing challengers to keep live
+                         "keep_champion": True,  # specifies whether to keep the champion always running
                     }
-            model_select_policy: A string in ['threshold_loss_ucb', 'threshold_loss_lcb', 'threshold_loss_avg',
-                'loss_ucb', 'loss_lcb', 'loss_avg'] to specify how to select one model to do prediction
-                from the live model pool. Default value is 'threshold_loss_ucb'.
-            metric: A string in ['mae_clipped', 'mae', 'mse', 'absolute_clipped', 'absolute', 'squared']
-                to specify the name of the loss function used for calculating the progressive validation loss in ChaCha.
-            random_seed (int): An integer of the random seed used in the searcher
-                (more specifically this the random seed for ConfigOracle)
+
+            model_select_policy: A string in ['threshold_loss_ucb',
+                'threshold_loss_lcb', 'threshold_loss_avg', 'loss_ucb', 'loss_lcb',
+                'loss_avg'] to specify how to select one model to do prediction from
+                the live model pool. Default value is 'threshold_loss_ucb'.
+            metric: A string in ['mae_clipped', 'mae', 'mse', 'absolute_clipped',
+                'absolute', 'squared'] to specify the name of the loss function used
+                for calculating the progressive validation loss in ChaCha.
+            random_seed: An integer of the random seed used in the searcher
+                (more specifically, this is the random seed for ConfigOracle).
             model_selection_mode: A string in ['min', 'max'] to specify the objective
                 as minimization or maximization.
-            cb_coef (float): A float coefficient (optional) used in the sample complexity bound.
+            cb_coef: A float coefficient (optional) used in the sample complexity bound.
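
A minimal end-to-end sketch of the constructor documented above; the data stream is hypothetical, and each example is a vw-format string:

.. code-block:: python

    from flaml import AutoVW

    autovw = AutoVW(
        max_live_model_num=5,
        search_space={
            AutoVW.VW_INTERACTION_ARG_NAME: AutoVW.AUTOMATIC,  # tune namespace interactions
            "learning_rate": 0.5,  # a fixed hyperparameter keeps a fixed value
        },
    )
    for vw_example in data_stream:  # hypothetical stream of online examples
        y_pred = autovw.predict(vw_example)
        autovw.learn(vw_example)
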
""" self._max_live_model_num = max_live_model_num self._search_space = search_space self._init_config = init_config - self._online_trial_args = {"metric": metric, - "min_resource_lease": min_resource_lease, - "cb_coef": cb_coef, - } + self._online_trial_args = { + "metric": metric, + "min_resource_lease": min_resource_lease, + "cb_coef": cb_coef, + } self._automl_runner_args = automl_runner_args self._scheduler_args = scheduler_args self._model_select_policy = model_select_policy @@ -85,100 +104,122 @@ def __init__(self, self._iter = 0 def _setup_trial_runner(self, vw_example): - """Set up the _trial_runner based on one vw_example - """ + """Set up the _trial_runner based on one vw_example.""" # setup the default search space for the namespace interaction hyperparameter search_space = self._search_space.copy() for k, v in self._search_space.items(): if k == self.VW_INTERACTION_ARG_NAME and v == self.AUTOMATIC: - raw_namespaces = self.get_ns_feature_dim_from_vw_example(vw_example).keys() - search_space[k] = polynomial_expansion_set(init_monomials=set(raw_namespaces)) + raw_namespaces = self.get_ns_feature_dim_from_vw_example( + vw_example + ).keys() + search_space[k] = polynomial_expansion_set( + init_monomials=set(raw_namespaces) + ) # setup the init config based on the input _init_config and search space init_config = self._init_config.copy() for k, v in search_space.items(): if k not in init_config.keys(): if isinstance(v, PolynomialExpansionSet): init_config[k] = set() - elif (not isinstance(v, Categorical) and not isinstance(v, Float)): + elif not isinstance(v, Categorical) and not isinstance(v, Float): init_config[k] = v - searcher_args = {"init_config": init_config, - "space": search_space, - "random_seed": self._random_seed, - 'online_trial_args': self._online_trial_args, - } + searcher_args = { + "init_config": init_config, + "space": search_space, + "random_seed": self._random_seed, + "online_trial_args": self._online_trial_args, + } logger.info("original search_space %s", self._search_space) logger.info("original init_config %s", self._init_config) - logger.info('searcher_args %s', searcher_args) - logger.info('scheduler_args %s', self._scheduler_args) - logger.info('automl_runner_args %s', self._automl_runner_args) + logger.info("searcher_args %s", searcher_args) + logger.info("scheduler_args %s", self._scheduler_args) + logger.info("automl_runner_args %s", self._automl_runner_args) searcher = ChampionFrontierSearcher(**searcher_args) scheduler = ChaChaScheduler(**self._scheduler_args) - self._trial_runner = OnlineTrialRunner(max_live_model_num=self._max_live_model_num, - searcher=searcher, - scheduler=scheduler, - **self._automl_runner_args) + self._trial_runner = OnlineTrialRunner( + max_live_model_num=self._max_live_model_num, + searcher=searcher, + scheduler=scheduler, + **self._automl_runner_args + ) def predict(self, data_sample): - """Predict on the input example (e.g., vw example) + """Predict on the input data sample. Args: - data_sample (vw_example) + data_sample: one data example in vw format. 
""" if self._trial_runner is None: self._setup_trial_runner(data_sample) self._best_trial = self._select_best_trial() self._y_predict = self._best_trial.predict(data_sample) # code for debugging purpose - if self._prediction_trial_id is None or \ - self._prediction_trial_id != self._best_trial.trial_id: + if ( + self._prediction_trial_id is None + or self._prediction_trial_id != self._best_trial.trial_id + ): self._prediction_trial_id = self._best_trial.trial_id - logger.info('prediction trial id changed to %s at iter %s, resource used: %s', - self._prediction_trial_id, self._iter, - self._best_trial.result.resource_used) + logger.info( + "prediction trial id changed to %s at iter %s, resource used: %s", + self._prediction_trial_id, + self._iter, + self._best_trial.result.resource_used, + ) return self._y_predict def learn(self, data_sample): - """Perform one online learning step with the given data sample + """Perform one online learning step with the given data sample. Args: - data_sample (vw_example): one data sample on which the model gets updated + data_sample: one data example in vw format. It will be used to + update the vw model. """ self._iter += 1 self._trial_runner.step(data_sample, (self._y_predict, self._best_trial)) def _select_best_trial(self): - """Select a best trial from the running trials accoring to the _model_select_policy - """ - best_score = float('+inf') if self._model_selection_mode == 'min' else float('-inf') + """Select a best trial from the running trials according to the _model_select_policy.""" + best_score = ( + float("+inf") if self._model_selection_mode == "min" else float("-inf") + ) new_best_trial = None for trial in self._trial_runner.running_trials: - if trial.result is not None and ('threshold' not in self._model_select_policy - or trial.result.resource_used >= self.WARMSTART_NUM): + if trial.result is not None and ( + "threshold" not in self._model_select_policy + or trial.result.resource_used >= self.WARMSTART_NUM + ): score = trial.result.get_score(self._model_select_policy) - if ('min' == self._model_selection_mode and score < best_score) or \ - ('max' == self._model_selection_mode and score > best_score): + if ("min" == self._model_selection_mode and score < best_score) or ( + "max" == self._model_selection_mode and score > best_score + ): best_score = score new_best_trial = trial if new_best_trial is not None: - logger.debug('best_trial resource used: %s', new_best_trial.result.resource_used) + logger.debug( + "best_trial resource used: %s", new_best_trial.result.resource_used + ) return new_best_trial else: # This branch will be triggered when the resource consumption all trials are smaller # than the WARMSTART_NUM threshold. In this case, we will select the _best_trial # selected in the previous iteration. - if self._best_trial is not None and self._best_trial.status == Trial.RUNNING: - logger.debug('old best trial %s', self._best_trial.trial_id) + if ( + self._best_trial is not None + and self._best_trial.status == Trial.RUNNING + ): + logger.debug("old best trial %s", self._best_trial.trial_id) return self._best_trial else: # this will be triggered in the first iteration or in the iteration where we want # to select the trial from the previous iteration but that trial has been paused # (i.e., self._best_trial.status != Trial.RUNNING) by the scheduler. 
-            logger.debug('using champion trial: %s',
-                         self._trial_runner.champion_trial.trial_id)
+                logger.debug(
+                    "using champion trial: %s",
+                    self._trial_runner.champion_trial.trial_id,
+                )
                 return self._trial_runner.champion_trial

     @staticmethod
     def get_ns_feature_dim_from_vw_example(vw_example) -> dict:
-        """Get a dictionary of feature dimensionality for each namespace singleton
-        """
+        """Get a dictionary of feature dimensionality for each namespace singleton."""
         return get_ns_feature_dim_from_vw_example(vw_example)
diff --git a/flaml/onlineml/trial.py b/flaml/onlineml/trial.py
index 0bf50df97f..5d4feafb3b 100644
--- a/flaml/onlineml/trial.py
+++ b/flaml/onlineml/trial.py
@@ -4,7 +4,7 @@
 import math
 import copy
 import collections
-from typing import Optional
+from typing import Optional, Union
 from sklearn.metrics import mean_squared_error, mean_absolute_error
 from flaml.tune import Trial

@@ -12,65 +12,68 @@

 def get_ns_feature_dim_from_vw_example(vw_example) -> dict:
-    """Get a dictionary of feature dimensionality for each namespace singleton
+    """Get a dictionary of feature dimensionality for each namespace singleton."""
+    # *************************A NOTE about the input vw example***********
+    # Assumption: assume the vw_example takes one of the following formats,
+    # depending on whether the example includes the feature names.

-    NOTE:
-    Assumption: assume the vw_example takes one of the following format
-    depending on whether the example includes the feature names

+    # format 1: `y |ns1 feature1:feature_value1 feature2:feature_value2 |ns2
+    # ns2 feature3:feature_value3 feature4:feature_value4`
+    # format 2: `y | ns1 feature_value1 feature_value2 |
+    # ns2 feature_value3 feature_value4`

-    format 1: 'y |ns1 feature1:feature_value1 feature2:feature_value2 |ns2
-    ns2 feature3:feature_value3 feature4:feature_value4'
-    format 2: 'y | ns1 feature_value1 feature_value2 |
-    ns2 feature_value3 feature_value4'

+    # The output in both cases is `{'ns1': 2, 'ns2': 2}`.

-    The output of both cases are {'ns1': 2, 'ns2': 2}

+    # For more information about the input format of vw example, please refer to
+    # https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Input-format.
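
A quick sanity check for the helper above, using the colon-delimited format 1 (no trailing space after the last feature):

.. code-block:: python

    from flaml.onlineml.trial import get_ns_feature_dim_from_vw_example

    vw_example = "0.5 |ns1 price:0.23 sqft:0.25"
    print(get_ns_feature_dim_from_vw_example(vw_example))  # {'ns1': 2}
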
-    For more information about the input formate of vw example, please refer to
-    https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Input-format
-    """
     ns_feature_dim = {}
-    data = vw_example.split('|')
+    data = vw_example.split("|")
     for i in range(1, len(data)):
-        if ':' in data[i]:
-            ns_w_feature = data[i].split(' ')
+        if ":" in data[i]:
+            ns_w_feature = data[i].split(" ")
             ns = ns_w_feature[0]
             feature = ns_w_feature[1:]
             feature_dim = len(feature)
         else:
-            data_split = data[i].split(' ')
+            data_split = data[i].split(" ")
             ns = data_split[0]
             feature_dim = len(data_split) - 1
             if len(data_split[-1]) == 0:
                 feature_dim -= 1
         ns_feature_dim[ns] = feature_dim
-    logger.debug('name space feature dimension %s', ns_feature_dim)
+    logger.debug("name space feature dimension %s", ns_feature_dim)
     return ns_feature_dim


 class OnlineResult:
-    """Class for managing the result statistics of a trial
-
-    Attributes:
-        observation_count: the total number of observations
-        resource_used: the sum of loss
-
-    Methods:
-        update_result(new_loss, new_resource_used, data_dimension)
-            Update result
-        get_score(score_name)
-            Get the score according to the input score_name
-    """
+    """Class for managing the result statistics of a trial."""
+
     prob_delta = 0.1
     LOSS_MIN = 0.0
     LOSS_MAX = np.inf
     CB_COEF = 0.05  # 0.001 for mse

-    def __init__(self, result_type_name: str, cb_coef: Optional[float] = None,
-                 init_loss: Optional[float] = 0.0, init_cb: Optional[float] = 100.0,
-                 mode: Optional[str] = 'min', sliding_window_size: Optional[int] = 100):
-        """
+    def __init__(
+        self,
+        result_type_name: str,
+        cb_coef: Optional[float] = None,
+        init_loss: Optional[float] = 0.0,
+        init_cb: Optional[float] = 100.0,
+        mode: Optional[str] = "min",
+        sliding_window_size: Optional[int] = 100,
+    ):
+        """Constructor.
+
         Args:
-            result_type_name (str): The name of the result type
+            result_type_name: A string to specify the name of the result type.
+            cb_coef: A float to specify the coefficient on the confidence bound.
+            init_loss: A float to specify the initial loss.
+            init_cb: A float to specify the initial confidence bound.
+            mode: A string in ['min', 'max'] to specify the objective as
+                minimization or maximization.
+            sliding_window_size: An int to specify the size of the sliding window
+                (for experimental purposes).
""" self._result_type_name = result_type_name # for example 'mse' or 'mae' self._mode = mode @@ -85,32 +88,40 @@ def __init__(self, result_type_name: str, cb_coef: Optional[float] = None, self._sliding_window_size = sliding_window_size self._loss_queue = collections.deque(maxlen=self._sliding_window_size) - def update_result(self, new_loss, new_resource_used, data_dimension, - bound_of_range=1.0, new_observation_count=1.0): - """Update result statistics - """ + def update_result( + self, + new_loss, + new_resource_used, + data_dimension, + bound_of_range=1.0, + new_observation_count=1.0, + ): + """Update result statistics.""" self.resource_used += new_resource_used # keep the running average instead of sum of loss to avoid over overflow - self._loss_avg = self._loss_avg * (self.observation_count / (self.observation_count + new_observation_count) - ) + new_loss / (self.observation_count + new_observation_count) + self._loss_avg = self._loss_avg * ( + self.observation_count / (self.observation_count + new_observation_count) + ) + new_loss / (self.observation_count + new_observation_count) self.observation_count += new_observation_count self._loss_cb = self._update_loss_cb(bound_of_range, data_dimension) self._loss_queue.append(new_loss) - def _update_loss_cb(self, bound_of_range, data_dim, - bound_name='sample_complexity_bound'): - """Calculate bound coef - """ - if bound_name == 'sample_complexity_bound': + def _update_loss_cb( + self, bound_of_range, data_dim, bound_name="sample_complexity_bound" + ): + """Calculate the coefficient of the confidence bound.""" + if bound_name == "sample_complexity_bound": # set the coefficient in the loss bound - if 'mae' in self.result_type_name: + if "mae" in self.result_type_name: coef = self._cb_coef * bound_of_range else: coef = 0.001 * bound_of_range comp_F = math.sqrt(data_dim) n = self.observation_count - return coef * comp_F * math.sqrt((np.log10(n / OnlineResult.prob_delta)) / n) + return ( + coef * comp_F * math.sqrt((np.log10(n / OnlineResult.prob_delta)) / n) + ) else: raise NotImplementedError @@ -120,8 +131,7 @@ def result_type_name(self): @property def loss_avg(self): - return self._loss_avg if \ - self.observation_count != 0 else self._init_loss + return self._loss_avg if self.observation_count != 0 else self._init_loss @property def loss_cb(self): @@ -137,53 +147,45 @@ def loss_ucb(self): @property def loss_avg_recent(self): - return sum(self._loss_queue) / len(self._loss_queue) \ - if len(self._loss_queue) != 0 else self._init_loss + return ( + sum(self._loss_queue) / len(self._loss_queue) + if len(self._loss_queue) != 0 + else self._init_loss + ) def get_score(self, score_name, cb_ratio=1): - if 'lcb' in score_name: + if "lcb" in score_name: return max(self._loss_avg - cb_ratio * self._loss_cb, OnlineResult.LOSS_MIN) - elif 'ucb' in score_name: + elif "ucb" in score_name: return min(self._loss_avg + cb_ratio * self._loss_cb, OnlineResult.LOSS_MAX) - elif 'avg' in score_name: + elif "avg" in score_name: return self._loss_avg else: raise NotImplementedError class BaseOnlineTrial(Trial): - """Class for online trial. 
- - Attributes: - config: the config for this trial - trial_id: the trial_id of this trial - min_resource_lease (float): the minimum resource realse - status: the status of this trial - start_time: the start time of this trial - custom_trial_name: a custom name for this trial - - Methods: - set_resource_lease(resource) - set_status(status) - set_checked_under_current_champion(checked_under_current_champion) - """ - - def __init__(self, - config: dict, - min_resource_lease: float, - is_champion: Optional[bool] = False, - is_checked_under_current_champion: Optional[bool] = True, - custom_trial_name: Optional[str] = 'mae', - trial_id: Optional[str] = None, - ): - """ + """Class for the online trial.""" + + def __init__( + self, + config: dict, + min_resource_lease: float, + is_champion: Optional[bool] = False, + is_checked_under_current_champion: Optional[bool] = True, + custom_trial_name: Optional[str] = "mae", + trial_id: Optional[str] = None, + ): + """Constructor. + Args: - config: the config dict - min_resource_lease: the minimum resource realse - is_champion: a bool variable - is_checked_under_current_champion: a bool variable - custom_trial_name: custom trial name - trial_id: the trial id + config: The configuration dictionary. + min_resource_lease: A float specifying the minimum resource lease. + is_champion: A bool variable indicating whether the trial is champion. + is_checked_under_current_champion: A bool indicating whether the trial + has been used under the current champion. + custom_trial_name: A string of a custom trial name. + trial_id: A string for the trial id. """ # ****basic variables self.config = config @@ -213,26 +215,25 @@ def resource_lease(self): return self._resource_lease def set_checked_under_current_champion(self, checked_under_current_champion: bool): - """TODO: add documentation why this is needed. This is needed because sometimes - we want to know whether a trial has been paused since a new champion is promoted. - We want to try to pause those running trials (even though they are not yet achieve - the next scheduling check point according to resource used and resource lease), - because a better trial is likely to be in the new challengers generated by the new - champion, so we want to try them as soon as possible. - If we wait until we reach the next scheduling point, we may waste a lot of resource - (depending on what is the current resource lease) on the old trials (note that new - trials is not possible to be scheduled to run until there is a slot openning). - Intuitively speaking, we want to squize an opening slot as soon as possible once - a new champion is promoted, such that we are able to try newly generated challengers. - """ + # This is needed because sometimes + # we want to know whether a trial has been paused since a new champion is promoted. + # We want to try to pause those running trials (even though they are not yet achieve + # the next scheduling check point according to resource used and resource lease), + # because a better trial is likely to be in the new challengers generated by the new + # champion, so we want to try them as soon as possible. + # If we wait until we reach the next scheduling point, we may waste a lot of resource + # (depending on what is the current resource lease) on the old trials (note that new + # trials is not possible to be scheduled to run until there is a slot openning). 
+ # Intuitively speaking, we want to squize an opening slot as soon as possible once + # a new champion is promoted, such that we are able to try newly generated challengers. self._is_checked_under_current_champion = checked_under_current_champion def set_resource_lease(self, resource: float): + """Sets the resource lease accordingly.""" self._resource_lease = resource def set_status(self, status): - """Sets the status of the trial and record the start time - """ + """Sets the status of the trial and record the start time.""" self.status = status if status == Trial.RUNNING: if self.start_time is None: @@ -240,74 +241,62 @@ def set_status(self, status): class VowpalWabbitTrial(BaseOnlineTrial): - """Implement BaseOnlineTrial for Vowpal Wabbit - - Attributes: - model: the online model - result: the anytime result for the online model - trainable_class: the model class (set as pyvw.vw for VowpalWabbitTrial) - - config: the config for this trial - trial_id: the trial_id of this trial - min_resource_lease (float): the minimum resource realse - status: the status of this trial - start_time: the start time of this trial - custom_trial_name: a custom name for this trial - - Methods: - set_resource_lease(resource) - set_status(status) - set_checked_under_current_champion(checked_under_current_champion) - - NOTE: - About result: - 1. training related results (need to be updated in the trainable class) - 2. result about resources lease (need to be updated externally) - - About namespaces in vw: - - Wiki in vw: - https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Namespaces - - Namespace vs features: - https://stackoverflow.com/questions/28586225/in-vowpal-wabbit-what-is-the-difference-between-a-namespace-and-feature - """ + """The class for Vowpal Wabbit online trials.""" + + # NOTE: 1. About namespaces in vw: + # - Wiki in vw: + # https://github.com/VowpalWabbit/vowpal_wabbit/wiki/Namespaces + # - Namespace vs features: + # https://stackoverflow.com/questions/28586225/in-vowpal-wabbit-what-is-the-difference-between-a-namespace-and-feature + + # About result: + # 1. training related results (need to be updated in the trainable class) + # 2. result about resources lease (need to be updated externally) cost_unit = 1.0 - interactions_config_key = 'interactions' + interactions_config_key = "interactions" MIN_RES_CONST = 5 - def __init__(self, - config: dict, - min_resource_lease: float, - metric: str = 'mae', - is_champion: Optional[bool] = False, - is_checked_under_current_champion: Optional[bool] = True, - custom_trial_name: Optional[str] = 'vw_mae_clipped', - trial_id: Optional[str] = None, - cb_coef: Optional[float] = None, - ): - """Constructor + def __init__( + self, + config: dict, + min_resource_lease: float, + metric: str = "mae", + is_champion: Optional[bool] = False, + is_checked_under_current_champion: Optional[bool] = True, + custom_trial_name: Optional[str] = "vw_mae_clipped", + trial_id: Optional[str] = None, + cb_coef: Optional[float] = None, + ): + """Constructor. Args: config (dict): the config of the trial (note that the config is a set - because the hyperparameters are ) - min_resource_lease (float): the minimum resource lease - metric (str): the loss metric - is_champion (bool): indicates whether the trial is the current champion or not + because the hyperparameters are). + min_resource_lease (float): the minimum resource lease. + metric (str): the loss metric. + is_champion (bool): indicates whether the trial is the current champion or not. 
is_checked_under_current_champion (bool): indicates whether this trials has - been paused under the current champion - trial_id (str): id of the trial (if None, it will be generated in the constructor) - + been paused under the current champion. + trial_id (str): id of the trial (if None, it will be generated in the constructor). """ try: from vowpalwabbit import pyvw except ImportError: raise ImportError( - 'To use AutoVW, please run pip install flaml[vw] to install vowpalwabbit') + "To use AutoVW, please run pip install flaml[vw] to install vowpalwabbit" + ) # attributes self.trial_id = self._config_to_id(config) if trial_id is None else trial_id - logger.info('Create trial with trial_id: %s', self.trial_id) - super().__init__(config, min_resource_lease, is_champion, is_checked_under_current_champion, - custom_trial_name, self.trial_id) - self.model = None # model is None until the config is scheduled to run + logger.info("Create trial with trial_id: %s", self.trial_id) + super().__init__( + config, + min_resource_lease, + is_champion, + is_checked_under_current_champion, + custom_trial_name, + self.trial_id, + ) + self.model = None # model is None until the config is scheduled to run self.result = None self.trainable_class = pyvw.vw # variables that are needed during online training @@ -320,45 +309,48 @@ def __init__(self, @staticmethod def _config_to_id(config): - """Generate an id for the provided config - """ + """Generate an id for the provided config.""" # sort config keys sorted_k_list = sorted(list(config.keys())) - config_id_full = '' + config_id_full = "" for key in sorted_k_list: v = config[key] - config_id = '|' + config_id = "|" if isinstance(v, set): value_list = sorted(v) - config_id += '_'.join([str(k) for k in value_list]) + config_id += "_".join([str(k) for k in value_list]) else: config_id += str(v) config_id_full = config_id_full + config_id return config_id_full def _initialize_vw_model(self, vw_example): - """Initialize a vw model using the trainable_class - """ + """Initialize a vw model using the trainable_class""" self._vw_config = self.config.copy() - ns_interactions = self.config.get(VowpalWabbitTrial.interactions_config_key, None) + ns_interactions = self.config.get( + VowpalWabbitTrial.interactions_config_key, None + ) # ensure the feature interaction config is a list (required by VW) if ns_interactions is not None: - self._vw_config[VowpalWabbitTrial.interactions_config_key] \ - = list(ns_interactions) + self._vw_config[VowpalWabbitTrial.interactions_config_key] = list( + ns_interactions + ) # get the dimensionality of the feature according to the namespace configuration namespace_feature_dim = get_ns_feature_dim_from_vw_example(vw_example) self._dim = self._get_dim_from_ns(namespace_feature_dim, ns_interactions) # construct an instance of vw model using the input config and fixed config self.model = self.trainable_class(**self._vw_config) - self.result = OnlineResult(self._metric, - cb_coef=self._cb_coef, - init_loss=0.0, init_cb=100.0,) + self.result = OnlineResult( + self._metric, + cb_coef=self._cb_coef, + init_loss=0.0, + init_cb=100.0, + ) def train_eval_model_online(self, data_sample, y_pred): - """Train and eval model online - """ + """Train and evaluate model online.""" # extract info needed the first time we see the data - if self._resource_lease == 'auto' or self._resource_lease is None: + if self._resource_lease == "auto" or self._resource_lease is None: assert self._dim is not None self._resource_lease = self._dim * self.MIN_RES_CONST y = 
self._get_y_from_vw_example(data_sample) @@ -369,20 +361,23 @@ def train_eval_model_online(self, data_sample, y_pred): # do one step of learning self.model.learn(data_sample) # update training related results accordingly - new_loss = self._get_loss(y, y_pred, self._metric, - self._y_min_observed, self._y_max_observed) + new_loss = self._get_loss( + y, y_pred, self._metric, self._y_min_observed, self._y_max_observed + ) # udpate sample size, sum of loss, and cost data_sample_size = 1 bound_of_range = self._y_max_observed - self._y_min_observed if bound_of_range == 0: bound_of_range = 1.0 - self.result.update_result(new_loss, - VowpalWabbitTrial.cost_unit * data_sample_size, - self._dim, bound_of_range) + self.result.update_result( + new_loss, + VowpalWabbitTrial.cost_unit * data_sample_size, + self._dim, + bound_of_range, + ) def predict(self, x): - """Predict using the model - """ + """Predict using the model.""" if self.model is None: # initialize self.model and self.result self._initialize_vw_model(x) @@ -390,14 +385,17 @@ def predict(self, x): def _get_loss(self, y_true, y_pred, loss_func_name, y_min_observed, y_max_observed): """Get instantaneous loss from y_true and y_pred, and loss_func_name - For mae_clip, we clip y_pred in the observed range of y + For mae_clip, we clip y_pred in the observed range of y """ - if 'mse' in loss_func_name or 'squared' in loss_func_name: + if "mse" in loss_func_name or "squared" in loss_func_name: loss_func = mean_squared_error - elif 'mae' in loss_func_name or 'absolute' in loss_func_name: + elif "mae" in loss_func_name or "absolute" in loss_func_name: loss_func = mean_absolute_error - if y_min_observed is not None and y_max_observed is not None and \ - 'clip' in loss_func_name: + if ( + y_min_observed is not None + and y_max_observed is not None + and "clip" in loss_func_name + ): # clip y_pred in the observed range of y y_pred = min(y_max_observed, max(y_pred, y_min_observed)) else: @@ -405,17 +403,17 @@ def _get_loss(self, y_true, y_pred, loss_func_name, y_min_observed, y_max_observ return loss_func([y_true], [y_pred]) def _update_y_range(self, y): - """Maintain running observed minimum and maximum target value - """ + """Maintain running observed minimum and maximum target value.""" if self._y_min_observed is None or y < self._y_min_observed: self._y_min_observed = y if self._y_max_observed is None or y > self._y_max_observed: self._y_max_observed = y @staticmethod - def _get_dim_from_ns(namespace_feature_dim: dict, namespace_interactions: [set, list]): - """Get the dimensionality of the corresponding feature of input namespace set - """ + def _get_dim_from_ns( + namespace_feature_dim: dict, namespace_interactions: Union[set, list] + ): + """Get the dimensionality of the corresponding feature of input namespace set.""" total_dim = sum(namespace_feature_dim.values()) if namespace_interactions: for f in namespace_interactions: @@ -431,6 +429,5 @@ def clean_up_model(self): @staticmethod def _get_y_from_vw_example(vw_example): - """Get y from a vw_example. this works for regression datasets. - """ - return float(vw_example.split('|')[0]) + """Get y from a vw_example. 
This works for regression datasets."""
+        return float(vw_example.split("|")[0])
diff --git a/flaml/onlineml/trial_runner.py b/flaml/onlineml/trial_runner.py
index d1cc6e5422..5510ac977f 100644
--- a/flaml/onlineml/trial_runner.py
+++ b/flaml/onlineml/trial_runner.py
@@ -9,42 +9,25 @@


 class OnlineTrialRunner:
-    """The OnlineTrialRunner class
-
-    Methods:
-    step(max_live_model_num, data_sample, prediction_trial_tuple)
-        Outputs a _max_live_model_num number of trials to run each time it is called
-    get_top_running_trials()
-        Get a list of trial ids, whose performance is among the top running trials
-    add_trial(trial)
-        Add trial to this TrialRunner.
-    stop_trial(trial)
-        Set the status of a trial to be Trial.TERMINATED and perform other subsequent operations
-    pause_trial(trial)
-        Set the status of a trial to be Trial.PAUSED and perform other subsequent operations
-    run_trial(trial)
-        Set the status of a trial to be Trial.RUNNING and perform other subsequent operations
-    get_trials()
-        Get all the trials added (whatever that status) in the the OnlineTrialRunner
-
-    NOTE about the status of a trial:
-        Trial.PENDING: All trials are set to be pending when frist added into the OnlineTrialRunner until
-            it is selected to run. By this definition, a trial with status Trial.PENDING is a challenger
-            trial added to the OnlineTrialRunner but never been selected to run.
-            It denotes the starting of trial's lifespan in the OnlineTrialRunner.
-        Trial.RUNNING: It indicates that this trial is one of the concurrently running trials.
-            The max number of Trial.RUNNING trials is running_budget.
-            The status of a trial will be set to Trial.RUNNING the next time it selected to run.
-            A trial's status may have the following change:
-            Trial.PENDING -> Trial.RUNNING
-            Trial.PAUSED - > Trial.RUNNING
-        Trial.PAUSED: The status of a trial is set to Trial.PAUSED once it is removed from the running trials.
-            Trial.RUNNING - > Trial.PAUSED
-        Trial.TERMINATED: set the status of a trial to Trial.TERMINATED when you never want to select it.
-            It denotes the real end of a trial's lifespan.
-        Status change routine of a trial
-            Trial.PENDING -> (Trial.RUNNING -> Trial.PAUSED -> Trial.RUNNING -> ...) -> Trial.TERMINATED(optional)
-    """
+    """Class for the OnlineTrialRunner."""
+
+    # ************NOTE about the status of a trial***************
+    # Trial.PENDING: All trials are set to be pending when first added into the OnlineTrialRunner until
+    #     it is selected to run. By this definition, a trial with status Trial.PENDING is a challenger
+    #     trial added to the OnlineTrialRunner but has never been selected to run.
+    #     It denotes the start of a trial's lifespan in the OnlineTrialRunner.
+    # Trial.RUNNING: It indicates that this trial is one of the concurrently running trials.
+    #     The max number of Trial.RUNNING trials is running_budget.
+    #     The status of a trial will be set to Trial.RUNNING the next time it is selected to run.
+    #     A trial's status may have the following changes:
+    #     Trial.PENDING -> Trial.RUNNING
+    #     Trial.PAUSED -> Trial.RUNNING
+    # Trial.PAUSED: The status of a trial is set to Trial.PAUSED once it is removed from the running trials.
+    #     Trial.RUNNING -> Trial.PAUSED
+    # Trial.TERMINATED: set the status of a trial to Trial.TERMINATED when you never want to select it.
+    #     It denotes the real end of a trial's lifespan.
+    # Status change routine of a trial:
+    #     Trial.PENDING -> (Trial.RUNNING -> Trial.PAUSED -> Trial.RUNNING -> ...)
RANDOM_SEED = 123456 WARMSTART_NUM = 100 @@ -57,33 +40,37 @@ def __init__( champion_test_policy="loss_ucb", **kwargs ): - """Constructor + """Constructor. Args: max_live_model_num: The maximum number of 'live'/running models allowed. - searcher: A class for generating Trial objects progressively. The ConfigOracle - is implemented in the searcher. - Required methods of the searcher: - - next_trial() - Generate the next trial to add. - - set_search_properties(metric: Optional[str], mode: Optional[str], config: Optional[dict], setting: Optional[dict]) - Generate new challengers based on the current champion and update the challenger list - - on_trial_result(trial_id: str, result: Dict) - Reprot results to the scheduler. - scheduler: A class for managing the 'live' trials and allocating the resources for the trials. - Required methods of the scheduler: - - on_trial_add(trial_runner, trial: Trial) - It adds candidate trials to the scheduler. It is called inside of the add_trial - function in the TrialRunner. - - on_trial_remove(trial_runner, trial: Trial) - Remove terminated trials from the scheduler. - - on_trial_result(trial_runner, trial: Trial, result: Dict) - Reprot results to the scheduler. - - choose_trial_to_run(trial_runner) -> Optional[Trial] - Among them, on_trial_result and choose_trial_to_run are the most important methods - champion_test_policy: A string to specify what test policy to test for champion. - Currently can choose from ['loss_ucb', 'loss_avg', 'loss_lcb', None]. + searcher: A class for generating Trial objects progressively. + The ConfigOracle is implemented in the searcher. + scheduler: A class for managing the 'live' trials and allocating the + resources for the trials. + champion_test_policy: A string to specify which test policy to use when testing + against the champion. Currently can choose from ['loss_ucb', 'loss_avg', 'loss_lcb', None]. """ + # ************A NOTE about the input searcher and scheduler****** + # Required methods of the searcher: + # - next_trial() + # Generate the next trial to add. + # - set_search_properties(metric: Optional[str], mode: Optional[str], + # config: Optional[dict], setting: Optional[dict]) + # Generate new challengers based on the current champion and update the challenger list. + # - on_trial_result(trial_id: str, result: Dict) + # Report results to the searcher. + # Required methods of the scheduler: + # - on_trial_add(trial_runner, trial: Trial) + # It adds candidate trials to the scheduler. It is called inside the add_trial + # function in the TrialRunner. + # - on_trial_remove(trial_runner, trial: Trial) + # Remove terminated trials from the scheduler. + # - on_trial_result(trial_runner, trial: Trial, result: Dict) + # Report results to the scheduler. + # - choose_trial_to_run(trial_runner) -> Optional[Trial] + # Among them, on_trial_result and choose_trial_to_run are the most important methods. + # ***************************************************************** # OnlineTrialRunner setting self._searcher = searcher self._scheduler = scheduler @@ -112,39 +99,37 @@ def __init__( @property def champion_trial(self) -> Trial: - """The champion trial""" + """The champion trial.""" return self._champion_trial @property def running_trials(self): - """The running/'live' trials""" + """The running/'live' trials.""" return self._running_trials
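As a usage sketch, the runner is driven one data sample at a time; the names below are hypothetical (a real setup first wires up a searcher and a scheduler, as the constructor above requires):

    # Hypothetical driver loop for an OnlineTrialRunner instance `runner`
    # over a stream of VW-format examples.
    for data_sample in data_stream:
        champion = runner.champion_trial
        y_pred = champion.predict(data_sample)  # predict with the champion model
        runner.step(
            data_sample=data_sample,
            prediction_trial_tuple=(y_pred, champion),  # (prediction_made, prediction_trial)
        )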
def step(self, data_sample=None, prediction_trial_tuple=None): - """Schedule up to max_live_model_num trials to run + """Schedule one trial to run each time it is called. Args: - data_sample - prediction_trial_tuple - - NOTE: - It consists of the following several parts: - Update model: - 0. Update running trials using observations received. - Tests for Champion - 1. Test for champion (BetterThan test, and WorseThan test) - 1.1 BetterThan test - 1.2 WorseThan test: a trial may be removed if WroseThan test is triggered - Online Scheduling: - 2. Report results to the searcher and scheduler (the scheduler will return a decision about - the status of the running trials). - 3. Pause or stop a trial according to the scheduler's decision. - Add trial into the OnlineTrialRunner if there are opening slots. - - TODO: - add documentation about the Args + data_sample: One data example. + prediction_trial_tuple: A tuple containing + (prediction_made, prediction_trial). """ - # ***********Update running trials with observation*************************** + # TODO: Will remove prediction_trial_tuple. + # NOTE: This function consists of the following several parts: + # * Update model: + # 0. Update running trials using observations received. + # * Tests for Champion: + # 1. Test for champion (BetterThan test, and WorseThan test) + # 1.1 BetterThan test + # 1.2 WorseThan test: a trial may be removed if the WorseThan test is triggered + # * Online Scheduling: + # 2. Report results to the searcher and scheduler (the scheduler will return a decision about + # the status of the running trials). + # 3. Pause or stop a trial according to the scheduler's decision. + # Add a trial into the OnlineTrialRunner if there are open slots. + + # ***********Update running trials with observation******************* if data_sample is not None: self._total_steps += 1 prediction_made, prediction_trial = ( @@ -206,7 +191,7 @@ def step(self, data_sample=None, prediction_trial_tuple=None): break def get_top_running_trials(self, top_ratio=None, top_metric="ucb") -> list: - """Get a list of trial ids, whose performance is among the top running trials""" + """Get a list of trial ids whose performance is among the top running trials.""" running_valid_trials = [ trial for trial in self._running_trials if trial.result is not None ] @@ -250,8 +235,8 @@ def _add_trial_from_searcher(self): """Add a new trial to this TrialRunner. NOTE: - The new trial is acquired from the input search algorithm, i.e. self._searcher - A 'new' trial means the trial is not in self._trial + The new trial is acquired from the input search algorithm, i.e., self._searcher. + A 'new' trial means the trial is not in self._trials. """ # (optionally) upper bound the number of trials in the OnlineTrialRunner if self._bound_trial_num and self._first_challenger_pool_size is not None: @@ -385,16 +370,13 @@ def get_trials(self) -> list: def add_trial(self, new_trial): """Add a new trial to this TrialRunner. - Trials may be added at any time. Args: - trial (Trial): Trial to queue. - - NOTE: - Only add the new trial when it does not exist (according to the trial_id, which is - the signature of the trail) in self._trials. + new_trial (Trial): Trial to queue. """ + # Only add the new trial when it does not exist (according to the trial_id, which is + # the signature of the trial) in self._trials.
for trial in self._trials: if trial.trial_id == new_trial.trial_id: trial.set_checked_under_current_champion(True) @@ -409,8 +391,8 @@ def add_trial(self, new_trial): self._scheduler.on_trial_add(self, new_trial) def stop_trial(self, trial): - """Stop a trial: set the status of a trial to be Trial.TERMINATED and perform - other subsequent operations + """Stop a trial: set the status of a trial to + Trial.TERMINATED and perform other subsequent operations. """ if trial.status in [Trial.ERROR, Trial.TERMINATED]: return @@ -428,8 +410,8 @@ def add_trial(self, new_trial): self._running_trials.remove(trial) def pause_trial(self, trial): - """Pause a trial: set the status of a trial to be Trial.PAUSED and perform other - subsequent operations + """Pause a trial: set the status of a trial to Trial.PAUSED + and perform other subsequent operations. """ if trial.status in [Trial.ERROR, Trial.TERMINATED]: return @@ -450,8 +432,8 @@ def add_trial(self, new_trial): self._running_trials.remove(trial) def run_trial(self, trial): - """Run a trial: set the status of a trial to be Trial.RUNNING and perform other - subsequent operations + """Run a trial: set the status of a trial to Trial.RUNNING + and perform other subsequent operations. """ if trial.status in [Trial.ERROR, Trial.TERMINATED]: return @@ -460,11 +442,11 @@ def add_trial(self, new_trial): self._running_trials.add(trial) def _better_than_champion_test(self, trial_to_test): - """Test whether there is a config in the existing trials that is better than - the current champion config + """Test whether there is a config in the existing trials that + is better than the current champion config. Returns: - A bool indicating whether a new champion is found + A bool indicating whether a new champion is found. """ if trial_to_test.result is not None and self._champion_trial.result is not None: if "ucb" in self._champion_test_policy: diff --git a/flaml/scheduler/online_scheduler.py b/flaml/scheduler/online_scheduler.py index 4a359091b2..55f2563fe9 100644 --- a/flaml/scheduler/online_scheduler.py +++ b/flaml/scheduler/online_scheduler.py @@ -3,33 +3,23 @@ from typing import Dict from flaml.scheduler import TrialScheduler from flaml.tune import Trial + logger = logging.getLogger(__name__) class OnlineScheduler(TrialScheduler): - """Implementation of the OnlineFIFOSchedulers. - - Methods: - on_trial_result(trial_runner, trial, result) - Report result and return a decision on the trial's status - choose_trial_to_run(trial_runner) - Decide which trial to run next - """ - def on_trial_result(self, trial_runner, trial: Trial, result: Dict): - """Report result and return a decision on the trial's status + """Class for the most basic OnlineScheduler.""" - Always keep a trial running (return status TrialScheduler.CONTINUE) - """ + def on_trial_result(self, trial_runner, trial: Trial, result: Dict): + """Report result and return a decision on the trial's status.""" + # Always keep a trial running (return status TrialScheduler.CONTINUE). return TrialScheduler.CONTINUE def choose_trial_to_run(self, trial_runner) -> Trial: - """Decide which trial to run next - - Trial prioritrization according to the status: - PENDING (trials that have not been tried) > PAUSED (trials that have been ran) - - For trials with the same status, it chooses the ones with smaller resource lease - """ + """Decide which trial to run next.""" + # Trial prioritization according to the status: + # PENDING (trials that have not been tried) > PAUSED (trials that have been run). + # For trials with the same status, it chooses the one with the smallest resource lease.
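+ # For example (illustrative, not in the original patch): given trials A (PENDING, lease=8), + # B (PAUSED, lease=2), and C (PAUSED, lease=8), A is chosen first because PENDING outranks + # PAUSED; once no PENDING trial remains, B wins the PAUSED tie-break via its smaller lease.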
for trial in trial_runner.get_trials(): if trial.status == Trial.PENDING: return trial @@ -37,7 +27,10 @@ def choose_trial_to_run(self, trial_runner) -> Trial: min_paused_resource_trial = None for trial in trial_runner.get_trials(): # if there is a tie, prefer the earlier added ones - if trial.status == Trial.PAUSED and trial.resource_lease < min_paused_resource: + if ( + trial.status == Trial.PAUSED + and trial.resource_lease < min_paused_resource + ): min_paused_resource = trial.resource_lease min_paused_resource_trial = trial if min_paused_resource_trial is not None: @@ -45,66 +38,56 @@ class OnlineSuccessiveDoublingScheduler(OnlineScheduler): - """Implementation of the OnlineSuccessiveDoublingScheduler. - - Methods: - on_trial_result(trial_runner, trial, result) - Report result and return a decision on the trial's status - choose_trial_to_run(trial_runner) - Decide which trial to run next - """ + """class for the OnlineSuccessiveDoublingScheduler algorithm.""" + def __init__(self, increase_factor: float = 2.0): - ''' + """Constructor. + Args: - increase_factor (float): a multiplicative factor used to increase resource lease. - The default value is 2.0 - ''' + increase_factor: A float multiplicative factor + used to increase the resource lease. Default is 2.0. + """ super().__init__() self._increase_factor = increase_factor def on_trial_result(self, trial_runner, trial: Trial, result: Dict): - """Report result and return a decision on the trial's status - - 1. Returns TrialScheduler.CONTINUE (i.e., keep the trial running), - if the resource consumed has not reached the current resource_lease.s - 2. otherwise double the current resource lease and return TrialScheduler.PAUSE - """ + """Report result and return a decision on the trial's status.""" + # 1. Returns TrialScheduler.CONTINUE (i.e., keep the trial running), + # if the resource consumed has not reached the current resource_lease. + # 2. Otherwise, double the current resource lease and return TrialScheduler.PAUSE. if trial.result is None or trial.result.resource_used < trial.resource_lease: return TrialScheduler.CONTINUE else: trial.set_resource_lease(trial.resource_lease * self._increase_factor) - logger.info('Doubled resource for trial %s, used: %s, current budget %s', - trial.trial_id, trial.result.resource_used, trial.resource_lease) + logger.info( + "Doubled resource for trial %s, used: %s, current budget %s", + trial.trial_id, + trial.result.resource_used, + trial.resource_lease, + ) return TrialScheduler.PAUSE
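To see the schedule this induces, a short sketch (numbers are illustrative; in this code the resource is typically the number of data samples consumed):

    # With increase_factor=2.0, a trial that exhausts its lease is paused and
    # its lease doubles, so after being revived k times the lease is initial * 2**k.
    lease = 10.0
    for _ in range(4):  # four pause/double cycles
        lease *= 2.0
    assert lease == 160.0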
+ """ super().__init__(increase_factor) - self._keep_champion = kwargs.get('keep_champion', True) - self._keep_challenger_metric = kwargs.get('keep_challenger_metric', 'ucb') - self._keep_challenger_ratio = kwargs.get('keep_challenger_ratio', 0.5) - self._pause_old_froniter = kwargs.get('pause_old_froniter', False) - logger.info('Using chacha scheduler with config %s', kwargs) + self._keep_champion = kwargs.get("keep_champion", True) + self._keep_challenger_metric = kwargs.get("keep_challenger_metric", "ucb") + self._keep_challenger_ratio = kwargs.get("keep_challenger_ratio", 0.5) + self._pause_old_froniter = kwargs.get("pause_old_froniter", False) + logger.info("Using chacha scheduler with config %s", kwargs) def on_trial_result(self, trial_runner, trial: Trial, result: Dict): - """Report result and return a decision on the trial's status - - Make a decision according to: SuccessiveDoubling + champion check + performance check - """ + """Report result and return a decision on the trial's status.""" + # Make a decision according to: SuccessiveDoubling + champion check + performance check. # Doubling scheduler makes a decision decision = super().on_trial_result(trial_runner, trial, result) # ***********Check whether the trial has been paused since a new champion is promoted** @@ -119,22 +102,28 @@ def on_trial_result(self, trial_runner, trial: Trial, result: Dict): if decision == TrialScheduler.CONTINUE: decision = TrialScheduler.PAUSE trial.set_checked_under_current_champion(True) - logger.info('Tentitively set trial as paused') + logger.info("Tentitively set trial as paused") # ****************Keep the champion always running****************** - if self._keep_champion and trial.trial_id == trial_runner.champion_trial.trial_id and \ - decision == TrialScheduler.PAUSE: + if ( + self._keep_champion + and trial.trial_id == trial_runner.champion_trial.trial_id + and decision == TrialScheduler.PAUSE + ): return TrialScheduler.CONTINUE # ****************Keep the trials with top performance always running****************** if self._keep_challenger_ratio is not None: if decision == TrialScheduler.PAUSE: - logger.debug('champion, %s', trial_runner.champion_trial.trial_id) + logger.debug("champion, %s", trial_runner.champion_trial.trial_id) # this can be inefficient when the # trials is large. TODO: need to improve efficiency. - top_trials = trial_runner.get_top_running_trials(self._keep_challenger_ratio, - self._keep_challenger_metric) - logger.debug('top_learners: %s', top_trials) + top_trials = trial_runner.get_top_running_trials( + self._keep_challenger_ratio, self._keep_challenger_metric + ) + logger.debug("top_learners: %s", top_trials) if trial in top_trials: - logger.debug('top runner %s: set from PAUSE to CONTINUE', trial.trial_id) + logger.debug( + "top runner %s: set from PAUSE to CONTINUE", trial.trial_id + ) return TrialScheduler.CONTINUE return decision diff --git a/flaml/scheduler/trial_scheduler.py b/flaml/scheduler/trial_scheduler.py index fbfdc0a671..a188b71131 100644 --- a/flaml/scheduler/trial_scheduler.py +++ b/flaml/scheduler/trial_scheduler.py @@ -1,23 +1,20 @@ -''' -Copyright 2020 The Ray Authors. +# Copyright 2020 The Ray Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -This source file is adapted here because ray does not fully support Windows. - -Copyright (c) Microsoft Corporation. -''' +# This source file is adapted here because ray does not fully support Windows. +# Copyright (c) Microsoft Corporation. from flaml.tune import trial_runner from flaml.tune.trial import Trial @@ -29,10 +26,8 @@ class TrialScheduler: PAUSE = "PAUSE" #: Status for pausing trial execution STOP = "STOP" #: Status for stopping trial execution - def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", - trial: Trial): + def on_trial_add(self, trial_runner: "trial_runner.TrialRunner", trial: Trial): pass - def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner", - trial: Trial): + def on_trial_remove(self, trial_runner: "trial_runner.TrialRunner", trial: Trial): pass diff --git a/flaml/searcher/blendsearch.py b/flaml/searcher/blendsearch.py index e6ce726da6..eb8fce9da9 100644 --- a/flaml/searcher/blendsearch.py +++ b/flaml/searcher/blendsearch.py @@ -1,8 +1,7 @@ -"""! - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See LICENSE file in the - * project root for license information. -""" +# ! +# * Copyright (c) Microsoft Corporation. All rights reserved. +# * Licensed under the MIT License. See LICENSE file in the +# * project root for license information. from typing import Dict, Optional, List, Tuple, Callable, Union import numpy as np import time @@ -22,14 +21,13 @@ from .search_thread import SearchThread from .flow2 import FLOW2 from ..tune.space import add_cost_to_space, indexof, normalize, define_by_run_func - import logging logger = logging.getLogger(__name__) class BlendSearch(Searcher): - """class for BlendSearch algorithm""" + """class for BlendSearch algorithm.""" cost_attr = "time_total_s" # cost attribute in result lagrange = "_lagrange" # suffix for lagrange-modified metric @@ -59,7 +57,7 @@ def __init__( seed: Optional[int] = 20, experimental: Optional[bool] = False, ): - """Constructor + """Constructor. Args: metric: A string of the metric name to optimize for. 
@@ -334,7 +332,7 @@ def _init_search(self): self.best_resource = self._ls.min_resource def save(self, checkpoint_path: str): - """save states to a checkpoint path""" + """save states to a checkpoint path.""" self._time_used += time.time() - self._start_time self._start_time = time.time() save_object = self @@ -342,7 +340,7 @@ def save(self, checkpoint_path: str): pickle.dump(save_object, outputFile) def restore(self, checkpoint_path: str): - """restore states from checkpoint""" + """restore states from checkpoint.""" with open(checkpoint_path, "rb") as inputFile: state = pickle.load(inputFile) self.__dict__ = state.__dict__ @@ -360,7 +358,7 @@ def is_ls_ever_converged(self): def on_trial_complete( self, trial_id: str, result: Optional[Dict] = None, error: bool = False ): - """search thread updater and cleaner""" + """search thread updater and cleaner.""" metric_constraint_satisfied = True if result and not error and self._metric_constraints: # account for metric constraints if any @@ -621,7 +619,7 @@ def _inferior(self, id1: int, id2: int) -> bool: return False def on_trial_result(self, trial_id: str, result: Dict): - """receive intermediate result""" + """receive intermediate result.""" if trial_id not in self._trial_proposed_by: return thread_id = self._trial_proposed_by[trial_id] @@ -632,7 +630,7 @@ def on_trial_result(self, trial_id: str, result: Dict): self._search_thread_pool[thread_id].on_trial_result(trial_id, result) def suggest(self, trial_id: str) -> Optional[Dict]: - """choose thread, suggest a valid config""" + """choose thread, suggest a valid config.""" if self._init_used and not self._points_to_evaluate: choice, backup = self._select_thread() # if choice < 0: # timeout @@ -902,14 +900,15 @@ def extract_scalar_reward(x: Dict): class BlendSearchTuner(BlendSearch, NNITuner): - """Tuner class for NNI""" + """Tuner class for NNI.""" def receive_trial_result(self, parameter_id, parameters, value, **kwargs): - """ - Receive trial's final result. - parameter_id: int - parameters: object created by 'generate_parameters()' - value: final metrics of the trial, including default metric + """Receive trial's final result. + + Args: + parameter_id: int. + parameters: object created by `generate_parameters()`. + value: final metrics of the trial, including default metric. """ result = { "config": parameters, @@ -926,20 +925,24 @@ def receive_trial_result(self, parameter_id, parameters, value, **kwargs): ... def generate_parameters(self, parameter_id, **kwargs) -> Dict: - """ - Returns a set of trial (hyper-)parameters, as a serializable object - parameter_id: int + """Returns a set of trial (hyper-)parameters, as a serializable object. + + Args: + parameter_id: int. """ return self.suggest(str(parameter_id)) ... def update_search_space(self, search_space): - """ + """Required by NNI. + Tuners are advised to support updating search space at run-time. If a tuner can only set search space once before generating first hyper-parameters, it should explicitly document this behaviour. - search_space: JSON object created by experiment owner + + Args: + search_space: JSON object created by experiment owner. 
""" config = {} for key, value in search_space.items(): @@ -991,7 +994,7 @@ def update_search_space(self, search_space): class CFO(BlendSearchTuner): - """class for CFO algorithm""" + """class for CFO algorithm.""" __name__ = "CFO" @@ -1045,6 +1048,8 @@ def on_trial_complete( class RandomSearch(CFO): + """Class for random search.""" + def suggest(self, trial_id: str) -> Optional[Dict]: if self._points_to_evaluate: return super().suggest(trial_id) diff --git a/flaml/searcher/cfo_cat.py b/flaml/searcher/cfo_cat.py index a6f884211d..2955cd7ae6 100644 --- a/flaml/searcher/cfo_cat.py +++ b/flaml/searcher/cfo_cat.py @@ -1,15 +1,13 @@ -'''! - * Copyright (c) 2021 Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See LICENSE file in the - * project root for license information. -''' +# ! +# * Copyright (c) Microsoft Corporation. All rights reserved. +# * Licensed under the MIT License. See LICENSE file in the +# * project root for license information. from .flow2 import FLOW2 from .blendsearch import CFO class FLOW2Cat(FLOW2): - '''Local search algorithm optimized for categorical variables - ''' + """Local search algorithm optimized for categorical variables.""" def _init_search(self): super()._init_search() @@ -25,7 +23,6 @@ def _init_search(self): class CFOCat(CFO): - '''CFO optimized for categorical variables - ''' + """CFO optimized for categorical variables.""" LocalSearch = FLOW2Cat diff --git a/flaml/searcher/flow2.py b/flaml/searcher/flow2.py index 9057c65842..4705d391d3 100644 --- a/flaml/searcher/flow2.py +++ b/flaml/searcher/flow2.py @@ -1,8 +1,7 @@ -"""! - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See LICENSE file in the - * project root for license information. -""" +# ! +# * Copyright (c) Microsoft Corporation. All rights reserved. +# * Licensed under the MIT License. See LICENSE file in the +# * project root for license information. from flaml.tune.sample import Domain from typing import Dict, Optional, Tuple import numpy as np @@ -29,7 +28,7 @@ class FLOW2(Searcher): - """Local search algorithm FLOW2, with adaptive step size""" + """Local search algorithm FLOW2, with adaptive step size.""" STEPSIZE = 0.1 STEP_LOWER_BOUND = 0.0001 @@ -47,13 +46,13 @@ def __init__( cost_attr: Optional[str] = "time_total_s", seed: Optional[int] = 20, ): - """Constructor + """Constructor. Args: init_config: a dictionary of a partial or full initial config, - e.g. from a subset of controlled dimensions + e.g., from a subset of controlled dimensions to the initial low-cost values. - e.g. {'epochs': 1} + E.g., {'epochs': 1}. metric: A string of the metric name to optimize for. mode: A string in ['min', 'max'] to specify the objective as minimization or maximization. @@ -243,8 +242,9 @@ def complete_config( lower: Optional[Dict] = None, upper: Optional[Dict] = None, ) -> Tuple[Dict, Dict]: - """generate a complete config from the partial config input - add minimal resource to config if available + """Generate a complete config from the partial config input. + + Add minimal resource to config if available. 
""" disturb = self._reset_times and partial_config == self.init_config # if not the first time to complete init_config, use random gaussian @@ -279,13 +279,13 @@ def create( return flow2 def normalize(self, config, recursive=False) -> Dict: - """normalize each dimension in config to [0,1]""" + """normalize each dimension in config to [0,1].""" return normalize( config, self._space, self.best_config, self.incumbent, recursive ) def denormalize(self, config): - """denormalize each dimension in config from [0,1]""" + """denormalize each dimension in config from [0,1].""" return denormalize( config, self._space, self.best_config, self.incumbent, self._random ) @@ -314,9 +314,11 @@ def set_search_properties( def on_trial_complete( self, trial_id: str, result: Optional[Dict] = None, error: bool = False ): - # compare with incumbent - # if better, move, reset num_complete and num_proposed - # if not better and num_complete >= 2*dim, num_allowed += 2 + """ + Compare with incumbent. + If better, move, reset num_complete and num_proposed. + If not better and num_complete >= 2*dim, num_allowed += 2. + """ self.trial_count_complete += 1 if not error and result: obj = result.get(self._metric) @@ -369,7 +371,7 @@ def on_trial_complete( # elif proposed_by: del self._proposed_by[trial_id] def on_trial_result(self, trial_id: str, result: Dict): - """early update of incumbent""" + """Early update of incumbent.""" if result: obj = result.get(self._metric) if obj: @@ -401,12 +403,12 @@ def rand_vector_unit_sphere(self, dim, trunc=0) -> np.ndarray: return vec / mag def suggest(self, trial_id: str) -> Optional[Dict]: - """suggest a new config, one of the following cases: - 1. same incumbent, increase resource - 2. same resource, move from the incumbent to a random direction - 3. same resource, move from the incumbent to the opposite direction - #TODO: better decouple FLOW2 config suggestion and stepsize update + """Suggest a new config, one of the following cases: + 1. same incumbent, increase resource. + 2. same resource, move from the incumbent to a random direction. + 3. same resource, move from the incumbent to the opposite direction. """ + # TODO: better decouple FLOW2 config suggestion and stepsize update self.trial_count_proposed += 1 if ( self._num_complete4incumbent > 0 @@ -516,13 +518,13 @@ def _project(self, config): @property def can_suggest(self) -> bool: - """can't suggest if 2*dim configs have been proposed for the incumbent - while fewer are completed + """Can't suggest if 2*dim configs have been proposed for the incumbent + while fewer are completed. 
""" return self._num_allowed4incumbent > 0 def config_signature(self, config, space: Dict = None) -> tuple: - """return the signature tuple of a config""" + """Return the signature tuple of a config.""" config = flatten_dict(config) if space: space = flatten_dict(space) @@ -558,14 +560,14 @@ def config_signature(self, config, space: Dict = None) -> tuple: @property def converged(self) -> bool: - """return whether the local search has converged""" + """Whether the local search has converged.""" if self._num_complete4incumbent < self.dir - 2: return False # check stepsize after enough configs are completed return self.step < self.step_lower_bound def reach(self, other: Searcher) -> bool: - """whether the incumbent can reach the incumbent of other""" + """whether the incumbent can reach the incumbent of other.""" config1, config2 = self.best_config, other.best_config incumbent1, incumbent2 = self.incumbent, other.incumbent if self._resource and config1[self.prune_attr] > config2[self.prune_attr]: diff --git a/flaml/searcher/online_searcher.py b/flaml/searcher/online_searcher.py index 26038c873e..536c0f2ede 100644 --- a/flaml/searcher/online_searcher.py +++ b/flaml/searcher/online_searcher.py @@ -11,14 +11,7 @@ class BaseSearcher: - """Implementation of the BaseSearcher - - Methods: - set_search_properties(metric, mode, config) - next_trial() - on_trial_result(trial_id, result) - on_trial_complete() - """ + """Abstract class for an online searcher.""" def __init__( self, @@ -50,28 +43,21 @@ def on_trial_complete(self, trial): class ChampionFrontierSearcher(BaseSearcher): - """The ChampionFrontierSearcher class - - Methods: - (metric, mode, config) - Generate a list of new challengers, and add them to the _challenger_list - next_trial() - Pop a trial from the _challenger_list - on_trial_result(trial_id, result) - Doing nothing - on_trial_complete() - Doing nothing - - NOTE: - This class serves the role of ConfigOralce. - Every time we create an online trial, we generate a searcher_trial_id. - At the same time, we also record the trial_id of the VW trial. - Note that the trial_id is a unique signature of the configuraiton. - So if two VWTrials are associated with the same config, they will have the same trial_id - (although not the same searcher_trial_id). - searcher_trial_id will be used in suggest() + """The ChampionFrontierSearcher class. + + NOTE about the correspondence about this code and the research paper: + [ChaCha for Online AutoML](https://arxiv.org/pdf/2106.04815.pdf) + This class serves the role of ConfigOralce as described in the paper. """ + # **************************More notes*************************** + # Every time we create an online trial, we generate a searcher_trial_id. + # At the same time, we also record the trial_id of the VW trial. + # Note that the trial_id is a unique signature of the configuration. + # So if two VWTrials are associated with the same config, they will have the same trial_id + # (although not the same searcher_trial_id). + # searcher_trial_id will be used in suggest(). + # ****the following constants are used when generating new challengers in # the _query_config_oracle function # how many item to add when doing the expansion @@ -109,17 +95,19 @@ def __init__( online_trial_args: Optional[Dict] = {}, nonpoly_searcher_name: Optional[str] = "CFO", ): - """Constructor + """Constructor. Args: - init_config: dict - space: dict - metric: str - mode: str - random_seed: int - online_trial_args: dict + init_config: A dictionary of initial configuration. 
+ space: A dictionary to specify the search space. + metric: A string of the metric name to optimize for. + mode: A string in ['min', 'max'] to specify the objective as + minimization or maximization. + random_seed: An integer of the random seed. + online_trial_args: A dictionary to specify the online trial + arguments for experimental purposes. nonpoly_searcher_name: A string to specify the search algorithm - for nonpoly hyperparameters + for nonpoly hyperparameters. """ self._init_config = init_config self._space = space @@ -154,7 +142,7 @@ def set_search_properties( setting: Optional[Dict] = {}, init_call: Optional[bool] = False, ): - """Construct search space with given config, and setup the search""" + """Construct search space with the given config, and setup the search.""" super().set_search_properties(metric, mode, config) # *********Use ConfigOracle (i.e., self._generate_new_space) to generate a list of new challengers logger.info("setting %s", setting) @@ -184,7 +172,7 @@ def set_search_properties( ) def next_trial(self): - """Return a trial from the _challenger_list""" + """Return a trial from the _challenger_list.""" next_trial = None if self._challenger_list: next_trial = self._challenger_list.pop() @@ -204,7 +192,7 @@ def _query_config_oracle( self, seed_config, seed_config_trial_id, seed_config_searcher_trial_id=None ) -> List[Trial]: """Given the seed config, generate a list of new configs (which are supposed to include - at least one config that has better performance than the input seed_config) + at least one config that has better performance than the input seed_config). """ # group the hyperparameters according to whether their configs are independent # of the other hyperparameters diff --git a/flaml/searcher/search_thread.py b/flaml/searcher/search_thread.py index e04dccdf9d..2bad28da5e 100644 --- a/flaml/searcher/search_thread.py +++ b/flaml/searcher/search_thread.py @@ -1,53 +1,59 @@ -'''! - * Copyright (c) 2020-2021 Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See LICENSE file in the - * project root for license information. -''' +# ! +# * Copyright (c) Microsoft Corporation. All rights reserved. +# * Licensed under the MIT License. See LICENSE file in the +# * project root for license information.
from typing import Dict, Optional import numpy as np + try: from ray import __version__ as ray_version - assert ray_version >= '1.0.0' + + assert ray_version >= "1.0.0" from ray.tune.suggest import Searcher except (ImportError, AssertionError): from .suggestion import Searcher from .flow2 import FLOW2 from ..tune.space import add_cost_to_space, unflatten_hierarchical - import logging + logger = logging.getLogger(__name__) class SearchThread: - '''Class of global or local search thread - ''' + """Class of global or local search thread.""" _eps = 1.0 - def __init__(self, mode: str = "min", - search_alg: Optional[Searcher] = None, - cost_attr: Optional[str] = 'time_total_s'): - ''' When search_alg is omitted, use local search FLOW2 - ''' + def __init__( + self, + mode: str = "min", + search_alg: Optional[Searcher] = None, + cost_attr: Optional[str] = "time_total_s", + ): + """When search_alg is omitted, use local search FLOW2.""" self._search_alg = search_alg self._is_ls = isinstance(search_alg, FLOW2) self._mode = mode - self._metric_op = 1 if mode == 'min' else -1 - self.cost_best = self.cost_last = self.cost_total = self.cost_best1 = \ - getattr(search_alg, 'cost_incumbent', 0) + self._metric_op = 1 if mode == "min" else -1 + self.cost_best = self.cost_last = self.cost_total = self.cost_best1 = getattr( + search_alg, "cost_incumbent", 0 + ) self.cost_best2 = 0 self.obj_best1 = self.obj_best2 = getattr( - search_alg, 'best_obj', np.inf) # inherently minimize + search_alg, "best_obj", np.inf + ) # inherently minimize # eci: estimated cost for improvement self.eci = self.cost_best self.priority = self.speed = 0 self._init_config = True - self.running = 0 # the number of running trials from the thread + self.running = 0 # the number of running trials from the thread self.cost_attr = cost_attr if search_alg: self.space = self._space = search_alg.space # unflattened space - if self.space and not isinstance(search_alg, FLOW2) and isinstance( - search_alg._space, dict + if ( + self.space + and not isinstance(search_alg, FLOW2) + and isinstance(search_alg._space, dict) ): # remember const config self._const = add_cost_to_space(self.space, {}, {}) @@ -57,8 +63,7 @@ def set_eps(cls, time_budget_s): cls._eps = max(min(time_budget_s / 1000.0, 1.0), 1e-9) def suggest(self, trial_id: str) -> Optional[Dict]: - ''' use the suggest() of the underlying search algorithm - ''' + """Use the suggest() of the underlying search algorithm.""" if isinstance(self._search_alg, FLOW2): config = self._search_alg.suggest(trial_id) else: @@ -68,12 +73,12 @@ def suggest(self, trial_id: str) -> Optional[Dict]: config.update(self._const) else: # define by run - config, self.space = unflatten_hierarchical( - config, self._space) + config, self.space = unflatten_hierarchical(config, self._space) except FloatingPointError: logger.warning( - 'The global search method raises FloatingPointError. ' - 'Ignoring for this iteration.') + "The global search method raises FloatingPointError. " + "Ignoring for this iteration." 
+ ) config = None if config is not None: self.running += 1 @@ -83,14 +88,14 @@ def update_priority(self, eci: Optional[float] = 0): # optimistic projection self.priority = eci * self.speed - self.obj_best1 - def update_eci(self, metric_target: float, - max_speed: Optional[float] = np.inf): + def update_eci(self, metric_target: float, max_speed: Optional[float] = np.inf): # calculate eci: estimated cost for improvement over metric_target best_obj = metric_target * self._metric_op if not self.speed: self.speed = max_speed - self.eci = max(self.cost_total - self.cost_best1, - self.cost_best1 - self.cost_best2) + self.eci = max( + self.cost_total - self.cost_best1, self.cost_best1 - self.cost_best2 + ) if self.obj_best1 > best_obj and self.speed > 0: self.eci = max(self.eci, 2 * (self.obj_best1 - best_obj) / self.speed) @@ -98,19 +103,23 @@ def _update_speed(self): # calculate speed; use 0 for invalid speed temporarily if self.obj_best2 > self.obj_best1: # discount the speed if there are unfinished trials - self.speed = (self.obj_best2 - self.obj_best1) / self.running / ( - max(self.cost_total - self.cost_best2, SearchThread._eps)) + self.speed = ( + (self.obj_best2 - self.obj_best1) + / self.running + / (max(self.cost_total - self.cost_best2, SearchThread._eps)) + ) else: self.speed = 0 - def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, - error: bool = False): - ''' update the statistics of the thread - ''' + def on_trial_complete( + self, trial_id: str, result: Optional[Dict] = None, error: bool = False + ): + """Update the statistics of the thread.""" if not self._search_alg: return - if not hasattr(self._search_alg, '_ot_trials') or ( - not error and trial_id in self._search_alg._ot_trials): + if not hasattr(self._search_alg, "_ot_trials") or ( + not error and trial_id in self._search_alg._ot_trials + ): # optuna doesn't handle error if self._is_ls or not self._init_config: try: @@ -118,7 +127,8 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, except RuntimeError as e: # rs is used in place of optuna sometimes if not str(e).endswith( - "has already finished and can not be updated."): + "has already finished and can not be updated." + ): raise e else: # init config is not proposed by self._search_alg @@ -132,8 +142,7 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, if obj < self.obj_best1: self.cost_best2 = self.cost_best1 self.cost_best1 = self.cost_total - self.obj_best2 = obj if np.isinf( - self.obj_best1) else self.obj_best1 + self.obj_best2 = obj if np.isinf(self.obj_best1) else self.obj_best1 self.obj_best1 = obj self.cost_best = self.cost_last self._update_speed() @@ -141,18 +150,17 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, assert self.running >= 0 def on_trial_result(self, trial_id: str, result: Dict): - ''' TODO update the statistics of the thread with partial result? - ''' + # TODO update the statistics of the thread with partial result? 
if not self._search_alg: return - if not hasattr(self._search_alg, '_ot_trials') or ( - trial_id in self._search_alg._ot_trials): + if not hasattr(self._search_alg, "_ot_trials") or ( + trial_id in self._search_alg._ot_trials + ): try: self._search_alg.on_trial_result(trial_id, result) except RuntimeError as e: # rs is used in place of optuna sometimes - if not str(e).endswith( - "has already finished and can not be updated."): + if not str(e).endswith("has already finished and can not be updated."): raise e if self.cost_attr in result and self.cost_last < result[self.cost_attr]: self.cost_last = result[self.cost_attr] @@ -167,12 +175,10 @@ def resource(self) -> float: return self._search_alg.resource def reach(self, thread) -> bool: - ''' whether the incumbent can reach the incumbent of thread - ''' + """Whether the incumbent can reach the incumbent of thread.""" return self._search_alg.reach(thread._search_alg) @property def can_suggest(self) -> bool: - ''' whether the thread can suggest new configs - ''' + """Whether the thread can suggest new configs.""" return self._search_alg.can_suggest diff --git a/flaml/searcher/suggestion.py b/flaml/searcher/suggestion.py index ea420b885c..68ed2fa3f8 100644 --- a/flaml/searcher/suggestion.py +++ b/flaml/searcher/suggestion.py @@ -1,22 +1,20 @@ -''' -Copyright 2020 The Ray Authors. +# Copyright 2020 The Ray Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -This source file is adapted here because ray does not fully support Windows. +# This source file is adapted here because ray does not fully support Windows. -Copyright (c) Microsoft Corporation. -''' +# Copyright (c) Microsoft Corporation. import time import functools import warnings @@ -25,8 +23,15 @@ from typing import Any, Dict, Optional, Union, List, Tuple, Callable import pickle from .variant_generator import parse_spec_vars -from ..tune.sample import Categorical, Domain, Float, Integer, LogUniform, \ - Quantized, Uniform +from ..tune.sample import ( + Categorical, + Domain, + Float, + Integer, + LogUniform, + Quantized, + Uniform, +) from ..tune.trial import flatten_dict, unflatten_dict logger = logging.getLogger(__name__) @@ -36,19 +41,22 @@ "space definitions. {cls} should however be instantiated with fully " "configured search spaces only. To use Ray Tune's automatic search space " "conversion, pass the space definition as part of the `config` argument " - "to `tune.run()` instead.") + "to `tune.run()` instead." 
+) UNDEFINED_SEARCH_SPACE = str( "Trying to sample a configuration from {cls}, but no search " "space has been defined. Either pass the `{space}` argument when " "instantiating the search algorithm, or pass a `config` to " - "`tune.run()`.") + "`tune.run()`." +) UNDEFINED_METRIC_MODE = str( "Trying to sample a configuration from {cls}, but the `metric` " "({metric}) or `mode` ({mode}) parameters have not been set. " "Either pass these arguments when instantiating the search algorithm, " - "or pass them to `tune.run()`.") + "or pass them to `tune.run()`." +) class Searcher: @@ -83,14 +91,17 @@ def on_trial_complete(self, trial_id, result, **kwargs): self.optimizer.update(configuration, result[self.metric]) tune.run(trainable_function, search_alg=ExampleSearch()) """ + FINISHED = "FINISHED" CKPT_FILE_TMPL = "searcher-state-{}.pkl" - def __init__(self, - metric: Optional[str] = None, - mode: Optional[str] = None, - max_concurrent: Optional[int] = None, - use_early_stopped_trials: Optional[bool] = None): + def __init__( + self, + metric: Optional[str] = None, + mode: Optional[str] = None, + max_concurrent: Optional[int] = None, + use_early_stopped_trials: Optional[bool] = None, + ): self._metric = metric self._mode = mode @@ -100,20 +111,21 @@ def __init__(self, return assert isinstance( - metric, type(mode)), "metric and mode must be of the same type" + metric, type(mode) + ), "metric and mode must be of the same type" if isinstance(mode, str): - assert mode in ["min", "max" - ], "if `mode` is a str must be 'min' or 'max'!" + assert mode in ["min", "max"], "if `mode` is a str must be 'min' or 'max'!" elif isinstance(mode, list): - assert len(mode) == len( - metric), "Metric and mode must be the same length" - assert all(mod in ["min", "max", "obs"] for mod in - mode), "All of mode must be 'min' or 'max' or 'obs'!" + assert len(mode) == len(metric), "Metric and mode must be the same length" + assert all( + mod in ["min", "max", "obs"] for mod in mode + ), "All of mode must be 'min' or 'max' or 'obs'!" else: raise ValueError("Mode must either be a list or a string") - def set_search_properties(self, metric: Optional[str], mode: Optional[str], - config: Dict) -> bool: + def set_search_properties( + self, metric: Optional[str], mode: Optional[str], config: Dict + ) -> bool: """Pass search properties to searcher. This method acts as an alternative to instantiating search algorithms with their own specific search spaces. Instead they can accept a @@ -171,10 +183,7 @@ class ConcurrencyLimiter(Searcher): tune.run(trainable, search_alg=search_alg) """ - def __init__(self, - searcher: Searcher, - max_concurrent: int, - batch: bool = False): + def __init__(self, searcher: Searcher, max_concurrent: int, batch: bool = False): assert type(max_concurrent) is int and max_concurrent > 0 self.searcher = searcher self.max_concurrent = max_concurrent @@ -182,16 +191,20 @@ def __init__(self, self.live_trials = set() self.cached_results = {} super(ConcurrencyLimiter, self).__init__( - metric=self.searcher.metric, mode=self.searcher.mode) + metric=self.searcher.metric, mode=self.searcher.mode + ) def suggest(self, trial_id: str) -> Optional[Dict]: - assert trial_id not in self.live_trials, ( - f"Trial ID {trial_id} must be unique: already found in set.") + assert ( + trial_id not in self.live_trials + ), f"Trial ID {trial_id} must be unique: already found in set."
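+ # E.g., with max_concurrent=2 and two live trials, a third call to suggest()
+ # returns None until one of the live trials completes and frees a slot.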
if len(self.live_trials) >= self.max_concurrent: logger.debug( f"Not providing a suggestion for {trial_id} due to " - "concurrency limit: %s/%s.", len(self.live_trials), - self.max_concurrent) + "concurrency limit: %s/%s.", + len(self.live_trials), + self.max_concurrent, + ) return suggestion = self.searcher.suggest(trial_id) @@ -199,10 +212,9 @@ def suggest(self, trial_id: str) -> Optional[Dict]: self.live_trials.add(trial_id) return suggestion - def on_trial_complete(self, - trial_id: str, - result: Optional[Dict] = None, - error: bool = False): + def on_trial_complete( + self, trial_id: str, result: Optional[Dict] = None, error: bool = False + ): if trial_id not in self.live_trials: return elif self.batch: @@ -212,14 +224,14 @@ def on_trial_complete(self, # full batch is completed. for trial_id, (result, error) in self.cached_results.items(): self.searcher.on_trial_complete( - trial_id, result=result, error=error) + trial_id, result=result, error=error + ) self.live_trials.remove(trial_id) self.cached_results = {} else: return else: - self.searcher.on_trial_complete( - trial_id, result=result, error=error) + self.searcher.on_trial_complete(trial_id, result=result, error=error) self.live_trials.remove(trial_id) def get_state(self) -> Dict: @@ -242,8 +254,9 @@ def on_pause(self, trial_id: str): def on_unpause(self, trial_id: str): self.searcher.on_unpause(trial_id) - def set_search_properties(self, metric: Optional[str], mode: Optional[str], - config: Dict) -> bool: + def set_search_properties( + self, metric: Optional[str], mode: Optional[str], config: Dict + ) -> bool: return self.searcher.set_search_properties(metric, mode, config) @@ -270,10 +283,12 @@ def set_search_properties(self, metric: Optional[str], mode: Optional[str], DEFINE_BY_RUN_WARN_THRESHOLD_S = 1 # 1 is arbitrary -def validate_warmstart(parameter_names: List[str], - points_to_evaluate: List[Union[List, Dict]], - evaluated_rewards: List, - validate_point_name_lengths: bool = True): +def validate_warmstart( + parameter_names: List[str], + points_to_evaluate: List[Union[List, Dict]], + evaluated_rewards: List, + validate_point_name_lengths: bool = True, +): """Generic validation of a Searcher's warm start functionality. Raises exceptions in case of type and length mismatches between parameters. @@ -285,29 +300,36 @@ def validate_warmstart(parameter_names: List[str], if not isinstance(points_to_evaluate, list): raise TypeError( "points_to_evaluate expected to be a list, got {}.".format( - type(points_to_evaluate))) + type(points_to_evaluate) + ) + ) for point in points_to_evaluate: if not isinstance(point, (dict, list)): raise TypeError( f"points_to_evaluate expected to include list or dict, " - f"got {point}.") + f"got {point}." + ) - if validate_point_name_lengths and ( - not len(point) == len(parameter_names)): - raise ValueError("Dim of point {}".format(point) - + " and parameter_names {}".format( - parameter_names) + " do not match.") + if validate_point_name_lengths and (not len(point) == len(parameter_names)): + raise ValueError( + "Dim of point {}".format(point) + + " and parameter_names {}".format(parameter_names) + + " do not match." 
+ ) if points_to_evaluate and evaluated_rewards: if not isinstance(evaluated_rewards, list): raise TypeError( "evaluated_rewards expected to be a list, got {}.".format( - type(evaluated_rewards))) + type(evaluated_rewards) + ) + ) if not len(evaluated_rewards) == len(points_to_evaluate): raise ValueError( "Dim of evaluated_rewards {}".format(evaluated_rewards) + " and points_to_evaluate {}".format(points_to_evaluate) - + " do not match.") + + " do not match." + ) class _OptunaTrialSuggestCaptor: @@ -421,30 +443,33 @@ def define_search_space(trial: optuna.Trial): .. versionadded:: 0.8.8 """ - def __init__(self, - space: Optional[Union[Dict[str, "OptunaDistribution"], List[ - Tuple], Callable[["OptunaTrial"], Optional[Dict[ - str, Any]]]]] = None, - metric: Optional[str] = None, - mode: Optional[str] = None, - points_to_evaluate: Optional[List[Dict]] = None, - sampler: Optional["BaseSampler"] = None, - seed: Optional[int] = None, - evaluated_rewards: Optional[List] = None): - assert ot is not None, ( - "Optuna must be installed! Run `pip install optuna`.") + def __init__( + self, + space: Optional[ + Union[ + Dict[str, "OptunaDistribution"], + List[Tuple], + Callable[["OptunaTrial"], Optional[Dict[str, Any]]], + ] + ] = None, + metric: Optional[str] = None, + mode: Optional[str] = None, + points_to_evaluate: Optional[List[Dict]] = None, + sampler: Optional["BaseSampler"] = None, + seed: Optional[int] = None, + evaluated_rewards: Optional[List] = None, + ): + assert ot is not None, "Optuna must be installed! Run `pip install optuna`." super(OptunaSearch, self).__init__( - metric=metric, - mode=mode, - max_concurrent=None, - use_early_stopped_trials=None) + metric=metric, mode=mode, max_concurrent=None, use_early_stopped_trials=None + ) if isinstance(space, dict) and space: resolved_vars, domain_vars, grid_vars = parse_spec_vars(space) if domain_vars or grid_vars: logger.warning( - UNRESOLVED_SEARCH_SPACE.format( - par="space", cls=type(self).__name__)) + UNRESOLVED_SEARCH_SPACE.format(par="space", cls=type(self).__name__) + ) space = self.convert_search_space(space) else: # Flatten to support nested dicts @@ -461,13 +486,15 @@ def __init__(self, logger.warning( "You passed an initialized sampler to `OptunaSearch`. The " "`seed` parameter has to be passed to the sampler directly " - "and will be ignored.") + "and will be ignored." + ) self._sampler = sampler or ot.samplers.TPESampler(seed=seed) - assert isinstance(self._sampler, BaseSampler), \ - "You can only pass an instance of `optuna.samplers.BaseSampler` " \ + assert isinstance(self._sampler, BaseSampler), ( + "You can only pass an instance of `optuna.samplers.BaseSampler` " "as a sampler to `OptunaSearcher`." 
+ ) self._ot_trials = {} self._ot_study = None @@ -488,24 +515,28 @@ def _setup_study(self, mode: str): pruner=pruner, study_name=self._study_name, direction="minimize" if mode == "min" else "maximize", - load_if_exists=True) + load_if_exists=True, + ) if self._points_to_evaluate: validate_warmstart( self._space, self._points_to_evaluate, self._evaluated_rewards, - validate_point_name_lengths=not callable(self._space)) + validate_point_name_lengths=not callable(self._space), + ) if self._evaluated_rewards: - for point, reward in zip(self._points_to_evaluate, - self._evaluated_rewards): + for point, reward in zip( + self._points_to_evaluate, self._evaluated_rewards + ): self.add_evaluated_point(point, reward) else: for point in self._points_to_evaluate: self._ot_study.enqueue_trial(point) - def set_search_properties(self, metric: Optional[str], mode: Optional[str], - config: Dict) -> bool: + def set_search_properties( + self, metric: Optional[str], mode: Optional[str], config: Dict + ) -> bool: if self._space: return False space = self.convert_search_space(config) @@ -519,8 +550,10 @@ def set_search_properties(self, metric: Optional[str], mode: Optional[str], return True def _suggest_from_define_by_run_func( - self, func: Callable[["OptunaTrial"], Optional[Dict[str, Any]]], - ot_trial: "OptunaTrial") -> Dict: + self, + func: Callable[["OptunaTrial"], Optional[Dict[str, Any]]], + ot_trial: "OptunaTrial", + ) -> Dict: captor = _OptunaTrialSuggestCaptor(ot_trial) time_start = time.time() ret = func(captor) @@ -531,35 +564,37 @@ def _suggest_from_define_by_run_func( f"took {time_taken} seconds to " "run. Ensure that actual computation, training takes " "place inside Tune's train functions or Trainables " - "passed to `tune.run`.") + "passed to `tune.run`." + ) if ret is not None: if not isinstance(ret, dict): raise TypeError( "The return value of the define-by-run function " "passed in the `space` argument should be " "either None or a `dict` with `str` keys. " - f"Got {type(ret)}.") + f"Got {type(ret)}." + ) if not all(isinstance(k, str) for k in ret.keys()): raise TypeError( "At least one of the keys in the dict returned by the " "define-by-run function passed in the `space` argument " - "was not a `str`.") - return { - **captor.captured_values, - **ret - } if ret else captor.captured_values + "was not a `str`." 
+ ) + return {**captor.captured_values, **ret} if ret else captor.captured_values def suggest(self, trial_id: str) -> Optional[Dict]: if not self._space: raise RuntimeError( UNDEFINED_SEARCH_SPACE.format( - cls=self.__class__.__name__, space="space")) + cls=self.__class__.__name__, space="space" + ) + ) if not self._metric or not self._mode: raise RuntimeError( UNDEFINED_METRIC_MODE.format( - cls=self.__class__.__name__, - metric=self._metric, - mode=self._mode)) + cls=self.__class__.__name__, metric=self._metric, mode=self._mode + ) + ) if isinstance(self._space, list): # Keep for backwards compatibility @@ -571,8 +606,9 @@ def suggest(self, trial_id: str) -> Optional[Dict]: # getattr will fetch the trial.suggest_ function on Optuna trials params = { - args[0] if len(args) > 0 else kwargs["name"]: getattr( - ot_trial, fn)(*args, **kwargs) + args[0] + if len(args) > 0 + else kwargs["name"]: getattr(ot_trial, fn)(*args, **kwargs) for (fn, args, kwargs) in self._space } elif callable(self._space): @@ -581,13 +617,13 @@ def suggest(self, trial_id: str) -> Optional[Dict]: ot_trial = self._ot_trials[trial_id] - params = self._suggest_from_define_by_run_func( - self._space, ot_trial) + params = self._suggest_from_define_by_run_func(self._space, ot_trial) else: # Use Optuna ask interface (since version 2.6.0) if trial_id not in self._ot_trials: self._ot_trials[trial_id] = self._ot_study.ask( - fixed_distributions=self._space) + fixed_distributions=self._space + ) ot_trial = self._ot_trials[trial_id] params = ot_trial.params @@ -599,10 +635,9 @@ def on_trial_result(self, trial_id: str, result: Dict): ot_trial = self._ot_trials[trial_id] ot_trial.report(metric, step) - def on_trial_complete(self, - trial_id: str, - result: Optional[Dict] = None, - error: bool = False): + def on_trial_complete( + self, trial_id: str, result: Optional[Dict] = None, error: bool = False + ): ot_trial = self._ot_trials[trial_id] val = result.get(self.metric, None) if result else None @@ -617,22 +652,26 @@ def on_trial_complete(self, except ValueError as exc: logger.warning(exc) # E.g. 
if NaN was reported - def add_evaluated_point(self, - parameters: Dict, - value: float, - error: bool = False, - pruned: bool = False, - intermediate_values: Optional[List[float]] = None): + def add_evaluated_point( + self, + parameters: Dict, + value: float, + error: bool = False, + pruned: bool = False, + intermediate_values: Optional[List[float]] = None, + ): if not self._space: raise RuntimeError( UNDEFINED_SEARCH_SPACE.format( - cls=self.__class__.__name__, space="space")) + cls=self.__class__.__name__, space="space" + ) + ) if not self._metric or not self._mode: raise RuntimeError( UNDEFINED_METRIC_MODE.format( - cls=self.__class__.__name__, - metric=self._metric, - mode=self._mode)) + cls=self.__class__.__name__, metric=self._metric, mode=self._mode + ) + ) ot_trial_state = OptunaTrialState.COMPLETE if error: @@ -642,8 +681,7 @@ def add_evaluated_point(self, if intermediate_values: intermediate_values_dict = { - i: value - for i, value in enumerate(intermediate_values) + i: value for i, value in enumerate(intermediate_values) } else: intermediate_values_dict = None @@ -653,13 +691,19 @@ def add_evaluated_point(self, value=value, params=parameters, distributions=self._space, - intermediate_values=intermediate_values_dict) + intermediate_values=intermediate_values_dict, + ) self._ot_study.add_trial(trial) def save(self, checkpoint_path: str): - save_object = (self._sampler, self._ot_trials, self._ot_study, - self._points_to_evaluate, self._evaluated_rewards) + save_object = ( + self._sampler, + self._ot_trials, + self._ot_study, + self._points_to_evaluate, + self._evaluated_rewards, + ) with open(checkpoint_path, "wb") as outputFile: pickle.dump(save_object, outputFile) @@ -667,12 +711,21 @@ def restore(self, checkpoint_path: str): with open(checkpoint_path, "rb") as inputFile: save_object = pickle.load(inputFile) if len(save_object) == 5: - self._sampler, self._ot_trials, self._ot_study, \ - self._points_to_evaluate, self._evaluated_rewards = save_object + ( + self._sampler, + self._ot_trials, + self._ot_study, + self._points_to_evaluate, + self._evaluated_rewards, + ) = save_object else: # Backwards compatibility - self._sampler, self._ot_trials, self._ot_study, \ - self._points_to_evaluate = save_object + ( + self._sampler, + self._ot_trials, + self._ot_study, + self._points_to_evaluate, + ) = save_object @staticmethod def convert_search_space(spec: Dict) -> Dict[str, Any]: @@ -684,7 +737,8 @@ def convert_search_space(spec: Dict) -> Dict[str, Any]: if grid_vars: raise ValueError( "Grid search parameters cannot be automatically converted " - "to an Optuna search space.") + "to an Optuna search space." + ) # Flatten and resolve again after checking for grid search. spec = flatten_dict(spec, prevent_delimiter=True) @@ -701,50 +755,54 @@ def resolve_value(domain: Domain) -> ot.distributions.BaseDistribution: logger.warning( "Optuna does not handle quantization in loguniform " "sampling. The parameter will be passed but it will " - "probably be ignored.") + "probably be ignored." + ) if isinstance(domain, Float): if isinstance(sampler, LogUniform): if quantize: logger.warning( "Optuna does not support both quantization and " - "sampling from LogUniform. Dropped quantization.") + "sampling from LogUniform. Dropped quantization." 
+ ) return ot.distributions.LogUniformDistribution( - domain.lower, domain.upper) + domain.lower, domain.upper + ) elif isinstance(sampler, Uniform): if quantize: return ot.distributions.DiscreteUniformDistribution( - domain.lower, domain.upper, quantize) + domain.lower, domain.upper, quantize + ) return ot.distributions.UniformDistribution( - domain.lower, domain.upper) + domain.lower, domain.upper + ) elif isinstance(domain, Integer): if isinstance(sampler, LogUniform): return ot.distributions.IntLogUniformDistribution( - domain.lower, domain.upper - 1, step=quantize or 1) + domain.lower, domain.upper - 1, step=quantize or 1 + ) elif isinstance(sampler, Uniform): # Upper bound should be inclusive for quantization and # exclusive otherwise return ot.distributions.IntUniformDistribution( domain.lower, domain.upper - int(bool(not quantize)), - step=quantize or 1) + step=quantize or 1, + ) elif isinstance(domain, Categorical): if isinstance(sampler, Uniform): - return ot.distributions.CategoricalDistribution( - domain.categories) + return ot.distributions.CategoricalDistribution(domain.categories) raise ValueError( "Optuna search does not support parameters of type " "`{}` with samplers of type `{}`".format( - type(domain).__name__, - type(domain.sampler).__name__)) + type(domain).__name__, type(domain.sampler).__name__ + ) + ) # Parameter name is e.g. "a/b/c" for nested dicts - values = { - "/".join(path): resolve_value(domain) - for path, domain in domain_vars - } + values = {"/".join(path): resolve_value(domain) for path, domain in domain_vars} return values diff --git a/flaml/searcher/variant_generator.py b/flaml/searcher/variant_generator.py index 4b34676606..d87a0bc9a0 100644 --- a/flaml/searcher/variant_generator.py +++ b/flaml/searcher/variant_generator.py @@ -1,22 +1,20 @@ -''' -Copyright 2020 The Ray Authors. +# Copyright 2020 The Ray Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -This source file is adapted here because ray does not fully support Windows. +# This source file is adapted here because ray does not fully support Windows. -Copyright (c) Microsoft Corporation. -''' +# Copyright (c) Microsoft Corporation. 
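[Editor's note] For orientation, the `resolve_value` helper reformatted above maps each flaml/tune domain onto the closest Optuna distribution, including the off-by-one adjustment that keeps integer upper bounds exclusive unless quantization is requested. Below is a minimal sketch of that mapping — not part of the patch — which assumes the searcher class is named `OptunaSearch` and lives in `flaml/searcher/suggestion.py`, and that `flaml/tune/sample.py` exports the ray-style helpers (`uniform`, `loguniform`, `randint`, `qrandint`) alongside the ones shown later in this diff:

    # Sketch only: module path, class name, and helper exports are assumptions.
    import flaml.tune.sample as ts
    from flaml.searcher.suggestion import OptunaSearch  # path assumed

    space = {
        "lr": ts.loguniform(1e-4, 1e-1),    # Float + LogUniform
        "momentum": ts.uniform(0.1, 0.9),   # Float + Uniform
        "layers": ts.randint(1, 4),         # Integer + Uniform, upper exclusive
        "units": ts.qrandint(32, 256, 32),  # Integer + Uniform, quantized
    }

    ot_space = OptunaSearch.convert_search_space(space)
    # Expected result, per resolve_value above:
    #   "lr"     -> ot.distributions.LogUniformDistribution(1e-4, 0.1)
    #   "layers" -> ot.distributions.IntUniformDistribution(1, 3, step=1)
    #               (upper bound reduced by 1 when not quantized)
    #   "units"  -> ot.distributions.IntUniformDistribution(32, 256, step=32)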
import copy import logging from typing import Any, Dict, Generator, List, Tuple @@ -31,11 +29,13 @@ class TuneError(Exception): """General error class raised by ray.tune.""" + pass def generate_variants( - unresolved_spec: Dict) -> Generator[Tuple[Dict, Dict], None, None]: + unresolved_spec: Dict, +) -> Generator[Tuple[Dict, Dict], None, None]: """Generates variants from a spec (dict) with unresolved values. There are two types of unresolved values: Grid search: These define a grid search over values. For example, the @@ -72,8 +72,9 @@ def grid_search(values: List) -> Dict[str, List]: _MAX_RESOLUTION_PASSES = 20 -def parse_spec_vars(spec: Dict) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[ - Tuple, Any]], List[Tuple[Tuple, Any]]]: +def parse_spec_vars( + spec: Dict, +) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[Tuple, Any]], List[Tuple[Tuple, Any]]]: resolved, unresolved = _split_resolved_unresolved_values(spec) resolved_vars = list(resolved.items()) @@ -107,12 +108,16 @@ def _generate_variants(spec: Dict) -> Tuple[Dict, Dict]: for path, value in grid_vars: resolved_vars[path] = _get_value(spec, path) for k, v in resolved.items(): - if (k in resolved_vars and v != resolved_vars[k] - and _is_resolved(resolved_vars[k])): + if ( + k in resolved_vars + and v != resolved_vars[k] + and _is_resolved(resolved_vars[k]) + ): raise ValueError( "The variable `{}` could not be unambiguously " "resolved to a single value. Consider simplifying " - "your configuration.".format(k)) + "your configuration.".format(k) + ) resolved_vars[k] = v yield resolved_vars, spec @@ -129,8 +134,7 @@ def _get_value(spec: Dict, path: Tuple) -> Any: return spec -def _resolve_domain_vars(spec: Dict, - domain_vars: List[Tuple[Tuple, Domain]]) -> Dict: +def _resolve_domain_vars(spec: Dict, domain_vars: List[Tuple[Tuple, Domain]]) -> Dict: resolved = {} error = True num_passes = 0 @@ -146,8 +150,8 @@ def _resolve_domain_vars(spec: Dict, error = e except Exception: raise ValueError( - "Failed to evaluate expression: {}: {}".format( - path, domain)) + "Failed to evaluate expression: {}: {}".format(path, domain) + ) else: assign_value(spec, path, value) resolved[path] = value @@ -156,8 +160,9 @@ def _resolve_domain_vars(spec: Dict, return resolved -def _grid_search_generator(unresolved_spec: Dict, - grid_vars: List) -> Generator[Dict, None, None]: +def _grid_search_generator( + unresolved_spec: Dict, grid_vars: List +) -> Generator[Dict, None, None]: value_indices = [0] * len(grid_vars) def increment(i): @@ -199,39 +204,44 @@ def _try_resolve(v) -> Tuple[bool, Any]: grid_values = v["grid_search"] if not isinstance(grid_values, list): raise TuneError( - "Grid search expected list of values, got: {}".format( - grid_values)) + "Grid search expected list of values, got: {}".format(grid_values) + ) return False, Categorical(grid_values).grid() return True, v def _split_resolved_unresolved_values( - spec: Dict) -> Tuple[Dict[Tuple, Any], Dict[Tuple, Any]]: + spec: Dict, +) -> Tuple[Dict[Tuple, Any], Dict[Tuple, Any]]: resolved_vars = {} unresolved_vars = {} for k, v in spec.items(): resolved, v = _try_resolve(v) if not resolved: - unresolved_vars[(k, )] = v + unresolved_vars[(k,)] = v elif isinstance(v, dict): # Recurse into a dict - _resolved_children, _unresolved_children = \ - _split_resolved_unresolved_values(v) + ( + _resolved_children, + _unresolved_children, + ) = _split_resolved_unresolved_values(v) for (path, value) in _resolved_children.items(): - resolved_vars[(k, ) + path] = value + resolved_vars[(k,) + path] = value for 
(path, value) in _unresolved_children.items(): - unresolved_vars[(k, ) + path] = value + unresolved_vars[(k,) + path] = value elif isinstance(v, list): # Recurse into a list for i, elem in enumerate(v): - _resolved_children, _unresolved_children = \ - _split_resolved_unresolved_values({i: elem}) + ( + _resolved_children, + _unresolved_children, + ) = _split_resolved_unresolved_values({i: elem}) for (path, value) in _resolved_children.items(): - resolved_vars[(k, ) + path] = value + resolved_vars[(k,) + path] = value for (path, value) in _unresolved_children.items(): - unresolved_vars[(k, ) + path] = value + unresolved_vars[(k,) + path] = value else: - resolved_vars[(k, )] = v + resolved_vars[(k,)] = v return resolved_vars, unresolved_vars @@ -252,7 +262,8 @@ def __getattribute__(self, item): value = dict.__getattribute__(self, item) if not _is_resolved(value): raise RecursiveDependencyError( - "`{}` recursively depends on {}".format(item, value)) + "`{}` recursively depends on {}".format(item, value) + ) elif isinstance(value, dict): return _UnresolvedAccessGuard(value) else: diff --git a/flaml/tune/analysis.py b/flaml/tune/analysis.py index 320b0d279f..d02747e398 100644 --- a/flaml/tune/analysis.py +++ b/flaml/tune/analysis.py @@ -1,22 +1,20 @@ -""" -Copyright 2020 The Ray Authors. +# Copyright 2020 The Ray Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -This source file is adapted here because ray does not fully support Windows. +# This source file is adapted here because ray does not fully support Windows. -Copyright (c) Microsoft Corporation. -""" +# Copyright (c) Microsoft Corporation. from typing import Dict, Optional import numpy as np from .trial import Trial diff --git a/flaml/tune/result.py b/flaml/tune/result.py index f7a1430b5e..461c991f13 100644 --- a/flaml/tune/result.py +++ b/flaml/tune/result.py @@ -1,22 +1,20 @@ -''' -Copyright 2020 The Ray Authors. +# Copyright 2020 The Ray Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -This source file is adapted here because ray does not fully support Windows. +# This source file is adapted here because ray does not fully support Windows. -Copyright (c) Microsoft Corporation. -''' +# Copyright (c) Microsoft Corporation. import os # yapf: disable @@ -83,8 +81,13 @@ DEFAULT_EXPERIMENT_INFO_KEYS = ("trainable_name", EXPERIMENT_TAG, TRIAL_ID) -DEFAULT_RESULT_KEYS = (TRAINING_ITERATION, TIME_TOTAL_S, TIMESTEPS_TOTAL, - MEAN_ACCURACY, MEAN_LOSS) +DEFAULT_RESULT_KEYS = ( + TRAINING_ITERATION, + TIME_TOTAL_S, + TIMESTEPS_TOTAL, + MEAN_ACCURACY, + MEAN_LOSS, +) # Make sure this doesn't regress AUTO_RESULT_KEYS = ( @@ -120,9 +123,11 @@ STDERR_FILE = "__stderr_file__" # Where Tune writes result files by default -DEFAULT_RESULTS_DIR = (os.environ.get("TEST_TMPDIR") - or os.environ.get("TUNE_RESULT_DIR") - or os.path.expanduser("~/ray_results")) +DEFAULT_RESULTS_DIR = ( + os.environ.get("TEST_TMPDIR") + or os.environ.get("TUNE_RESULT_DIR") + or os.path.expanduser("~/ray_results") +) # Meta file about status under each experiment directory, can be # parsed by automlboard if exists. diff --git a/flaml/tune/sample.py b/flaml/tune/sample.py index 13ccffe7fa..dcc6fca56e 100644 --- a/flaml/tune/sample.py +++ b/flaml/tune/sample.py @@ -1,22 +1,20 @@ -''' -Copyright 2020 The Ray Authors. +# Copyright 2020 The Ray Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -This source file is included here because ray does not fully support Windows. +# This source file is included here because ray does not fully support Windows. -Copyright (c) Microsoft Corporation. -''' +# Copyright (c) Microsoft Corporation. import logging import random from copy import copy @@ -38,6 +36,7 @@ class Domain: allow specification of specific samplers (e.g. 
``uniform()`` or ``loguniform()``). """ + sampler = None default_sampler_cls = None @@ -47,11 +46,13 @@ def cast(self, value): def set_sampler(self, sampler, allow_override=False): if self.sampler and not allow_override: - raise ValueError("You can only choose one sampler for parameter " - "domains. Existing sampler for parameter {}: " - "{}. Tried to add {}".format( - self.__class__.__name__, self.sampler, - sampler)) + raise ValueError( + "You can only choose one sampler for parameter " + "domains. Existing sampler for parameter {}: " + "{}. Tried to add {}".format( + self.__class__.__name__, self.sampler, sampler + ) + ) self.sampler = sampler def get_sampler(self): @@ -80,10 +81,12 @@ def domain_str(self): class Sampler: - def sample(self, - domain: Domain, - spec: Optional[Union[List[Dict], Dict]] = None, - size: int = 1): + def sample( + self, + domain: Domain, + spec: Optional[Union[List[Dict], Dict]] = None, + size: int = 1, + ): raise NotImplementedError @@ -107,7 +110,7 @@ def __str__(self): class Normal(Sampler): - def __init__(self, mean: float = 0., sd: float = 0.): + def __init__(self, mean: float = 0.0, sd: float = 0.0): self.mean = mean self.sd = sd @@ -120,50 +123,58 @@ def __str__(self): class Grid(Sampler): """Dummy sampler used for grid search""" - def sample(self, - domain: Domain, - spec: Optional[Union[List[Dict], Dict]] = None, - size: int = 1): + def sample( + self, + domain: Domain, + spec: Optional[Union[List[Dict], Dict]] = None, + size: int = 1, + ): return RuntimeError("Do not call `sample()` on grid.") class Float(Domain): class _Uniform(Uniform): - def sample(self, - domain: "Float", - spec: Optional[Union[List[Dict], Dict]] = None, - size: int = 1): - assert domain.lower > float("-inf"), \ - "Uniform needs a lower bound" - assert domain.upper < float("inf"), \ - "Uniform needs a upper bound" + def sample( + self, + domain: "Float", + spec: Optional[Union[List[Dict], Dict]] = None, + size: int = 1, + ): + assert domain.lower > float("-inf"), "Uniform needs a lower bound" + assert domain.upper < float("inf"), "Uniform needs a upper bound" items = np.random.uniform(domain.lower, domain.upper, size=size) return items if len(items) > 1 else domain.cast(items[0]) class _LogUniform(LogUniform): - def sample(self, - domain: "Float", - spec: Optional[Union[List[Dict], Dict]] = None, - size: int = 1): - assert domain.lower > 0, \ - "LogUniform needs a lower bound greater than 0" - assert 0 < domain.upper < float("inf"), \ - "LogUniform needs a upper bound greater than 0" + def sample( + self, + domain: "Float", + spec: Optional[Union[List[Dict], Dict]] = None, + size: int = 1, + ): + assert domain.lower > 0, "LogUniform needs a lower bound greater than 0" + assert ( + 0 < domain.upper < float("inf") + ), "LogUniform needs a upper bound greater than 0" logmin = np.log(domain.lower) / np.log(self.base) logmax = np.log(domain.upper) / np.log(self.base) - items = self.base**(np.random.uniform(logmin, logmax, size=size)) + items = self.base ** (np.random.uniform(logmin, logmax, size=size)) return items if len(items) > 1 else domain.cast(items[0]) class _Normal(Normal): - def sample(self, - domain: "Float", - spec: Optional[Union[List[Dict], Dict]] = None, - size: int = 1): - assert not domain.lower or domain.lower == float("-inf"), \ - "Normal sampling does not allow a lower value bound." - assert not domain.upper or domain.upper == float("inf"), \ - "Normal sampling does not allow a upper value bound." 
+ def sample( + self, + domain: "Float", + spec: Optional[Union[List[Dict], Dict]] = None, + size: int = 1, + ): + assert not domain.lower or domain.lower == float( + "-inf" + ), "Normal sampling does not allow a lower value bound." + assert not domain.upper or domain.upper == float( + "inf" + ), "Normal sampling does not allow a upper value bound." items = np.random.normal(self.mean, self.sd, size=size) return items if len(items) > 1 else domain.cast(items[0]) @@ -181,11 +192,13 @@ def uniform(self): if not self.lower > float("-inf"): raise ValueError( "Uniform requires a lower bound. Make sure to set the " - "`lower` parameter of `Float()`.") + "`lower` parameter of `Float()`." + ) if not self.upper < float("inf"): raise ValueError( "Uniform requires a upper bound. Make sure to set the " - "`upper` parameter of `Float()`.") + "`upper` parameter of `Float()`." + ) new = copy(self) new.set_sampler(self._Uniform()) return new @@ -196,33 +209,39 @@ def loguniform(self, base: float = 10): "LogUniform requires a lower bound greater than 0." f"Got: {self.lower}. Did you pass a variable that has " "been log-transformed? If so, pass the non-transformed value " - "instead.") + "instead." + ) if not 0 < self.upper < float("inf"): raise ValueError( "LogUniform requires a upper bound greater than 0. " f"Got: {self.lower}. Did you pass a variable that has " "been log-transformed? If so, pass the non-transformed value " - "instead.") + "instead." + ) new = copy(self) new.set_sampler(self._LogUniform(base)) return new - def normal(self, mean=0., sd=1.): + def normal(self, mean=0.0, sd=1.0): new = copy(self) new.set_sampler(self._Normal(mean, sd)) return new def quantized(self, q: float): - if self.lower > float("-inf") and not isclose(self.lower / q, - round(self.lower / q)): + if self.lower > float("-inf") and not isclose( + self.lower / q, round(self.lower / q) + ): raise ValueError( f"Your lower variable bound {self.lower} is not divisible by " - f"quantization factor {q}.") - if self.upper < float("inf") and not isclose(self.upper / q, - round(self.upper / q)): + f"quantization factor {q}." + ) + if self.upper < float("inf") and not isclose( + self.upper / q, round(self.upper / q) + ): raise ValueError( f"Your upper variable bound {self.upper} is not divisible by " - f"quantization factor {q}.") + f"quantization factor {q}." 
+ ) new = copy(self) new.set_sampler(Quantized(new.get_sampler(), q), allow_override=True) @@ -238,26 +257,30 @@ def domain_str(self): class Integer(Domain): class _Uniform(Uniform): - def sample(self, - domain: "Integer", - spec: Optional[Union[List[Dict], Dict]] = None, - size: int = 1): + def sample( + self, + domain: "Integer", + spec: Optional[Union[List[Dict], Dict]] = None, + size: int = 1, + ): items = np.random.randint(domain.lower, domain.upper, size=size) return items if len(items) > 1 else domain.cast(items[0]) class _LogUniform(LogUniform): - def sample(self, - domain: "Integer", - spec: Optional[Union[List[Dict], Dict]] = None, - size: int = 1): - assert domain.lower > 0, \ - "LogUniform needs a lower bound greater than 0" - assert 0 < domain.upper < float("inf"), \ - "LogUniform needs a upper bound greater than 0" + def sample( + self, + domain: "Integer", + spec: Optional[Union[List[Dict], Dict]] = None, + size: int = 1, + ): + assert domain.lower > 0, "LogUniform needs a lower bound greater than 0" + assert ( + 0 < domain.upper < float("inf") + ), "LogUniform needs a upper bound greater than 0" logmin = np.log(domain.lower) / np.log(self.base) logmax = np.log(domain.upper) / np.log(self.base) - items = self.base**(np.random.uniform(logmin, logmax, size=size)) + items = self.base ** (np.random.uniform(logmin, logmax, size=size)) items = np.round(items).astype(int) return items if len(items) > 1 else domain.cast(items[0]) @@ -286,13 +309,15 @@ def loguniform(self, base: float = 10): "LogUniform requires a lower bound greater than 0." f"Got: {self.lower}. Did you pass a variable that has " "been log-transformed? If so, pass the non-transformed value " - "instead.") + "instead." + ) if not 0 < self.upper < float("inf"): raise ValueError( "LogUniform requires a upper bound greater than 0. " f"Got: {self.lower}. Did you pass a variable that has " "been log-transformed? If so, pass the non-transformed value " - "instead.") + "instead." 
+ ) new = copy(self) new.set_sampler(self._LogUniform(base)) return new @@ -307,10 +332,12 @@ def domain_str(self): class Categorical(Domain): class _Uniform(Uniform): - def sample(self, - domain: "Categorical", - spec: Optional[Union[List[Dict], Dict]] = None, - size: int = 1): + def sample( + self, + domain: "Categorical", + spec: Optional[Union[List[Dict], Dict]] = None, + size: int = 1, + ): items = random.choices(domain.categories, k=size) return items if len(items) > 1 else domain.cast(items[0]) @@ -349,10 +376,12 @@ def __init__(self, sampler: Sampler, q: Union[float, int]): def get_sampler(self): return self.sampler - def sample(self, - domain: Domain, - spec: Optional[Union[List[Dict], Dict]] = None, - size: int = 1): + def sample( + self, + domain: Domain, + spec: Optional[Union[List[Dict], Dict]] = None, + size: int = 1, + ): values = self.sampler.sample(domain, spec, size) quantized = np.round(np.divide(values, self.q)) * self.q if not isinstance(quantized, np.ndarray): @@ -361,12 +390,18 @@ def sample(self, class PolynomialExpansionSet: - - def __init__(self, init_monomials: set = (), highest_poly_order: int = None, - allow_self_inter: bool = False): + def __init__( + self, + init_monomials: set = (), + highest_poly_order: int = None, + allow_self_inter: bool = False, + ): self._init_monomials = init_monomials - self._highest_poly_order = highest_poly_order if \ - highest_poly_order is not None else len(self._init_monomials) + self._highest_poly_order = ( + highest_poly_order + if highest_poly_order is not None + else len(self._init_monomials) + ) self._allow_self_inter = allow_self_inter @property @@ -471,7 +506,7 @@ def qlograndint(lower: int, upper: int, q: int, base: float = 10): return Integer(lower, upper).loguniform(base).quantized(q) -def randn(mean: float = 0., sd: float = 1.): +def randn(mean: float = 0.0, sd: float = 1.0): """Sample a float value normally with ``mean`` and ``sd``. Args: mean (float): Mean of the normal distribution. Defaults to 0. @@ -492,7 +527,8 @@ def qrandn(mean: float, sd: float, q: float): return Float(None, None).normal(mean, sd).quantized(q) -def polynomial_expansion_set(init_monomials: set, highest_poly_order: int = None, - allow_self_inter: bool = False): +def polynomial_expansion_set( + init_monomials: set, highest_poly_order: int = None, allow_self_inter: bool = False +): return PolynomialExpansionSet(init_monomials, highest_poly_order, allow_self_inter) diff --git a/flaml/tune/space.py b/flaml/tune/space.py index c252a8b99c..91fe08868f 100644 --- a/flaml/tune/space.py +++ b/flaml/tune/space.py @@ -117,7 +117,7 @@ def define_by_run_func(trial, space: Dict, path: str = "") -> Optional[Dict[str, def unflatten_hierarchical(config: Dict, space: Dict) -> Tuple[Dict, Dict]: - """unflatten hierarchical config""" + """Unflatten hierarchical config.""" hier = {} subspace = {} for key, value in config.items(): @@ -152,7 +152,7 @@ def unflatten_hierarchical(config: Dict, space: Dict) -> Tuple[Dict, Dict]: def add_cost_to_space(space: Dict, low_cost_point: Dict, choice_cost: Dict): - """Update the space in place by adding low_cost_point and choice_cost + """Update the space in place by adding low_cost_point and choice_cost. Returns: A dict with constant values. @@ -240,8 +240,9 @@ def normalize( normalized_reference_config: Dict, recursive: bool = False, ): - """normalize config in space according to reference_config. - normalize each dimension in config to [0,1]. + """Normalize config in space according to reference_config. 
+ + Normalize each dimension in config to [0,1]. """ config_norm = {} for key, value in config.items(): @@ -410,7 +411,7 @@ def denormalize( def indexof(domain: Dict, config: Dict) -> int: - """find the index of config in domain.categories""" + """Find the index of config in domain.categories.""" index = config.get("_choice_") if index is not None: return index @@ -441,10 +442,10 @@ def complete_config( lower: Optional[Dict] = None, upper: Optional[Dict] = None, ) -> Tuple[Dict, Dict]: - """Complete partial config in space + """Complete partial config in space. Returns: - config, space + config, space. """ config = partial_config.copy() normalized = normalize(config, space, partial_config, {}) diff --git a/flaml/tune/trial.py b/flaml/tune/trial.py index 394c1ffaa6..30d4fe663b 100644 --- a/flaml/tune/trial.py +++ b/flaml/tune/trial.py @@ -1,22 +1,20 @@ -''' -Copyright 2020 The Ray Authors. +# Copyright 2020 The Ray Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -This source file is adapted here because ray does not fully support Windows. +# This source file is adapted here because ray does not fully support Windows. -Copyright (c) Microsoft Corporation. -''' +# Copyright (c) Microsoft Corporation. import uuid import time from numbers import Number @@ -29,7 +27,8 @@ def flatten_dict(dt, delimiter="/", prevent_delimiter=False): # Raise if delimiter is any of the keys raise ValueError( "Found delimiter `{}` in key when trying to flatten array." - "Please avoid using the delimiter in your specification.") + "Please avoid using the delimiter in your specification." + ) while any(isinstance(v, dict) for v in dt.values()): remove = [] add = {} @@ -41,7 +40,8 @@ def flatten_dict(dt, delimiter="/", prevent_delimiter=False): raise ValueError( "Found delimiter `{}` in key when trying to " "flatten array. Please avoid using the delimiter " - "in your specification.") + "in your specification." + ) add[delimiter.join([key, str(subkey)])] = v remove.append(key) dt.update(add) @@ -106,31 +106,35 @@ def update_last_result(self, result): "max": value, "min": value, "avg": value, - "last": value + "last": value, } self.metric_n_steps[metric] = {} for n in self.n_steps: key = "last-{:d}-avg".format(n) self.metric_analysis[metric][key] = value # Store n as string for correct restore. 
- self.metric_n_steps[metric][str(n)] = deque( - [value], maxlen=n) + self.metric_n_steps[metric][str(n)] = deque([value], maxlen=n) else: step = result["training_iteration"] or 1 self.metric_analysis[metric]["max"] = max( - value, self.metric_analysis[metric]["max"]) + value, self.metric_analysis[metric]["max"] + ) self.metric_analysis[metric]["min"] = min( - value, self.metric_analysis[metric]["min"]) - self.metric_analysis[metric]["avg"] = 1 / step * ( - value + (step - 1) * self.metric_analysis[metric]["avg"]) + value, self.metric_analysis[metric]["min"] + ) + self.metric_analysis[metric]["avg"] = ( + 1 + / step + * (value + (step - 1) * self.metric_analysis[metric]["avg"]) + ) self.metric_analysis[metric]["last"] = value for n in self.n_steps: key = "last-{:d}-avg".format(n) self.metric_n_steps[metric][str(n)].append(value) self.metric_analysis[metric][key] = sum( - self.metric_n_steps[metric][str(n)]) / len( - self.metric_n_steps[metric][str(n)]) + self.metric_n_steps[metric][str(n)] + ) / len(self.metric_n_steps[metric][str(n)]) def set_status(self, status): """Sets the status of the trial.""" diff --git a/flaml/tune/trial_runner.py b/flaml/tune/trial_runner.py index 75c5181ea3..41d34be4ce 100644 --- a/flaml/tune/trial_runner.py +++ b/flaml/tune/trial_runner.py @@ -1,8 +1,7 @@ -"""! - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See LICENSE file in the - * project root for license information. -""" +# ! +# * Copyright (c) Microsoft Corporation. All rights reserved. +# * Licensed under the MIT License. See LICENSE file in the +# * project root for license information. from typing import Optional # try: diff --git a/flaml/tune/tune.py b/flaml/tune/tune.py index d2c28f0403..c98c3bc6bc 100644 --- a/flaml/tune/tune.py +++ b/flaml/tune/tune.py @@ -1,8 +1,7 @@ -"""! - * Copyright (c) Microsoft Corporation. All rights reserved. - * Licensed under the MIT License. See LICENSE file in the - * project root for license information. -""" +# ! +# * Copyright (c) Microsoft Corporation. All rights reserved. +# * Licensed under the MIT License. See LICENSE file in the +# * project root for license information. from typing import Optional, Union, List, Callable, Tuple import numpy as np import datetime @@ -32,7 +31,7 @@ class ExperimentAnalysis(EA): - """Class for storing the experiment results""" + """Class for storing the experiment results.""" def __init__(self, trials, metric, mode): try: diff --git a/flaml/version.py b/flaml/version.py index 49e0fc1e09..a5f830a2c0 100644 --- a/flaml/version.py +++ b/flaml/version.py @@ -1 +1 @@ -__version__ = "0.7.0" +__version__ = "0.7.1"
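[Editor's note] As a closing sanity check on the sampling helpers reformatted in flaml/tune/sample.py: `qlograndint` and `qrandn` appear verbatim in the diff above, and drawing values through `get_sampler()` follows the `Sampler.sample(domain, spec, size)` interface shown there. A minimal sketch, not part of the patch:

    # qlograndint(16, 256, q=16) builds Integer(16, 256).loguniform().quantized(16);
    # qrandn(0.0, 1.0, q=0.1) builds Float(None, None).normal(0.0, 1.0).quantized(0.1).
    from flaml.tune.sample import qlograndint, qrandn

    batch = qlograndint(16, 256, q=16)  # log-uniform integer, snapped to multiples of 16
    noise = qrandn(0.0, 1.0, q=0.1)     # normal float, rounded to 0.1 steps

    for domain in (batch, noise):
        # The Quantized wrapper samples from the inner sampler, then rounds to q.
        draws = domain.get_sampler().sample(domain, size=5)
        print(type(domain).__name__, draws)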