Skip to content

Commit

Permalink
Updated compatibility check for aqua (#952)
Browse files Browse the repository at this point in the history
  • Loading branch information
VipulMascarenhas authored Sep 23, 2024
2 parents a76b698 + 2a27ef6 commit 5b4d2d3
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 17 deletions.
14 changes: 10 additions & 4 deletions ads/aqua/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
)
from ads.aqua.data import AquaResourceIdentifier
from ads.common.auth import AuthState, default_signer
from ads.common.decorator.threaded import threaded
from ads.common.extended_enum import ExtendedEnumMeta
from ads.common.object_storage_details import ObjectStorageDetails
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
Expand Down Expand Up @@ -225,6 +226,7 @@ def read_file(file_path: str, **kwargs) -> str:
return UNKNOWN


@threaded()
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
signer = default_signer() if artifact_path.startswith("oci://") else {}
Expand Down Expand Up @@ -1065,11 +1067,15 @@ def get_hf_model_info(repo_id: str) -> ModelInfo:


@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
def list_hf_models(query:str) -> List[str]:
def list_hf_models(query: str) -> List[str]:
try:
models= HfApi().list_models(model_name=query,task="text-generation",sort="downloads",direction=-1,limit=20)
models = HfApi().list_models(
model_name=query,
task="text-generation",
sort="downloads",
direction=-1,
limit=20,
)
return [model.id for model in models if model.disabled is None]
except HfHubHTTPError as err:
raise format_hf_custom_error_message(err) from err


5 changes: 2 additions & 3 deletions ads/aqua/extension/common_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@
from huggingface_hub.utils import LocalTokenNotFoundError
from tornado.web import HTTPError

from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.errors import AquaResourceAccessError, AquaRuntimeError
from ads.aqua.common.utils import (
fetch_service_compartment,
get_huggingface_login_timeout,
known_realm,
)
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.extension.errors import Errors
from ads.aqua.extension.utils import ui_compatability_check


