From f70d6bb550204bf9c900eb1728681ade8d887817 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Wed, 4 Oct 2023 12:25:53 -0700 Subject: [PATCH] Adds tests for the operator loader. --- ads/opctl/operator/common/operator_loader.py | 43 +- .../operator/common/operator_schema.yaml | 2 +- ads/opctl/operator/common/utils.py | 5 +- .../operator/lowcode/forecast/MLoperator | 2 +- ads/opctl/operator/runtime/runtime.py | 2 +- .../with_extras/operator/test_common_utils.py | 106 +++- .../operator/test_files/Dockerfile.test | 2 + .../operator/test_operator_config.py | 88 +++- .../operator/test_operator_loader.py | 469 +++++++++++++++--- .../with_extras/operator/test_runtime.py | 123 +++-- 10 files changed, 692 insertions(+), 150 deletions(-) create mode 100644 tests/unitary/with_extras/operator/test_files/Dockerfile.test diff --git a/ads/opctl/operator/common/operator_loader.py b/ads/opctl/operator/common/operator_loader.py index 0f7a4a72d..cf95c7106 100644 --- a/ads/opctl/operator/common/operator_loader.py +++ b/ads/opctl/operator/common/operator_loader.py @@ -67,15 +67,15 @@ class OperatorInfo(DataClassSerializable): Generates the conda prefix for the custom conda pack. """ - type: str - gpu: bool - description: str - version: str - conda: str - conda_type: str - path: str - keywords: List[str] - backends: List[str] + type: str = "" + gpu: bool = False + description: str = "" + version: str = "" + conda: str = "" + conda_type: str = "" + path: str = "" + keywords: List[str] = None + backends: List[str] = None @property def conda_prefix(self) -> str: @@ -100,7 +100,7 @@ def conda_prefix(self) -> str: ) def __post_init__(self): - self.gpu = self.gpu == "yes" + self.gpu = self.gpu == True or self.gpu == "yes" self.version = self.version or "v1" self.conda_type = self.conda_type or PACK_TYPE.CUSTOM self.conda = self.conda or f"{self.type}_{self.version}" @@ -137,7 +137,9 @@ def from_yaml( obj: OperatorInfo = super().from_yaml( yaml_string=yaml_string, uri=uri, loader=loader, **kwargs ) - obj.path = os.path.dirname(uri) + + if uri: + obj.path = os.path.dirname(uri) return obj @@ -671,25 +673,6 @@ def _module_from_file(module_name: str, module_path: str) -> Any: return module -def _module_constant_values(module_name: str, module_path: str) -> Dict[str, Any]: - """Returns the list of constant variables from a given module. - - Parameters - ---------- - module_name (str) - The name of the module to be imported. - module_path (str) - The physical path of the module. - - Returns - ------- - Dict[str, Any] - Map of variable names and their values. - """ - module = _module_from_file(module_name, module_path) - return {name: value for name, value in vars(module).items()} - - def _operator_info(path: str = None, name: str = None) -> OperatorInfo: """ Extracts operator's details by given path. diff --git a/ads/opctl/operator/common/operator_schema.yaml b/ads/opctl/operator/common/operator_schema.yaml index f862c59b6..da30f4269 100644 --- a/ads/opctl/operator/common/operator_schema.yaml +++ b/ads/opctl/operator/common/operator_schema.yaml @@ -43,6 +43,6 @@ conda_type: default: custom allowed: - service - - custom + - published meta: description: "The operator's conda environment type. Can be either service or custom type." diff --git a/ads/opctl/operator/common/utils.py b/ads/opctl/operator/common/utils.py index 64738d803..fc13d6d3c 100644 --- a/ads/opctl/operator/common/utils.py +++ b/ads/opctl/operator/common/utils.py @@ -12,11 +12,10 @@ import fsspec import yaml from cerberus import Validator -from yaml import SafeLoader from ads.opctl import logger from ads.opctl.operator import __operators__ -from ads.opctl.utils import run_command +from ads.opctl import utils CONTAINER_NETWORK = "CONTAINER_NETWORK" @@ -95,7 +94,7 @@ def _build_image( logger.info(f"Build image: {command}") - proc = run_command(command) + proc = utils.run_command(command) if proc.returncode != 0: raise RuntimeError("Docker build failed.") diff --git a/ads/opctl/operator/lowcode/forecast/MLoperator b/ads/opctl/operator/lowcode/forecast/MLoperator index e1b32afc3..bf22bae38 100644 --- a/ads/opctl/operator/lowcode/forecast/MLoperator +++ b/ads/opctl/operator/lowcode/forecast/MLoperator @@ -1,6 +1,6 @@ type: forecast version: v1 -conda_type: custom +conda_type: published gpu: no keywords: - Prophet diff --git a/ads/opctl/operator/runtime/runtime.py b/ads/opctl/operator/runtime/runtime.py index ae22796f2..3342743ae 100644 --- a/ads/opctl/operator/runtime/runtime.py +++ b/ads/opctl/operator/runtime/runtime.py @@ -29,7 +29,7 @@ class OPERATOR_LOCAL_RUNTIME_TYPE(ExtendedEnum): class Runtime(DataClassSerializable): """Base class for the operator's runtimes.""" - _schema: ClassVar[str] = None + _schema: ClassVar[str] = "" kind: str = OPERATOR_LOCAL_RUNTIME_KIND type: str = None version: str = None diff --git a/tests/unitary/with_extras/operator/test_common_utils.py b/tests/unitary/with_extras/operator/test_common_utils.py index d95695986..f296a420e 100644 --- a/tests/unitary/with_extras/operator/test_common_utils.py +++ b/tests/unitary/with_extras/operator/test_common_utils.py @@ -4,21 +4,103 @@ # Copyright (c) 2023 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import argparse +import os +import unittest +from unittest.mock import patch, MagicMock -class TestCommonUtils: - """Tests all operator's common utils methods.""" +from ads.opctl.operator.common.utils import _build_image, _parse_input_args - def test_build_image_success(self): - pass - def test_build_image_fail(self): - pass +class TestBuildImage(unittest.TestCase): + def setUp(self): + self.curr_dir = os.path.dirname(os.path.abspath(__file__)) + self.dockerfile = os.path.join(self.curr_dir, "test_files", "Dockerfile.test") + self.image_name = "test_image" + self.tag = "test_tag" + self.target = "test_target" + self.kwargs = {"arg1": "value1", "arg2": "value2"} - def test_load_yaml_from_string(self): - pass + @patch("ads.opctl.utils.run_command") + def test_build_image(self, mock_run_command): + mock_proc = MagicMock() + mock_proc.returncode = 0 + mock_run_command.return_value = mock_proc - def test_load_yaml_from_uri(self): - pass + image_name = _build_image( + self.dockerfile, + self.image_name, + tag=self.tag, + target=self.target, + **self.kwargs, + ) - def test_parse_input_args(self): - pass + expected_image_name = f"{self.image_name}:{self.tag}" + command = [ + "docker", + "build", + "-t", + expected_image_name, + "-f", + self.dockerfile, + "--target", + self.target, + os.path.dirname(self.dockerfile), + ] + + mock_run_command.assert_called_once_with(command) + self.assertEqual(image_name, expected_image_name) + + def test_build_image_missing_dockerfile(self): + with self.assertRaises(FileNotFoundError): + _build_image("non_existing_docker_file", "non_existing_image") + + def test_build_image_missing_image_name(self): + with self.assertRaises(ValueError): + _build_image(self.dockerfile, None) + + @patch("ads.opctl.utils.run_command") + def test_build_image_docker_build_failure(self, mock_run_command): + mock_proc = MagicMock() + mock_proc.returncode = 1 + mock_run_command.return_value = mock_proc + + with self.assertRaises(RuntimeError): + _build_image(self.dockerfile, self.image_name) + + +class TestParseInputArgs(unittest.TestCase): + def test_parse_input_args_with_file(self): + raw_args = ["-f", "path/to/file.yaml"] + expected_output = ( + argparse.Namespace(file="path/to/file.yaml", spec=None, verify=False), + [], + ) + self.assertEqual(_parse_input_args(raw_args), expected_output) + + def test_parse_input_args_with_spec(self): + raw_args = ["-s", "spec"] + expected_output = (argparse.Namespace(file=None, spec="spec", verify=False), []) + self.assertEqual(_parse_input_args(raw_args), expected_output) + + def test_parse_input_args_with_verify(self): + raw_args = ["-v", "True"] + expected_output = (argparse.Namespace(file=None, spec=None, verify=True), []) + self.assertEqual(_parse_input_args(raw_args), expected_output) + + def test_parse_input_args_with_unknown_args(self): + raw_args = ["-f", "path/to/file.yaml", "--unknown-arg", "value"] + expected_output = ( + argparse.Namespace(file="path/to/file.yaml", spec=None, verify=False), + ["--unknown-arg", "value"], + ) + self.assertEqual(_parse_input_args(raw_args), expected_output) + + @patch("argparse.ArgumentParser.parse_known_args") + def test_parse_input_args_with_no_args(self, mock_parse_known_args): + mock_parse_known_args.return_value = ( + argparse.Namespace(file=None, spec=None, verify=False), + [], + ) + expected_output = (argparse.Namespace(file=None, spec=None, verify=False), []) + self.assertEqual(_parse_input_args([]), expected_output) diff --git a/tests/unitary/with_extras/operator/test_files/Dockerfile.test b/tests/unitary/with_extras/operator/test_files/Dockerfile.test new file mode 100644 index 000000000..8049807ad --- /dev/null +++ b/tests/unitary/with_extras/operator/test_files/Dockerfile.test @@ -0,0 +1,2 @@ +FROM baseImage +RUN echo "This is a message" diff --git a/tests/unitary/with_extras/operator/test_operator_config.py b/tests/unitary/with_extras/operator/test_operator_config.py index 2a0909d4c..52158c774 100644 --- a/tests/unitary/with_extras/operator/test_operator_config.py +++ b/tests/unitary/with_extras/operator/test_operator_config.py @@ -5,8 +5,90 @@ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from dataclasses import dataclass + +import pytest +import yaml + +from ads.opctl.operator.common.operator_config import OperatorConfig + + class TestOperatorConfig: - """Tests base class representing operator config.""" + def test_operator_config(self): + # Test valid operator config + + @dataclass(repr=True) + class MyOperatorConfig(OperatorConfig): + @classmethod + def _load_schema(cls) -> str: + return yaml.safe_load( + """ + kind: + required: true + type: string + version: + required: true + type: string + type: + required: true + type: string + spec: + required: true + type: dict + schema: + foo: + required: false + type: string + """ + ) + + config = MyOperatorConfig.from_dict( + { + "kind": "operator", + "type": "my-operator", + "version": "v1", + "spec": {"foo": "bar"}, + } + ) + assert config.kind == "operator" + assert config.type == "my-operator" + assert config.version == "v1" + assert config.spec == {"foo": "bar"} + + # Test invalid operator config + @dataclass(repr=True) + class InvalidOperatorConfig(OperatorConfig): + @classmethod + def _load_schema(cls) -> str: + return yaml.safe_load( + """ + kind: + required: true + type: string + version: + required: true + type: string + allowed: + - v1 + type: + required: true + type: string + spec: + required: true + type: dict + schema: + foo: + required: true + type: string + """ + ) - def test_validate_dict(self): - pass + with pytest.raises(ValueError): + InvalidOperatorConfig.from_dict( + { + "kind": "operator", + "type": "invalid-operator", + "version": "v2", + "spec": {"foo1": 123}, + } + ) diff --git a/tests/unitary/with_extras/operator/test_operator_loader.py b/tests/unitary/with_extras/operator/test_operator_loader.py index b2005665d..bbe71d87c 100644 --- a/tests/unitary/with_extras/operator/test_operator_loader.py +++ b/tests/unitary/with_extras/operator/test_operator_loader.py @@ -4,84 +4,433 @@ # Copyright (c) 2023 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ - -class TestCommonUtils: - """Tests common utils in the operator_loader module.""" - - def test_operator_info(self): - pass - - def test_operator_info_list(self): - pass - - -class TestOperatorInfo: - """Tests the class representing brief information about the operator.""" - - def test_init(self): - pass +import os +import shutil +import unittest +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from ads.opctl.operator.common.const import ARCH_TYPE, PACK_TYPE +from ads.opctl.operator.common.operator_loader import ( + GitOperatorLoader, + LocalOperatorLoader, + OperatorInfo, + OperatorLoader, + RemoteOperatorLoader, + ServiceOperatorLoader, +) + + +class TestOperatorInfo(unittest.TestCase): + def test_construction(self): + operator_info = OperatorInfo( + type="example", + gpu="yes", + description="An example operator", + version="v1", + conda="example_v1", + conda_type=PACK_TYPE.CUSTOM, + path="/path/to/operator", + backends=["backend1", "backend2"], + ) + assert (operator_info.type, "example") + self.assertTrue(operator_info.gpu) + assert (operator_info.description, "An example operator") + assert (operator_info.version, "v1") + assert (operator_info.conda, "example_v1") + assert (operator_info.conda_type, PACK_TYPE.CUSTOM) + assert (operator_info.path, "/path/to/operator") + assert (operator_info.backends, ["backend1", "backend2"]) def test_conda_prefix(self): - pass - - def test_from_yaml(self): - pass + operator_info = OperatorInfo( + type="example", + gpu="yes", + description="An example operator", + version="v1", + conda="example_v1", + conda_type=PACK_TYPE.CUSTOM, + path="/path/to/operator", + backends=["backend1", "backend2"], + ) + assert ( + operator_info.conda_prefix, + f"{ARCH_TYPE.GPU}/example/1/example_v1", + ) + + def test_conda_prefix_without_gpu(self): + operator_info = OperatorInfo( + type="example", + gpu="no", + description="An example operator", + version="v1", + conda="example_v1", + conda_type=PACK_TYPE.CUSTOM, + path="/path/to/operator", + backends=["backend1", "backend2"], + ) + assert ( + operator_info.conda_prefix, + f"{ARCH_TYPE.CPU}/example/1/example_v1", + ) + + def test_post_init(self): + operator_info = OperatorInfo( + type="example", + gpu="yes", # Should be converted to boolean + version="", # Should be set to "v1" + conda_type=None, # Should be set to PACK_TYPE.CUSTOM + conda=None, # Should be set to "example_v1" + ) + self.assertTrue(operator_info.gpu) + assert (operator_info.version, "v1") + assert (operator_info.conda_type, PACK_TYPE.CUSTOM) + assert (operator_info.conda, "example_v1") + + def test_from_yaml_with_yaml_string(self): + yaml_string = """ + type: example + gpu: yes + description: An example operator + version: v1 + conda_type: published + path: /path/to/operator + backends: + - backend1 + - backend2 + """ + operator_info = OperatorInfo.from_yaml(yaml_string=yaml_string) + assert (operator_info.type, "example") + self.assertTrue(operator_info.gpu) + assert (operator_info.description, "An example operator") + assert (operator_info.version, "v1") + assert (operator_info.conda, "example_v1") + assert (operator_info.conda_type, PACK_TYPE.CUSTOM) + assert (operator_info.path, "/path/to/operator") + assert (operator_info.backends, ["backend1", "backend2"]) + + @patch("ads.common.serializer.Serializable.from_yaml") + def test_from_yaml_with_uri(self, mock_from_yaml): + uri = "http://example.com/operator.yaml" + loader = MagicMock() + mock_from_yaml.return_value = OperatorInfo( + type="example", + gpu="yes", + description="An example operator", + version="v1", + conda="example_v1", + conda_type=PACK_TYPE.CUSTOM, + path="/path/to/operator", + backends=["backend1", "backend2"], + ) + operator_info = OperatorInfo.from_yaml(uri=uri, loader=loader) + mock_from_yaml.assert_called_with(yaml_string=None, uri=uri, loader=loader) + assert (operator_info.type, "example") + self.assertTrue(operator_info.gpu) + assert (operator_info.description, "An example operator") + assert (operator_info.version, "v1") + assert (operator_info.conda, "example_v1") + assert (operator_info.conda_type, PACK_TYPE.CUSTOM) + assert (operator_info.path, "http://example.com") + assert (operator_info.backends, ["backend1", "backend2"]) class TestOperatorLoader: - """Tests operator loader class""" - - def test_init(self): - pass - - def test_from_uri_success(self): - pass - - def test_from_uri_fail(self): - pass - - -class TestServiceOperatorLoader: - """Tests class to load a service operator.""" + def setup_method(self): + # Create a mock Loader instance for testing + self.loader = Mock() + self.operator_loader = OperatorLoader(self.loader) + + def test_load_operator(self): + # Define a mock OperatorInfo object to return when load is called on the loader + mock_operator_info = OperatorInfo( + type="mock_operator", + gpu=False, + description="Mock Operator", + version="v1", + conda="mock_operator_v1", + conda_type="custom", + path="/path/to/mock_operator", + backends=["cpu"], + ) + + # Mock the _load method to return the mock_operator_info object + self.loader.load.return_value = mock_operator_info + + # Call the load method of the OperatorLoader + operator_info = self.operator_loader.load() + + # Check if the returned OperatorInfo object matches the expected values + + assert operator_info.type == "mock_operator" + assert operator_info.gpu == False + assert operator_info.description == "Mock Operator" + assert operator_info.version == "v1" + assert operator_info.conda == "mock_operator_v1" + assert operator_info.conda_type == "custom" + assert operator_info.path == "/path/to/mock_operator" + assert operator_info.backends == ["cpu"] + + def test_load_operator_exception(self): + # Mock the _load method to raise an exception + self.loader.load.side_effect = Exception("Error loading operator") + + # Call the load method of the OperatorLoader and expect an exception + with pytest.raises(Exception): + self.operator_loader.load() + + @pytest.mark.parametrize( + "test_name, uri, expected_result", + [ + ("Service Path", "forecast", ServiceOperatorLoader), + ("Local Path", "/path/to/local_operator", LocalOperatorLoader), + ("OCI Path", "oci://bucket/operator.zip", RemoteOperatorLoader), + ( + "Git Path", + "https://github.com/my-operator-repository", + GitOperatorLoader, + ), + ], + ) + def test_from_uri(self, test_name, uri, expected_result): + # Call the from_uri method of the OperatorLoader class + operator_loader = OperatorLoader.from_uri(uri=uri) + assert isinstance(operator_loader.loader, expected_result) + + def test_empty_uri(self): + # Test with an empty URI that should raise a ValueError + with pytest.raises(ValueError): + OperatorLoader.from_uri(uri="", uri_dst=None) + + def test_invalid_uri(self): + # Test with an invalid URI that should raise a ValueError + with pytest.raises(ValueError): + OperatorLoader.from_uri(uri="aws://", uri_dst=None) + + +class TestServiceOperatorLoader(unittest.TestCase): + def setUp(self): + # Create a mock ServiceOperatorLoader instance for testing + self.loader = ServiceOperatorLoader(uri="mock_service_operator") + self.mock_operator_info = OperatorInfo( + type="mock_operator", + gpu="no", + description="Mock Operator", + version="v1", + conda="mock_operator_v1", + conda_type="custom", + path="/path/to/mock_operator", + backends=["cpu"], + ) def test_compatible(self): - pass + # Test the compatible method with a valid URI + uri = "forecast" + self.assertTrue(ServiceOperatorLoader.compatible(uri=uri)) - def test_load(self): - pass - - -class TestLocalOperatorLoader: - """Tests class to load a local operator.""" - - def test_compatible(self): - pass + # Test the compatible method with an invalid URI + uri = "invalid_service_operator" + self.assertFalse(ServiceOperatorLoader.compatible(uri=uri)) def test_load(self): - pass - - -class TestRemoteOperatorLoader: - """Tests class to load a remote operator.""" + # Mock the _load method to return the mock_operator_info object + self.loader._load = Mock(return_value=self.mock_operator_info) + + # Call the load method of the ServiceOperatorLoader + operator_info = self.loader.load() + + # Check if the returned OperatorInfo object matches the expected values + self.assertEqual(operator_info.type, "mock_operator") + self.assertEqual(operator_info.gpu, False) + self.assertEqual(operator_info.description, "Mock Operator") + self.assertEqual(operator_info.version, "v1") + self.assertEqual(operator_info.conda, "mock_operator_v1") + self.assertEqual(operator_info.conda_type, "custom") + self.assertEqual(operator_info.path, "/path/to/mock_operator") + self.assertEqual(operator_info.backends, ["cpu"]) + + def test_load_exception(self): + # Mock the _load method to raise an exception + self.loader._load = Mock( + side_effect=Exception("Error loading service operator") + ) + + # Call the load method of the ServiceOperatorLoader and expect an exception + with self.assertRaises(Exception): + self.loader.load() + + +class TestLocalOperatorLoader(unittest.TestCase): + def setUp(self): + # Create a mock LocalOperatorLoader instance for testing + self.loader = LocalOperatorLoader(uri="path/to/local/operator") + self.mock_operator_info = OperatorInfo( + type="mock_operator", + gpu=False, + description="Mock Operator", + version="v1", + conda="mock_operator_v1", + conda_type="custom", + path="/path/to/mock_operator", + backends=["cpu"], + ) def test_compatible(self): - pass + # Test the compatible method with a valid URI + uri = "path/to/local/operator" + self.assertTrue(LocalOperatorLoader.compatible(uri=uri)) + + # Test the compatible method with an invalid URI + uri = "http://example.com/remote/operator" + self.assertFalse(LocalOperatorLoader.compatible(uri=uri)) def test_load(self): - pass + # Mock the _load method to return the mock_operator_info object + self.loader._load = Mock(return_value=self.mock_operator_info) + + # Call the load method of the LocalOperatorLoader + operator_info = self.loader.load() + + # Check if the returned OperatorInfo object matches the expected values + self.assertEqual(operator_info.type, "mock_operator") + self.assertEqual(operator_info.gpu, False) + self.assertEqual(operator_info.description, "Mock Operator") + self.assertEqual(operator_info.version, "v1") + self.assertEqual(operator_info.conda, "mock_operator_v1") + self.assertEqual(operator_info.conda_type, "custom") + self.assertEqual(operator_info.path, "/path/to/mock_operator") + self.assertEqual(operator_info.backends, ["cpu"]) + + def test_load_exception(self): + # Mock the _load method to raise an exception + self.loader._load = Mock(side_effect=Exception("Error loading local operator")) + + # Call the load method of the LocalOperatorLoader and expect an exception + with self.assertRaises(Exception): + self.loader.load() + + +class TestRemoteOperatorLoader(unittest.TestCase): + def setUp(self): + # Create a mock RemoteOperatorLoader instance for testing + self.loader = RemoteOperatorLoader(uri="oci://bucket/operator.zip") + self.mock_operator_info = OperatorInfo( + type="mock_operator", + gpu=False, + description="Mock Operator", + version="v1", + conda="mock_operator_v1", + conda_type="custom", + path="/path/to/mock_operator", + backends=["cpu"], + ) - def test_cleanup(self): - pass + def test_compatible(self): + # Test the compatible method with a valid URI + uri = "oci://bucket/operator.zip" + self.assertTrue(RemoteOperatorLoader.compatible(uri=uri)) + # Test the compatible method with an invalid URI + uri = "http://example.com/remote/operator" + self.assertFalse(RemoteOperatorLoader.compatible(uri=uri)) -class TestGitOperatorLoader: - """Tests class to load an operator from a GIT repository.""" + def test_load(self): + # Mock the _load method to return the mock_operator_info object + self.loader._load = Mock(return_value=self.mock_operator_info) + + # Call the load method of the RemoteOperatorLoader + operator_info = self.loader.load() + + # Check if the returned OperatorInfo object matches the expected values + self.assertEqual(operator_info.type, "mock_operator") + self.assertEqual(operator_info.gpu, False) + self.assertEqual(operator_info.description, "Mock Operator") + self.assertEqual(operator_info.version, "v1") + self.assertEqual(operator_info.conda, "mock_operator_v1") + self.assertEqual(operator_info.conda_type, "custom") + self.assertEqual(operator_info.path, "/path/to/mock_operator") + self.assertEqual(operator_info.backends, ["cpu"]) + + def test_load_exception(self): + # Mock the _load method to raise an exception + self.loader._load = Mock(side_effect=Exception("Error loading remote operator")) + + # Call the load method of the RemoteOperatorLoader and expect an exception + with self.assertRaises(Exception): + self.loader.load() + + +class TestGitOperatorLoader(unittest.TestCase): + def setUp(self): + # Create a temporary directory for testing + self.temp_dir = "temp_git_loader" + os.makedirs(self.temp_dir, exist_ok=True) + + # Create a mock GitOperatorLoader instance for testing + self.loader = GitOperatorLoader( + uri="https://github.com/mock_operator_repository.git@feature-branch#forecasting", + uri_dst=self.temp_dir, + ) + self.mock_operator_info = OperatorInfo( + type="mock_operator", + gpu=False, + description="Mock Operator", + version="v1", + conda="mock_operator_v1", + conda_type="custom", + path=os.path.join(self.temp_dir, "forecasting"), + backends=["cpu"], + ) + + def tearDown(self): + # Clean up the temporary directory + shutil.rmtree(self.temp_dir, ignore_errors=True) def test_compatible(self): - pass + # Test the compatible method with a valid URI + uri = ( + "https://github.com/mock_operator_repository.git@feature-branch#forecasting" + ) + self.assertTrue(GitOperatorLoader.compatible(uri=uri)) - def test_load(self): - pass + # Test the compatible method with an invalid URI + uri = "http://example.com/remote/operator" + self.assertFalse(GitOperatorLoader.compatible(uri=uri)) - def test_cleanup(self): - pass + def test_load(self): + # Mock the git.Repo.clone_from method to avoid actual Git operations + with patch("git.Repo.clone_from") as mock_clone_from: + mock_clone_from.return_value = Mock() + + # Mock the _load method to return the mock_operator_info object + self.loader._load = Mock(return_value=self.mock_operator_info) + + # Call the load method of the GitOperatorLoader + operator_info = self.loader.load() + + # Check if the returned OperatorInfo object matches the expected values + self.assertEqual(operator_info.type, "mock_operator") + self.assertEqual(operator_info.gpu, False) + self.assertEqual(operator_info.description, "Mock Operator") + self.assertEqual(operator_info.version, "v1") + self.assertEqual(operator_info.conda, "mock_operator_v1") + self.assertEqual(operator_info.conda_type, "custom") + self.assertEqual( + operator_info.path, os.path.join(self.temp_dir, "forecasting") + ) + self.assertEqual(operator_info.backends, ["cpu"]) + + def test_load_exception(self): + # Mock the git.Repo.clone_from method to raise an exception + with patch( + "git.Repo.clone_from", side_effect=Exception("Error cloning Git repository") + ): + # Mock the _load method to raise an exception + self.loader._load = Mock( + side_effect=Exception("Error loading Git operator") + ) + + # Call the load method of the GitOperatorLoader and expect an exception + with self.assertRaises(Exception): + self.loader.load() diff --git a/tests/unitary/with_extras/operator/test_runtime.py b/tests/unitary/with_extras/operator/test_runtime.py index 0c08baaa2..6e3f8d7b3 100644 --- a/tests/unitary/with_extras/operator/test_runtime.py +++ b/tests/unitary/with_extras/operator/test_runtime.py @@ -4,52 +4,97 @@ # Copyright (c) 2023 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -import pytest -from my_module.runtime import ContainerRuntime, PythonRuntime +import unittest +from unittest.mock import patch, MagicMock +from ads.opctl.operator.runtime.runtime import ( + ContainerRuntime, + ContainerRuntimeSpec, + PythonRuntime, + Runtime, + OPERATOR_LOCAL_RUNTIME_TYPE, +) -class TestContainerRuntime: + +class TestRuntime(unittest.TestCase): + def setUp(self): + self.runtime = Runtime() + + def test_kind(self): + self.assertEqual(self.runtime.kind, "operator.local") + + def test_type(self): + self.assertIsNone(self.runtime.type) + + def test_version(self): + self.assertIsNone(self.runtime.version) + + @patch("ads.opctl.operator.runtime.runtime._load_yaml_from_uri") + @patch("ads.opctl.operator.runtime.runtime.Validator") + def test_validate_dict(self, mock_validator, mock_load_yaml): + mock_validator.return_value.validate.return_value = True + self.assertTrue(Runtime._validate_dict({})) + mock_load_yaml.assert_called_once() + mock_validator.assert_called_once() + + @patch("ads.opctl.operator.runtime.runtime._load_yaml_from_uri") + @patch("ads.opctl.operator.runtime.runtime.Validator") + def test_validate_dict_invalid(self, mock_validator, mock_load_yaml): + mock_validator.return_value = MagicMock( + errors=[{"error": "error"}], validate=MagicMock(return_value=False) + ) + mock_validator.return_value.validate.return_value = False + with self.assertRaises(ValueError): + Runtime._validate_dict({}) + mock_load_yaml.assert_called_once() + mock_validator.assert_called_once() + + +class TestContainerRuntime(unittest.TestCase): def test_init(self): - runtime = ContainerRuntime(image="my-image", command="my-command") - self.assertEqual(runtime.image, "my-image") - self.assertEqual(runtime.command, "my-command") - - def test_to_dict(self): - runtime = ContainerRuntime(image="my-image", command="my-command") - runtime_dict = runtime.to_dict() - self.assertIsInstance(runtime_dict, dict) - self.assertEqual(runtime_dict["type"], "container") - self.assertEqual(runtime_dict["image"], "my-image") - self.assertEqual(runtime_dict["command"], "my-command") - - def test_from_dict(self): - runtime_dict = { - "type": "container", - "image": "my-image", - "command": "my-command", - } - runtime = ContainerRuntime.from_dict(runtime_dict) + runtime = ContainerRuntime.init( + image="my-image", + env=[{"name": "VAR1", "value": "value1"}], + volume=["/data"], + ) self.assertIsInstance(runtime, ContainerRuntime) - self.assertEqual(runtime.image, "my-image") - self.assertEqual(runtime.command, "my-command") + self.assertEqual(runtime.type, OPERATOR_LOCAL_RUNTIME_TYPE.CONTAINER.value) + self.assertEqual(runtime.version, "v1") + self.assertIsInstance(runtime.spec, ContainerRuntimeSpec) + self.assertEqual(runtime.spec.image, "my-image") + self.assertEqual(runtime.spec.env, [{"name": "VAR1", "value": "value1"}]) + self.assertEqual(runtime.spec.volume, ["/data"]) + def test_validate_dict(self): + valid_dict = { + "kind": "operator.local", + "type": "container", + "version": "v1", + "spec": { + "image": "my-image", + "env": [{"name": "VAR1", "value": "value1"}], + "volume": ["/data"], + }, + } + self.assertTrue(ContainerRuntime._validate_dict(valid_dict)) -class TestPythonRuntime: - def test_init(self): - runtime = PythonRuntime(version="v2") - self.assertEqual(runtime.version, "v2") - self.assertEqual(runtime.type, "python") + invalid_dict = { + "kind": "operator.local", + "type": "unknown", + "version": "v1", + "spec": { + "image": "my-image", + "env": [{"name": "VAR1"}], + "volume": ["/data"], + }, + } + with self.assertRaises(ValueError): + ContainerRuntime._validate_dict(invalid_dict) - def test_to_dict(self): - runtime = PythonRuntime(version="v2") - runtime_dict = runtime.to_dict() - self.assertIsInstance(runtime_dict, dict) - self.assertEqual(runtime_dict["type"], "python") - self.assertEqual(runtime_dict["version"], "v2") - def test_from_dict(self): - runtime_dict = {"type": "python", "version": "v2"} - runtime = PythonRuntime.from_dict(runtime_dict) +class TestPythonRuntime(unittest.TestCase): + def test_init(self): + runtime = PythonRuntime.init() self.assertIsInstance(runtime, PythonRuntime) - self.assertEqual(runtime.version, "v2") self.assertEqual(runtime.type, "python") + self.assertEqual(runtime.version, "v1")