Skip to content

Commit

Permalink
Fix integration tests to run with instance_principal (#339)
Browse files Browse the repository at this point in the history
  • Loading branch information
liudmylaru authored Nov 14, 2023
1 parent 6b75c58 commit c8ed767
Show file tree
Hide file tree
Showing 18 changed files with 156 additions and 84 deletions.
14 changes: 10 additions & 4 deletions ads/jobs/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
import yaml
from ads.common.auth import default_signer

# Special type to represent the current enclosed class.
# This type is used by factory class method or when a method returns ``self``.
Self = TypeVar("Self", bound="Serializable")
"""Special type to represent the current enclosed class.
This type is used by factory class method or when a method returns ``self``.
"""


class Serializable(ABC):
Expand Down Expand Up @@ -72,6 +70,14 @@ def _write_to_file(s: str, uri: str, **kwargs) -> None:
"if you wish to overwrite."
)

# Add default signer if the uri is an object storage uri, and
# the user does not specify config or signer.
if (
uri.startswith("oci://")
and "config" not in kwargs
and "signer" not in kwargs
):
kwargs.update(default_signer())
with fsspec.open(uri, "w", **kwargs) as f:
f.write(s)

Expand Down
19 changes: 11 additions & 8 deletions ads/opctl/config/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _fill_config_with_defaults(self, ads_config_path: str) -> None:
else:
self.config["execution"]["auth"] = AuthType.API_KEY
# determine profile
if self.config["execution"]["auth"] == AuthType.RESOURCE_PRINCIPAL:
if self.config["execution"]["auth"] != AuthType.API_KEY:
profile = self.config["execution"]["auth"].upper()
exec_config.pop("oci_profile", None)
self.config["execution"]["oci_profile"] = None
Expand Down Expand Up @@ -202,20 +202,23 @@ def _get_service_config(self, oci_profile: str, ads_config_folder: str) -> Dict:
def _config_flex_shape_details(self):
infrastructure = self.config["infrastructure"]
backend = self.config["execution"].get("backend", None)
if backend == BACKEND_NAME.JOB.value or backend == BACKEND_NAME.MODEL_DEPLOYMENT.value:
if (
backend == BACKEND_NAME.JOB.value
or backend == BACKEND_NAME.MODEL_DEPLOYMENT.value
):
shape_name = infrastructure.get("shape_name", "")
if shape_name.endswith(".Flex"):
if (
"ocpus" not in infrastructure or
"memory_in_gbs" not in infrastructure
"ocpus" not in infrastructure
or "memory_in_gbs" not in infrastructure
):
raise ValueError(
"Parameters `ocpus` and `memory_in_gbs` must be provided for using flex shape. "
"Call `ads opctl config` to specify."
)
infrastructure["shape_config_details"] = {
"ocpus": infrastructure.pop("ocpus"),
"memory_in_gbs": infrastructure.pop("memory_in_gbs")
"memory_in_gbs": infrastructure.pop("memory_in_gbs"),
}
elif backend == BACKEND_NAME.DATAFLOW.value:
executor_shape = infrastructure.get("executor_shape", "")
Expand All @@ -224,7 +227,7 @@ def _config_flex_shape_details(self):
"driver_shape_memory_in_gbs",
"driver_shape_ocpus",
"executor_shape_memory_in_gbs",
"executor_shape_ocpus"
"executor_shape_ocpus",
]
# executor_shape and driver_shape must be the same shape family
if executor_shape.endswith(".Flex") or driver_shape.endswith(".Flex"):
Expand All @@ -236,9 +239,9 @@ def _config_flex_shape_details(self):
)
infrastructure["driver_shape_config"] = {
"ocpus": infrastructure.pop("driver_shape_ocpus"),
"memory_in_gbs": infrastructure.pop("driver_shape_memory_in_gbs")
"memory_in_gbs": infrastructure.pop("driver_shape_memory_in_gbs"),
}
infrastructure["executor_shape_config"] = {
"ocpus": infrastructure.pop("executor_shape_ocpus"),
"memory_in_gbs": infrastructure.pop("executor_shape_memory_in_gbs")
"memory_in_gbs": infrastructure.pop("executor_shape_memory_in_gbs"),
}
7 changes: 3 additions & 4 deletions ads/opctl/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2022 Oracle and/or its affiliates.
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


Expand Down Expand Up @@ -88,9 +88,8 @@ def get_namespace(auth: dict) -> str:


