Skip to content

Commit

Permalink
ODSC-47592. Fix test_jobs_python_runtime.py
Browse files Browse the repository at this point in the history
- updated driver_utils.py to work with Instance Principals
- added default_signer() into test_notebook_driver_with_outputs()
  • Loading branch information
liudmylaru committed Nov 7, 2023
1 parent c4e5e7f commit 969e15d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 17 deletions.
29 changes: 17 additions & 12 deletions ads/jobs/templates/driver_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
CONST_ENV_PIP_REQ = "OCI__PIP_REQUIREMENTS"
CONST_ENV_PIP_PKG = "OCI__PIP_PKG"
CONST_API_KEY = "api_key"
CONST_INSTANCE_PRINCIPAL = "instance_principal"


DEFAULT_CODE_DIR = os.path.join(
Expand Down Expand Up @@ -81,17 +82,28 @@ class OCIHelper:

@staticmethod
def init_oci_client(client_class):
"""Initializes OCI client with API key or Resource Principal.
"""Initializes OCI client with API key, Resource Principal or Instance Principal.
Parameters
----------
client_class :
The class of OCI client to be initialized.
"""
if (
os.environ.get(CONST_ENV_ADS_IAM, "").lower() == CONST_API_KEY
or CONST_ENV_OCI_RP not in os.environ
):
if CONST_ENV_OCI_RP in os.environ:
logger.info(
"Initializing %s with Resource Principal...", client_class.__name__
)
client = client_class(
{}, signer=oci.auth.signers.get_resource_principals_signer()
)
elif os.environ.get(CONST_ENV_ADS_IAM).lower() == CONST_INSTANCE_PRINCIPAL:
logger.info(
"Initializing %s with Instance Principal...", {client_class.__name__}
)
client = client_class(
{}, signer=oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
)
else:
logger.info("Initializing %s with API Key...", {client_class.__name__})
client = client_class(
oci.config.from_file(
Expand All @@ -103,13 +115,6 @@ def init_oci_client(client_class):
),
)
)
else:
logger.info(
"Initializing %s with Resource Principal...", client_class.__name__
)
client = client_class(
{}, signer=oci.auth.signers.get_resource_principals_signer()
)
return client

@staticmethod
Expand Down
14 changes: 9 additions & 5 deletions tests/integration/jobs/test_jobs_notebook_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from zipfile import ZipFile

import fsspec

from ads.common.auth import default_signer
from tests.integration.config import secrets
from tests.integration.jobs.test_dsc_job import DSCJobTestCaseWithCleanUp
from tests.integration.jobs.test_jobs_notebook import NotebookDriverRunTest
Expand All @@ -19,7 +19,9 @@


class NotebookRuntimeTest(DSCJobTestCaseWithCleanUp):
NOTEBOOK_PATH = os.path.join(os.path.dirname(__file__), "../fixtures/ads_check.ipynb")
NOTEBOOK_PATH = os.path.join(
os.path.dirname(__file__), "../fixtures/ads_check.ipynb"
)
NOTEBOOK_PATH_EXCLUDE = os.path.join(
os.path.dirname(__file__), "../fixtures/exclude_check.ipynb"
)
Expand Down Expand Up @@ -89,7 +91,9 @@ class NotebookDriverIntegrationTest(NotebookDriverRunTest):
def test_notebook_driver_with_outputs(self):
"""Tests run the notebook driver with a notebook plotting and saving data."""
# Notebook to be executed
notebook_path = os.path.join(os.path.dirname(__file__), "../fixtures/plot.ipynb")
notebook_path = os.path.join(
os.path.dirname(__file__), "../fixtures/plot.ipynb"
)
# Object storage output location
output_uri = f"oci://{secrets.jobs.BUCKET_B}@{secrets.common.NAMESPACE}/notebook_driver_int_test/plot/"
# Run the notebook with driver and check the logs
Expand All @@ -100,7 +104,7 @@ def test_notebook_driver_with_outputs(self):
# Check the notebook saved to object storage.
with fsspec.open(
os.path.join(output_uri, os.path.basename(notebook_path)),
config=os.path.expanduser("~/.oci/config"),
**default_signer(),
) as f:
outputs = [cell.get("outputs") for cell in json.load(f).get("cells")]
# There should be 7 cells in the notebook
Expand All @@ -113,7 +117,7 @@ def test_notebook_driver_with_outputs(self):
# Check the JSON output file from the notebook
with fsspec.open(
os.path.join(output_uri, "data.json"),
config=os.path.expanduser("~/.oci/config"),
**default_signer(),
) as f:
data = json.load(f)
# There should be 10 data points
Expand Down

0 comments on commit 969e15d

Please sign in to comment.