Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forecasting Operator #268

Merged
merged 334 commits into from
Nov 16, 2023
Merged
Show file tree
Hide file tree
Changes from 250 commits
Commits
Show all changes
334 commits
Select commit Hold shift + click to select a range
71ab9a6
added the autots multivariate operator, added the report and model ou…
codeloop Sep 14, 2023
cd9326c
syntax fix in model params
codeloop Sep 14, 2023
0c4e143
minor spacing in text
ahosler Sep 14, 2023
a3112c9
Update autots.py
ahosler Sep 14, 2023
8e248f0
Fixes the local backend for operators.
mrDzurb Sep 14, 2023
64d6f4d
Adds IAM policies documentation section for the operators
mrDzurb Sep 14, 2023
98b04b9
Removes json2table converter from operator's code
mrDzurb Sep 14, 2023
c5e7551
Adds opctl environment validator decorator.
mrDzurb Sep 15, 2023
a28acd3
Cleans up the operator's redundant functionality.
mrDzurb Sep 15, 2023
79d3f9b
Removes redundant docker files.
mrDzurb Sep 17, 2023
22c3338
Adds supported backends to the operator's init file.
mrDzurb Sep 18, 2023
b75c305
Merge branch 'feature/forecasting' of https://github.com/oracle/accel…
mrDzurb Sep 18, 2023
c299c08
fixed bugs that is causing errors when series are missing in test dat…
prasankh Sep 19, 2023
484cc24
initial benchmarks
prasankh Sep 19, 2023
f02d4a6
get the possible tunable model params via kwargs
codeloop Sep 20, 2023
0c3321f
Add AutoTS Multivariate Operator (#332)
codeloop Sep 20, 2023
8f8d0ba
update default yaml (#330)
ahosler Sep 20, 2023
b028fb6
Update benchmark_datasets.py
ahosler Sep 20, 2023
b1b836b
Update benchmark_datasets.py
ahosler Sep 20, 2023
a513d7b
adding 15% margin to automlx numbers
ahosler Sep 20, 2023
5a1d78a
Cleans up the operator's redundant functionality. (#334)
mrDzurb Sep 20, 2023
0a6fd95
Feature/forecasting benchmarks (#336)
prasankh Sep 21, 2023
9c08c4f
added flag for feature engineering
prasankh Sep 21, 2023
dea1cdd
Adds opctl environment validator decorator. (#333)
mrDzurb Sep 21, 2023
8303205
Adds operator loader to load operator from the remote source.
mrDzurb Sep 24, 2023
016aade
Merge branch 'feature/forecasting' of https://github.com/oracle/accel…
mrDzurb Sep 24, 2023
d7cf7ce
Improves the operator init command.
mrDzurb Sep 24, 2023
59a8440
formatting
prasankh Sep 25, 2023
4afde23
made preprocessing optional in automlx
prasankh Sep 25, 2023
a4a3ece
preprocessing flag not required
prasankh Sep 25, 2023
2d5345b
Relax the type field in forecast schema YAML.
mrDzurb Sep 25, 2023
8b5ff81
Uses default profile in ads opctl auth in case if given profile not f…
mrDzurb Sep 25, 2023
eacc574
ODSC-47199: Adds supporting GPU image for the operators.
mrDzurb Sep 26, 2023
5ce3453
Feature/forecasting skip feature engineering (#350)
prasankh Sep 26, 2023
d376777
fix typo
codeloop Sep 26, 2023
51954d5
Merge branch 'feature/forecasting' into feature/forecast_explain
codeloop Sep 26, 2023
e106bd9
ODSC-47630/fix_datapane_failure_for_only_one_series (#351)
govarsha Sep 26, 2023
02c1f00
ODSC-46457/refactor_forecast.csv (#349)
govarsha Sep 26, 2023
d0669f5
ODSC-47199/Adds supporting GPU image for the operators. (#352)
mrDzurb Sep 26, 2023
1d7d13a
ODSC-47777/Custom Operator Integration (#348)
mrDzurb Sep 26, 2023
9660068
bug fix and yaml docs
ahosler Sep 27, 2023
7b765da
Refactors ADS OPCTL OPERATOR.
mrDzurb Sep 28, 2023
d32ac94
Refactors the operator's documentaion.
mrDzurb Sep 28, 2023
997d27f
Refactors the `ads opctl operator` CLI. (#357)
mrDzurb Sep 28, 2023
ac932aa
Merge branch 'main' of https://github.com/oracle/accelerated-data-sci…
mrDzurb Sep 28, 2023
13eaf2d
Fixes the problem with the operator's base image.
mrDzurb Sep 28, 2023
e97f7bb
Moves operator information form the init file to the MLoperator yaml.
mrDzurb Oct 2, 2023
076cb58
Replaces the name attribute with the type one in the ads operator cli.
mrDzurb Oct 2, 2023
189bf8c
Fixes operator's publish conda command.
mrDzurb Oct 2, 2023
869856d
Fixes the operator's documentation.
mrDzurb Oct 2, 2023
6b93921
Merge branch 'feature/forecasting' into feature/forecast_explain
codeloop Oct 3, 2023
9b448bc
Moves operator information form the init file to the MLoperator yaml.…
mrDzurb Oct 4, 2023
6aa416a
Adds opeartor tests structure.
mrDzurb Oct 3, 2023
00815f1
Adds tests for the operator yaml generator.
mrDzurb Oct 4, 2023
776aa9a
update the custom predict & model explain
codeloop Oct 4, 2023
f70d6bb
Adds tests for the operator loader.
mrDzurb Oct 4, 2023
7b287d4
skipping issues with training metrics
ahosler Oct 5, 2023
5cc0c1d
added docstring, other fixes
codeloop Oct 5, 2023
308894f
Merge branch 'feature/forecasting' into feature/forecast_explain
codeloop Oct 5, 2023
c890366
remove redundant import
codeloop Oct 5, 2023
0fdcb9f
remove redundant import
codeloop Oct 5, 2023
1854750
Operator. Unit tests for the common utils. (#364)
mrDzurb Oct 5, 2023
ff4b1a9
Fixes the operator run command.
mrDzurb Oct 6, 2023
5bdc0dc
Fixes building base operator image.
mrDzurb Oct 7, 2023
1869db8
generate explaination when explain_model kwargs is available for auto…
codeloop Oct 9, 2023
2ae6d15
code cleanup, add dependency to env yaml
codeloop Oct 9, 2023
27148f2
updating forecasting operator docs
ahosler Oct 9, 2023
4dbf46b
remove breakpoint
codeloop Oct 9, 2023
4d0ecc1
Fixes common utils unit tests.
mrDzurb Oct 9, 2023
9801e8a
Merge branch 'main' of https://github.com/oracle/accelerated-data-sci…
mrDzurb Oct 9, 2023
efae87b
Fixes the DataClassSerializer from_yaml method.
mrDzurb Oct 9, 2023
c802c18
fix comments + docstrings
codeloop Oct 10, 2023
a25951a
extended operator errors and added forecast specific exceptions
prasankh Sep 29, 2023
b8911c8
reverted errors in cmd.py
prasankh Oct 9, 2023
a03cd9c
Extended operator errors and added forecast specific exceptions (#360)
prasankh Oct 10, 2023
e006d73
[ODSC-47260] Global Explainer Class : Feature/forecast explain (#367)
codeloop Oct 11, 2023
9d83e05
Merge branch 'main' of https://github.com/oracle/accelerated-data-sci…
mrDzurb Oct 11, 2023
3459d6c
add local explainer
codeloop Oct 12, 2023
19f6287
add report for local explanability, format with black & normalize scores
codeloop Oct 12, 2023
75c931e
fix return type & docstring
codeloop Oct 12, 2023
983a58b
updating yaml formatting
ahosler Oct 12, 2023
2563bf1
Fixes GIT operator loader.
mrDzurb Oct 12, 2023
42b0202
Merge branch 'feature/forecasting' of https://github.com/oracle/accel…
mrDzurb Oct 12, 2023
b280b10
Removes operator.py from the forecast.
mrDzurb Oct 17, 2023
fcf7c16
Moves forecast.operate into the __main__
mrDzurb Oct 17, 2023
23db302
fixed bug in holdout data summary metrics per horizon
govarsha Oct 17, 2023
b1cdc4b
updating docs formatting
ahosler Oct 17, 2023
b579c4c
ODSC-47259 Local Explainer Class : Feature/forecast explain (#375)
codeloop Oct 17, 2023
5eca230
bug fixes
ahosler Oct 17, 2023
b6f7e01
enable time limit in automlx
ahosler Oct 17, 2023
f3ab090
Merge branch 'main' of https://github.com/oracle/accelerated-data-sci…
mrDzurb Oct 18, 2023
39c3b9a
Fixes operator's tests.
mrDzurb Oct 19, 2023
3c220c3
Made small changes
govarsha Oct 19, 2023
6e33aef
fixed bug in holdout data summary metrics per horizon (#377)
govarsha Oct 19, 2023
346e931
handles empty test file
govarsha Oct 4, 2023
15df43b
handles when test file doesnot have any series
govarsha Oct 11, 2023
9c68820
Fixes OPCTL unit tests.
mrDzurb Oct 20, 2023
d0e0a13
handles when entire series or some values in series is missing in tes…
govarsha Oct 20, 2023
2e5527f
updating formatting
ahosler Oct 20, 2023
652ca51
Fixes OPCTL DataFlow unit tests.
mrDzurb Oct 20, 2023
e86d470
ODSC-48941. Run operator within 'ads opctl run'
mrDzurb Oct 23, 2023
2e98395
fixing merge error
ahosler Oct 23, 2023
b94c9a6
Adds an option to merge operator and backend configs into one YAML.
mrDzurb Oct 23, 2023
d18fca0
Adjusts the operators documentation.
mrDzurb Oct 23, 2023
c7b8dc7
Run operator within "ads opctl run" (#384)
mrDzurb Oct 23, 2023
f05b8e4
Improves operators exploration documentation.
mrDzurb Oct 23, 2023
a425a7e
Merge branch 'main' of https://github.com/oracle/accelerated-data-sci…
mrDzurb Oct 24, 2023
a91d880
Added unit tests
govarsha Oct 25, 2023
63d064c
handles case where there are no series in test data
govarsha Oct 25, 2023
e019ff0
Merge branch 'feature/forecasting' of https://github.com/oracle/accel…
govarsha Oct 25, 2023
8bbfb10
small change
govarsha Oct 25, 2023
140fc7a
add local & global explanation for prophet model
codeloop Oct 25, 2023
33e8781
refactor local explainer
codeloop Oct 25, 2023
f568558
add the reports for global & local explainability, reformat with blac…
codeloop Oct 26, 2023
62b1b14
add runtime dependency, internal date col to constants, update docstring
codeloop Oct 26, 2023
bd989f3
Reduced horizon parameters and handled datetime error for AutoMLx model
prasankh Oct 27, 2023
bca8a9f
Fixes operator init method.
mrDzurb Oct 27, 2023
d7592a8
minor refactoring
prasankh Oct 30, 2023
98d8b72
resolving autots models bug
ahosler Oct 30, 2023
7c32dfa
added boolean disables and added test_metrics.csv generation
govarsha Oct 30, 2023
1660455
updated metrics_filename, test_metrics_filename, forecast_filename in…
govarsha Oct 30, 2023
05350e1
additional references to periods
ahosler Oct 30, 2023
6824a57
changed report_file_name to report_filename
govarsha Oct 30, 2023
ebd8c19
ODSC-49155: Operator init method fails when OCI config is not provide…
mrDzurb Oct 30, 2023
43e9d04
changes to enable the model_kwargs dict to be passed directly through…
govarsha Oct 30, 2023
dccc779
added changes to schema.yaml
govarsha Oct 30, 2023
a9417fd
added comments
govarsha Oct 30, 2023
bc88e7b
added unit test to test autots parameter passthrough
govarsha Oct 31, 2023
7fc641f
removed comments
govarsha Oct 31, 2023
4ad4873
small fix
govarsha Oct 31, 2023
c9bee4f
Reduced horizon parameters and handled datetime error for AutoMLx mod…
prasankh Oct 31, 2023
3ddb17a
resolving autots models bug (#398)
ahosler Oct 31, 2023
74c11b7
Fixes operator init method.
mrDzurb Oct 27, 2023
31f6098
Reduced horizon parameters and handled datetime error for AutoMLx model
prasankh Oct 27, 2023
40e38b3
minor refactoring
prasankh Oct 30, 2023
3165efd
additional references to periods
ahosler Oct 30, 2023
02ce9e9
resolving autots models bug
ahosler Oct 30, 2023
e961bca
changes to enable the model_kwargs dict to be passed directly through…
govarsha Oct 30, 2023
6075727
removed comments
govarsha Oct 31, 2023
0ba2f62
small fix
govarsha Oct 31, 2023
bb3637b
Revert "small fix"
ahosler Oct 31, 2023
c4cb8ce
Revert "removed comments"
ahosler Oct 31, 2023
dc196d7
Revert "changes to enable the model_kwargs dict to be passed directly…
ahosler Oct 31, 2023
0df1441
Revert "resolving autots models bug"
ahosler Oct 31, 2023
73dd161
Revert "additional references to periods"
ahosler Oct 31, 2023
27523dc
Revert "minor refactoring"
ahosler Oct 31, 2023
24c20fb
Revert "Reduced horizon parameters and handled datetime error for Aut…
ahosler Oct 31, 2023
54f61f7
Revert "Fixes operator init method."
ahosler Oct 31, 2023
498326b
Merge branch 'feature/forecasting' into ODSC-48871/autots_parameter_p…
ahosler Oct 31, 2023
50883c5
ODSC-48871/autots parameter passthrough (#399)
ahosler Oct 31, 2023
52eaf9c
Merge branch 'feature/forecasting' into feature/add-boolean-disables-…
ahosler Oct 31, 2023
9a8ab34
added train_metrics to base_model
govarsha Oct 31, 2023
db4b6f5
added forecast_col_name attribute to base_model
govarsha Oct 31, 2023
1c4949a
small fixes to be consistent with recent changes
govarsha Oct 31, 2023
7e12673
patching issue with automlx check
ahosler Oct 31, 2023
8559ebf
fixing merge conflicts
govarsha Oct 31, 2023
baff6cd
resolving merge conflicts
govarsha Oct 31, 2023
d4d49bc
Merge branch 'feature/add-boolean-disables-and-save-train-test-metric…
govarsha Oct 31, 2023
e7c3575
[ODSC-48860/48861] Global & Local explainability for prophet model (#…
codeloop Oct 31, 2023
3a5600d
freq function causing errors with automlx
ahosler Oct 31, 2023
98c074f
ODSC-47050: Adds unit tests for the operator backends.
mrDzurb Oct 27, 2023
19aade0
attempt 2 to get freq of datetime for automlx
ahosler Nov 1, 2023
5f1755f
attempt 2 to get freq of datetime for automlx
ahosler Nov 1, 2023
0cdb5d0
Feature/forecasting automlx freq bug (#400)
ahosler Nov 1, 2023
4eef4e0
Merge branch 'feature/forecasting' into feature/add-boolean-disables-…
ahosler Nov 1, 2023
223f9c0
Added boolean disables, generation of train (metrics.csv) and test me…
ahosler Nov 1, 2023
7ee5aa6
Merge branch 'feature/forecasting' into ODSC-46836/fix_incomplete_tes…
ahosler Nov 1, 2023
e65e5c1
Odsc 46836/fix incomplete testdata issues (#385)
ahosler Nov 1, 2023
b4ea0d0
explain bool bug
ahosler Nov 1, 2023
c5d05b1
patching bugs from merging
ahosler Nov 1, 2023
8ca0e96
ODSC-47050: Adds unit tests for the operator backends. (#401)
mrDzurb Nov 1, 2023
8be7bfb
remove print stmts
ahosler Nov 1, 2023
1d191f4
lld changes, moving data reading outside base_model class
prasankh Nov 2, 2023
0a65a61
auto algorithm improvements
prasankh Nov 2, 2023
8a9ffd6
add local & global explanation for arima model, add reports to datapane
codeloop Nov 3, 2023
8ea2035
added code to sort by datetime col
govarsha Nov 3, 2023
586c856
clean up
prasankh Nov 3, 2023
1a1f8c0
added sorting for additional data
govarsha Nov 3, 2023
db2f13a
added format parameter for to_datetime functions
govarsha Nov 3, 2023
a8014c1
new docs structure
ahosler Nov 3, 2023
fccdcdf
grabbing requirements
ahosler Nov 3, 2023
6bfd0a8
added unit test for automlx when unsorted data is given
govarsha Nov 3, 2023
8b90a19
fixing previous merge mistakes in base_model unit tests
govarsha Nov 3, 2023
258c8a2
fix for ODSC-48265: training metrics mismatch
govarsha Nov 3, 2023
9fcc8f3
resolving circular import
ahosler Nov 3, 2023
05ec20b
check for additional data
ahosler Nov 6, 2023
2c16b8c
updated auto model conditions
prasankh Nov 7, 2023
a7ebb8b
dropping target col if present in additional data
prasankh Nov 7, 2023
890a4d3
Feature/forecasting auto algorithm improvements (#407)
ahosler Nov 7, 2023
a87733a
Bug Fix : Dropping target column if present in additional data (#412)
ahosler Nov 7, 2023
e16c4b8
Merge branch 'feature/forecasting' into ODSC-49028/sort_datetime_col
ahosler Nov 7, 2023
f3dae93
ODSC-49028/sort by datetime col & ODSC-48265/training metrics mismatc…
ahosler Nov 7, 2023
3780d7e
merge resolution
ahosler Nov 7, 2023
66f2f79
creating csv output files
ahosler Nov 7, 2023
9a2299d
create the agg local explanation in long format
codeloop Nov 7, 2023
1cb6983
factoring out code
ahosler Nov 7, 2023
28ae7d7
check for additional data
ahosler Nov 7, 2023
1c2c1f2
using const throughout
ahosler Nov 7, 2023
a71fc50
Merge branch 'feature/forecasting' into feature/arima_model_explain
ahosler Nov 7, 2023
73be38f
[ODSC-48857 | ODSC-48858] Global Explainability & Local Explainabilit…
ahosler Nov 7, 2023
dc88d6b
refactoring code
ahosler Nov 7, 2023
5d829fc
changes for ODSC-49565
govarsha Nov 7, 2023
605f343
re-factoring
ahosler Nov 7, 2023
eb16cb0
bug fix for automlx explanation generation
codeloop Nov 7, 2023
3200c2c
ODSC-49565/corrections in metrics calculation per horizon (#415)
ahosler Nov 7, 2023
7bd9a62
adding yaml example
ahosler Nov 7, 2023
98dc11e
new docs structure (#409)
ahosler Nov 7, 2023
33d6f67
resolve example
ahosler Nov 7, 2023
627ee7c
minor docs formatting
ahosler Nov 7, 2023
04643b9
Merge branch 'main' of https://github.com/oracle/accelerated-data-sci…
mrDzurb Nov 8, 2023
33bb90e
clean docs
ahosler Nov 8, 2023
b77efea
typos
ahosler Nov 8, 2023
24013c3
ODSC-49703: Move the ADS config errors into debug level.
mrDzurb Nov 8, 2023
f3022c7
making the output formatting more consistent
ahosler Nov 12, 2023
fe542d6
making the output formatting more consistent (#419)
ahosler Nov 12, 2023
f6e1566
add local, global explanation for the autots model, add the formatted…
codeloop Nov 13, 2023
83f6f81
Merge branch 'feature/forecasting' into feature/autots_explain_model
codeloop Nov 13, 2023
14eca17
error statement re-word
ahosler Nov 13, 2023
df2fb6f
cleaning up output files
ahosler Nov 14, 2023
63895fc
Merge branch 'feature/forecasting' into feature/autots_explain_model
ahosler Nov 14, 2023
bb1c01e
[ODSC-48861 | ODSC-48862] Add local, global explanation for the AutoT…
ahosler Nov 14, 2023
6433593
automlx changes
ahosler Nov 14, 2023
e8278ae
updating forecast dependencies
ahosler Nov 14, 2023
460f007
adding requirements to ads forecast
ahosler Nov 14, 2023
c308ccf
rc1
ahosler Nov 14, 2023
b547481
updating docs for non-conda release
ahosler Nov 14, 2023
e2f3567
update pyproject
ahosler Nov 14, 2023
42143c5
update pyproject
ahosler Nov 14, 2023
7578885
test data bug
ahosler Nov 14, 2023
722b449
Updates the dev-requirements.txt with forecast requirements.
mrDzurb Nov 14, 2023
ca0829b
Merge branch 'feature/forecasting' of https://github.com/oracle/accel…
mrDzurb Nov 14, 2023
bf2cda0
Merge branch 'main' of https://github.com/oracle/accelerated-data-sci…
mrDzurb Nov 14, 2023
8084dea
PII operator. (#395)
mrDzurb Nov 15, 2023
9a0f820
minor typo
ahosler Nov 15, 2023
b64f45c
using datatables
ahosler Nov 15, 2023
a58c715
clean up
ahosler Nov 15, 2023
0be7edd
Merge branch 'main' into feature/forecasting
ahosler Nov 15, 2023
72be593
relaxing lightgbm req
ahosler Nov 15, 2023
6e69020
support for no index
ahosler Nov 15, 2023
1869b52
Merge branch 'main' of https://github.com/oracle/accelerated-data-sci…
mrDzurb Nov 15, 2023
702779b
ODSC-49703: Move the ADS config errors into debug level. (#418)
mrDzurb Nov 15, 2023
6d9c84f
Changes the version of ADS fro the forecasting.
mrDzurb Nov 15, 2023
a1f09f0
Merge branch 'feature/forecasting' of https://github.com/oracle/accel…
mrDzurb Nov 15, 2023
6899cbc
Fixing test for pii operator (#430)
mingkang111 Nov 16, 2023
a86d887
updated unittests according to latest changes
govarsha Nov 16, 2023
8a780b5
updated unittests according to latest changes (#431)
ahosler Nov 16, 2023
ab85920
more forecast unit tests
ahosler Nov 16, 2023
4d272bc
adding test lib
ahosler Nov 16, 2023
e064a6e
add docker dependency
ahosler Nov 16, 2023
c47eaec
remove tests
ahosler Nov 16, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,6 @@ logs/

# vim
*.swp

# Python Wheel
*.whl
4 changes: 4 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ exclude build/lib/notebooks/**
exclude benchmark/**
include ads/ads
include ads/model/common/*.*
include ads/operator/**/*.md
include ads/operator/**/*.yaml
include ads/operator/**/*.whl
include ads/operator/**/MLoperator
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ You have various options when installing ADS.
python3 -m pip install oracle-ads
```

### Installing OCI AI Operators

To use the AI Forecast Operator, install the "forecast" dependencies using the following command:

```bash
python3 -m pip install 'oracle_ads[forecast]==2.9.0rc1'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2.9.0?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

```

### Installing extras libraries

To work with gradient boosting models, install the `boosted` module. This module includes XGBoost and LightGBM model classes.
Expand Down
7 changes: 4 additions & 3 deletions ads/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
import ads.opctl.cli
import ads.jobs.cli
import ads.pipeline.cli
import os
import json
import ads.opctl.operator.cli
except Exception as ex:
print(
"Please run `pip install oracle-ads[opctl]` to install "
"the required dependencies for ADS CLI."
"the required dependencies for ADS CLI. \n"
f"{str(ex)}"
)
logger.debug(ex)
logger.debug(traceback.format_exc())
Expand All @@ -44,6 +44,7 @@ def cli():
cli.add_command(ads.opctl.cli.commands)
cli.add_command(ads.jobs.cli.commands)
cli.add_command(ads.pipeline.cli.commands)
cli.add_command(ads.opctl.operator.cli.commands)


if __name__ == "__main__":
Expand Down
62 changes: 32 additions & 30 deletions ads/common/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,8 +629,8 @@ def create_signer(self) -> Dict:
user=configuration["user"],
fingerprint=configuration["fingerprint"],
private_key_file_location=configuration.get("key_file"),
pass_phrase= configuration.get("pass_phrase"),
private_key_content=configuration.get("key_content")
pass_phrase=configuration.get("pass_phrase"),
private_key_content=configuration.get("key_content"),
),
"client_kwargs": self.client_kwargs,
}
Expand Down Expand Up @@ -750,21 +750,10 @@ class SecurityToken(AuthSignerGenerator):
a given user - it requires that user's private key and security token.
It prepares extra arguments necessary for creating clients for variety of OCI services.
"""
SECURITY_TOKEN_GENERIC_HEADERS = [
"date",
"(request-target)",
"host"
]
SECURITY_TOKEN_BODY_HEADERS = [
"content-length",
"content-type",
"x-content-sha256"
]
SECURITY_TOKEN_REQUIRED = [
"security_token_file",
"key_file",
"region"
]

SECURITY_TOKEN_GENERIC_HEADERS = ["date", "(request-target)", "host"]
SECURITY_TOKEN_BODY_HEADERS = ["content-length", "content-type", "x-content-sha256"]
SECURITY_TOKEN_REQUIRED = ["security_token_file", "key_file", "region"]

def __init__(self, args: Optional[Dict] = None):
"""
Expand Down Expand Up @@ -831,12 +820,18 @@ def create_signer(self) -> Dict:
return {
"config": configuration,
"signer": oci.auth.signers.SecurityTokenSigner(
token=self._read_security_token_file(configuration.get("security_token_file")),
token=self._read_security_token_file(
configuration.get("security_token_file")
),
private_key=oci.signer.load_private_key_from_file(
configuration.get("key_file"), configuration.get("pass_phrase")
),
generic_headers=configuration.get("generic_headers", self.SECURITY_TOKEN_GENERIC_HEADERS),
body_headers=configuration.get("body_headers", self.SECURITY_TOKEN_BODY_HEADERS)
generic_headers=configuration.get(
"generic_headers", self.SECURITY_TOKEN_GENERIC_HEADERS
),
body_headers=configuration.get(
"body_headers", self.SECURITY_TOKEN_BODY_HEADERS
),
),
"client_kwargs": self.client_kwargs,
}
Expand All @@ -849,30 +844,37 @@ def _validate_and_refresh_token(self, configuration: Dict[str, Any]):
configuration: Dict
Security token configuration.
"""
security_token = self._read_security_token_file(configuration.get("security_token_file"))
security_token_container = oci.auth.security_token_container.SecurityTokenContainer(
session_key_supplier=None,
security_token=security_token
security_token = self._read_security_token_file(
configuration.get("security_token_file")
)
security_token_container = (
oci.auth.security_token_container.SecurityTokenContainer(
session_key_supplier=None, security_token=security_token
)
)

if not security_token_container.valid():
raise SecurityTokenError(
"Security token has expired. Call `oci session authenticate` to generate new session."
)

time_now = int(time.time())
time_expired = security_token_container.get_jwt()["exp"]
if time_expired - time_now < SECURITY_TOKEN_LEFT_TIME:
if not self.oci_config_location:
logger.warning("Can not auto-refresh token. Specify parameter `oci_config_location` through ads.set_auth() or ads.auth.create_signer().")
logger.warning(
"Can not auto-refresh token. Specify parameter `oci_config_location` through ads.set_auth() or ads.auth.create_signer()."
)
else:
result = os.system(f"oci session refresh --config-file {self.oci_config_location} --profile {self.oci_key_profile}")
result = os.system(
f"oci session refresh --config-file {self.oci_config_location} --profile {self.oci_key_profile}"
)
if result == 1:
logger.warning(
"Some error happened during auto-refreshing the token. Continue using the current one that's expiring in less than {SECURITY_TOKEN_LEFT_TIME} seconds."
"Please follow steps in https://docs.oracle.com/en-us/iaas/Content/API/SDKDocs/clitoken.htm to renew token."
)

date_time = datetime.fromtimestamp(time_expired).strftime("%Y-%m-%d %H:%M:%S")
logger.info(f"Session is valid until {date_time}.")

Expand All @@ -894,7 +896,7 @@ def _read_security_token_file(self, security_token_file: str) -> str:
raise ValueError("Invalid `security_token_file`. Specify a valid path.")
try:
token = None
with open(expanded_path, 'r') as f:
with open(expanded_path, "r") as f:
token = f.read()
return token
except:
Expand All @@ -903,7 +905,7 @@ def _read_security_token_file(self, security_token_file: str) -> str:

class AuthFactory:
"""
AuthFactory class which contains list of registered signers and alllows to register new signers.
AuthFactory class which contains list of registered signers and allows to register new signers.
Check documentation for more signers: https://docs.oracle.com/en-us/iaas/tools/python/latest/api/signing.html.
Current signers:
Expand Down
2 changes: 2 additions & 0 deletions ads/common/decorator/runtime_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class OptionalDependency:
OPTUNA = "oracle-ads[optuna]"
SPARK = "oracle-ads[spark]"
HUGGINGFACE = "oracle-ads[huggingface]"
FORECAST = "oracle-ads[forecast]"
PII = "oracle-ads[pii]"


def runtime_dependency(
Expand Down
6 changes: 3 additions & 3 deletions ads/common/object_storage_details.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2021, 2022 Oracle and/or its affiliates.
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
Expand All @@ -15,7 +15,7 @@
from ads.common import oci_client


class InvalidObjectStoragePath(Exception): # pragma: no cover
class InvalidObjectStoragePath(Exception): # pragma: no cover
"""Invalid Object Storage Path."""

pass
Expand Down Expand Up @@ -137,4 +137,4 @@ def is_oci_path(uri: str = None) -> bool:
"""
if not uri:
return False
return uri.startswith("oci://")
return uri.lower().startswith("oci://")
79 changes: 56 additions & 23 deletions ads/common/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

"""
This module provides a base class for serializable items, as well as methods for serializing and
deserializing objects to and from JSON and YAML formats. It also includes methods for reading and
writing serialized objects to and from files.
"""

import dataclasses
import json
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -271,11 +277,16 @@ def from_yaml(

Parameters
----------
yaml_string (string, optional): YAML string. Defaults to None.
uri (string, optional): URI location of file containing YAML string. Defaults to None.
loader (callable, optional): Custom YAML loader. Defaults to CLoader/SafeLoader.
kwargs (dict): keyword arguments to be passed into fsspec.open(). For OCI object storage, this should be config="path/to/.oci/config".
For other storage connections consider e.g. host, port, username, password, etc.
yaml_string (string, optional)
YAML string. Defaults to None.
uri (string, optional)
URI location of file containing YAML string. Defaults to None.
loader (callable, optional)
Custom YAML loader. Defaults to CLoader/SafeLoader.
kwargs (dict)
keyword arguments to be passed into fsspec.open().
For OCI object storage, this should be config="path/to/.oci/config".
For other storage connections consider e.g. host, port, username, password, etc.

Raises
------
Expand All @@ -288,10 +299,10 @@ def from_yaml(
Returns instance of the class
"""
if yaml_string:
return cls.from_dict(yaml.load(yaml_string, Loader=loader))
return cls.from_dict(yaml.load(yaml_string, Loader=loader), **kwargs)
if uri:
yaml_dict = yaml.load(cls._read_from_file(uri=uri, **kwargs), Loader=loader)
return cls.from_dict(yaml_dict)
return cls.from_dict(yaml_dict, **kwargs)
raise ValueError("Must provide either YAML string or URI location")

@classmethod
Expand Down Expand Up @@ -345,8 +356,8 @@ class DataClassSerializable(Serializable):
Returns an instance of the class instantiated from the dictionary provided.
"""

@staticmethod
def _validate_dict(obj_dict: Dict) -> bool:
@classmethod
def _validate_dict(cls, obj_dict: Dict) -> bool:
"""validate the dictionary.

Parameters
Expand Down Expand Up @@ -379,7 +390,7 @@ def to_dict(self, **kwargs) -> Dict:
obj_dict = dataclasses.asdict(self)
if "side_effect" in kwargs and kwargs["side_effect"]:
obj_dict = DataClassSerializable._normalize_dict(
obj_dict=obj_dict, case=kwargs["side_effect"]
obj_dict=obj_dict, case=kwargs["side_effect"], recursively=True
)
return obj_dict

Expand All @@ -388,6 +399,8 @@ def from_dict(
cls,
obj_dict: dict,
side_effect: Optional[SideEffect] = SideEffect.CONVERT_KEYS_TO_LOWER.value,
ignore_unknown: Optional[bool] = False,
**kwargs,
) -> "DataClassSerializable":
"""Returns an instance of the class instantiated by the dictionary provided.

Expand All @@ -399,6 +412,8 @@ def from_dict(
side effect to take on the dictionary. The side effect can be either
convert the dictionary keys to "lower" (SideEffect.CONVERT_KEYS_TO_LOWER.value)
or "upper"(SideEffect.CONVERT_KEYS_TO_UPPER.value) cases.
ignore_unknown: (bool, optional). Defaults to `False`.
Whether to ignore unknown fields or not.

Returns
-------
Expand All @@ -415,25 +430,36 @@ def from_dict(

allowed_fields = set([f.name for f in dataclasses.fields(cls)])
wrong_fields = set(obj_dict.keys()) - allowed_fields
if wrong_fields:
if wrong_fields and not ignore_unknown:
logger.warning(
f"The class {cls.__name__} doesn't contain attributes: `{list(wrong_fields)}`. "
"These fields will be ignored."
)

obj = cls(**{key: obj_dict[key] for key in allowed_fields})
obj = cls(**{key: obj_dict.get(key) for key in allowed_fields})

for key, value in obj_dict.items():
if isinstance(value, dict) and hasattr(
getattr(cls(), key).__class__, "from_dict"
if (
key in allowed_fields
and isinstance(value, dict)
and hasattr(getattr(cls(), key).__class__, "from_dict")
):
attribute = getattr(cls(), key).__class__.from_dict(value)
attribute = getattr(cls(), key).__class__.from_dict(
value,
ignore_unknown=ignore_unknown,
side_effect=side_effect,
**kwargs,
)
setattr(obj, key, attribute)

return obj

@staticmethod
def _normalize_dict(
obj_dict: Dict, case: str = SideEffect.CONVERT_KEYS_TO_LOWER.value
obj_dict: Dict,
recursively: bool = False,
case: str = SideEffect.CONVERT_KEYS_TO_LOWER.value,
**kwargs,
) -> Dict:
"""lower all the keys.

Expand All @@ -444,6 +470,8 @@ def _normalize_dict(
case: (optional, str). Defaults to "lower".
the case to normalized to. can be either "lower" (SideEffect.CONVERT_KEYS_TO_LOWER.value)
or "upper"(SideEffect.CONVERT_KEYS_TO_UPPER.value).
recursively: (bool, optional). Defaults to `False`.
Whether to recursively normalize the dictionary or not.

Returns
-------
Expand All @@ -452,12 +480,16 @@ def _normalize_dict(
"""
normalized_obj_dict = {}
for key, value in obj_dict.items():
if isinstance(value, dict):
if recursively and isinstance(value, dict):
value = DataClassSerializable._normalize_dict(
value, case=SideEffect.CONVERT_KEYS_TO_UPPER.value
value, case=case, recursively=recursively, **kwargs
)
normalized_obj_dict = DataClassSerializable._normalize_key(
normalized_obj_dict=normalized_obj_dict, key=key, value=value, case=case
normalized_obj_dict=normalized_obj_dict,
key=key,
value=value,
case=case,
**kwargs,
)
return normalized_obj_dict

Expand All @@ -467,7 +499,7 @@ def _normalize_key(
) -> Dict:
"""helper function to normalize the key in the case specified and add it back to the dictionary.

Paramaters
Parameters
----------
normalized_obj_dict: (Dict)
the dictionary to append the key and value to.
Expand All @@ -476,17 +508,18 @@ def _normalize_key(
value: (Union[str, Dict])
value to be added.
case: (str)
the case to normalized to. can be either "lower" (SideEffect.CONVERT_KEYS_TO_LOWER.value)
The case to normalized to. can be either "lower" (SideEffect.CONVERT_KEYS_TO_LOWER.value)
or "upper"(SideEffect.CONVERT_KEYS_TO_UPPER.value).

Raises
------
NotImplementedError: if case provided is not either "lower" or "upper".
NotImplementedError
Raised when `case` is not supported.

Returns
-------
Dict
normalized dictionary with the key and value added in the case specified.
Normalized dictionary with the key and value added in the case specified.
"""
if case.lower() == SideEffect.CONVERT_KEYS_TO_LOWER.value:
normalized_obj_dict[key.lower()] = value
Expand Down
Loading
Loading