def get_region_key(auth: dict) -> str:
if len(auth["config"]) > 0:
tenancy = auth["config"]["tenancy"]
else:
tenancy = auth["config"].get("tenancy")
if not tenancy:
tenancy = auth["signer"].tenancy_id
client = OCIClientFactory(**auth).identity
return client.get_tenancy(tenancy).data.home_region_key
Expand Down
5 changes: 2 additions & 3 deletions tests/integration/jobs/test_dsc_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,10 @@ def assert_job_creation(self, job, expected_infra_spec, expected_runtime_spec):
random.seed(threading.get_ident() + os.getpid())
random_suffix = "".join(random.choices(string.ascii_uppercase, k=6))
yaml_uri = f"oci://{self.BUCKET}@{self.NAMESPACE}/tests/{timestamp}/example_job_{random_suffix}.yaml"
config_path = "~/.oci/config"
job.to_yaml(uri=yaml_uri, config=config_path)
job.to_yaml(uri=yaml_uri)
print(f"Job YAML saved to {yaml_uri}")
try:
job = Job.from_yaml(uri=yaml_uri, config=config_path)
job = Job.from_yaml(uri=yaml_uri)
except Exception:
self.fail(f"Failed to load job from YAML\n{traceback.format_exc()}")

Expand Down
64 changes: 57 additions & 7 deletions tests/integration/jobs/test_jobs_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,88 @@

from click.testing import CliRunner

from ads.common.auth import AuthType
from ads.jobs.cli import run, watch, delete


class TestJobsCLI:
# TeamCity uses Instance Principal; when running locally, set OCI_IAM_TYPE to security_token.
auth = os.environ.get("OCI_IAM_TYPE", AuthType.INSTANCE_PRINCIPAL)

def test_create_watch_delete_job(self):
curr_dir = os.path.dirname(os.path.abspath(__file__))
runner = CliRunner()
res = runner.invoke(
run, args=["-f", os.path.join(curr_dir, "../yamls", "sample_job.yaml")]
run,
args=[
"-f",
os.path.join(curr_dir, "../yamls", "sample_job.yaml"),
"--auth",
self.auth,
],
)
assert res.exit_code == 0, res.output
run_id = res.output.split("\n")[1]
res2 = runner.invoke(watch, args=[run_id])
res2 = runner.invoke(
watch,
args=[
run_id,
"--auth",
self.auth,
],
)
assert res2.exit_code == 0, res2.output

res3 = runner.invoke(delete, args=[run_id])
res3 = runner.invoke(
delete,
args=[
run_id,
"--auth",
self.auth,
],
)
assert res3.exit_code == 0, res3.output

def test_create_watch_delete_dataflow(self):
curr_dir = os.path.dirname(os.path.abspath(__file__))
runner = CliRunner()
res = runner.invoke(
run, args=["-f", os.path.join(curr_dir, "../yamls", "sample_dataflow.yaml")]
run,
args=[
"-f",
os.path.join(curr_dir, "../yamls", "sample_dataflow.yaml"),
"--auth",
self.auth,
],
)
assert res.exit_code == 0, res.output
run_id = res.output.split("\n")[1]
res2 = runner.invoke(watch, args=[run_id])
res2 = runner.invoke(
watch,
args=[
run_id,
"--auth",
self.auth,
],
)
assert res2.exit_code == 0, res2.output

res3 = runner.invoke(
run, args=["-f", os.path.join(curr_dir, "../yamls", "sample_dataflow.yaml")]
run,
args=[
"-f",
os.path.join(curr_dir, "../yamls", "sample_dataflow.yaml"),
"--auth",
self.auth,
],
)
run_id2 = res3.output.split("\n")[1]
res4 = runner.invoke(delete, args=[run_id2])
res4 = runner.invoke(
delete,
args=[
run_id2,
"--auth",
self.auth,
],
)
assert res4.exit_code == 0, res4.output
5 changes: 2 additions & 3 deletions tests/integration/jobs/test_jobs_notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tempfile

