diff --git a/ads/aqua/extension/utils.py b/ads/aqua/extension/utils.py index e39b35b53..90787beb6 100644 --- a/ads/aqua/extension/utils.py +++ b/ads/aqua/extension/utils.py @@ -28,4 +28,7 @@ def validate_function_parameters(data_class, input_data: Dict): @cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now)) def ui_compatability_check(): + """This method caches the service compartment OCID details that is set by either the environment variable or if + fetched from the configuration. The cached result is returned when multiple calls are made in quick succession + from the UI to avoid multiple config file loads.""" return ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment() diff --git a/tests/unitary/with_extras/aqua/test_common_handler.py b/tests/unitary/with_extras/aqua/test_common_handler.py index 88e3e6e06..ec0590b07 100644 --- a/tests/unitary/with_extras/aqua/test_common_handler.py +++ b/tests/unitary/with_extras/aqua/test_common_handler.py @@ -15,6 +15,7 @@ import ads.config from ads.aqua.constants import AQUA_GA_LIST from ads.aqua.extension.common_handler import CompatibilityCheckHandler +from ads.aqua.extension.utils import ui_compatability_check class TestDataset: @@ -28,6 +29,9 @@ def setUp(self, ipython_init_mock) -> None: self.common_handler = CompatibilityCheckHandler(MagicMock(), MagicMock()) self.common_handler.request = MagicMock() + def tearDown(self) -> None: + ui_compatability_check.cache_clear() + def test_get_ok(self): """Test to check if ok is returned when ODSC_MODEL_COMPARTMENT_OCID is set.""" with patch.dict( @@ -36,15 +40,22 @@ def test_get_ok(self): ): reload(ads.config) reload(ads.aqua) + reload(ads.aqua.extension.utils) reload(ads.aqua.extension.common_handler) with patch( "ads.aqua.extension.base_handler.AquaAPIhandler.finish" ) as mock_finish: - mock_finish.side_effect = lambda x: x - self.common_handler.request.path = "aqua/hello" - result = self.common_handler.get() - assert result["status"] == "ok" + with patch( + "ads.aqua.extension.utils.fetch_service_compartment" + ) as mock_fetch_service_compartment: + mock_fetch_service_compartment.return_value = ( + TestDataset.SERVICE_COMPARTMENT_ID + ) + mock_finish.side_effect = lambda x: x + self.common_handler.request.path = "aqua/hello" + result = self.common_handler.get() + assert result["status"] == "ok" def test_get_compatible_status(self): """Test to check if compatible is returned when ODSC_MODEL_COMPARTMENT_OCID is not set @@ -55,12 +66,13 @@ def test_get_compatible_status(self): ): reload(ads.config) reload(ads.aqua) + reload(ads.aqua.extension.utils) reload(ads.aqua.extension.common_handler) with patch( "ads.aqua.extension.base_handler.AquaAPIhandler.finish" ) as mock_finish: with patch( - "ads.aqua.extension.common_handler.fetch_service_compartment" + "ads.aqua.extension.utils.fetch_service_compartment" ) as mock_fetch_service_compartment: mock_fetch_service_compartment.return_value = None mock_finish.side_effect = lambda x: x @@ -77,12 +89,13 @@ def test_raise_not_compatible_error(self): ): reload(ads.config) reload(ads.aqua) + reload(ads.aqua.extension.utils) reload(ads.aqua.extension.common_handler) with patch( "ads.aqua.extension.base_handler.AquaAPIhandler.finish" ) as mock_finish: with patch( - "ads.aqua.extension.common_handler.fetch_service_compartment" + "ads.aqua.extension.utils.fetch_service_compartment" ) as mock_fetch_service_compartment: mock_fetch_service_compartment.return_value = None mock_finish.side_effect = lambda x: x diff --git a/tests/unitary/with_extras/aqua/test_handlers.py b/tests/unitary/with_extras/aqua/test_handlers.py index 97c5660f7..74b9853b4 100644 --- a/tests/unitary/with_extras/aqua/test_handlers.py +++ b/tests/unitary/with_extras/aqua/test_handlers.py @@ -13,7 +13,6 @@ from notebook.base.handlers import APIHandler, IPythonHandler from oci.exceptions import ServiceError from parameterized import parameterized -from tornado.httpserver import HTTPRequest from tornado.httputil import HTTPServerRequest from tornado.web import Application, HTTPError @@ -191,6 +190,7 @@ def setUpClass(cls): reload(ads.config) reload(ads.aqua) + reload(ads.aqua.extension.utils) reload(ads.aqua.extension.common_handler) @classmethod @@ -200,6 +200,7 @@ def tearDownClass(cls): reload(ads.config) reload(ads.aqua) + reload(ads.aqua.extension.utils) reload(ads.aqua.extension.common_handler) @parameterized.expand(