Commit bb2476b

Merge branch 'main' into ODSC-65032/md-shapes-api

kumar-shivam-ranjan authored Nov 26, 2024
2 parents 8dc7bba + affeae4 commit bb2476b
Showing 20 changed files with 190 additions and 98 deletions.
10 changes: 6 additions & 4 deletions ads/opctl/operator/common/utils.py
@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--

 # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -18,23 +17,26 @@
 from cerberus import Validator

 from ads.opctl import logger, utils
-from ads.opctl.operator import __operators__

 CONTAINER_NETWORK = "CONTAINER_NETWORK"


 class OperatorValidator(Validator):
     """The custom validator class."""

-    pass
+    def validate(self, obj_dict, **kwargs):
+        # Model should be case insensitive
+        if "model" in obj_dict["spec"]:
+            obj_dict["spec"]["model"] = str(obj_dict["spec"]["model"]).lower()
+        return super().validate(obj_dict, **kwargs)


 def create_output_folder(name):
     output_folder = name
     protocol = fsspec.utils.get_protocol(output_folder)
     storage_options = {}
     if protocol != "file":
-        storage_options = auth or default_signer()
+        storage_options = default_signer()

     fs = fsspec.filesystem(protocol, **storage_options)
     name_suffix = 1
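
The new `validate` override makes the `model` field case-insensitive before schema validation runs. A minimal sketch of the behavior, using a hypothetical one-field spec schema (cerberus is the validator library imported above):

    from cerberus import Validator

    class OperatorValidator(Validator):
        """The custom validator class."""

        def validate(self, obj_dict, **kwargs):
            # Model should be case insensitive
            if "model" in obj_dict["spec"]:
                obj_dict["spec"]["model"] = str(obj_dict["spec"]["model"]).lower()
            return super().validate(obj_dict, **kwargs)

    schema = {"spec": {"type": "dict", "schema": {"model": {"type": "string"}}}}
    validator = OperatorValidator(schema)
    print(validator.validate({"spec": {"model": "Prophet"}}))  # True -- lowered to "prophet" first
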
5 changes: 2 additions & 3 deletions ads/opctl/operator/lowcode/anomaly/model/base_model.py
@@ -166,9 +166,8 @@ def generate_report(self):
         yaml_appendix = rc.Yaml(self.config.to_dict())
         summary = rc.Block(
             rc.Group(
-                rc.Text(
-                    f"You selected the **`{self.spec.model}`** model.\n{model_description.text}\n"
-                ),
+                rc.Text(f"You selected the **`{self.spec.model}`** model.\n"),
+                model_description,
                 rc.Text(
                     "Based on your dataset, you could have also selected "
                     f"any of the models: `{'`, `'.join(SupportedModels.keys() if self.spec.datetime_column else NonTimeADSupportedModels.keys())}`."
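
Both report fixes in this commit (here and in forecast/base_model.py below) follow the same pattern: `model_description` is passed into the block as a component rather than having its `.text` interpolated into a string, so the description renders as its own element. A rough sketch of the pattern; `report_creator` is assumed to be the library behind the `rc` alias:

    import report_creator as rc  # assumed import name behind the `rc` alias

    model = "autots"
    model_description = rc.Text("AutoTS is an automated time-series forecasting framework.")

    summary = rc.Block(
        rc.Group(
            rc.Text(f"You selected the **`{model}`** model.\n"),
            model_description,  # a component, not an interpolated .text string
        )
    )
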
4 changes: 2 additions & 2 deletions ads/opctl/operator/lowcode/anomaly/model/factory.py
@@ -26,9 +26,9 @@ class UnSupportedModelError(Exception):

     def __init__(self, operator_config: AnomalyOperatorConfig, model_type: str):
         supported_models = (
-            SupportedModels.values
+            SupportedModels.values()
             if operator_config.spec.datetime_column
-            else NonTimeADSupportedModels.values
+            else NonTimeADSupportedModels.values()
         )
         message = (
             f"Model: `{model_type}` is not supported. "
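
The missing parentheses meant the error message formatted the bound method object (`<bound method ...>`) rather than the list of model names. A reduced stand-in for the enum-style helper, hypothetical except for the `values()` call taken from the diff:

    class SupportedModels:
        """Stand-in for the real constants class in the anomaly operator."""

        Prophet = "prophet"
        Arima = "arima"

        @classmethod
        def values(cls):
            # Collect the string constants defined on the class
            return [v for k, v in vars(cls).items() if isinstance(v, str) and not k.startswith("_")]

    print(f"Supported models: {SupportedModels.values}")    # <bound method ...> (the old bug)
    print(f"Supported models: {SupportedModels.values()}")  # ['prophet', 'arima']
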
24 changes: 14 additions & 10 deletions ads/opctl/operator/lowcode/common/transformations.py
@@ -1,18 +1,19 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--

-# Copyright (c) 2023 Oracle and/or its affiliates.
+# Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

+from abc import ABC
+
+import pandas as pd
+
 from ads.opctl import logger
+from ads.opctl.operator.lowcode.common.const import DataColumns
 from ads.opctl.operator.lowcode.common.errors import (
-    InvalidParameterError,
     DataMismatchError,
+    InvalidParameterError,
 )
-from ads.opctl.operator.lowcode.common.const import DataColumns
 from ads.opctl.operator.lowcode.common.utils import merge_category_columns
-import pandas as pd
-from abc import ABC


 class Transformations(ABC):
@@ -58,6 +59,7 @@ def run(self, data):
         """
         clean_df = self._remove_trailing_whitespace(data)
         # clean_df = self._normalize_column_names(clean_df)
+        if self.name == "historical_data":
            self._check_historical_dataset(clean_df)
         clean_df = self._set_series_id_column(clean_df)
@@ -95,8 +97,11 @@ def run(self, data):
     def _remove_trailing_whitespace(self, df):
         return df.apply(lambda x: x.str.strip() if x.dtype == "object" else x)

+    # def _normalize_column_names(self, df):
+    #     return df.rename(columns=lambda x: re.sub("[^A-Za-z0-9_]+", "", x))
+
     def _set_series_id_column(self, df):
-        self._target_category_columns_map = dict()
+        self._target_category_columns_map = {}
         if not self.target_category_columns:
             df[DataColumns.Series] = "Series 1"
             self.has_artificial_series = True
@@ -125,10 +130,10 @@ def _format_datetime_col(self, df):
             df[self.dt_column_name] = pd.to_datetime(
                 df[self.dt_column_name], format=self.dt_column_format
             )
-        except:
+        except Exception as ee:
             raise InvalidParameterError(
                 f"Unable to determine the datetime type for column: {self.dt_column_name} in dataset: {self.name}. Please specify the format explicitly. (For example adding 'format: %d/%m/%Y' underneath 'name: {self.dt_column_name}' in the datetime_column section of the yaml file if you haven't already. For reference, here is the first datetime given: {df[self.dt_column_name].values[0]}"
-            )
+            ) from ee
         return df

     def _set_multi_index(self, df):
@@ -242,7 +247,6 @@ def _check_historical_dataset(self, df):
                     "Class": "A",
                     "Num": 2
                 },
-            }
         """

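
The `except` change above both narrows the bare clause and chains the original parse error via `from ee`, so the root cause stays visible in the traceback. A small sketch of the explicit-format path the error message recommends (`ValueError` stands in for `InvalidParameterError`):

    import pandas as pd

    df = pd.DataFrame({"ds": ["31/12/2024", "01/01/2025"]})

    try:
        # Supplying format= skips pandas' guessing, which fails on ambiguous day/month data
        df["ds"] = pd.to_datetime(df["ds"], format="%d/%m/%Y")
    except Exception as ee:
        raise ValueError(
            "Unable to determine the datetime type for column: ds. "
            "Please specify the format explicitly."
        ) from ee

    print(df["ds"].min())  # 2024-12-31 00:00:00
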
74 changes: 37 additions & 37 deletions ads/opctl/operator/lowcode/common/utils.py
@@ -1,42 +1,32 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*--

 # Copyright (c) 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

-import argparse
-import logging
 import os
 import shutil
 import sys
 import tempfile
-import time
-from string import Template
-from typing import Any, Dict, List, Tuple
-import pandas as pd
-from ads.opctl import logger
-import oracledb
+from typing import List, Union

 import fsspec
 import yaml
-from typing import Union
+import oracledb
+import pandas as pd

+from ads.common.object_storage_details import ObjectStorageDetails
+from ads.opctl import logger
+from ads.opctl.operator.common.operator_config import OutputDirectory
 from ads.opctl.operator.lowcode.common.errors import (
     InputDataError,
     InvalidParameterError,
     PermissionsError,
     DataMismatchError,
 )
-from ads.opctl.operator.common.operator_config import OutputDirectory
-from ads.common.object_storage_details import ObjectStorageDetails
 from ads.secrets import ADBSecretKeeper


 def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):
-    if fsspec.utils.get_protocol(filename) == "file":
-        return pd_fn(filename, **kwargs)
-    elif fsspec.utils.get_protocol(filename) in ["http", "https"]:
+    if fsspec.utils.get_protocol(filename) == "file" or fsspec.utils.get_protocol(
+        filename
+    ) in ["http", "https"]:
         return pd_fn(filename, **kwargs)

     storage_options = storage_options or (
@@ -48,7 +38,7 @@ def call_pandas_fsspec(pd_fn, filename, storage_options, **kwargs):

 def load_data(data_spec, storage_options=None, **kwargs):
     if data_spec is None:
-        raise InvalidParameterError(f"No details provided for this data source.")
+        raise InvalidParameterError("No details provided for this data source.")
     filename = data_spec.url
     format = data_spec.format
     columns = data_spec.columns
@@ -67,7 +57,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
     if not format:
         _, format = os.path.splitext(filename)
         format = format[1:]
-    if format in ["json", "clipboard", "excel", "csv", "feather", "hdf"]:
+    if format in ["json", "clipboard", "excel", "csv", "feather", "hdf", "parquet"]:
         read_fn = getattr(pd, f"read_{format}")
         data = call_pandas_fsspec(
             read_fn, filename, storage_options=storage_options
@@ -84,19 +74,31 @@ def load_data(data_spec, storage_options=None, **kwargs):
         with tempfile.TemporaryDirectory() as temp_dir:
             if vault_secret_id is not None:
                 try:
-                    with ADBSecretKeeper.load_secret(vault_secret_id, wallet_dir=temp_dir) as adwsecret:
-                        if 'wallet_location' in adwsecret and 'wallet_location' not in connect_args:
-                            shutil.unpack_archive(adwsecret["wallet_location"], temp_dir)
-                            connect_args['wallet_location'] = temp_dir
-                        if 'user_name' in adwsecret and 'user' not in connect_args:
-                            connect_args['user'] = adwsecret['user_name']
-                        if 'password' in adwsecret and 'password' not in connect_args:
-                            connect_args['password'] = adwsecret['password']
-                        if 'service_name' in adwsecret and 'service_name' not in connect_args:
-                            connect_args['service_name'] = adwsecret['service_name']
+                    with ADBSecretKeeper.load_secret(
+                        vault_secret_id, wallet_dir=temp_dir
+                    ) as adwsecret:
+                        if (
+                            "wallet_location" in adwsecret
+                            and "wallet_location" not in connect_args
+                        ):
+                            shutil.unpack_archive(
+                                adwsecret["wallet_location"], temp_dir
+                            )
+                            connect_args["wallet_location"] = temp_dir
+                        if "user_name" in adwsecret and "user" not in connect_args:
+                            connect_args["user"] = adwsecret["user_name"]
+                        if "password" in adwsecret and "password" not in connect_args:
+                            connect_args["password"] = adwsecret["password"]
+                        if (
+                            "service_name" in adwsecret
+                            and "service_name" not in connect_args
+                        ):
+                            connect_args["service_name"] = adwsecret["service_name"]

                 except Exception as e:
-                    raise Exception(f"Could not retrieve database credentials from vault {vault_secret_id}: {e}")
+                    raise Exception(
+                        f"Could not retrieve database credentials from vault {vault_secret_id}: {e}"
+                    )

             con = oracledb.connect(**connect_args)
             if table_name is not None:
@@ -105,11 +107,11 @@ def load_data(data_spec, storage_options=None, **kwargs):
                 data = pd.read_sql(sql, con)
             else:
                 raise InvalidParameterError(
-                    f"Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
+                    "Database `connect_args` provided without sql query or table name. Please specify either `sql` or `table_name`."
                 )
     else:
         raise InvalidParameterError(
-            f"No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively."
+            "No filename/url provided, and no connect_args provided. Please specify one of these if you want to read data from a file or a database respectively."
         )
     if columns:
         # keep only these columns, done after load because only CSV supports stream filtering
@@ -232,7 +234,7 @@ def human_time_friendly(seconds):
         accumulator.append(
             "{} {}{}".format(int(amount), unit, "" if amount == 1 else "s")
         )
-    accumulator.append("{} secs".format(round(seconds, 2)))
+    accumulator.append(f"{round(seconds, 2)} secs")
     return ", ".join(accumulator)


@@ -248,9 +250,7 @@ def find_output_dirname(output_dir: OutputDirectory):
         unique_output_dir = f"{output_dir}_{counter}"
         counter += 1
     logger.warn(
-        "Since the output directory was not specified, the output will be saved to {} directory.".format(
-            unique_output_dir
-        )
+        f"Since the output directory was not specified, the output will be saved to {unique_output_dir} directory."
     )
     return unique_output_dir

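
With `parquet` added to the format whitelist, `load_data` dispatches to `pd.read_parquet` the same way it does for the other `read_*` functions. A hypothetical invocation; `DataSpec` is a stand-in for the operator's data-source config object (the real one carries at least the attributes read above):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class DataSpec:
        url: Optional[str] = None
        format: Optional[str] = None
        columns: Optional[list] = None
        connect_args: Optional[dict] = None
        sql: Optional[str] = None
        table_name: Optional[str] = None
        vault_secret_id: Optional[str] = None

    # format="parquet" now resolves to pd.read_parquet via getattr(pd, f"read_{format}")
    spec = DataSpec(url="oci://bucket@namespace/sales.parquet", format="parquet")
    # df = load_data(spec)  # storage_options fall back to the default OCI signer
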
12 changes: 10 additions & 2 deletions ads/opctl/operator/lowcode/forecast/model/automlx.py
@@ -2,6 +2,7 @@
 # Copyright (c) 2023, 2024 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 import logging
+import os
 import traceback

 import numpy as np
@@ -80,10 +81,17 @@ def _build_model(self) -> pd.DataFrame:

         from automlx import Pipeline, init

+        cpu_count = os.cpu_count()
         try:
+            if cpu_count < 4:
+                engine = "local"
+                engine_opts = None
+            else:
+                engine = "ray"
+                engine_opts = ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
             init(
-                engine="ray",
-                engine_opts={"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},
+                engine=engine,
+                engine_opts=engine_opts,
                 loglevel=logging.CRITICAL,
             )
         except Exception as e:
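
The `init` call now falls back to the in-process engine on small shapes instead of always starting Ray. A sketch of just the selection logic; the threshold of 4 CPUs comes from the diff, while the helper name is hypothetical (note the committed code wraps `engine_opts` in a one-element tuple via the trailing comma, whereas a plain dict is shown here):

    import os

    def pick_automlx_engine(min_cpus_for_ray: int = 4):
        """Choose the AutoMLx execution engine from the host's CPU count."""
        cpu_count = os.cpu_count() or 1
        if cpu_count < min_cpus_for_ray:
            return "local", None
        return "ray", {"ray_setup": {"_temp_dir": "/tmp/ray-temp"}}

    engine, engine_opts = pick_automlx_engine()
    # init(engine=engine, engine_opts=engine_opts, loglevel=logging.CRITICAL)  # automlx
    print(engine)
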
3 changes: 2 additions & 1 deletion ads/opctl/operator/lowcode/forecast/model/base_model.py
@@ -148,8 +148,9 @@ def generate_report(self):
         header_section = rc.Block(
             rc.Heading("Forecast Report", level=1),
             rc.Text(
-                f"You selected the {self.spec.model} model.\n{model_description}\nBased on your dataset, you could have also selected any of the models: {SupportedModels.keys()}."
+                f"You selected the {self.spec.model} model.\nBased on your dataset, you could have also selected any of the models: {SupportedModels.keys()}."
             ),
+            model_description,
             rc.Group(
                 rc.Metric(
                     heading="Analysis was completed in ",
5 changes: 3 additions & 2 deletions ads/opctl/operator/lowcode/forecast/model/factory.py
@@ -11,6 +11,7 @@
 from .autots import AutoTSOperatorModel
 from .base_model import ForecastOperatorBaseModel
 from .forecast_datasets import ForecastDatasets
+from .ml_forecast import MLForecastOperatorModel
 from .neuralprophet import NeuralProphetOperatorModel
 from .prophet import ProphetOperatorModel
@@ -19,7 +20,7 @@ class UnSupportedModelError(Exception):
     def __init__(self, model_type: str):
         super().__init__(
             f"Model: `{model_type}` "
-            f"is not supported. Supported models: {SupportedModels.values}"
+            f"is not supported. Supported models: {SupportedModels.values()}"
         )


@@ -32,7 +33,7 @@ class ForecastOperatorModelFactory:
         SupportedModels.Prophet: ProphetOperatorModel,
         SupportedModels.Arima: ArimaOperatorModel,
         SupportedModels.NeuralProphet: NeuralProphetOperatorModel,
-        # SupportedModels.LGBForecast: MLForecastOperatorModel,
+        SupportedModels.LGBForecast: MLForecastOperatorModel,
         SupportedModels.AutoMLX: AutoMLXOperatorModel,
         SupportedModels.AutoTS: AutoTSOperatorModel,
     }
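
Uncommenting the `LGBForecast` entry (together with the new `ml_forecast` import above) makes the model selectable through the factory's name-to-class map. A reduced sketch of the dispatch pattern; the class bodies, the lower-case keys, and `get_model_class` are stand-ins:

    class ProphetOperatorModel: ...
    class MLForecastOperatorModel: ...

    class UnSupportedModelError(Exception):
        def __init__(self, model_type: str):
            super().__init__(f"Model: `{model_type}` is not supported.")

    _MAP = {
        "prophet": ProphetOperatorModel,
        "lgbforecast": MLForecastOperatorModel,  # entry enabled by this commit
    }

    def get_model_class(model_type: str):
        # Keys match because OperatorValidator.validate lower-cases spec.model upstream
        if model_type not in _MAP:
            raise UnSupportedModelError(model_type)
        return _MAP[model_type]

    print(get_model_class("lgbforecast").__name__)  # MLForecastOperatorModel
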
5 changes: 4 additions & 1 deletion ads/opctl/operator/lowcode/forecast/model/prophet.py
@@ -142,6 +142,9 @@ def _build_model(self) -> pd.DataFrame:
             dt_column=self.spec.datetime_column.name,
         )

+        # if os.environ["OCI__IS_SPARK"]:
+        #     pass
+        # else:
         Parallel(n_jobs=-1, require="sharedmem")(
             delayed(ProphetOperatorModel._train_model)(
                 self, i, series_id, df, model_kwargs.copy()
@@ -354,7 +357,7 @@ def _generate_report(self):
             logger.warn(f"Failed to generate Explanations with error: {e}.")
             logger.debug(f"Full Traceback: {traceback.format_exc()}")

-        model_description = (
+        model_description = rc.Text(
             "Prophet is a procedure for forecasting time series data based on an additive "
             "model where non-linear trends are fit with yearly, weekly, and daily seasonality, "
             "plus holiday effects. It works best with time series that have strong seasonal "
2 changes: 1 addition & 1 deletion ads/opctl/operator/lowcode/forecast/schema.yaml
@@ -311,7 +311,7 @@ spec:
       missing_value_imputation:
         type: boolean
         required: false
-        default: false
+        default: true
       outlier_treatment:
         type: boolean
         required: false
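
Flipping the schema default means forecast specs that omit `missing_value_imputation` now get it enabled during normalization rather than disabled. A minimal demonstration with just this field, using cerberus (the validator library this operator's schemas are checked with):

    from cerberus import Validator

    schema = {
        "missing_value_imputation": {
            "type": "boolean",
            "required": False,
            "default": True,  # was false before this commit
        }
    }

    validator = Validator(schema)
    # A spec that omits the key now gets imputation enabled during normalization:
    print(validator.normalized({}))  # {'missing_value_imputation': True}
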
(Diffs for the remaining 10 changed files not shown.)
