Skip to content

Commit

Permalink
added compatibility check
Browse files Browse the repository at this point in the history
  • Loading branch information
VipulMascarenhas committed Sep 21, 2024
1 parent a76b698 commit 882a215
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 10 deletions.
14 changes: 10 additions & 4 deletions ads/aqua/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
)
from ads.aqua.data import AquaResourceIdentifier
from ads.common.auth import AuthState, default_signer
from ads.common.decorator.threaded import threaded
from ads.common.extended_enum import ExtendedEnumMeta
from ads.common.object_storage_details import ObjectStorageDetails
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
Expand Down Expand Up @@ -225,6 +226,7 @@ def read_file(file_path: str, **kwargs) -> str:
return UNKNOWN


@threaded()
def load_config(file_path: str, config_file_name: str, **kwargs) -> dict:
artifact_path = f"{file_path.rstrip('/')}/{config_file_name}"
signer = default_signer() if artifact_path.startswith("oci://") else {}
Expand Down Expand Up @@ -1065,11 +1067,15 @@ def get_hf_model_info(repo_id: str) -> ModelInfo:


@cached(cache=TTLCache(maxsize=1, ttl=timedelta(hours=5), timer=datetime.now))
def list_hf_models(query:str) -> List[str]:
def list_hf_models(query: str) -> List[str]:
try:
models= HfApi().list_models(model_name=query,task="text-generation",sort="downloads",direction=-1,limit=20)
models = HfApi().list_models(
model_name=query,
task="text-generation",
sort="downloads",
direction=-1,
limit=20,
)
return [model.id for model in models if model.disabled is None]
except HfHubHTTPError as err:
raise format_hf_custom_error_message(err) from err


5 changes: 2 additions & 3 deletions ads/aqua/extension/common_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,15 @@
from huggingface_hub.utils import LocalTokenNotFoundError
from tornado.web import HTTPError

from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.errors import AquaResourceAccessError, AquaRuntimeError
from ads.aqua.common.utils import (
fetch_service_compartment,
get_huggingface_login_timeout,
known_realm,
)
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.extension.errors import Errors
from ads.aqua.extension.utils import ui_compatability_check


class ADSVersionHandler(AquaAPIhandler):
Expand Down Expand Up @@ -51,7 +50,7 @@ def get(self):
AquaResourceAccessError: raised when aqua is not accessible in the given session/region.
"""
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
if ui_compatability_check():
return self.finish({"status": "ok"})
elif known_realm():
return self.finish({"status": "compatible"})
Expand Down
4 changes: 2 additions & 2 deletions ads/aqua/extension/common_ws_msg_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from importlib import metadata
from typing import List, Union

from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, fetch_service_compartment
from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.errors import AquaResourceAccessError
from ads.aqua.common.utils import known_realm
Expand All @@ -17,6 +16,7 @@
CompatibilityCheckResponse,
RequestResponseType,
)
from ads.aqua.extension.utils import ui_compatability_check


class AquaCommonWsMsgHandler(AquaWSMsgHandler):
Expand All @@ -39,7 +39,7 @@ def process(self) -> Union[AdsVersionResponse, CompatibilityCheckResponse]:
)
return response
if request.get("kind") == "CompatibilityCheck":
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
if ui_compatability_check():
return CompatibilityCheckResponse(
message_id=request.get("message_id"),
kind=RequestResponseType.CompatibilityCheck,
Expand Down
10 changes: 9 additions & 1 deletion ads/aqua/extension/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
from dataclasses import fields
from datetime import datetime, timedelta
from typing import Dict, Optional

from cachetools import TTLCache, cached
from tornado.web import HTTPError

from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
from ads.aqua.common.utils import fetch_service_compartment
from ads.aqua.extension.errors import Errors


Expand All @@ -21,3 +24,8 @@ def validate_function_parameters(data_class, input_data: Dict):
raise HTTPError(
400, Errors.MISSING_REQUIRED_PARAMETER.format(required_parameter)
)


@cached(cache=TTLCache(maxsize=1, ttl=timedelta(minutes=1), timer=datetime.now))
def ui_compatability_check():
return ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment()

0 comments on commit 882a215

Please sign in to comment.