Skip to content

Commit

Permalink
Operator. Unit tests for the common utils. (#364)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrDzurb authored Oct 5, 2023
2 parents 7b287d4 + f70d6bb commit 1854750
Show file tree
Hide file tree
Showing 18 changed files with 883 additions and 82 deletions.
6 changes: 2 additions & 4 deletions ads/opctl/operator/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ads.opctl.utils import suppress_traceback

from .__init__ import __operators__
from .cmd import apply as cmd_apply
from .cmd import run as cmd_run
from .cmd import build_conda as cmd_build_conda
from .cmd import build_image as cmd_build_image
from .cmd import create as cmd_create
Expand Down Expand Up @@ -289,6 +289,4 @@ def run(debug: bool, **kwargs: Dict[str, Any]) -> None:
with fsspec.open(backend, "r", **auth) as f:
backend = suppress_traceback(debug)(yaml.safe_load)(f.read())

suppress_traceback(debug)(cmd_apply)(
config=operator_spec, backend=backend, **kwargs
)
suppress_traceback(debug)(cmd_run)(config=operator_spec, backend=backend, **kwargs)
2 changes: 1 addition & 1 deletion ads/opctl/operator/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def create(
raise NotImplementedError()


def apply(config: Dict, backend: Union[Dict, str] = None, **kwargs) -> None:
def run(config: Dict, backend: Union[Dict, str] = None, **kwargs) -> None:
"""
Runs the operator with the given specification on the targeted backend.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from ads.common.serializer import DataClassSerializable

from .common.utils import OperatorValidator
from .utils import OperatorValidator


@dataclass(repr=True)
Expand Down
45 changes: 14 additions & 31 deletions ads/opctl/operator/common/operator_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from ads.common.serializer import DataClassSerializable
from ads.common.utils import copy_from_uri
from ads.opctl import logger
from ads.opctl.constants import BACKEND_NAME, OPERATOR_MODULE_PATH
from ads.opctl.constants import OPERATOR_MODULE_PATH
from ads.opctl.operator import __operators__

from .const import ARCH_TYPE, PACK_TYPE
Expand Down Expand Up @@ -67,15 +67,15 @@ class OperatorInfo(DataClassSerializable):
Generates the conda prefix for the custom conda pack.
"""

type: str
gpu: bool
description: str
version: str
conda: str
conda_type: str
path: str
keywords: List[str]
backends: List[str]
type: str = ""
gpu: bool = False
description: str = ""
version: str = ""
conda: str = ""
conda_type: str = ""
path: str = ""
keywords: List[str] = None
backends: List[str] = None

@property
def conda_prefix(self) -> str:
Expand All @@ -100,7 +100,7 @@ def conda_prefix(self) -> str:
)

def __post_init__(self):
self.gpu = self.gpu == "yes"
self.gpu = self.gpu == True or self.gpu == "yes"
self.version = self.version or "v1"
self.conda_type = self.conda_type or PACK_TYPE.CUSTOM
self.conda = self.conda or f"{self.type}_{self.version}"
Expand Down Expand Up @@ -137,7 +137,9 @@ def from_yaml(
obj: OperatorInfo = super().from_yaml(
yaml_string=yaml_string, uri=uri, loader=loader, **kwargs
)
obj.path = os.path.dirname(uri)

if uri:
obj.path = os.path.dirname(uri)
return obj


Expand Down Expand Up @@ -671,25 +673,6 @@ def _module_from_file(module_name: str, module_path: str) -> Any:
return module


def _module_constant_values(module_name: str, module_path: str) -> Dict[str, Any]:
"""Returns the list of constant variables from a given module.
Parameters
----------
module_name (str)
The name of the module to be imported.
module_path (str)
The physical path of the module.
Returns
-------
Dict[str, Any]
Map of variable names and their values.
"""
module = _module_from_file(module_name, module_path)
return {name: value for name, value in vars(module).items()}


def _operator_info(path: str = None, name: str = None) -> OperatorInfo:
"""
Extracts operator's details by given path.
Expand Down
2 changes: 1 addition & 1 deletion ads/opctl/operator/common/operator_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@ conda_type:
default: custom
allowed:
- service
- custom
- published
meta:
description: "The operator's conda environment type. Can be either service or custom type."
22 changes: 2 additions & 20 deletions ads/opctl/operator/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
import fsspec
import yaml
from cerberus import Validator
from yaml import SafeLoader

from ads.opctl import logger
from ads.opctl.operator import __operators__
from ads.opctl.utils import run_command
from ads.opctl import utils

CONTAINER_NETWORK = "CONTAINER_NETWORK"

Expand Down Expand Up @@ -95,7 +94,7 @@ def _build_image(

logger.info(f"Build image: {command}")

proc = run_command(command)
proc = utils.run_command(command)
if proc.returncode != 0:
raise RuntimeError("Docker build failed.")

Expand Down Expand Up @@ -145,23 +144,6 @@ def _load_yaml_from_string(doc: str, **kwargs) -> Dict:
)


def _load_multi_document_yaml_from_string(doc: str, **kwargs) -> Dict:
"""Loads multiline YAML from string and merge it with env variables and kwargs."""
template_dict = {**os.environ, **kwargs}
return yaml.load_all(
Template(doc).substitute(
**template_dict,
),
Loader=SafeLoader,
)


def _load_multi_document_yaml_from_uri(uri: str, **kwargs) -> Dict:
"""Loads multiline YAML from file and merge it with env variables and kwargs."""
with fsspec.open(uri) as f:
return _load_multi_document_yaml_from_string(str(f.read(), "UTF-8"), **kwargs)


def _load_yaml_from_uri(uri: str, **kwargs) -> str:
"""Loads YAML from the URI path. Can be Object Storage path."""
with fsspec.open(uri) as f:
Expand Down
2 changes: 1 addition & 1 deletion ads/opctl/operator/lowcode/forecast/MLoperator
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
type: forecast
version: v1
conda_type: custom
conda_type: published
gpu: no
keywords:
- Prophet
Expand Down
6 changes: 4 additions & 2 deletions ads/opctl/operator/lowcode/forecast/operator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ads.common.serializer import DataClassSerializable
from ads.opctl.operator.common.utils import _load_yaml_from_uri
from ads.opctl.operator.operator_config import OperatorConfig
from ads.opctl.operator.common.operator_config import OperatorConfig

from .const import SupportedMetrics

Expand Down Expand Up @@ -103,7 +103,9 @@ def __post_init__(self):
self.metric = (self.metric or "").lower() or SupportedMetrics.SMAPE.lower()
self.confidence_interval_width = self.confidence_interval_width or 0.80
self.report_file_name = self.report_file_name or "report.html"
self.preprocessing = self.preprocessing if self.preprocessing is not None else True
self.preprocessing = (
self.preprocessing if self.preprocessing is not None else True
)
self.report_theme = self.report_theme or "light"
self.metrics_filename = self.metrics_filename or "metrics.csv"
self.forecast_filename = self.forecast_filename or "forecast.csv"
Expand Down
2 changes: 1 addition & 1 deletion ads/opctl/operator/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class OPERATOR_LOCAL_RUNTIME_TYPE(ExtendedEnum):
class Runtime(DataClassSerializable):
"""Base class for the operator's runtimes."""

_schema: ClassVar[str] = None
_schema: ClassVar[str] = ""
kind: str = OPERATOR_LOCAL_RUNTIME_KIND
type: str = None
version: str = None
Expand Down
42 changes: 41 additions & 1 deletion tests/unitary/with_extras/operator/test_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,49 @@
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


class TestCMD:
class TestOperatorCMD:
"""Tests operator commands."""

def test_list(self):
"""Ensures that the list of the registered operators can be printed."""
pass

def test_info(self):
"""Ensures that the detailed information about the particular operator can be printed."""
pass

def test_init_backend_config(self):
"""Ensures the operator's backend configs can be generated."""
pass

def test_init_success(self):
"""Ensures that a starter YAML configurations for the operator can be generated."""
pass

def test_init_fail(self):
"""Ensures that generating starter specification fails in case of wrong input attributes."""
pass

def test_build_image(self):
"""Ensures that operator's image can be successfully built."""
pass

def test_publish_image(self):
"""Ensures that operator's image can be successfully published."""
pass

def test_verify(self):
"""Ensures that operator's config can be successfully verified."""
pass

def test_build_conda(self):
"""Ensures that the operator's conda environment can be successfully built."""
pass

def test_publish_conda(self):
"""Ensures that the operator's conda environment can be successfully published."""
pass

def test_run(self):
"""Ensures that the operator can be run on the targeted backend."""
pass
101 changes: 98 additions & 3 deletions tests/unitary/with_extras/operator/test_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,103 @@
# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import argparse
import os
import unittest
from unittest.mock import patch, MagicMock

class TestCommonUtils:
"""Tests all common utils methods of the operators."""
from ads.opctl.operator.common.utils import _build_image, _parse_input_args

pass

class TestBuildImage(unittest.TestCase):
def setUp(self):
self.curr_dir = os.path.dirname(os.path.abspath(__file__))
self.dockerfile = os.path.join(self.curr_dir, "test_files", "Dockerfile.test")
self.image_name = "test_image"
self.tag = "test_tag"
self.target = "test_target"
self.kwargs = {"arg1": "value1", "arg2": "value2"}

@patch("ads.opctl.utils.run_command")
def test_build_image(self, mock_run_command):
mock_proc = MagicMock()
mock_proc.returncode = 0
mock_run_command.return_value = mock_proc

image_name = _build_image(
self.dockerfile,
self.image_name,
tag=self.tag,
target=self.target,
**self.kwargs,
)

expected_image_name = f"{self.image_name}:{self.tag}"
command = [
"docker",
"build",
"-t",
expected_image_name,
"-f",
self.dockerfile,
"--target",
self.target,
os.path.dirname(self.dockerfile),
]

mock_run_command.assert_called_once_with(command)
self.assertEqual(image_name, expected_image_name)

def test_build_image_missing_dockerfile(self):
with self.assertRaises(FileNotFoundError):
_build_image("non_existing_docker_file", "non_existing_image")

def test_build_image_missing_image_name(self):
with self.assertRaises(ValueError):
_build_image(self.dockerfile, None)

@patch("ads.opctl.utils.run_command")
def test_build_image_docker_build_failure(self, mock_run_command):
mock_proc = MagicMock()
mock_proc.returncode = 1
mock_run_command.return_value = mock_proc

with self.assertRaises(RuntimeError):
_build_image(self.dockerfile, self.image_name)


class TestParseInputArgs(unittest.TestCase):
def test_parse_input_args_with_file(self):
raw_args = ["-f", "path/to/file.yaml"]
expected_output = (
argparse.Namespace(file="path/to/file.yaml", spec=None, verify=False),
[],
)
self.assertEqual(_parse_input_args(raw_args), expected_output)

def test_parse_input_args_with_spec(self):
raw_args = ["-s", "spec"]
expected_output = (argparse.Namespace(file=None, spec="spec", verify=False), [])
self.assertEqual(_parse_input_args(raw_args), expected_output)

def test_parse_input_args_with_verify(self):
raw_args = ["-v", "True"]
expected_output = (argparse.Namespace(file=None, spec=None, verify=True), [])
self.assertEqual(_parse_input_args(raw_args), expected_output)

def test_parse_input_args_with_unknown_args(self):
raw_args = ["-f", "path/to/file.yaml", "--unknown-arg", "value"]
expected_output = (
argparse.Namespace(file="path/to/file.yaml", spec=None, verify=False),
["--unknown-arg", "value"],
)
self.assertEqual(_parse_input_args(raw_args), expected_output)

@patch("argparse.ArgumentParser.parse_known_args")
def test_parse_input_args_with_no_args(self, mock_parse_known_args):
mock_parse_known_args.return_value = (
argparse.Namespace(file=None, spec=None, verify=False),
[],
)
expected_output = (argparse.Namespace(file=None, spec=None, verify=False), [])
self.assertEqual(_parse_input_args([]), expected_output)
2 changes: 2 additions & 0 deletions tests/unitary/with_extras/operator/test_files/Dockerfile.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
FROM baseImage
RUN echo "This is a message"
11 changes: 0 additions & 11 deletions tests/unitary/with_extras/operator/test_operator.py

This file was deleted.

Loading

0 comments on commit 1854750

Please sign in to comment.