Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix integration tests to run with instance_principal #339

Merged
merged 15 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions ads/jobs/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
import yaml
from ads.common.auth import default_signer

# Special type to represent the current enclosed class.
# This type is used by factory class method or when a method returns ``self``.
Self = TypeVar("Self", bound="Serializable")
"""Special type to represent the current enclosed class.

This type is used by factory class method or when a method returns ``self``.
"""


class Serializable(ABC):
Expand Down Expand Up @@ -72,6 +70,14 @@ def _write_to_file(s: str, uri: str, **kwargs) -> None:
"if you wish to overwrite."
)

# Add default signer if the uri is an object storage uri, and
# the user does not specify config or signer.
if (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to check on the "oci://"?

Copy link
Member Author

@liudmylaru liudmylaru Nov 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just copied lines 94-99 (https://github.com/oracle/accelerated-data-science/blob/main/ads/jobs/serializer.py#L94-L99) from the same file, I think QQ told me he forgot to add this code here also. So my answer to your question - I do not know :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the uri does not start with oci://, the config and signer are not needed. Also we don't want to send the credentials to non-oci server (which could be a security issue).

uri.startswith("oci://")
and "config" not in kwargs
and "signer" not in kwargs
):
kwargs.update(default_signer())
with fsspec.open(uri, "w", **kwargs) as f:
f.write(s)

Expand Down
19 changes: 11 additions & 8 deletions ads/opctl/config/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _fill_config_with_defaults(self, ads_config_path: str) -> None:
else:
self.config["execution"]["auth"] = AuthType.API_KEY
# determine profile
if self.config["execution"]["auth"] == AuthType.RESOURCE_PRINCIPAL:
if self.config["execution"]["auth"] != AuthType.API_KEY:
profile = self.config["execution"]["auth"].upper()
exec_config.pop("oci_profile", None)
self.config["execution"]["oci_profile"] = None
Expand Down Expand Up @@ -202,20 +202,23 @@ def _get_service_config(self, oci_profile: str, ads_config_folder: str) -> Dict:
def _config_flex_shape_details(self):
infrastructure = self.config["infrastructure"]
backend = self.config["execution"].get("backend", None)
if backend == BACKEND_NAME.JOB.value or backend == BACKEND_NAME.MODEL_DEPLOYMENT.value:
if (
backend == BACKEND_NAME.JOB.value
or backend == BACKEND_NAME.MODEL_DEPLOYMENT.value
):
shape_name = infrastructure.get("shape_name", "")
if shape_name.endswith(".Flex"):
if (
"ocpus" not in infrastructure or
"memory_in_gbs" not in infrastructure
"ocpus" not in infrastructure
or "memory_in_gbs" not in infrastructure
):
raise ValueError(
"Parameters `ocpus` and `memory_in_gbs` must be provided for using flex shape. "
"Call `ads opctl config` to specify."
)
infrastructure["shape_config_details"] = {
"ocpus": infrastructure.pop("ocpus"),
"memory_in_gbs": infrastructure.pop("memory_in_gbs")
"memory_in_gbs": infrastructure.pop("memory_in_gbs"),
}
elif backend == BACKEND_NAME.DATAFLOW.value:
executor_shape = infrastructure.get("executor_shape", "")
Expand All @@ -224,7 +227,7 @@ def _config_flex_shape_details(self):
"driver_shape_memory_in_gbs",
"driver_shape_ocpus",
"executor_shape_memory_in_gbs",
"executor_shape_ocpus"
"executor_shape_ocpus",
]
# executor_shape and driver_shape must be the same shape family
if executor_shape.endswith(".Flex") or driver_shape.endswith(".Flex"):
Expand All @@ -236,9 +239,9 @@ def _config_flex_shape_details(self):
)
infrastructure["driver_shape_config"] = {
"ocpus": infrastructure.pop("driver_shape_ocpus"),
"memory_in_gbs": infrastructure.pop("driver_shape_memory_in_gbs")
"memory_in_gbs": infrastructure.pop("driver_shape_memory_in_gbs"),
}
infrastructure["executor_shape_config"] = {
"ocpus": infrastructure.pop("executor_shape_ocpus"),
"memory_in_gbs": infrastructure.pop("executor_shape_memory_in_gbs")
"memory_in_gbs": infrastructure.pop("executor_shape_memory_in_gbs"),
}
7 changes: 3 additions & 4 deletions ads/opctl/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2022 Oracle and/or its affiliates.
# Copyright (c) 2022, 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


Expand Down Expand Up @@ -88,9 +88,8 @@ def get_namespace(auth: dict) -> str:


