Add Holt-Winters exponential smoothing (#962)
* tentatively implement holt-winters-no covariates

* fix forecast method, clean class

* checking external regressors too

* update test forecast

* remove duplicated test file, re-add sarimax, search space cleanup

* Update flaml/automl/model.py

removed links; the most important one was probably https://robjhyndman.com/hyndsight/ets-regressors/

Co-authored-by: Chi Wang <wang.chi@microsoft.com>

* prevent short series

* add docs

---------

Co-authored-by: Andrea W <a.ruggerini@ammagamma.com>
Co-authored-by: Chi Wang <wang.chi@microsoft.com>
3 people committed Apr 4, 2023
1 parent 4c20c85 commit 7f9402b
Showing 5 changed files with 112 additions and 9 deletions.
3 changes: 3 additions & 0 deletions flaml/automl/ml.py
@@ -38,6 +38,7 @@
    Prophet,
    ARIMA,
    SARIMAX,
    HoltWinters,
    TransformersEstimator,
    TemporalFusionTransformerEstimator,
    TransformersEstimatorModelSelection,
@@ -156,6 +157,8 @@ def get_estimator_class(task: str, estimator_name: str) -> EstimatorSubclass:
        estimator_class = ARIMA
    elif estimator_name == "sarimax":
        estimator_class = SARIMAX
    elif estimator_name == "holt-winters":
        estimator_class = HoltWinters
    elif estimator_name == "transformer":
        estimator_class = TransformersEstimator
    elif estimator_name == "tft":
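For reference, a minimal sketch of the new dispatch path; `get_estimator_class` and its signature come from the hunk above, while the `"ts_forecast"` task string is an assumption based on FLAML's forecasting task name:

```python
# Hypothetical lookup of the new estimator by name ("ts_forecast" is assumed).
from flaml.automl.ml import get_estimator_class

estimator_class = get_estimator_class("ts_forecast", "holt-winters")
print(estimator_class.__name__)  # -> HoltWinters
```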
88 changes: 88 additions & 0 deletions flaml/automl/model.py
@@ -2377,6 +2377,94 @@ def fit(self, X_train, y_train, budget=None, free_mem_ratio=0, **kwargs):
        return train_time


class HoltWinters(ARIMA):
    """
    The class for tuning the Holt-Winters model, also known as 'triple exponential smoothing'.
    """

    @classmethod
    def search_space(cls, **params):
        space = {
            "damped_trend": {"domain": tune.choice([True, False]), "init_value": False},
            "trend": {"domain": tune.choice(["add", "mul", None]), "init_value": "add"},
            "seasonal": {
                "domain": tune.choice(["add", "mul", None]),
                "init_value": "add",
            },
            "use_boxcox": {"domain": tune.choice([False, True]), "init_value": False},
            "seasonal_periods": {  # statsmodels casts this to None if "seasonal" is None
                "domain": tune.choice(
                    [7, 12, 4, 52, 6]
                ),  # e.g. 7: weekly cycle (daily data), 12: yearly (monthly data), 4: yearly (quarterly data), 52: yearly (weekly data)
                "init_value": 7,
            },
        }
        return space

    def fit(self, X_train, y_train, budget=None, free_mem_ratio=0, **kwargs):
        import warnings

        warnings.filterwarnings("ignore")
        from statsmodels.tsa.holtwinters import (
            ExponentialSmoothing as HWExponentialSmoothing,
        )

        current_time = time.time()
        train_df = self._join(X_train, y_train)
        train_df = self._preprocess(train_df)
        regressors = list(train_df)
        regressors.remove(TS_VALUE_COL)
        if regressors:
            logger.warning("Regressors are ignored for Holt-Winters ETS models.")

        # Override incompatible parameters
        if (
            X_train.shape[0] < 2 * self.params["seasonal_periods"]
        ):  # this would prevent heuristic initialization from working properly
            self.params["seasonal"] = None
        if (
            self.params["seasonal"] == "mul" and (train_df.y == 0).sum() > 0
        ):  # cannot have multiplicative seasonality in this case
            self.params["seasonal"] = "add"
        if self.params["trend"] == "mul" and (train_df.y == 0).sum() > 0:
            self.params["trend"] = "add"

        if not self.params["seasonal"] or self.params["trend"] not in [
            "mul",
            "add",
        ]:
            self.params["damped_trend"] = False

        model = HWExponentialSmoothing(
            train_df[[TS_VALUE_COL]],
            damped_trend=self.params["damped_trend"],
            seasonal=self.params["seasonal"],
            trend=self.params["trend"],
        )
        with suppress_stdout_stderr():
            model = model.fit()
        train_time = time.time() - current_time
        self._model = model
        return train_time

    def predict(self, X, **kwargs):
        if self._model is not None:
            if isinstance(X, int):
                forecast = self._model.forecast(steps=X)
            elif isinstance(X, DataFrame):
                start = X[TS_TIMESTAMP_COL].iloc[0]
                end = X[TS_TIMESTAMP_COL].iloc[-1]
                forecast = self._model.predict(start=start, end=end, **kwargs)
            else:
                raise ValueError(
                    "X needs to be either a pandas DataFrame with dates as the first column"
                    " or an int number of periods for predict()."
                )
            return forecast
        else:
            return np.ones(X if isinstance(X, int) else X.shape[0])


class TS_SKLearn(SKLearnEstimator):
"""The class for tuning SKLearn Regressors for time-series forecasting, using hcrystalball"""

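The `fit`/`predict` pair above delegates to statsmodels' `ExponentialSmoothing`. As a minimal standalone sketch of the wrapped call, assuming a daily series with a weekly cycle (the toy data is illustrative, not from the diff):

```python
import pandas as pd
from statsmodels.tsa.holtwinters import ExponentialSmoothing

# Four weeks of daily data: a weekly pattern plus a mild upward trend.
idx = pd.date_range("2023-01-01", periods=28, freq="D")
y = pd.Series([10.0 + (i % 7) + 0.1 * i for i in range(28)], index=idx)

# Mirrors the knobs tuned by the search space above.
model = ExponentialSmoothing(
    y, trend="add", seasonal="add", seasonal_periods=7, damped_trend=False
).fit()
print(model.forecast(steps=7))  # forecast the next week
```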
9 changes: 7 additions & 2 deletions flaml/automl/task/generic_task.py
@@ -1055,9 +1055,14 @@ def default_estimator_list(
            try:
                import prophet

-                estimator_list += ["prophet", "arima", "sarimax"]
+                estimator_list += [
+                    "prophet",
+                    "arima",
+                    "sarimax",
+                    "holt-winters",
+                ]
            except ImportError:
-                estimator_list += ["arima", "sarimax"]
+                estimator_list += ["arima", "sarimax", "holt-winters"]
        elif not self.is_regression():
            estimator_list += ["lrl1"]
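As a hedged sketch of the effect (the toy dataframe and column names are assumptions, not part of the diff): a plain `ts_forecast` run now considers `"holt-winters"` by default, with or without prophet installed.

```python
import numpy as np
import pandas as pd
from flaml import AutoML

# Toy daily series with a weekly cycle and a slight trend.
df = pd.DataFrame(
    {
        "ds": pd.date_range("2022-01-01", periods=120, freq="D"),
        "y": 10 + np.sin(np.arange(120) * 2 * np.pi / 7) + 0.05 * np.arange(120),
    }
)
automl = AutoML()
automl.fit(dataframe=df, label="y", task="ts_forecast", period=7, time_budget=5)
print(automl.best_estimator)  # "holt-winters" is now among the candidates
```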
20 changes: 13 additions & 7 deletions test/automl/test_forecast.py
@@ -2,7 +2,9 @@
from flaml import AutoML


-def test_forecast_automl(budget=5):
+def test_forecast_automl(
+    budget=5, estimators_when_no_prophet=["arima", "sarimax", "holt-winters"]
+):
    # using dataframe
    import statsmodels.api as sm

@@ -39,7 +41,7 @@ def test_forecast_automl(budget=5):
    automl.fit(
        dataframe=df,
        **settings,
-        estimator_list=["arima", "sarimax"],
+        estimator_list=estimators_when_no_prophet,
        period=time_horizon,
    )
    """ retrieve best config and best learner"""
@@ -89,7 +91,7 @@ def test_forecast_automl(budget=5):
        X_train=X_train,
        y_train=y_train,
        **settings,
-        estimator_list=["arima", "sarimax"],
+        estimator_list=estimators_when_no_prophet,
        period=time_horizon,
    )

@@ -161,7 +163,9 @@ def load_multi_dataset():
    return df


-def test_multivariate_forecast_num(budget=5):
+def test_multivariate_forecast_num(
+    budget=5, estimators_when_no_prophet=["arima", "sarimax", "holt-winters"]
+):
    df = load_multi_dataset()
    # split data into train and test
    time_horizon = 180
@@ -193,7 +197,7 @@ def test_multivariate_forecast_num(budget=5):
    automl.fit(
        dataframe=train_df,
        **settings,
-        estimator_list=["arima", "sarimax"],
+        estimator_list=estimators_when_no_prophet,
        period=time_horizon,
    )
    """ retrieve best config and best learner"""
@@ -293,7 +297,9 @@ def above_monthly_avg(date, temp):
    return train_df, test_df


-def test_multivariate_forecast_cat(budget=5):
+def test_multivariate_forecast_cat(
+    budget=5, estimators_when_no_prophet=["arima", "sarimax", "holt-winters"]
+):
    time_horizon = 180
    train_df, test_df = load_multi_dataset_cat(time_horizon)
    X_test = test_df[
@@ -320,7 +326,7 @@ def test_multivariate_forecast_cat(budget=5):
    automl.fit(
        dataframe=train_df,
        **settings,
-        estimator_list=["arima", "sarimax"],
+        estimator_list=estimators_when_no_prophet,
        period=time_horizon,
    )
    """ retrieve best config and best learner"""
1 change: 1 addition & 0 deletions website/docs/Use-Cases/Task-Oriented-AutoML.md
@@ -125,6 +125,7 @@ The estimator list can contain one or more estimator names, each corresponding t
- 'prophet': Prophet for task "ts_forecast". Hyperparameters: changepoint_prior_scale, seasonality_prior_scale, holidays_prior_scale, seasonality_mode.
- 'arima': ARIMA for task "ts_forecast". Hyperparameters: p, d, q.
- 'sarimax': SARIMAX for task "ts_forecast". Hyperparameters: p, d, q, P, D, Q, s.
- 'holt-winters': Holt-Winters (triple exponential smoothing) model for task "ts_forecast". Hyperparameters: seasonal_periods, seasonal, use_boxcox, trend, damped_trend.
- 'transformer': Huggingface transformer models for task "seq-classification", "seq-regression", "multichoice-classification", "token-classification" and "summarization". Hyperparameters: learning_rate, num_train_epochs, per_device_train_batch_size, warmup_ratio, weight_decay, adam_epsilon, seed.
- 'temporal_fusion_transformer': TemporalFusionTransformerEstimator for task "ts_forecast_panel". Hyperparameters: gradient_clip_val, hidden_size, hidden_continuous_size, attention_head_size, dropout, learning_rate. There is a [known issue](https://github.com/jdb78/pytorch-forecasting/issues/1145) with pytorch-forecast logging.
* Custom estimator. Use custom estimator for:
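A minimal sketch of selecting the new estimator by name, per the docs entry above; the toy data is an assumption:

```python
import numpy as np
import pandas as pd
from flaml import AutoML

df = pd.DataFrame(
    {
        "ds": pd.date_range("2022-01-01", periods=120, freq="D"),
        "y": 10 + np.sin(np.arange(120) * 2 * np.pi / 7) + 0.05 * np.arange(120),
    }
)
automl = AutoML()
automl.fit(
    dataframe=df,
    label="y",
    task="ts_forecast",
    period=7,
    time_budget=5,
    estimator_list=["holt-winters"],
)
# best_config holds the tuned hyperparameters listed above:
# damped_trend, trend, seasonal, use_boxcox, seasonal_periods.
print(automl.best_config)
```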
