diff --git a/ads/aqua/common/utils.py b/ads/aqua/common/utils.py index a59dac646..ede4ddb88 100644 --- a/ads/aqua/common/utils.py +++ b/ads/aqua/common/utils.py @@ -58,6 +58,7 @@ ) from ads.aqua.data import AquaResourceIdentifier from ads.common.auth import AuthState, default_signer +from ads.common.decorator.threaded import threaded from ads.common.extended_enum import ExtendedEnumMeta from ads.common.object_storage_details import ObjectStorageDetails from ads.common.oci_resource import SEARCH_TYPE, OCIResource @@ -225,6 +226,7 @@ def read_file(file_path: str, **kwargs) -> str: return UNKNOWN +@threaded() def load_config(file_path: str, config_file_name: str, **kwargs) -> dict: artifact_path = f"{file_path.rstrip('/')}/{config_file_name}" signer = default_signer() if artifact_path.startswith("oci://") else {} @@ -1065,11 +1067,15 @@ def get_hf_model_info(repo_id: str) -> ModelInfo: @cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now)) -def list_hf_models(query:str) -> List[str]: +def list_hf_models(query: str) -> List[str]: try: - models= HfApi().list_models(model_name=query,task="text-generation",sort="downloads",direction=-1,limit=20) + models = HfApi().list_models( + model_name=query, + task="text-generation", + sort="downloads", + direction=-1, + limit=20, + ) return [model.id for model in models if model.disabled is None] except HfHubHTTPError as err: raise format_hf_custom_error_message(err) from err - - diff --git a/ads/aqua/extension/common_handler.py b/ads/aqua/extension/common_handler.py index c114b3a14..cc9a2f663 100644 --- a/ads/aqua/extension/common_handler.py +++ b/ads/aqua/extension/common_handler.py @@ -11,16 +11,15 @@ from huggingface_hub.utils import LocalTokenNotFoundError from tornado.web import HTTPError -from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID from ads.aqua.common.decorator import handle_exceptions from ads.aqua.common.errors import AquaResourceAccessError, AquaRuntimeError from ads.aqua.common.utils import ( - fetch_service_compartment, get_huggingface_login_timeout, known_realm, ) from ads.aqua.extension.base_handler import AquaAPIhandler from ads.aqua.extension.errors import Errors +from ads.aqua.extension.utils import ui_compatability_check class ADSVersionHandler(AquaAPIhandler): @@ -51,7 +50,7 @@ def get(self): AquaResourceAccessError: raised when aqua is not accessible in the given session/region. """ - if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment(): + if ui_compatability_check(): return self.finish({"status": "ok"}) elif known_realm(): return self.finish({"status": "compatible"}) diff --git a/ads/aqua/extension/common_ws_msg_handler.py b/ads/aqua/extension/common_ws_msg_handler.py index 71cb545f4..cc54af1de 100644 --- a/ads/aqua/extension/common_ws_msg_handler.py +++ b/ads/aqua/extension/common_ws_msg_handler.py @@ -7,7 +7,6 @@ from importlib import metadata from typing import List, Union -from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, fetch_service_compartment from ads.aqua.common.decorator import handle_exceptions from ads.aqua.common.errors import AquaResourceAccessError from ads.aqua.common.utils import known_realm @@ -17,6 +16,7 @@ CompatibilityCheckResponse, RequestResponseType, ) +from ads.aqua.extension.utils import ui_compatability_check class AquaCommonWsMsgHandler(AquaWSMsgHandler): @@ -39,7 +39,7 @@ def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]: ) return response if request.get("kind") == "CompatibilityCheck": - if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment(): + if ui_compatability_check(): return CompatibilityCheckResponse( message_id=request.get("message_id"), kind=RequestResponseType.CompatibilityCheck, diff --git a/ads/aqua/extension/utils.py b/ads/aqua/extension/utils.py index c757d91e2..90787beb6 100644 --- a/ads/aqua/extension/utils.py +++ b/ads/aqua/extension/utils.py @@ -1,12 +1,15 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- # Copyright (c) 2024 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ from dataclasses import fields +from datetime import datetime, timedelta from typing import Dict, Optional +from cachetools import TTLCache, cached from tornado.web import HTTPError +from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID +from ads.aqua.common.utils import fetch_service_compartment from ads.aqua.extension.errors import Errors @@ -21,3 +24,11 @@ def validate_function_parameters(data_class, input_data: Dict): raise HTTPError( 400, Errors.MISSING_REQUIRED_PARAMETER.format(required_parameter) ) + + +@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now)) +def ui_compatability_check(): + """This method caches the service compartment OCID details that is set by either the environment variable or if + fetched from the configuration. The cached result is returned when multiple calls are made in quick succession + from the UI to avoid multiple config file loads.""" + return ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment() diff --git a/tests/unitary/with_extras/aqua/test_common_handler.py b/tests/unitary/with_extras/aqua/test_common_handler.py index 88e3e6e06..ec0590b07 100644 --- a/tests/unitary/with_extras/aqua/test_common_handler.py +++ b/tests/unitary/with_extras/aqua/test_common_handler.py @@ -15,6 +15,7 @@ import ads.config from ads.aqua.constants import AQUA_GA_LIST from ads.aqua.extension.common_handler import CompatibilityCheckHandler +from ads.aqua.extension.utils import ui_compatability_check class TestDataset: @@ -28,6 +29,9 @@ def setUp(self, ipython_init_mock) -> None: self.common_handler = CompatibilityCheckHandler(MagicMock(), MagicMock()) self.common_handler.request = MagicMock() + def tearDown(self) -> None: + ui_compatability_check.cache_clear() + def test_get_ok(self): """Test to check if ok is returned when ODSC_MODEL_COMPARTMENT_OCID is set.""" with patch.dict( @@ -36,15 +40,22 @@ def test_get_ok(self): ): reload(ads.config) reload(ads.aqua) + reload(ads.aqua.extension.utils) reload(ads.aqua.extension.common_handler) with patch( "ads.aqua.extension.base_handler.AquaAPIhandler.finish" ) as mock_finish: - mock_finish.side_effect = lambda x: x - self.common_handler.request.path = "aqua/hello" - result = self.common_handler.get() - assert result["status"] == "ok" + with patch( + "ads.aqua.extension.utils.fetch_service_compartment" + ) as mock_fetch_service_compartment: + mock_fetch_service_compartment.return_value = ( + TestDataset.SERVICE_COMPARTMENT_ID + ) + mock_finish.side_effect = lambda x: x + self.common_handler.request.path = "aqua/hello" + result = self.common_handler.get() + assert result["status"] == "ok" def test_get_compatible_status(self): """Test to check if compatible is returned when ODSC_MODEL_COMPARTMENT_OCID is not set @@ -55,12 +66,13 @@ def test_get_compatible_status(self): ): reload(ads.config) reload(ads.aqua) + reload(ads.aqua.extension.utils) reload(ads.aqua.extension.common_handler) with patch( "ads.aqua.extension.base_handler.AquaAPIhandler.finish" ) as mock_finish: with patch( - "ads.aqua.extension.common_handler.fetch_service_compartment" + "ads.aqua.extension.utils.fetch_service_compartment" ) as mock_fetch_service_compartment: mock_fetch_service_compartment.return_value = None mock_finish.side_effect = lambda x: x @@ -77,12 +89,13 @@ def test_raise_not_compatible_error(self): ): reload(ads.config) reload(ads.aqua) + reload(ads.aqua.extension.utils) reload(ads.aqua.extension.common_handler) with patch( "ads.aqua.extension.base_handler.AquaAPIhandler.finish" ) as mock_finish: with patch( - "ads.aqua.extension.common_handler.fetch_service_compartment" + "ads.aqua.extension.utils.fetch_service_compartment" ) as mock_fetch_service_compartment: mock_fetch_service_compartment.return_value = None mock_finish.side_effect = lambda x: x diff --git a/tests/unitary/with_extras/aqua/test_handlers.py b/tests/unitary/with_extras/aqua/test_handlers.py index 97c5660f7..74b9853b4 100644 --- a/tests/unitary/with_extras/aqua/test_handlers.py +++ b/tests/unitary/with_extras/aqua/test_handlers.py @@ -13,7 +13,6 @@ from notebook.base.handlers import APIHandler, IPythonHandler from oci.exceptions import ServiceError from parameterized import parameterized -from tornado.httpserver import HTTPRequest from tornado.httputil import HTTPServerRequest from tornado.web import Application, HTTPError @@ -191,6 +190,7 @@ def setUpClass(cls): reload(ads.config) reload(ads.aqua) + reload(ads.aqua.extension.utils) reload(ads.aqua.extension.common_handler) @classmethod @@ -200,6 +200,7 @@ def tearDownClass(cls): reload(ads.config) reload(ads.aqua) + reload(ads.aqua.extension.utils) reload(ads.aqua.extension.common_handler) @parameterized.expand(