def get_region_key(auth: dict) -> str:
if len(auth["config"]) > 0:
tenancy = auth["config"]["tenancy"]
else:
tenancy = auth["config"].get("tenancy")
if not tenancy:
tenancy = auth["signer"].tenancy_id
client = OCIClientFactory(**auth).identity
return client.get_tenancy(tenancy).data.home_region_key
Expand Down
5 changes: 2 additions & 3 deletions tests/integration/jobs/test_dsc_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,10 @@ def assert_job_creation(self, job, expected_infra_spec, expected_runtime_spec):
random.seed(threading.get_ident() + os.getpid())
random_suffix = "".join(random.choices(string.ascii_uppercase, k=6))
yaml_uri = f"oci://{self.BUCKET}@{self.NAMESPACE}/tests/{timestamp}/example_job_{random_suffix}.yaml"
config_path = "~/.oci/config"
job.to_yaml(uri=yaml_uri, config=config_path)
job.to_yaml(uri=yaml_uri)
print(f"Job YAML saved to {yaml_uri}")
try:
job = Job.from_yaml(uri=yaml_uri, config=config_path)
job = Job.from_yaml(uri=yaml_uri)
except Exception:
self.fail(f"Failed to load job from YAML\n{traceback.format_exc()}")

Expand Down
64 changes: 57 additions & 7 deletions tests/integration/jobs/test_jobs_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,38 +7,88 @@

from click.testing import CliRunner

from ads.common.auth import AuthType
from ads.jobs.cli import run, watch, delete


class TestJobsCLI:
# TeamCity will use Instance Principal, when running locally - set OCI_IAM_TYPE to security_token
auth = os.environ.get("OCI_IAM_TYPE", AuthType.INSTANCE_PRINCIPAL)

def test_create_watch_delete_job(self):
curr_dir = os.path.dirname(os.path.abspath(__file__))
runner = CliRunner()
res = runner.invoke(
run, args=["-f", os.path.join(curr_dir, "../yamls", "sample_job.yaml")]
run,
args=[
"-f",
os.path.join(curr_dir, "../yamls", "sample_job.yaml"),
"--auth",
self.auth,
],
)
assert res.exit_code == 0, res.output
run_id = res.output.split("\n")[1]
res2 = runner.invoke(watch, args=[run_id])
res2 = runner.invoke(
watch,
args=[
run_id,
"--auth",
self.auth,
],
)
assert res2.exit_code == 0, res2.output

res3 = runner.invoke(delete, args=[run_id])
res3 = runner.invoke(
delete,
args=[
run_id,
"--auth",
self.auth,
],
)
assert res3.exit_code == 0, res3.output

def test_create_watch_delete_dataflow(self):
curr_dir = os.path.dirname(os.path.abspath(__file__))
runner = CliRunner()
res = runner.invoke(
run, args=["-f", os.path.join(curr_dir, "../yamls", "sample_dataflow.yaml")]
run,
args=[
"-f",
os.path.join(curr_dir, "../yamls", "sample_dataflow.yaml"),
"--auth",
self.auth,
],
)
assert res.exit_code == 0, res.output
run_id = res.output.split("\n")[1]
res2 = runner.invoke(watch, args=[run_id])
res2 = runner.invoke(
watch,
args=[
run_id,
"--auth",
self.auth,
],
)
assert res2.exit_code == 0, res2.output

res3 = runner.invoke(
run, args=["-f", os.path.join(curr_dir, "../yamls", "sample_dataflow.yaml")]
run,
args=[
"-f",
os.path.join(curr_dir, "../yamls", "sample_dataflow.yaml"),
"--auth",
self.auth,
],
)
run_id2 = res3.output.split("\n")[1]
res4 = runner.invoke(delete, args=[run_id2])
res4 = runner.invoke(
delete,
args=[
run_id2,
"--auth",
self.auth,
],
)
assert res4.exit_code == 0, res4.output
5 changes: 2 additions & 3 deletions tests/integration/jobs/test_jobs_notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import tempfile

