diff --git a/ads/opctl/operator/cli.py b/ads/opctl/operator/cli.py index 318159dec..26e6c38e6 100644 --- a/ads/opctl/operator/cli.py +++ b/ads/opctl/operator/cli.py @@ -18,7 +18,7 @@ from ads.opctl.utils import suppress_traceback from .__init__ import __operators__ -from .cmd import apply as cmd_apply +from .cmd import run as cmd_run from .cmd import build_conda as cmd_build_conda from .cmd import build_image as cmd_build_image from .cmd import create as cmd_create @@ -289,6 +289,4 @@ def run(debug: bool, **kwargs: Dict[str, Any]) -> None: with fsspec.open(backend, "r", **auth) as f: backend = suppress_traceback(debug)(yaml.safe_load)(f.read()) - suppress_traceback(debug)(cmd_apply)( - config=operator_spec, backend=backend, **kwargs - ) + suppress_traceback(debug)(cmd_run)(config=operator_spec, backend=backend, **kwargs) diff --git a/ads/opctl/operator/cmd.py b/ads/opctl/operator/cmd.py index a97e72f34..9f7f2b095 100644 --- a/ads/opctl/operator/cmd.py +++ b/ads/opctl/operator/cmd.py @@ -685,7 +685,7 @@ def create( raise NotImplementedError() -def apply(config: Dict, backend: Union[Dict, str] = None, **kwargs) -> None: +def run(config: Dict, backend: Union[Dict, str] = None, **kwargs) -> None: """ Runs the operator with the given specification on the targeted backend. diff --git a/ads/opctl/operator/operator_config.py b/ads/opctl/operator/common/operator_config.py similarity index 97% rename from ads/opctl/operator/operator_config.py rename to ads/opctl/operator/common/operator_config.py index f1640913c..944984d20 100644 --- a/ads/opctl/operator/operator_config.py +++ b/ads/opctl/operator/common/operator_config.py @@ -12,7 +12,7 @@ from ads.common.serializer import DataClassSerializable -from .common.utils import OperatorValidator +from .utils import OperatorValidator @dataclass(repr=True) diff --git a/ads/opctl/operator/common/operator_loader.py b/ads/opctl/operator/common/operator_loader.py index 2c1eff964..cf95c7106 100644 --- a/ads/opctl/operator/common/operator_loader.py +++ b/ads/opctl/operator/common/operator_loader.py @@ -24,7 +24,7 @@ from ads.common.serializer import DataClassSerializable from ads.common.utils import copy_from_uri from ads.opctl import logger -from ads.opctl.constants import BACKEND_NAME, OPERATOR_MODULE_PATH +from ads.opctl.constants import OPERATOR_MODULE_PATH from ads.opctl.operator import __operators__ from .const import ARCH_TYPE, PACK_TYPE @@ -67,15 +67,15 @@ class OperatorInfo(DataClassSerializable): Generates the conda prefix for the custom conda pack. """ - type: str - gpu: bool - description: str - version: str - conda: str - conda_type: str - path: str - keywords: List[str] - backends: List[str] + type: str = "" + gpu: bool = False + description: str = "" + version: str = "" + conda: str = "" + conda_type: str = "" + path: str = "" + keywords: List[str] = None + backends: List[str] = None @property def conda_prefix(self) -> str: @@ -100,7 +100,7 @@ def conda_prefix(self) -> str: ) def __post_init__(self): - self.gpu = self.gpu == "yes" + self.gpu = self.gpu == True or self.gpu == "yes" self.version = self.version or "v1" self.conda_type = self.conda_type or PACK_TYPE.CUSTOM self.conda = self.conda or f"{self.type}_{self.version}" @@ -137,7 +137,9 @@ def from_yaml( obj: OperatorInfo = super().from_yaml( yaml_string=yaml_string, uri=uri, loader=loader, **kwargs ) - obj.path = os.path.dirname(uri) + + if uri: + obj.path = os.path.dirname(uri) return obj @@ -671,25 +673,6 @@ def _module_from_file(module_name: str, module_path: str) -> Any: return module -def _module_constant_values(module_name: str, module_path: str) -> Dict[str, Any]: - """Returns the list of constant variables from a given module. - - Parameters - ---------- - module_name (str) - The name of the module to be imported. - module_path (str) - The physical path of the module. - - Returns - ------- - Dict[str, Any] - Map of variable names and their values. - """ - module = _module_from_file(module_name, module_path) - return {name: value for name, value in vars(module).items()} - - def _operator_info(path: str = None, name: str = None) -> OperatorInfo: """ Extracts operator's details by given path. diff --git a/ads/opctl/operator/common/operator_schema.yaml b/ads/opctl/operator/common/operator_schema.yaml index f862c59b6..da30f4269 100644 --- a/ads/opctl/operator/common/operator_schema.yaml +++ b/ads/opctl/operator/common/operator_schema.yaml @@ -43,6 +43,6 @@ conda_type: default: custom allowed: - service - - custom + - published meta: description: "The operator's conda environment type. Can be either service or custom type." diff --git a/ads/opctl/operator/common/utils.py b/ads/opctl/operator/common/utils.py index 94480fd90..fc13d6d3c 100644 --- a/ads/opctl/operator/common/utils.py +++ b/ads/opctl/operator/common/utils.py @@ -12,11 +12,10 @@ import fsspec import yaml from cerberus import Validator -from yaml import SafeLoader from ads.opctl import logger from ads.opctl.operator import __operators__ -from ads.opctl.utils import run_command +from ads.opctl import utils CONTAINER_NETWORK = "CONTAINER_NETWORK" @@ -95,7 +94,7 @@ def _build_image( logger.info(f"Build image: {command}") - proc = run_command(command) + proc = utils.run_command(command) if proc.returncode != 0: raise RuntimeError("Docker build failed.") @@ -145,23 +144,6 @@ def _load_yaml_from_string(doc: str, **kwargs) -> Dict: ) -def _load_multi_document_yaml_from_string(doc: str, **kwargs) -> Dict: - """Loads multiline YAML from string and merge it with env variables and kwargs.""" - template_dict = {**os.environ, **kwargs} - return yaml.load_all( - Template(doc).substitute( - **template_dict, - ), - Loader=SafeLoader, - ) - - -def _load_multi_document_yaml_from_uri(uri: str, **kwargs) -> Dict: - """Loads multiline YAML from file and merge it with env variables and kwargs.""" - with fsspec.open(uri) as f: - return _load_multi_document_yaml_from_string(str(f.read(), "UTF-8"), **kwargs) - - def _load_yaml_from_uri(uri: str, **kwargs) -> str: """Loads YAML from the URI path. Can be Object Storage path.""" with fsspec.open(uri) as f: diff --git a/ads/opctl/operator/lowcode/forecast/MLoperator b/ads/opctl/operator/lowcode/forecast/MLoperator index e1b32afc3..bf22bae38 100644 --- a/ads/opctl/operator/lowcode/forecast/MLoperator +++ b/ads/opctl/operator/lowcode/forecast/MLoperator @@ -1,6 +1,6 @@ type: forecast version: v1 -conda_type: custom +conda_type: published gpu: no keywords: - Prophet diff --git a/ads/opctl/operator/lowcode/forecast/operator_config.py b/ads/opctl/operator/lowcode/forecast/operator_config.py index bcbc5fb20..f0e385f68 100644 --- a/ads/opctl/operator/lowcode/forecast/operator_config.py +++ b/ads/opctl/operator/lowcode/forecast/operator_config.py @@ -10,7 +10,7 @@ from ads.common.serializer import DataClassSerializable from ads.opctl.operator.common.utils import _load_yaml_from_uri -from ads.opctl.operator.operator_config import OperatorConfig +from ads.opctl.operator.common.operator_config import OperatorConfig from .const import SupportedMetrics @@ -103,7 +103,9 @@ def __post_init__(self): self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower() self.confidence_interval_width = self.confidence_interval_width or 0.80 self.report_file_name = self.report_file_name or "report.html" - self.preprocessing = self.preprocessing if self.preprocessing is not None else True + self.preprocessing = ( + self.preprocessing if self.preprocessing is not None else True + ) self.report_theme = self.report_theme or "light" self.metrics_filename = self.metrics_filename or "metrics.csv" self.forecast_filename = self.forecast_filename or "forecast.csv" diff --git a/ads/opctl/operator/runtime/runtime.py b/ads/opctl/operator/runtime/runtime.py index ae22796f2..3342743ae 100644 --- a/ads/opctl/operator/runtime/runtime.py +++ b/ads/opctl/operator/runtime/runtime.py @@ -29,7 +29,7 @@ class OPERATOR_LOCAL_RUNTIME_TYPE(ExtendedEnum): class Runtime(DataClassSerializable): """Base class for the operator's runtimes.""" - _schema: ClassVar[str] = None + _schema: ClassVar[str] = "" kind: str = OPERATOR_LOCAL_RUNTIME_KIND type: str = None version: str = None diff --git a/tests/unitary/with_extras/operator/test_cmd.py b/tests/unitary/with_extras/operator/test_cmd.py index 6b9f30889..b59ceb40a 100644 --- a/tests/unitary/with_extras/operator/test_cmd.py +++ b/tests/unitary/with_extras/operator/test_cmd.py @@ -5,9 +5,49 @@ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -class TestCMD: +class TestOperatorCMD: + """Tests operator commands.""" + def test_list(self): """Ensures that the list of the registered operators can be printed.""" + pass def test_info(self): """Ensures that the detailed information about the particular operator can be printed.""" + pass + + def test_init_backend_config(self): + """Ensures the operator's backend configs can be generated.""" + pass + + def test_init_success(self): + """Ensures that a starter YAML configurations for the operator can be generated.""" + pass + + def test_init_fail(self): + """Ensures that generating starter specification fails in case of wrong input attributes.""" + pass + + def test_build_image(self): + """Ensures that operator's image can be successfully built.""" + pass + + def test_publish_image(self): + """Ensures that operator's image can be successfully published.""" + pass + + def test_verify(self): + """Ensures that operator's config can be successfully verified.""" + pass + + def test_build_conda(self): + """Ensures that the operator's conda environment can be successfully built.""" + pass + + def test_publish_conda(self): + """Ensures that the operator's conda environment can be successfully published.""" + pass + + def test_run(self): + """Ensures that the operator can be run on the targeted backend.""" + pass diff --git a/tests/unitary/with_extras/operator/test_common_utils.py b/tests/unitary/with_extras/operator/test_common_utils.py index 5c95d19f3..f296a420e 100644 --- a/tests/unitary/with_extras/operator/test_common_utils.py +++ b/tests/unitary/with_extras/operator/test_common_utils.py @@ -4,8 +4,103 @@ # Copyright (c) 2023 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import argparse +import os +import unittest +from unittest.mock import patch, MagicMock -class TestCommonUtils: - """Tests all common utils methods of the operators.""" +from ads.opctl.operator.common.utils import _build_image, _parse_input_args - pass + +class TestBuildImage(unittest.TestCase): + def setUp(self): + self.curr_dir = os.path.dirname(os.path.abspath(__file__)) + self.dockerfile = os.path.join(self.curr_dir, "test_files", "Dockerfile.test") + self.image_name = "test_image" + self.tag = "test_tag" + self.target = "test_target" + self.kwargs = {"arg1": "value1", "arg2": "value2"} + + @patch("ads.opctl.utils.run_command") + def test_build_image(self, mock_run_command): + mock_proc = MagicMock() + mock_proc.returncode = 0 + mock_run_command.return_value = mock_proc + + image_name = _build_image( + self.dockerfile, + self.image_name, + tag=self.tag, + target=self.target, + **self.kwargs, + ) + + expected_image_name = f"{self.image_name}:{self.tag}" + command = [ + "docker", + "build", + "-t", + expected_image_name, + "-f", + self.dockerfile, + "--target", + self.target, + os.path.dirname(self.dockerfile), + ] + + mock_run_command.assert_called_once_with(command) + self.assertEqual(image_name, expected_image_name) + + def test_build_image_missing_dockerfile(self): + with self.assertRaises(FileNotFoundError): + _build_image("non_existing_docker_file", "non_existing_image") + + def test_build_image_missing_image_name(self): + with self.assertRaises(ValueError): + _build_image(self.dockerfile, None) + + @patch("ads.opctl.utils.run_command") + def test_build_image_docker_build_failure(self, mock_run_command): + mock_proc = MagicMock() + mock_proc.returncode = 1 + mock_run_command.return_value = mock_proc + + with self.assertRaises(RuntimeError): + _build_image(self.dockerfile, self.image_name) + + +class TestParseInputArgs(unittest.TestCase): + def test_parse_input_args_with_file(self): + raw_args = ["-f", "path/to/file.yaml"] + expected_output = ( + argparse.Namespace(file="path/to/file.yaml", spec=None, verify=False), + [], + ) + self.assertEqual(_parse_input_args(raw_args), expected_output) + + def test_parse_input_args_with_spec(self): + raw_args = ["-s", "spec"] + expected_output = (argparse.Namespace(file=None, spec="spec", verify=False), []) + self.assertEqual(_parse_input_args(raw_args), expected_output) + + def test_parse_input_args_with_verify(self): + raw_args = ["-v", "True"] + expected_output = (argparse.Namespace(file=None, spec=None, verify=True), []) + self.assertEqual(_parse_input_args(raw_args), expected_output) + + def test_parse_input_args_with_unknown_args(self): + raw_args = ["-f", "path/to/file.yaml", "--unknown-arg", "value"] + expected_output = ( + argparse.Namespace(file="path/to/file.yaml", spec=None, verify=False), + ["--unknown-arg", "value"], + ) + self.assertEqual(_parse_input_args(raw_args), expected_output) + + @patch("argparse.ArgumentParser.parse_known_args") + def test_parse_input_args_with_no_args(self, mock_parse_known_args): + mock_parse_known_args.return_value = ( + argparse.Namespace(file=None, spec=None, verify=False), + [], + ) + expected_output = (argparse.Namespace(file=None, spec=None, verify=False), []) + self.assertEqual(_parse_input_args([]), expected_output) diff --git a/tests/unitary/with_extras/operator/test_files/Dockerfile.test b/tests/unitary/with_extras/operator/test_files/Dockerfile.test new file mode 100644 index 000000000..8049807ad --- /dev/null +++ b/tests/unitary/with_extras/operator/test_files/Dockerfile.test @@ -0,0 +1,2 @@ +FROM baseImage +RUN echo "This is a message" diff --git a/tests/unitary/with_extras/operator/test_operator.py b/tests/unitary/with_extras/operator/test_operator.py deleted file mode 100644 index ef2557f59..000000000 --- a/tests/unitary/with_extras/operator/test_operator.py +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8; -*- - -# Copyright (c) 2023 Oracle and/or its affiliates. -# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ - - -class TestOperator: - """Tests operator base class.""" - - pass diff --git a/tests/unitary/with_extras/operator/test_operator_backend.py b/tests/unitary/with_extras/operator/test_operator_backend.py new file mode 100644 index 000000000..1831be238 --- /dev/null +++ b/tests/unitary/with_extras/operator/test_operator_backend.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python +# -*- coding: utf-8; -*- + +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + + +class TestLocalOperatorBackend: + """Tests local operator backend.""" + + pass + + +class TestMLJobOperatorBackend: + """Tests backend class to run operator on Data Science Jobs.""" + + pass + + +class TestDataFlowOperatorBackend: + """Tests backend class to run operator on Data Flow Applications.""" + + pass diff --git a/tests/unitary/with_extras/operator/test_operator_config.py b/tests/unitary/with_extras/operator/test_operator_config.py index c1e8f9589..52158c774 100644 --- a/tests/unitary/with_extras/operator/test_operator_config.py +++ b/tests/unitary/with_extras/operator/test_operator_config.py @@ -5,7 +5,90 @@ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from dataclasses import dataclass + +import pytest +import yaml + +from ads.opctl.operator.common.operator_config import OperatorConfig + + class TestOperatorConfig: - """Tests operator config class""" + def test_operator_config(self): + # Test valid operator config + + @dataclass(repr=True) + class MyOperatorConfig(OperatorConfig): + @classmethod + def _load_schema(cls) -> str: + return yaml.safe_load( + """ + kind: + required: true + type: string + version: + required: true + type: string + type: + required: true + type: string + spec: + required: true + type: dict + schema: + foo: + required: false + type: string + """ + ) + + config = MyOperatorConfig.from_dict( + { + "kind": "operator", + "type": "my-operator", + "version": "v1", + "spec": {"foo": "bar"}, + } + ) + assert config.kind == "operator" + assert config.type == "my-operator" + assert config.version == "v1" + assert config.spec == {"foo": "bar"} + + # Test invalid operator config + @dataclass(repr=True) + class InvalidOperatorConfig(OperatorConfig): + @classmethod + def _load_schema(cls) -> str: + return yaml.safe_load( + """ + kind: + required: true + type: string + version: + required: true + type: string + allowed: + - v1 + type: + required: true + type: string + spec: + required: true + type: dict + schema: + foo: + required: true + type: string + """ + ) - pass + with pytest.raises(ValueError): + InvalidOperatorConfig.from_dict( + { + "kind": "operator", + "type": "invalid-operator", + "version": "v2", + "spec": {"foo1": 123}, + } + ) diff --git a/tests/unitary/with_extras/operator/test_operator_loader.py b/tests/unitary/with_extras/operator/test_operator_loader.py new file mode 100644 index 000000000..bbe71d87c --- /dev/null +++ b/tests/unitary/with_extras/operator/test_operator_loader.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python +# -*- coding: utf-8; -*- + +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import os +import shutil +import unittest +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from ads.opctl.operator.common.const import ARCH_TYPE, PACK_TYPE +from ads.opctl.operator.common.operator_loader import ( + GitOperatorLoader, + LocalOperatorLoader, + OperatorInfo, + OperatorLoader, + RemoteOperatorLoader, + ServiceOperatorLoader, +) + + +class TestOperatorInfo(unittest.TestCase): + def test_construction(self): + operator_info = OperatorInfo( + type="example", + gpu="yes", + description="An example operator", + version="v1", + conda="example_v1", + conda_type=PACK_TYPE.CUSTOM, + path="/path/to/operator", + backends=["backend1", "backend2"], + ) + assert (operator_info.type, "example") + self.assertTrue(operator_info.gpu) + assert (operator_info.description, "An example operator") + assert (operator_info.version, "v1") + assert (operator_info.conda, "example_v1") + assert (operator_info.conda_type, PACK_TYPE.CUSTOM) + assert (operator_info.path, "/path/to/operator") + assert (operator_info.backends, ["backend1", "backend2"]) + + def test_conda_prefix(self): + operator_info = OperatorInfo( + type="example", + gpu="yes", + description="An example operator", + version="v1", + conda="example_v1", + conda_type=PACK_TYPE.CUSTOM, + path="/path/to/operator", + backends=["backend1", "backend2"], + ) + assert ( + operator_info.conda_prefix, + f"{ARCH_TYPE.GPU}/example/1/example_v1", + ) + + def test_conda_prefix_without_gpu(self): + operator_info = OperatorInfo( + type="example", + gpu="no", + description="An example operator", + version="v1", + conda="example_v1", + conda_type=PACK_TYPE.CUSTOM, + path="/path/to/operator", + backends=["backend1", "backend2"], + ) + assert ( + operator_info.conda_prefix, + f"{ARCH_TYPE.CPU}/example/1/example_v1", + ) + + def test_post_init(self): + operator_info = OperatorInfo( + type="example", + gpu="yes", # Should be converted to boolean + version="", # Should be set to "v1" + conda_type=None, # Should be set to PACK_TYPE.CUSTOM + conda=None, # Should be set to "example_v1" + ) + self.assertTrue(operator_info.gpu) + assert (operator_info.version, "v1") + assert (operator_info.conda_type, PACK_TYPE.CUSTOM) + assert (operator_info.conda, "example_v1") + + def test_from_yaml_with_yaml_string(self): + yaml_string = """ + type: example + gpu: yes + description: An example operator + version: v1 + conda_type: published + path: /path/to/operator + backends: + - backend1 + - backend2 + """ + operator_info = OperatorInfo.from_yaml(yaml_string=yaml_string) + assert (operator_info.type, "example") + self.assertTrue(operator_info.gpu) + assert (operator_info.description, "An example operator") + assert (operator_info.version, "v1") + assert (operator_info.conda, "example_v1") + assert (operator_info.conda_type, PACK_TYPE.CUSTOM) + assert (operator_info.path, "/path/to/operator") + assert (operator_info.backends, ["backend1", "backend2"]) + + @patch("ads.common.serializer.Serializable.from_yaml") + def test_from_yaml_with_uri(self, mock_from_yaml): + uri = "http://example.com/operator.yaml" + loader = MagicMock() + mock_from_yaml.return_value = OperatorInfo( + type="example", + gpu="yes", + description="An example operator", + version="v1", + conda="example_v1", + conda_type=PACK_TYPE.CUSTOM, + path="/path/to/operator", + backends=["backend1", "backend2"], + ) + operator_info = OperatorInfo.from_yaml(uri=uri, loader=loader) + mock_from_yaml.assert_called_with(yaml_string=None, uri=uri, loader=loader) + assert (operator_info.type, "example") + self.assertTrue(operator_info.gpu) + assert (operator_info.description, "An example operator") + assert (operator_info.version, "v1") + assert (operator_info.conda, "example_v1") + assert (operator_info.conda_type, PACK_TYPE.CUSTOM) + assert (operator_info.path, "http://example.com") + assert (operator_info.backends, ["backend1", "backend2"]) + + +class TestOperatorLoader: + def setup_method(self): + # Create a mock Loader instance for testing + self.loader = Mock() + self.operator_loader = OperatorLoader(self.loader) + + def test_load_operator(self): + # Define a mock OperatorInfo object to return when load is called on the loader + mock_operator_info = OperatorInfo( + type="mock_operator", + gpu=False, + description="Mock Operator", + version="v1", + conda="mock_operator_v1", + conda_type="custom", + path="/path/to/mock_operator", + backends=["cpu"], + ) + + # Mock the _load method to return the mock_operator_info object + self.loader.load.return_value = mock_operator_info + + # Call the load method of the OperatorLoader + operator_info = self.operator_loader.load() + + # Check if the returned OperatorInfo object matches the expected values + + assert operator_info.type == "mock_operator" + assert operator_info.gpu == False + assert operator_info.description == "Mock Operator" + assert operator_info.version == "v1" + assert operator_info.conda == "mock_operator_v1" + assert operator_info.conda_type == "custom" + assert operator_info.path == "/path/to/mock_operator" + assert operator_info.backends == ["cpu"] + + def test_load_operator_exception(self): + # Mock the _load method to raise an exception + self.loader.load.side_effect = Exception("Error loading operator") + + # Call the load method of the OperatorLoader and expect an exception + with pytest.raises(Exception): + self.operator_loader.load() + + @pytest.mark.parametrize( + "test_name, uri, expected_result", + [ + ("Service Path", "forecast", ServiceOperatorLoader), + ("Local Path", "/path/to/local_operator", LocalOperatorLoader), + ("OCI Path", "oci://bucket/operator.zip", RemoteOperatorLoader), + ( + "Git Path", + "https://github.com/my-operator-repository", + GitOperatorLoader, + ), + ], + ) + def test_from_uri(self, test_name, uri, expected_result): + # Call the from_uri method of the OperatorLoader class + operator_loader = OperatorLoader.from_uri(uri=uri) + assert isinstance(operator_loader.loader, expected_result) + + def test_empty_uri(self): + # Test with an empty URI that should raise a ValueError + with pytest.raises(ValueError): + OperatorLoader.from_uri(uri="", uri_dst=None) + + def test_invalid_uri(self): + # Test with an invalid URI that should raise a ValueError + with pytest.raises(ValueError): + OperatorLoader.from_uri(uri="aws://", uri_dst=None) + + +class TestServiceOperatorLoader(unittest.TestCase): + def setUp(self): + # Create a mock ServiceOperatorLoader instance for testing + self.loader = ServiceOperatorLoader(uri="mock_service_operator") + self.mock_operator_info = OperatorInfo( + type="mock_operator", + gpu="no", + description="Mock Operator", + version="v1", + conda="mock_operator_v1", + conda_type="custom", + path="/path/to/mock_operator", + backends=["cpu"], + ) + + def test_compatible(self): + # Test the compatible method with a valid URI + uri = "forecast" + self.assertTrue(ServiceOperatorLoader.compatible(uri=uri)) + + # Test the compatible method with an invalid URI + uri = "invalid_service_operator" + self.assertFalse(ServiceOperatorLoader.compatible(uri=uri)) + + def test_load(self): + # Mock the _load method to return the mock_operator_info object + self.loader._load = Mock(return_value=self.mock_operator_info) + + # Call the load method of the ServiceOperatorLoader + operator_info = self.loader.load() + + # Check if the returned OperatorInfo object matches the expected values + self.assertEqual(operator_info.type, "mock_operator") + self.assertEqual(operator_info.gpu, False) + self.assertEqual(operator_info.description, "Mock Operator") + self.assertEqual(operator_info.version, "v1") + self.assertEqual(operator_info.conda, "mock_operator_v1") + self.assertEqual(operator_info.conda_type, "custom") + self.assertEqual(operator_info.path, "/path/to/mock_operator") + self.assertEqual(operator_info.backends, ["cpu"]) + + def test_load_exception(self): + # Mock the _load method to raise an exception + self.loader._load = Mock( + side_effect=Exception("Error loading service operator") + ) + + # Call the load method of the ServiceOperatorLoader and expect an exception + with self.assertRaises(Exception): + self.loader.load() + + +class TestLocalOperatorLoader(unittest.TestCase): + def setUp(self): + # Create a mock LocalOperatorLoader instance for testing + self.loader = LocalOperatorLoader(uri="path/to/local/operator") + self.mock_operator_info = OperatorInfo( + type="mock_operator", + gpu=False, + description="Mock Operator", + version="v1", + conda="mock_operator_v1", + conda_type="custom", + path="/path/to/mock_operator", + backends=["cpu"], + ) + + def test_compatible(self): + # Test the compatible method with a valid URI + uri = "path/to/local/operator" + self.assertTrue(LocalOperatorLoader.compatible(uri=uri)) + + # Test the compatible method with an invalid URI + uri = "http://example.com/remote/operator" + self.assertFalse(LocalOperatorLoader.compatible(uri=uri)) + + def test_load(self): + # Mock the _load method to return the mock_operator_info object + self.loader._load = Mock(return_value=self.mock_operator_info) + + # Call the load method of the LocalOperatorLoader + operator_info = self.loader.load() + + # Check if the returned OperatorInfo object matches the expected values + self.assertEqual(operator_info.type, "mock_operator") + self.assertEqual(operator_info.gpu, False) + self.assertEqual(operator_info.description, "Mock Operator") + self.assertEqual(operator_info.version, "v1") + self.assertEqual(operator_info.conda, "mock_operator_v1") + self.assertEqual(operator_info.conda_type, "custom") + self.assertEqual(operator_info.path, "/path/to/mock_operator") + self.assertEqual(operator_info.backends, ["cpu"]) + + def test_load_exception(self): + # Mock the _load method to raise an exception + self.loader._load = Mock(side_effect=Exception("Error loading local operator")) + + # Call the load method of the LocalOperatorLoader and expect an exception + with self.assertRaises(Exception): + self.loader.load() + + +class TestRemoteOperatorLoader(unittest.TestCase): + def setUp(self): + # Create a mock RemoteOperatorLoader instance for testing + self.loader = RemoteOperatorLoader(uri="oci://bucket/operator.zip") + self.mock_operator_info = OperatorInfo( + type="mock_operator", + gpu=False, + description="Mock Operator", + version="v1", + conda="mock_operator_v1", + conda_type="custom", + path="/path/to/mock_operator", + backends=["cpu"], + ) + + def test_compatible(self): + # Test the compatible method with a valid URI + uri = "oci://bucket/operator.zip" + self.assertTrue(RemoteOperatorLoader.compatible(uri=uri)) + + # Test the compatible method with an invalid URI + uri = "http://example.com/remote/operator" + self.assertFalse(RemoteOperatorLoader.compatible(uri=uri)) + + def test_load(self): + # Mock the _load method to return the mock_operator_info object + self.loader._load = Mock(return_value=self.mock_operator_info) + + # Call the load method of the RemoteOperatorLoader + operator_info = self.loader.load() + + # Check if the returned OperatorInfo object matches the expected values + self.assertEqual(operator_info.type, "mock_operator") + self.assertEqual(operator_info.gpu, False) + self.assertEqual(operator_info.description, "Mock Operator") + self.assertEqual(operator_info.version, "v1") + self.assertEqual(operator_info.conda, "mock_operator_v1") + self.assertEqual(operator_info.conda_type, "custom") + self.assertEqual(operator_info.path, "/path/to/mock_operator") + self.assertEqual(operator_info.backends, ["cpu"]) + + def test_load_exception(self): + # Mock the _load method to raise an exception + self.loader._load = Mock(side_effect=Exception("Error loading remote operator")) + + # Call the load method of the RemoteOperatorLoader and expect an exception + with self.assertRaises(Exception): + self.loader.load() + + +class TestGitOperatorLoader(unittest.TestCase): + def setUp(self): + # Create a temporary directory for testing + self.temp_dir = "temp_git_loader" + os.makedirs(self.temp_dir, exist_ok=True) + + # Create a mock GitOperatorLoader instance for testing + self.loader = GitOperatorLoader( + uri="https://github.com/mock_operator_repository.git@feature-branch#forecasting", + uri_dst=self.temp_dir, + ) + self.mock_operator_info = OperatorInfo( + type="mock_operator", + gpu=False, + description="Mock Operator", + version="v1", + conda="mock_operator_v1", + conda_type="custom", + path=os.path.join(self.temp_dir, "forecasting"), + backends=["cpu"], + ) + + def tearDown(self): + # Clean up the temporary directory + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_compatible(self): + # Test the compatible method with a valid URI + uri = ( + "https://github.com/mock_operator_repository.git@feature-branch#forecasting" + ) + self.assertTrue(GitOperatorLoader.compatible(uri=uri)) + + # Test the compatible method with an invalid URI + uri = "http://example.com/remote/operator" + self.assertFalse(GitOperatorLoader.compatible(uri=uri)) + + def test_load(self): + # Mock the git.Repo.clone_from method to avoid actual Git operations + with patch("git.Repo.clone_from") as mock_clone_from: + mock_clone_from.return_value = Mock() + + # Mock the _load method to return the mock_operator_info object + self.loader._load = Mock(return_value=self.mock_operator_info) + + # Call the load method of the GitOperatorLoader + operator_info = self.loader.load() + + # Check if the returned OperatorInfo object matches the expected values + self.assertEqual(operator_info.type, "mock_operator") + self.assertEqual(operator_info.gpu, False) + self.assertEqual(operator_info.description, "Mock Operator") + self.assertEqual(operator_info.version, "v1") + self.assertEqual(operator_info.conda, "mock_operator_v1") + self.assertEqual(operator_info.conda_type, "custom") + self.assertEqual( + operator_info.path, os.path.join(self.temp_dir, "forecasting") + ) + self.assertEqual(operator_info.backends, ["cpu"]) + + def test_load_exception(self): + # Mock the git.Repo.clone_from method to raise an exception + with patch( + "git.Repo.clone_from", side_effect=Exception("Error cloning Git repository") + ): + # Mock the _load method to raise an exception + self.loader._load = Mock( + side_effect=Exception("Error loading Git operator") + ) + + # Call the load method of the GitOperatorLoader and expect an exception + with self.assertRaises(Exception): + self.loader.load() diff --git a/tests/unitary/with_extras/operator/test_operator_yaml_generator.py b/tests/unitary/with_extras/operator/test_operator_yaml_generator.py new file mode 100644 index 000000000..354e6fc88 --- /dev/null +++ b/tests/unitary/with_extras/operator/test_operator_yaml_generator.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +# -*- coding: utf-8; -*- + +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import pytest +import yaml + +from ads.opctl.operator.common.operator_yaml_generator import YamlGenerator + + +class TestOperatorYamlGenerator: + """Tests class for generating the YAML config based on the given YAML schema.""" + + @pytest.mark.parametrize( + "schema, values, expected_result", + [ + # Test case: Basic schema with required and default values + ( + { + "key1": { + "type": "string", + "default": "test_value", + "required": True, + }, + "key2": {"type": "number", "default": 42, "required": True}, + }, + {}, + {"key1": "test_value", "key2": 42}, + ), + # Test case: Basic schema with required and default values + ( + { + "key1": {"type": "string", "required": True}, + "key2": {"type": "number", "default": 42, "required": True}, + }, + {"key1": "test_value"}, + {"key1": "test_value", "key2": 42}, + ), + # Test case: Basic schema with required and default values + ( + { + "key1": {"type": "string", "required": True}, + "key2": {"type": "number", "default": 42}, + }, + {"key1": "test_value"}, + {"key1": "test_value"}, + ), + # Test case: Schema with dependencies + ( + { + "model": {"type": "string", "required": True, "default": "prophet"}, + "owner_name": { + "type": "string", + "dependencies": {"model": "prophet"}, + }, + }, + {"owner_name": "value"}, + {"model": "prophet", "owner_name": "value"}, + ), + # Test case: Schema with dependencies + ( + { + "model": {"type": "string", "required": True, "default": "prophet"}, + "owner_name": { + "type": "string", + "dependencies": {"model": "prophet1"}, + }, + }, + {"owner_name": "value"}, + {"model": "prophet"}, + ), + ], + ) + def test_generate_example(self, schema, values, expected_result): + yaml_generator = YamlGenerator(schema=schema) + yaml_config = yaml_generator.generate_example(values) + assert yaml_config == yaml.dump(expected_result) diff --git a/tests/unitary/with_extras/operator/test_runtime.py b/tests/unitary/with_extras/operator/test_runtime.py index 396ca303e..6e3f8d7b3 100644 --- a/tests/unitary/with_extras/operator/test_runtime.py +++ b/tests/unitary/with_extras/operator/test_runtime.py @@ -4,8 +4,97 @@ # Copyright (c) 2023 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import unittest +from unittest.mock import patch, MagicMock -class TestContainerRuntime: - """Tests operator container local runtime.""" +from ads.opctl.operator.runtime.runtime import ( + ContainerRuntime, + ContainerRuntimeSpec, + PythonRuntime, + Runtime, + OPERATOR_LOCAL_RUNTIME_TYPE, +) - pass + +class TestRuntime(unittest.TestCase): + def setUp(self): + self.runtime = Runtime() + + def test_kind(self): + self.assertEqual(self.runtime.kind, "operator.local") + + def test_type(self): + self.assertIsNone(self.runtime.type) + + def test_version(self): + self.assertIsNone(self.runtime.version) + + @patch("ads.opctl.operator.runtime.runtime._load_yaml_from_uri") + @patch("ads.opctl.operator.runtime.runtime.Validator") + def test_validate_dict(self, mock_validator, mock_load_yaml): + mock_validator.return_value.validate.return_value = True + self.assertTrue(Runtime._validate_dict({})) + mock_load_yaml.assert_called_once() + mock_validator.assert_called_once() + + @patch("ads.opctl.operator.runtime.runtime._load_yaml_from_uri") + @patch("ads.opctl.operator.runtime.runtime.Validator") + def test_validate_dict_invalid(self, mock_validator, mock_load_yaml): + mock_validator.return_value = MagicMock( + errors=[{"error": "error"}], validate=MagicMock(return_value=False) + ) + mock_validator.return_value.validate.return_value = False + with self.assertRaises(ValueError): + Runtime._validate_dict({}) + mock_load_yaml.assert_called_once() + mock_validator.assert_called_once() + + +class TestContainerRuntime(unittest.TestCase): + def test_init(self): + runtime = ContainerRuntime.init( + image="my-image", + env=[{"name": "VAR1", "value": "value1"}], + volume=["/data"], + ) + self.assertIsInstance(runtime, ContainerRuntime) + self.assertEqual(runtime.type, OPERATOR_LOCAL_RUNTIME_TYPE.CONTAINER.value) + self.assertEqual(runtime.version, "v1") + self.assertIsInstance(runtime.spec, ContainerRuntimeSpec) + self.assertEqual(runtime.spec.image, "my-image") + self.assertEqual(runtime.spec.env, [{"name": "VAR1", "value": "value1"}]) + self.assertEqual(runtime.spec.volume, ["/data"]) + + def test_validate_dict(self): + valid_dict = { + "kind": "operator.local", + "type": "container", + "version": "v1", + "spec": { + "image": "my-image", + "env": [{"name": "VAR1", "value": "value1"}], + "volume": ["/data"], + }, + } + self.assertTrue(ContainerRuntime._validate_dict(valid_dict)) + + invalid_dict = { + "kind": "operator.local", + "type": "unknown", + "version": "v1", + "spec": { + "image": "my-image", + "env": [{"name": "VAR1"}], + "volume": ["/data"], + }, + } + with self.assertRaises(ValueError): + ContainerRuntime._validate_dict(invalid_dict) + + +class TestPythonRuntime(unittest.TestCase): + def test_init(self): + runtime = PythonRuntime.init() + self.assertIsInstance(runtime, PythonRuntime) + self.assertEqual(runtime.type, "python") + self.assertEqual(runtime.version, "v1")