Skip to content

Commit

Permalink
resolving circular import
Browse files Browse the repository at this point in the history
  • Loading branch information
ahosler committed Nov 3, 2023
1 parent 586c856 commit 9fcc8f3
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions ads/opctl/operator/lowcode/forecast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from .const import SupportedMetrics, SupportedModels
from .errors import ForecastInputDataError, ForecastSchemaYamlError
from .operator_config import ForecastOperatorSpec, ForecastOperatorConfig
from .model.forecast_datasets import ForecastDatasets


def _label_encode_dataframe(df, no_encode=set()):
Expand Down Expand Up @@ -464,7 +463,9 @@ def human_time_friendly(seconds):
return ", ".join(accumulator)


def select_auto_model(datasets: ForecastDatasets, operator_config: ForecastOperatorConfig) -> str:
def select_auto_model(
datasets: "ForecastDatasets", operator_config: ForecastOperatorConfig
) -> str:
"""
Selects AutoMLX or Arima model based on column count.
Expand All @@ -487,10 +488,10 @@ def select_auto_model(datasets: ForecastDatasets, operator_config: ForecastOpera
if num_of_additional_cols < 15 and row_count < 10000 and number_of_series < 10:
return SupportedModels.AutoMLX
elif row_count < 10000 and number_of_series > 10:
operator_config.spec.model_kwargs['model_list'] = "fast_parallel"
operator_config.spec.model_kwargs["model_list"] = "fast_parallel"
return SupportedModels.AutoTS
elif row_count < 20000 and number_of_series > 10:
operator_config.spec.model_kwargs['model_list'] = "superfast"
operator_config.spec.model_kwargs["model_list"] = "superfast"
return SupportedModels.AutoTS
elif row_count > 20000:
return SupportedModels.NeuralProphet
Expand Down

0 comments on commit 9fcc8f3

Please sign in to comment.