Skip to content

Commit

Permalink
Forecasting Operator (#268)
Browse files Browse the repository at this point in the history
Co-authored-by: Vikas Pandey <vikas.v.pandey@oracle.com>
Co-authored-by: Allen Hosler <allen.hosler@oracle.com>
Co-authored-by: Prashant Sankhla <prashant.s.sankhla@oracle.com>
Co-authored-by: Vikas Pandey <vikaspandey707@gmail.com>
Co-authored-by: Prashant Sankhla <sankhlaprashant15@gmail.com>
Co-authored-by: Goalla Varsha <goalla.v.varsha@oracle.com>
Co-authored-by: MING KANG <ming.kang@oracle.com>
  • Loading branch information
8 people authored Nov 16, 2023
1 parent 2f9be39 commit 79722d0
Show file tree
Hide file tree
Showing 201 changed files with 15,472 additions and 723 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,6 @@ logs/

# vim
*.swp

# Python Wheel
*.whl
4 changes: 4 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ exclude build/lib/notebooks/**
exclude benchmark/**
include ads/ads
include ads/model/common/*.*
include ads/operator/**/*.md
include ads/operator/**/*.yaml
include ads/operator/**/*.whl
include ads/operator/**/MLoperator
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ You have various options when installing ADS.
python3 -m pip install oracle-ads
```

### Installing OCI AI Operators

To use the AI Forecast Operator, install the "forecast" dependencies using the following command:

```bash
python3 -m pip install 'oracle_ads[forecast]==2.9.0'
```

### Installing extra libraries

To work with gradient boosting models, install the `boosted` module. This module includes XGBoost and LightGBM model classes.
Expand Down
7 changes: 4 additions & 3 deletions ads/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
import ads.opctl.cli
import ads.jobs.cli
import ads.pipeline.cli
import os
import json
import ads.opctl.operator.cli
except Exception as ex:
print(
"Please run `pip install oracle-ads[opctl]` to install "
"the required dependencies for ADS CLI."
"the required dependencies for ADS CLI. \n"
f"{str(ex)}"
)
logger.debug(ex)
logger.debug(traceback.format_exc())
Expand All @@ -44,6 +44,7 @@ def cli():
cli.add_command(ads.opctl.cli.commands)
cli.add_command(ads.jobs.cli.commands)
cli.add_command(ads.pipeline.cli.commands)
cli.add_command(ads.opctl.operator.cli.commands)


if __name__ == "__main__":
Expand Down
62 changes: 32 additions & 30 deletions ads/common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,8 @@ def create_signer(self) -> Dict:
user=configuration["user"],
fingerprint=configuration["fingerprint"],
private_key_file_location=configuration.get("key_file"),
pass_phrase= configuration.get("pass_phrase"),
private_key_content=configuration.get("key_content")
pass_phrase=configuration.get("pass_phrase"),
private_key_content=configuration.get("key_content"),
),
"client_kwargs": self.client_kwargs,
}
Expand Down Expand Up @@ -750,21 +750,10 @@ class SecurityToken(AuthSignerGenerator):
a given user - it requires that user's private key and security token.
It prepares extra arguments necessary for creating clients for a variety of OCI services.
"""
SECURITY_TOKEN_GENERIC_HEADERS = [
"date",
"(request-target)",
"host"
]
SECURITY_TOKEN_BODY_HEADERS = [
"content-length",
"content-type",
"x-content-sha256"
]
SECURITY_TOKEN_REQUIRED = [
"security_token_file",
"key_file",
"region"
]

SECURITY_TOKEN_GENERIC_HEADERS = ["date", "(request-target)", "host"]
SECURITY_TOKEN_BODY_HEADERS = ["content-length", "content-type", "x-content-sha256"]
SECURITY_TOKEN_REQUIRED = ["security_token_file", "key_file", "region"]

def __init__(self, args: Optional[Dict] = None):
"""
Expand Down Expand Up @@ -831,12 +820,18 @@ def create_signer(self) -> Dict:
return {
"config": configuration,
"signer": oci.auth.signers.SecurityTokenSigner(
token=self._read_security_token_file(configuration.get("security_token_file")),
token=self._read_security_token_file(
configuration.get("security_token_file")
),
private_key=oci.signer.load_private_key_from_file(
configuration.get("key_file"), configuration.get("pass_phrase")
),
generic_headers=configuration.get("generic_headers", self.SECURITY_TOKEN_GENERIC_HEADERS),
body_headers=configuration.get("body_headers", self.SECURITY_TOKEN_BODY_HEADERS)
generic_headers=configuration.get(
"generic_headers", self.SECURITY_TOKEN_GENERIC_HEADERS
),
body_headers=configuration.get(
"body_headers", self.SECURITY_TOKEN_BODY_HEADERS
),
),
"client_kwargs": self.client_kwargs,
}
Expand All @@ -849,30 +844,37 @@ def _validate_and_refresh_token(self, configuration: Dict[str, Any]):
configuration: Dict
Security token configuration.
"""
security_token = self._read_security_token_file(configuration.get("security_token_file"))
security_token_container = oci.auth.security_token_container.SecurityTokenContainer(
session_key_supplier=None,
security_token=security_token
security_token = self._read_security_token_file(
configuration.get("security_token_file")
)
security_token_container = (
oci.auth.security_token_container.SecurityTokenContainer(
session_key_supplier=None, security_token=security_token
)
)

if not security_token_container.valid():
raise SecurityTokenError(
"Security token has expired. Call `oci session authenticate` to generate new session."
)

time_now = int(time.time())
time_expired = security_token_container.get_jwt()["exp"]
if time_expired - time_now < SECURITY_TOKEN_LEFT_TIME:
if not self.oci_config_location:
logger.warning("Can not auto-refresh token. Specify parameter `oci_config_location` through ads.set_auth() or ads.auth.create_signer().")
logger.warning(
"Can not auto-refresh token. Specify parameter `oci_config_location` through ads.set_auth() or ads.auth.create_signer()."
)
else:
result = os.system(f"oci session refresh --config-file {self.oci_config_location} --profile {self.oci_key_profile}")
result = os.system(
f"oci session refresh --config-file {self.oci_config_location} --profile {self.oci_key_profile}"
)
if result == 1:
logger.warning(
"Some error happened during auto-refreshing the token. Continue using the current one that's expiring in less than {SECURITY_TOKEN_LEFT_TIME} seconds."
"Please follow steps in https://docs.oracle.com/en-us/iaas/Content/API/SDKDocs/clitoken.htm to renew token."
)

date_time = datetime.fromtimestamp(time_expired).strftime("%Y-%m-%d %H:%M:%S")
logger.info(f"Session is valid until {date_time}.")

Expand All @@ -894,7 +896,7 @@ def _read_security_token_file(self, security_token_file: str) -> str:
raise ValueError("Invalid `security_token_file`. Specify a valid path.")
try:
token = None
with open(expanded_path, 'r') as f:
with open(expanded_path, "r") as f:
token = f.read()
return token
except:
Expand All @@ -903,7 +905,7 @@ def _read_security_token_file(self, security_token_file: str) -> str:

class AuthFactory:
"""
AuthFactory class which contains list of registered signers and alllows to register new signers.
AuthFactory class which contains the list of registered signers and allows registering new signers.
Check documentation for more signers: https://docs.oracle.com/en-us/iaas/tools/python/latest/api/signing.html.
Current signers:
Expand Down
2 changes: 2 additions & 0 deletions ads/common/decorator/runtime_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class OptionalDependency:
OPTUNA = "oracle-ads[optuna]"
SPARK = "oracle-ads[spark]"
HUGGINGFACE = "oracle-ads[huggingface]"
FORECAST = "oracle-ads[forecast]"
PII = "oracle-ads[pii]"
FEATURE_STORE = "oracle-ads[feature-store]"
GRAPHVIZ = "oracle-ads[graphviz]"
MLM_INSIGHTS = "oracle-ads[mlm_insights]"
Expand Down
6 changes: 3 additions & 3 deletions ads/common/object_storage_details.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
Expand All @@ -15,7 +15,7 @@
from ads.common import oci_client


class InvalidObjectStoragePath(Exception): # pragma: no cover
class InvalidObjectStoragePath(Exception): # pragma: no cover
"""Invalid Object Storage Path."""

pass
Expand Down Expand Up @@ -137,4 +137,4 @@ def is_oci_path(uri: str = None) -> bool:
"""
if not uri:
return False
return uri.startswith("oci://")
return uri.lower().startswith("oci://")
79 changes: 56 additions & 23 deletions ads/common/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

"""
This module provides a base class for serializable items, as well as methods for serializing and
deserializing objects to and from JSON and YAML formats. It also includes methods for reading and
writing serialized objects to and from files.
"""

import dataclasses
import json
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -271,11 +277,16 @@ def from_yaml(
Parameters
----------
yaml_string (string, optional): YAML string. Defaults to None.
uri (string, optional): URI location of file containing YAML string. Defaults to None.
loader (callable, optional): Custom YAML loader. Defaults to CLoader/SafeLoader.
kwargs (dict): keyword arguments to be passed into fsspec.open(). For OCI object storage, this should be config="path/to/.oci/config".
For other storage connections consider e.g. host, port, username, password, etc.
yaml_string (string, optional)
YAML string. Defaults to None.
uri (string, optional)
URI location of file containing YAML string. Defaults to None.
loader (callable, optional)
Custom YAML loader. Defaults to CLoader/SafeLoader.
kwargs (dict)
keyword arguments to be passed into fsspec.open().
For OCI object storage, this should be config="path/to/.oci/config".
For other storage connections consider e.g. host, port, username, password, etc.
Raises
------
Expand All @@ -288,10 +299,10 @@ def from_yaml(
Returns instance of the class
"""
if yaml_string:
return cls.from_dict(yaml.load(yaml_string, Loader=loader))
return cls.from_dict(yaml.load(yaml_string, Loader=loader), **kwargs)
if uri:
yaml_dict = yaml.load(cls._read_from_file(uri=uri, **kwargs), Loader=loader)
return cls.from_dict(yaml_dict)
return cls.from_dict(yaml_dict, **kwargs)
raise ValueError("Must provide either YAML string or URI location")

@classmethod
Expand Down Expand Up @@ -345,8 +356,8 @@ class DataClassSerializable(Serializable):
Returns an instance of the class instantiated from the dictionary provided.
"""

@staticmethod
def _validate_dict(obj_dict: Dict) -> bool:
@classmethod
def _validate_dict(cls, obj_dict: Dict) -> bool:
"""validate the dictionary.
Parameters
Expand Down Expand Up @@ -379,7 +390,7 @@ def to_dict(self, **kwargs) -> Dict:
obj_dict = dataclasses.asdict(self)
if "side_effect" in kwargs and kwargs["side_effect"]:
obj_dict = DataClassSerializable._normalize_dict(
obj_dict=obj_dict, case=kwargs["side_effect"]
obj_dict=obj_dict, case=kwargs["side_effect"], recursively=True
)
return obj_dict

Expand All @@ -388,6 +399,8 @@ def from_dict(
cls,
obj_dict: dict,
side_effect: Optional[SideEffect] = SideEffect.CONVERT_KEYS_TO_LOWER.value,
ignore_unknown: Optional[bool] = False,
**kwargs,
) -> "DataClassSerializable":
"""Returns an instance of the class instantiated by the dictionary provided.
Expand All @@ -399,6 +412,8 @@ def from_dict(
Side effect to apply to the dictionary keys. The side effect can either
convert the dictionary keys to "lower" (SideEffect.CONVERT_KEYS_TO_LOWER.value)
or "upper"(SideEffect.CONVERT_KEYS_TO_UPPER.value) cases.
ignore_unknown: (bool, optional). Defaults to `False`.
Whether to ignore unknown fields or not.
Returns
-------
Expand All @@ -415,25 +430,36 @@ def from_dict(

allowed_fields = set([f.name for f in dataclasses.fields(cls)])
wrong_fields = set(obj_dict.keys()) - allowed_fields
if wrong_fields:
if wrong_fields and not ignore_unknown:
logger.warning(
f"The class {cls.__name__} doesn't contain attributes: `{list(wrong_fields)}`. "
"These fields will be ignored."
)

obj = cls(**{key: obj_dict[key] for key in allowed_fields})
obj = cls(**{key: obj_dict.get(key) for key in allowed_fields})

for key, value in obj_dict.items():
if isinstance(value, dict) and hasattr(
getattr(cls(), key).__class__, "from_dict"
if (
key in allowed_fields
and isinstance(value, dict)
and hasattr(getattr(cls(), key).__class__, "from_dict")
):
attribute = getattr(cls(), key).__class__.from_dict(value)
attribute = getattr(cls(), key).__class__.from_dict(
value,
ignore_unknown=ignore_unknown,
side_effect=side_effect,
**kwargs,
)
setattr(obj, key, attribute)

return obj

@staticmethod
def _normalize_dict(
obj_dict: Dict, case: str = SideEffect.CONVERT_KEYS_TO_LOWER.value
obj_dict: Dict,
recursively: bool = False,
case: str = SideEffect.CONVERT_KEYS_TO_LOWER.value,
**kwargs,
) -> Dict:
"""Normalize all the keys of the dictionary to the specified case.
Expand All @@ -444,6 +470,8 @@ def _normalize_dict(
case: (optional, str). Defaults to "lower".
The case to normalize to. Can be either "lower" (SideEffect.CONVERT_KEYS_TO_LOWER.value)
or "upper"(SideEffect.CONVERT_KEYS_TO_UPPER.value).
recursively: (bool, optional). Defaults to `False`.
Whether to recursively normalize the dictionary or not.
Returns
-------
Expand All @@ -452,12 +480,16 @@ def _normalize_dict(
"""
normalized_obj_dict = {}
for key, value in obj_dict.items():
if isinstance(value, dict):
if recursively and isinstance(value, dict):
value = DataClassSerializable._normalize_dict(
value, case=SideEffect.CONVERT_KEYS_TO_UPPER.value
value, case=case, recursively=recursively, **kwargs
)
normalized_obj_dict = DataClassSerializable._normalize_key(
normalized_obj_dict=normalized_obj_dict, key=key, value=value, case=case
normalized_obj_dict=normalized_obj_dict,
key=key,
value=value,
case=case,
**kwargs,
)
return normalized_obj_dict

Expand All @@ -467,7 +499,7 @@ def _normalize_key(
) -> Dict:
"""helper function to normalize the key in the case specified and add it back to the dictionary.
Paramaters
Parameters
----------
normalized_obj_dict: (Dict)
the dictionary to append the key and value to.
Expand All @@ -476,17 +508,18 @@ def _normalize_key(
value: (Union[str, Dict])
value to be added.
case: (str)
the case to normalized to. can be either "lower" (SideEffect.CONVERT_KEYS_TO_LOWER.value)
The case to normalize to. Can be either "lower" (SideEffect.CONVERT_KEYS_TO_LOWER.value)
or "upper"(SideEffect.CONVERT_KEYS_TO_UPPER.value).
Raises
------
NotImplementedError: if case provided is not either "lower" or "upper".
NotImplementedError
Raised when `case` is not supported.
Returns
-------
Dict
normalized dictionary with the key and value added in the case specified.
Normalized dictionary with the key and value added in the case specified.
"""
if case.lower() == SideEffect.CONVERT_KEYS_TO_LOWER.value:
normalized_obj_dict[key.lower()] = value
Expand Down
Loading

0 comments on commit 79722d0

Please sign in to comment.