Skip to content

Commit

Permalink
add comments and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
VipulMascarenhas committed Sep 23, 2024
1 parent 882a215 commit 2a27ef6
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 7 deletions.
3 changes: 3 additions & 0 deletions ads/aqua/extension/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,7 @@ def validate_function_parameters(data_class, input_data: Dict):

@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
def ui_compatability_check():
"""This method caches the service compartment OCID details that is set by either the environment variable or if
fetched from the configuration. The cached result is returned when multiple calls are made in quick succession
from the UI to avoid multiple config file loads."""
return ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment()
25 changes: 19 additions & 6 deletions tests/unitary/with_extras/aqua/test_common_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ads.config
from ads.aqua.constants import AQUA_GA_LIST
from ads.aqua.extension.common_handler import CompatibilityCheckHandler
from ads.aqua.extension.utils import ui_compatability_check


class TestDataset:
Expand All @@ -28,6 +29,9 @@ def setUp(self, ipython_init_mock) -> None:
self.common_handler = CompatibilityCheckHandler(MagicMock(), MagicMock())
self.common_handler.request = MagicMock()

def tearDown(self) -> None:
ui_compatability_check.cache_clear()

def test_get_ok(self):
"""Test to check if ok is returned when ODSC_MODEL_COMPARTMENT_OCID is set."""
with patch.dict(
Expand All @@ -36,15 +40,22 @@ def test_get_ok(self):
):
reload(ads.config)
reload(ads.aqua)
reload(ads.aqua.extension.utils)
reload(ads.aqua.extension.common_handler)

with patch(
"ads.aqua.extension.base_handler.AquaAPIhandler.finish"
) as mock_finish:
mock_finish.side_effect = lambda x: x
self.common_handler.request.path = "aqua/hello"
result = self.common_handler.get()
assert result["status"] == "ok"
with patch(
"ads.aqua.extension.utils.fetch_service_compartment"
) as mock_fetch_service_compartment:
mock_fetch_service_compartment.return_value = (
TestDataset.SERVICE_COMPARTMENT_ID
)
mock_finish.side_effect = lambda x: x
self.common_handler.request.path = "aqua/hello"
result = self.common_handler.get()
assert result["status"] == "ok"

def test_get_compatible_status(self):
"""Test to check if compatible is returned when ODSC_MODEL_COMPARTMENT_OCID is not set
Expand All @@ -55,12 +66,13 @@ def test_get_compatible_status(self):
):
reload(ads.config)
reload(ads.aqua)
reload(ads.aqua.extension.utils)
reload(ads.aqua.extension.common_handler)
with patch(
"ads.aqua.extension.base_handler.AquaAPIhandler.finish"
) as mock_finish:
with patch(
"ads.aqua.extension.common_handler.fetch_service_compartment"
"ads.aqua.extension.utils.fetch_service_compartment"
) as mock_fetch_service_compartment:
mock_fetch_service_compartment.return_value = None
mock_finish.side_effect = lambda x: x
Expand All @@ -77,12 +89,13 @@ def test_raise_not_compatible_error(self):
):
reload(ads.config)
reload(ads.aqua)
reload(ads.aqua.extension.utils)
reload(ads.aqua.extension.common_handler)
with patch(
"ads.aqua.extension.base_handler.AquaAPIhandler.finish"
) as mock_finish:
with patch(
"ads.aqua.extension.common_handler.fetch_service_compartment"
"ads.aqua.extension.utils.fetch_service_compartment"
) as mock_fetch_service_compartment:
mock_fetch_service_compartment.return_value = None
mock_finish.side_effect = lambda x: x
Expand Down
3 changes: 2 additions & 1 deletion tests/unitary/with_extras/aqua/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from notebook.base.handlers import APIHandler, IPythonHandler
from oci.exceptions import ServiceError
from parameterized import parameterized
from tornado.httpserver import HTTPRequest
from tornado.httputil import HTTPServerRequest
from tornado.web import Application, HTTPError

Expand Down Expand Up @@ -191,6 +190,7 @@ def setUpClass(cls):

reload(ads.config)
reload(ads.aqua)
reload(ads.aqua.extension.utils)
reload(ads.aqua.extension.common_handler)

@classmethod
Expand All @@ -200,6 +200,7 @@ def tearDownClass(cls):

reload(ads.config)
reload(ads.aqua)
reload(ads.aqua.extension.utils)
reload(ads.aqua.extension.common_handler)

@parameterized.expand(
Expand Down

0 comments on commit 2a27ef6

Please sign in to comment.