Skip to content

Commit

Permalink
Adds tests for the operator loader.
Browse files Browse the repository at this point in the history
  • Loading branch information
mrDzurb committed Oct 4, 2023
1 parent 00815f1 commit f70d6bb
Show file tree
Hide file tree
Showing 10 changed files with 692 additions and 150 deletions.
43 changes: 13 additions & 30 deletions ads/opctl/operator/common/operator_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ class OperatorInfo(DataClassSerializable):
Generates the conda prefix for the custom conda pack.
"""

type: str
gpu: bool
description: str
version: str
conda: str
conda_type: str
path: str
keywords: List[str]
backends: List[str]
type: str = ""
gpu: bool = False
description: str = ""
version: str = ""
conda: str = ""
conda_type: str = ""
path: str = ""
keywords: List[str] = None
backends: List[str] = None

@property
def conda_prefix(self) -> str:
Expand All @@ -100,7 +100,7 @@ def conda_prefix(self) -> str:
)

def __post_init__(self):
self.gpu = self.gpu == "yes"
self.gpu = self.gpu == True or self.gpu == "yes"
self.version = self.version or "v1"
self.conda_type = self.conda_type or PACK_TYPE.CUSTOM
self.conda = self.conda or f"{self.type}_{self.version}"
Expand Down Expand Up @@ -137,7 +137,9 @@ def from_yaml(
obj: OperatorInfo = super().from_yaml(
yaml_string=yaml_string, uri=uri, loader=loader, **kwargs
)
obj.path = os.path.dirname(uri)

if uri:
obj.path = os.path.dirname(uri)
return obj


Expand Down Expand Up @@ -671,25 +673,6 @@ def _module_from_file(module_name: str, module_path: str) -> Any:
return module


def _module_constant_values(module_name: str, module_path: str) -> Dict[str, Any]:
"""Returns the list of constant variables from a given module.
Parameters
----------
module_name (str)
The name of the module to be imported.
module_path (str)
The physical path of the module.
Returns
-------
Dict[str, Any]
Map of variable names and their values.
"""
module = _module_from_file(module_name, module_path)
return {name: value for name, value in vars(module).items()}


def _operator_info(path: str = None, name: str = None) -> OperatorInfo:
"""
Extracts operator's details by given path.
Expand Down
2 changes: 1 addition & 1 deletion ads/opctl/operator/common/operator_schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,6 @@ conda_type:
default: custom
allowed:
- service
- custom
- published
meta:
description: "The operator's conda environment type. Can be either service or custom type."
5 changes: 2 additions & 3 deletions ads/opctl/operator/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
import fsspec
import yaml
from cerberus import Validator
from yaml import SafeLoader

from ads.opctl import logger
from ads.opctl.operator import __operators__
from ads.opctl.utils import run_command
from ads.opctl import utils

CONTAINER_NETWORK = "CONTAINER_NETWORK"