import fsspec
from ads.common.auth import default_signer, AuthType
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
NotebookRuntimeHandler,
)
Expand Down Expand Up @@ -64,9 +65,7 @@ def run_notebook(
# Clear the files in output URI
try:
# Ignore the error for unit tests.
fs = fsspec.filesystem(
"oci", config=os.path.expanduser("~/.oci/config")
)
fs = fsspec.filesystem("oci", **default_signer())
if fs.find(output_uri):
fs.rm(output_uri, recursive=True)
except:
Expand Down
18 changes: 13 additions & 5 deletions tests/integration/jobs/test_jobs_notebook_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import pytest
import os
import tempfile
from zipfile import ZipFile

import fsspec

from ads.common.auth import default_signer
from tests.integration.config import secrets
from tests.integration.jobs.test_dsc_job import DSCJobTestCaseWithCleanUp
from tests.integration.jobs.test_jobs_notebook import NotebookDriverRunTest
Expand All @@ -19,7 +20,9 @@


class NotebookRuntimeTest(DSCJobTestCaseWithCleanUp):
NOTEBOOK_PATH = os.path.join(os.path.dirname(__file__), "../fixtures/ads_check.ipynb")
NOTEBOOK_PATH = os.path.join(
os.path.dirname(__file__), "../fixtures/ads_check.ipynb"
)
NOTEBOOK_PATH_EXCLUDE = os.path.join(
os.path.dirname(__file__), "../fixtures/exclude_check.ipynb"
)
Expand Down Expand Up @@ -86,10 +89,15 @@ def test_create_job_with_notebook(self):


class NotebookDriverIntegrationTest(NotebookDriverRunTest):
@pytest.mark.skip(
reason="api_keys not an option anymore, this test is candidate to be removed"
)
def test_notebook_driver_with_outputs(self):
"""Tests run the notebook driver with a notebook plotting and saving data."""
# Notebook to be executed
notebook_path = os.path.join(os.path.dirname(__file__), "../fixtures/plot.ipynb")
notebook_path = os.path.join(
os.path.dirname(__file__), "../fixtures/plot.ipynb"
)
# Object storage output location
output_uri = f"oci://{secrets.jobs.BUCKET_B}@{secrets.common.NAMESPACE}/notebook_driver_int_test/plot/"
# Run the notebook with driver and check the logs
Expand All @@ -100,7 +108,7 @@ def test_notebook_driver_with_outputs(self):
# Check the notebook saved to object storage.
with fsspec.open(
os.path.join(output_uri, os.path.basename(notebook_path)),
config=os.path.expanduser("~/.oci/config"),
**default_signer(),
) as f:
outputs = [cell.get("outputs") for cell in json.load(f).get("cells")]
# There should be 7 cells in the notebook
Expand All @@ -113,7 +121,7 @@ def test_notebook_driver_with_outputs(self):
# Check the JSON output file from the notebook
with fsspec.open(
os.path.join(output_uri, "data.json"),
config=os.path.expanduser("~/.oci/config"),
**default_signer(),
) as f:
data = json.load(f)
# There should be 10 data points
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/jobs/test_jobs_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def job_run_test_infra(self):
@staticmethod
def list_objects(uri: str) -> list:
"""Lists objects on OCI object storage."""
oci_os = fsspec.filesystem("oci", config=oci.config.from_file())
oci_os = fsspec.filesystem("oci", **default_signer())
if uri.startswith("oci://"):
uri = uri[len("oci://") :]
items = oci_os.ls(uri, detail=False, refresh=True)
Expand All @@ -126,7 +126,7 @@ def list_objects(uri: str) -> list:
@staticmethod
def remove_objects(uri: str):
"""Removes objects from OCI object storage."""
oci_os = fsspec.filesystem("oci", config=oci.config.from_file())
oci_os = fsspec.filesystem("oci", **default_signer())
try:
oci_os.rm(uri, recursive=True)
except FileNotFoundError:
Expand Down
12 changes: 11 additions & 1 deletion tests/integration/opctl/test_opctl_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
)
ADS_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")

if "TEAMCITY_VERSION" in os.environ:
# When running in TeamCity, the working directory is supplied via CHECKOUT_DIR="%teamcity.build.checkoutDir%".
WORK_DIR = os.getenv("CHECKOUT_DIR", "~")
CONDA_PACK_FOLDER = f"{WORK_DIR}/conda"
else:
CONDA_PACK_FOLDER = "~/conda"


def _assert_run_command(cmd_str, expected_outputs: list = None):
runner = CliRunner()
Expand Down Expand Up @@ -48,7 +55,7 @@ class TestLocalRunsWithConda:
# For tests, we can always run the command in debug mode (-d).
# By default, pytest only prints the logs if the test fails,
# in which case we would like to see the debug logs.
CMD_OPTIONS = "-d -b local "
CMD_OPTIONS = f"-d -b local --conda-pack-folder {CONDA_PACK_FOLDER} "

def test_hello_world(self):
test_folder = os.path.join(TESTS_FILES_DIR, "hello_world_test")
Expand Down Expand Up @@ -79,6 +86,9 @@ def test_linear_reg_test(self):
]
_assert_run_command(cmd, expected_outputs)

@pytest.mark.skip(
reason="spark do not support instance principal - this test candidate to remove"
)
def test_spark_run(self):
test_folder = os.path.join(TESTS_FILES_DIR, "spark_test")
cmd = (
Expand Down
Loading

0 comments on commit c8ed767

Please sign in to comment.