class ADSVersionHandler(AquaAPIhandler):
Expand Down Expand Up @@ -51,7 +50,7 @@ def get(self):
AquaResourceAccessError: raised when aqua is not accessible in the given session/region.
"""
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
if ui_compatability_check():
return self.finish({"status": "ok"})
elif known_realm():
return self.finish({"status": "compatible"})
Expand Down
4 changes: 2 additions & 2 deletions ads/aqua/extension/common_ws_msg_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from importlib import metadata
from typing import List, Union

from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, fetch_service_compartment
from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.errors import AquaResourceAccessError
from ads.aqua.common.utils import known_realm
Expand All @@ -17,6 +16,7 @@
CompatibilityCheckResponse,
RequestResponseType,
)
from ads.aqua.extension.utils import ui_compatability_check


class AquaCommonWsMsgHandler(AquaWSMsgHandler):
Expand All @@ -39,7 +39,7 @@ def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
)
return response
if request.get("kind") == "CompatibilityCheck":
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
if ui_compatability_check():
return CompatibilityCheckResponse(
message_id=request.get("message_id"),
kind=RequestResponseType.CompatibilityCheck,
Expand Down
13 changes: 12 additions & 1 deletion ads/aqua/extension/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
from dataclasses import fields
from datetime import datetime, timedelta
from typing import Dict, Optional

from cachetools import TTLCache, cached
from tornado.web import HTTPError

from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
from ads.aqua.common.utils import fetch_service_compartment
from ads.aqua.extension.errors import Errors


Expand All @@ -21,3 +24,11 @@ def validate_function_parameters(data_class, input_data: Dict):
raise HTTPError(
400, Errors.MISSING_REQUIRED_PARAMETER.format(required_parameter)
)


@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
def ui_compatability_check():
"""This method caches the service compartment OCID details that is set by either the environment variable or if
fetched from the configuration. The cached result is returned when multiple calls are made in quick succession
from the UI to avoid multiple config file loads."""
return ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment()
25 changes: 19 additions & 6 deletions tests/unitary/with_extras/aqua/test_common_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ads.config
from ads.aqua.constants import AQUA_GA_LIST
from ads.aqua.extension.common_handler import CompatibilityCheckHandler
from ads.aqua.extension.utils import ui_compatability_check


class TestDataset:
Expand All @@ -28,6 +29,9 @@ def setUp(self, ipython_init_mock) -> None:
self.common_handler = CompatibilityCheckHandler(MagicMock(), MagicMock())
self.common_handler.request = MagicMock()

def tearDown(self) -> None:
ui_compatability_check.cache_clear()

def test_get_ok(self):
"""Test to check if ok is returned when ODSC_MODEL_COMPARTMENT_OCID is set."""
with patch.dict(
Expand All @@ -36,15 +40,22 @@ def test_get_ok(self):
):
reload(ads.config)
reload(ads.aqua)
reload(ads.aqua.extension.utils)
reload(ads.aqua.extension.common_handler)

with patch(
"ads.aqua.extension.base_handler.AquaAPIhandler.finish"
) as mock_finish:
mock_finish.side_effect = lambda x: x
self.common_handler.request.path = "aqua/hello"
result = self.common_handler.get()
assert result["status"] == "ok"
with patch(
"ads.aqua.extension.utils.fetch_service_compartment"
) as mock_fetch_service_compartment:
mock_fetch_service_compartment.return_value = (
TestDataset.SERVICE_COMPARTMENT_ID
)
mock_finish.side_effect = lambda x: x
self.common_handler.request.path = "aqua/hello"
result = self.common_handler.get()
assert result["status"] == "ok"

def test_get_compatible_status(self):
"""Test to check if compatible is returned when ODSC_MODEL_COMPARTMENT_OCID is not set
Expand All @@ -55,12 +66,13 @@ def test_get_compatible_status(self):
):
reload(ads.config)
reload(ads.aqua)
reload(ads.aqua.extension.utils)
reload(ads.aqua.extension.common_handler)
with patch(
"ads.aqua.extension.base_handler.AquaAPIhandler.finish"
) as mock_finish:
with patch(
"ads.aqua.extension.common_handler.fetch_service_compartment"
"ads.aqua.extension.utils.fetch_service_compartment"
) as mock_fetch_service_compartment:
mock_fetch_service_compartment.return_value = None
mock_finish.side_effect = lambda x: x
Expand All @@ -77,12 +89,13 @@ def test_raise_not_compatible_error(self):
):
reload(ads.config)
reload(ads.aqua)
reload(ads.aqua.extension.utils)
reload(ads.aqua.extension.common_handler)
with patch(
"ads.aqua.extension.base_handler.AquaAPIhandler.finish"
) as mock_finish:
with patch(
"ads.aqua.extension.common_handler.fetch_service_compartment"
"ads.aqua.extension.utils.fetch_service_compartment"
) as mock_fetch_service_compartment:
mock_fetch_service_compartment.return_value = None
mock_finish.side_effect = lambda x: x
Expand Down
3 changes: 2 additions & 1 deletion tests/unitary/with_extras/aqua/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from notebook.base.handlers import APIHandler, IPythonHandler
from oci.exceptions import ServiceError
from parameterized import parameterized
from tornado.httpserver import HTTPRequest
from tornado.httputil import HTTPServerRequest
from tornado.web import Application, HTTPError

Expand Down Expand Up @@ -191,6 +190,7 @@ def setUpClass(cls):

reload(ads.config)
reload(ads.aqua)
reload(ads.aqua.extension.utils)
reload(ads.aqua.extension.common_handler)

@classmethod
Expand All @@ -200,6 +200,7 @@ def tearDownClass(cls):

reload(ads.config)
reload(ads.aqua)
reload(ads.aqua.extension.utils)
reload(ads.aqua.extension.common_handler)

@parameterized.expand(
Expand Down

0 comments on commit 5b4d2d3

Please sign in to comment.