import fsspec
from ads.common.auth import default_signer, AuthType
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
NotebookRuntimeHandler,
)
Expand Down Expand Up @@ -64,9 +65,7 @@ def run_notebook(
# Clear the files in output URI
try:
# Ignore the error for unit tests.
fs = fsspec.filesystem(
"oci", config=os.path.expanduser("~/.oci/config")
)
fs = fsspec.filesystem("oci", **default_signer())
if fs.find(output_uri):
fs.rm(output_uri, recursive=True)
except:
Expand Down
18 changes: 13 additions & 5 deletions tests/integration/jobs/test_jobs_notebook_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import pytest
import os
import tempfile
from zipfile import ZipFile

import fsspec

from ads.common.auth import default_signer
from tests.integration.config import secrets
from tests.integration.jobs.test_dsc_job import DSCJobTestCaseWithCleanUp
from tests.integration.jobs.test_jobs_notebook import NotebookDriverRunTest
Expand All @@ -19,7 +20,9 @@


class NotebookRuntimeTest(DSCJobTestCaseWithCleanUp):
NOTEBOOK_PATH = os.path.join(os.path.dirname(__file__), "../fixtures/ads_check.ipynb")
NOTEBOOK_PATH = os.path.join(
os.path.dirname(__file__), "../fixtures/ads_check.ipynb"
)
NOTEBOOK_PATH_EXCLUDE = os.path.join(
os.path.dirname(__file__), "../fixtures/exclude_check.ipynb"
)
Expand Down Expand Up @@ -86,10 +89,15 @@ def test_create_job_with_notebook(self):


class NotebookDriverIntegrationTest(NotebookDriverRunTest):
@pytest.mark.skip(
reason="api_keys not an option anymore, this test is candidate to be removed"
)
def test_notebook_driver_with_outputs(self):
"""Tests run the notebook driver with a notebook plotting and saving data."""
# Notebook to be executed
notebook_path = os.path.join(os.path.dirname(__file__), "../fixtures/plot.ipynb")
notebook_path = os.path.join(
os.path.dirname(__file__), "../fixtures/plot.ipynb"
)
# Object storage output location
output_uri = f"oci://{secrets.jobs.BUCKET_B}@{secrets.common.NAMESPACE}/notebook_driver_int_test/plot/"
# Run the notebook with driver and check the logs
Expand All @@ -100,7 +108,7 @@ def test_notebook_driver_with_outputs(self):
# Check the notebook saved to object storage.
with fsspec.open(
os.path.join(output_uri, os.path.basename(notebook_path)),
config=os.path.expanduser("~/.oci/config"),
**default_signer(),
) as f:
outputs = [cell.get("outputs") for cell in json.load(f).get("cells")]
# There should be 7 cells in the notebook
Expand All @@ -113,7 +121,7 @@ def test_notebook_driver_with_outputs(self):
# Check the JSON output file from the notebook
with fsspec.open(
os.path.join(output_uri, "data.json"),
config=os.path.expanduser("~/.oci/config"),
**default_signer(),
) as f:
data = json.load(f)
# There should be 10 data points
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/jobs/test_jobs_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def job_run_test_infra(self):
@staticmethod
def list_objects(uri: str) -> list:
"""Lists objects on OCI object storage."""
oci_os = fsspec.filesystem("oci", config=oci.config.from_file())
oci_os = fsspec.filesystem("oci", **default_signer())
if uri.startswith("oci://"):
uri = uri[len("oci://") :]
items = oci_os.ls(uri, detail=False, refresh=True)
Expand All @@ -126,7 +126,7 @@ def list_objects(uri: str) -> list:
@staticmethod
def remove_objects(uri: str):
"""Removes objects from OCI object storage."""
oci_os = fsspec.filesystem("oci", config=oci.config.from_file())
oci_os = fsspec.filesystem("oci", **default_signer())
try:
oci_os.rm(uri, recursive=True)
except FileNotFoundError:
Expand Down
12 changes: 11 additions & 1 deletion tests/integration/opctl/test_opctl_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@
)
ADS_CONFIG_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")

if "TEAMCITY_VERSION" in os.environ:
# When running in TeamCity we specify dir, which is CHECKOUT_DIR="%teamcity.build.checkoutDir%"
WORK_DIR = os.getenv("CHECKOUT_DIR", "~")
CONDA_PACK_FOLDER = f"{WORK_DIR}/conda"
else:
CONDA_PACK_FOLDER = "~/conda"


def _assert_run_command(cmd_str, expected_outputs: list = None):
runner = CliRunner()
Expand Down Expand Up @@ -48,7 +55,7 @@ class TestLocalRunsWithConda:
# For tests, we can always run the command in debug mode (-d)
# By default, pytest only print the logs if the test is failed,
# in which case we would like to see the debug logs.
CMD_OPTIONS = "-d -b local "
CMD_OPTIONS = f"-d -b local --conda-pack-folder {CONDA_PACK_FOLDER} "

def test_hello_world(self):
test_folder = os.path.join(TESTS_FILES_DIR, "hello_world_test")
Expand Down Expand Up @@ -79,6 +86,9 @@ def test_linear_reg_test(self):
]
_assert_run_command(cmd, expected_outputs)

@pytest.mark.skip(
reason="spark do not support instance principal - this test candidate to remove"
)
def test_spark_run(self):
test_folder = os.path.join(TESTS_FILES_DIR, "spark_test")
cmd = (
Expand Down
Loading