Expand Down Expand Up @@ -95,7 +94,7 @@ def _build_image(

logger.info(f"Build image: {command}")

proc = run_command(command)
proc = utils.run_command(command)
if proc.returncode != 0:
raise RuntimeError("Docker build failed.")

Expand Down
2 changes: 1 addition & 1 deletion ads/opctl/operator/lowcode/forecast/MLoperator
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
type: forecast
version: v1
conda_type: custom
conda_type: published
gpu: no
keywords:
- Prophet
Expand Down
2 changes: 1 addition & 1 deletion ads/opctl/operator/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class OPERATOR_LOCAL_RUNTIME_TYPE(ExtendedEnum):
class Runtime(DataClassSerializable):
"""Base class for the operator's runtimes."""

_schema: ClassVar[str] = None
_schema: ClassVar[str] = ""
kind: str = OPERATOR_LOCAL_RUNTIME_KIND
type: str = None
version: str = None
Expand Down
106 changes: 94 additions & 12 deletions tests/unitary/with_extras/operator/test_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,103 @@
# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import argparse
import os
import unittest
from unittest.mock import patch, MagicMock

class TestCommonUtils:
"""Tests all operator's common utils methods."""
from ads.opctl.operator.common.utils import _build_image, _parse_input_args

def test_build_image_success(self):
pass

def test_build_image_fail(self):
pass
class TestBuildImage(unittest.TestCase):
def setUp(self):
self.curr_dir = os.path.dirname(os.path.abspath(__file__))
self.dockerfile = os.path.join(self.curr_dir, "test_files", "Dockerfile.test")
self.image_name = "test_image"
self.tag = "test_tag"
self.target = "test_target"
self.kwargs = {"arg1": "value1", "arg2": "value2"}

def test_load_yaml_from_string(self):
pass
@patch("ads.opctl.utils.run_command")
def test_build_image(self, mock_run_command):
mock_proc = MagicMock()
mock_proc.returncode = 0
mock_run_command.return_value = mock_proc

def test_load_yaml_from_uri(self):
pass
image_name = _build_image(
self.dockerfile,
self.image_name,
tag=self.tag,
target=self.target,
**self.kwargs,
)

def test_parse_input_args(self):
pass
expected_image_name = f"{self.image_name}:{self.tag}"
command = [
"docker",
"build",
"-t",
expected_image_name,
"-f",
self.dockerfile,
"--target",
self.target,
os.path.dirname(self.dockerfile),
]

mock_run_command.assert_called_once_with(command)
self.assertEqual(image_name, expected_image_name)

def test_build_image_missing_dockerfile(self):
with self.assertRaises(FileNotFoundError):
_build_image("non_existing_docker_file", "non_existing_image")

def test_build_image_missing_image_name(self):
with self.assertRaises(ValueError):
_build_image(self.dockerfile, None)

@patch("ads.opctl.utils.run_command")
def test_build_image_docker_build_failure(self, mock_run_command):
mock_proc = MagicMock()
mock_proc.returncode = 1
mock_run_command.return_value = mock_proc

with self.assertRaises(RuntimeError):
_build_image(self.dockerfile, self.image_name)


class TestParseInputArgs(unittest.TestCase):
def test_parse_input_args_with_file(self):
raw_args = ["-f", "path/to/file.yaml"]
expected_output = (
argparse.Namespace(file="path/to/file.yaml", spec=None, verify=False),
[],
)
self.assertEqual(_parse_input_args(raw_args), expected_output)

def test_parse_input_args_with_spec(self):
raw_args = ["-s", "spec"]
expected_output = (argparse.Namespace(file=None, spec="spec", verify=False), [])
self.assertEqual(_parse_input_args(raw_args), expected_output)

def test_parse_input_args_with_verify(self):
raw_args = ["-v", "True"]
expected_output = (argparse.Namespace(file=None, spec=None, verify=True), [])
self.assertEqual(_parse_input_args(raw_args), expected_output)

def test_parse_input_args_with_unknown_args(self):
raw_args = ["-f", "path/to/file.yaml", "--unknown-arg", "value"]
expected_output = (
argparse.Namespace(file="path/to/file.yaml", spec=None, verify=False),
["--unknown-arg", "value"],
)
self.assertEqual(_parse_input_args(raw_args), expected_output)

@patch("argparse.ArgumentParser.parse_known_args")
def test_parse_input_args_with_no_args(self, mock_parse_known_args):
mock_parse_known_args.return_value = (
argparse.Namespace(file=None, spec=None, verify=False),
[],
)
expected_output = (argparse.Namespace(file=None, spec=None, verify=False), [])
self.assertEqual(_parse_input_args([]), expected_output)
2 changes: 2 additions & 0 deletions tests/unitary/with_extras/operator/test_files/Dockerfile.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
FROM baseImage
RUN echo "This is a message"
88 changes: 85 additions & 3 deletions tests/unitary/with_extras/operator/test_operator_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,90 @@
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


from dataclasses import dataclass

import pytest
import yaml

from ads.opctl.operator.common.operator_config import OperatorConfig


class TestOperatorConfig:
"""Tests base class representing operator config."""
def test_operator_config(self):
# Test valid operator config

@dataclass(repr=True)
class MyOperatorConfig(OperatorConfig):
@classmethod
def _load_schema(cls) -> str:
return yaml.safe_load(
"""
kind:
required: true
type: string
version:
required: true
type: string
type:
required: true
type: string
spec:
required: true
type: dict
schema:
foo:
required: false
type: string
"""
)

config = MyOperatorConfig.from_dict(
{
"kind": "operator",
"type": "my-operator",
"version": "v1",
"spec": {"foo": "bar"},
}
)
assert config.kind == "operator"
assert config.type == "my-operator"
assert config.version == "v1"
assert config.spec == {"foo": "bar"}

# Test invalid operator config
@dataclass(repr=True)
class InvalidOperatorConfig(OperatorConfig):
@classmethod
def _load_schema(cls) -> str:
return yaml.safe_load(
"""
kind:
required: true
type: string
version:
required: true
type: string
allowed:
- v1
type:
required: true
type: string
spec:
required: true
type: dict
schema:
foo:
required: true
type: string
"""
)

def test_validate_dict(self):
pass
with pytest.raises(ValueError):
InvalidOperatorConfig.from_dict(
{
"kind": "operator",
"type": "invalid-operator",
"version": "v2",
"spec": {"foo1": 123},
}
)
Loading

0 comments on commit f70d6bb

Please sign in to comment.