Skip to content

Commit

Permalink
Added boolean disables, generation of train (metrics.csv) and test me…
Browse files Browse the repository at this point in the history
…trics (test_metrics.csv) (#397)
  • Loading branch information
ahosler authored Nov 1, 2023
2 parents 0cdb5d0 + 4eef4e0 commit 223f9c0
Show file tree
Hide file tree
Showing 11 changed files with 469 additions and 218 deletions.
5 changes: 1 addition & 4 deletions ads/opctl/operator/lowcode/forecast/model/arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .. import utils
from .base_model import ForecastOperatorBaseModel
from ..operator_config import ForecastOperatorConfig


class ArimaOperatorModel(ForecastOperatorBaseModel):
Expand Down Expand Up @@ -149,17 +150,13 @@ def _generate_report(self):
"it predicts future values based on past values."
)
other_sections = all_sections
forecast_col_name = "yhat"
train_metrics = False
ds_column_series = self.data[self.spec.datetime_column.name]
ds_forecast_col = self.outputs[0].index
ci_col_names = ["yhat_lower", "yhat_upper"]

return (
model_description,
other_sections,
forecast_col_name,
train_metrics,
ds_column_series,
ds_forecast_col,
ci_col_names,
Expand Down
6 changes: 1 addition & 5 deletions ads/opctl/operator/lowcode/forecast/model/automlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def _generate_report(self):

all_sections = [selected_models_text, selected_models_section]

if self.spec.explain:
if self.spec.generate_explanations:
# If the key is present, call the "explain_model" method
self.explain_model()

Expand Down Expand Up @@ -263,17 +263,13 @@ def _generate_report(self):
"high-quality features in your dataset, which are then provided for further processing."
)
other_sections = all_sections
forecast_col_name = "yhat"
train_metrics = False
ds_column_series = self.data[self.spec.datetime_column.name]
ds_forecast_col = self.outputs[0]["ds"]
ci_col_names = ["yhat_lower", "yhat_upper"]

return (
model_description,
other_sections,
forecast_col_name,
train_metrics,
ds_column_series,
ds_forecast_col,
ci_col_names,
Expand Down
5 changes: 1 addition & 4 deletions ads/opctl/operator/lowcode/forecast/model/autots.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from .. import utils
from .base_model import ForecastOperatorBaseModel
from ..operator_config import ForecastOperatorConfig
from ads.common.decorator.runtime_dependency import runtime_dependency


Expand Down Expand Up @@ -261,8 +262,6 @@ def _generate_report(self) -> tuple:
)

other_sections = all_sections
forecast_col_name = "yhat"
train_metrics = False

ds_column_series = pd.to_datetime(self.data[self.spec.datetime_column.name])
ds_forecast_col = self.outputs[0].index
Expand All @@ -271,8 +270,6 @@ def _generate_report(self) -> tuple:
return (
model_description,
other_sections,
forecast_col_name,
train_metrics,
ds_column_series,
ds_forecast_col,
ci_col_names,
Expand Down
Loading

0 comments on commit 223f9c0

Please sign in to comment.