Skip to content

Commit

Permalink
bug fix for datetime column in mlforecast
Browse files Browse the repository at this point in the history
  • Loading branch information
prasankh committed May 21, 2024
1 parent 2a625ed commit 2e39598
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions ads/opctl/operator/lowcode/forecast/model/ml_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(self, config: ForecastOperatorConfig, datasets: ForecastDatasets):
self.local_explanation = {}
self.formatted_global_explanation = None
self.formatted_local_explanation = None
self.date_col = config.spec.datetime_column.name

def set_kwargs(self):
"""
Expand Down Expand Up @@ -73,8 +74,8 @@ def _train_model(self, data_train, data_test, model_kwargs):
alpha=model_kwargs["lower_quantile"],
),
},
freq=pd.infer_freq(data_train["Date"].drop_duplicates())
or pd.infer_freq(data_train["Date"].drop_duplicates()[-5:]),
freq=pd.infer_freq(data_train[self.date_col].drop_duplicates())
or pd.infer_freq(data_train[self.date_col].drop_duplicates()[-5:]),
target_transforms=[Differences([12])],
lags=model_kwargs.get(
"lags",
Expand Down Expand Up @@ -104,7 +105,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
data_train[self.model_columns],
static_features=model_kwargs.get("static_features", []),
id_col=ForecastOutputColumns.SERIES,
time_col=self.spec.datetime_column.name,
time_col=self.date_col,
target_col=self.spec.target_column,
fitted=True,
max_horizon=None if num_models is False else self.spec.horizon,
Expand Down Expand Up @@ -168,7 +169,7 @@ def _build_model(self) -> pd.DataFrame:
confidence_interval_width=self.spec.confidence_interval_width,
horizon=self.spec.horizon,
target_column=self.original_target_column,
dt_column=self.spec.datetime_column.name,
dt_column=self.date_col,
)
self._train_model(data_train, data_test, model_kwargs)
return self.forecast_output.get_forecast_long()
Expand Down

0 comments on commit 2e39598

Please sign in to comment.