Skip to content

Commit

Permalink
add hf apis
Browse files Browse the repository at this point in the history
  • Loading branch information
VipulMascarenhas committed Jul 25, 2024
1 parent 94a9d8a commit 99c4ae4
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 12 deletions.
8 changes: 8 additions & 0 deletions ads/aqua/common/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
PARAM_TYPE_LLAMA_CPP = "LLAMA_CPP_PARAMS"


class EvaluationContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
    # Single string constant: the container family identifier for AQUA evaluation.
    AQUA_EVALUATION_CONTAINER_FAMILY = "odsc-llm-evaluate"


class FineTuningContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
    # Single string constant: the container family identifier for AQUA fine-tuning.
    AQUA_FINETUNING_CONTAINER_FAMILY = "odsc-llm-fine-tuning"


class HuggingFaceTags(str, metaclass=ExtendedEnumMeta):
TEXT_GENERATION_INFERENCE = "text-generation-inference"

Expand Down
31 changes: 30 additions & 1 deletion ads/aqua/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import os
import random
import re
import shlex
import subprocess
from datetime import datetime, timedelta
from functools import wraps
from pathlib import Path
Expand Down Expand Up @@ -47,7 +49,7 @@
VLLM_INFERENCE_RESTRICTED_PARAMS,
)
from ads.aqua.data import AquaResourceIdentifier
from ads.common.auth import default_signer
from ads.common.auth import AuthState, default_signer
from ads.common.extended_enum import ExtendedEnumMeta
from ads.common.object_storage_details import ObjectStorageDetails
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
Expand Down Expand Up @@ -771,6 +773,33 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""


def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
    """Upload a local folder to object storage with the OCI CLI.

    Args:
        os_path (str): Object storage URI with prefix. This is the path to upload to.
        local_dir (str): Local directory passed to the CLI as ``--src-dir``.
        model_name (str): Name of the Hugging Face model; appended to the
            prefix of ``os_path`` to form the object prefix.

    Returns:
        str: ``oci://<bucket>@<namespace>/<prefix>/<model_name>/`` — the URI of
        the prefix the objects are uploaded under.

    Raises:
        ValueError: If versioning is not enabled on the target bucket.
    """
    os_details: ObjectStorageDetails = ObjectStorageDetails.from_path(os_path)
    if not os_details.is_bucket_versioned():
        raise ValueError(f"Version is not enabled at object storage location {os_path}")
    auth_state = AuthState()
    object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"

    # Build the argument vector directly instead of formatting one string and
    # re-splitting it: a directory or model name that contains whitespace or
    # quotes would otherwise be broken into (or injected as) extra CLI arguments.
    # str() keeps the rendering the original f-string produced for each value.
    command = [
        "oci",
        "os",
        "object",
        "bulk-upload",
        "--src-dir",
        str(local_dir),
        "--prefix",
        str(object_path),
        "-bn",
        str(os_details.bucket),
        "-ns",
        str(os_details.namespace),
        "--auth",
        str(auth_state.oci_iam_type),
        "--profile",
        str(auth_state.oci_key_profile),
        "--no-overwrite",
    ]
    try:
        logger.info(f"Running: {' '.join(command)}")
        subprocess.check_call(command)
    except subprocess.CalledProcessError as e:
        # check_call() does not capture output, so e.stdout is always None;
        # log the exception itself, which carries the command and exit status.
        # NOTE(review): the failure is only logged and the URI is still
        # returned, as before — confirm callers do not need a hard failure here.
        logger.error(
            f"Error uploading the object. Exit code: {e.returncode} with error {e}"
        )

    return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path


def is_service_managed_container(container):
    """Tell whether ``container`` starts with the service-managed URI scheme.

    A falsy ``container`` (``None`` or an empty string) is returned unchanged;
    otherwise the result is the boolean from ``str.startswith``.
    """
    if not container:
        # Hand back the falsy input itself rather than coercing it to False.
        return container
    return container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)

Expand Down
69 changes: 65 additions & 4 deletions ads/aqua/extension/common_handler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


from importlib import metadata

import huggingface_hub
import requests
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from tornado.web import HTTPError

from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
Expand Down Expand Up @@ -46,16 +48,75 @@ def get(self):
"""
if ODSC_MODEL_COMPARTMENT_OCID or fetch_service_compartment():
return self.finish(dict(status="ok"))
return self.finish({"status": "ok"})
elif known_realm():
return self.finish(dict(status="compatible"))
return self.finish({"status": "compatible"})
else:
raise AquaResourceAccessError(
f"The AI Quick actions extension is not compatible in the given region."
"The AI Quick actions extension is not compatible in the given region."
)


class NetworkStatusHandler(AquaAPIhandler):
    """Handler to check internet connection."""

    @handle_exceptions
    def get(self):
        """Issue a short-timeout GET request and reply ``"success"`` if it does not raise.

        Only the absence of an exception from ``requests.get`` is checked; the
        response itself is not inspected.
        """
        probe_url = "https://huggingface.com"
        timeout_seconds = 2
        requests.get(probe_url, timeout=timeout_seconds)
        return self.finish("success")


class HFLoginHandler(AquaAPIhandler):
    """Handler to login to HF."""

    @handle_exceptions
    def post(self, *args, **kwargs):
        """Log in to Hugging Face with the token supplied in the JSON body.

        Raises
        ------
        HTTPError
            400 if the body cannot be parsed, is empty, or has no ``token``.
        """
        try:
            payload = self.get_json_body()
        except Exception as ex:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex

        # Guard clauses: reject an empty body, then a missing/empty token.
        if not payload:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        hf_token = payload.get("token")
        if not hf_token:
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("token"))

        huggingface_hub.login(token=hf_token, new_session=False)

        response = {"status": 200, "message": "login successful"}
        return self.finish(response)


class HFUserStatusHandler(AquaAPIhandler):
    """Handler to check if user logged in to the HF."""

    @handle_exceptions
    def get(self):
        """Report whether a Hugging Face token is available locally.

        Calls ``HfApi().whoami()``; when that raises ``LocalTokenNotFoundError``
        the error is converted to an ``AquaRuntimeError`` with login instructions.

        Raises
        ------
        AquaRuntimeError
            If no local Hugging Face token is found.
        """
        try:
            HfApi().whoami()
        except LocalTokenNotFoundError as err:
            # The two literals are concatenated by the parser, so the first one
            # must end with a space to keep the sentences apart in the message.
            raise AquaRuntimeError(
                "You are not logged in. Please log in to Hugging Face using the `huggingface-cli login` command. "
                "See https://huggingface.co/settings/tokens.",
            ) from err

        return self.finish({"status": 200, "message": "logged in"})


# Route table for this module: each entry is a (URL pattern, handler class) pair.
# NOTE(review): how the patterns are prefixed/registered is not visible here —
# presumably the extension loader consumes this list; confirm before renaming.
__handlers__ = [
("ads_version", ADSVersionHandler),
("hello", CompatibilityCheckHandler),
("network_status", NetworkStatusHandler),
("hf_login", HFLoginHandler),
("hf_logged_in", HFUserStatusHandler),
]
157 changes: 151 additions & 6 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,26 @@
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import re
from typing import Optional
from urllib.parse import urlparse

from huggingface_hub import HfApi
from huggingface_hub.utils import (
GatedRepoError,
HfHubHTTPError,
RepositoryNotFoundError,
RevisionNotFoundError,
)
from tornado.web import HTTPError

from ads.aqua.common.decorator import handle_exceptions
from ads.aqua.common.errors import AquaValueError
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.extension.errors import Errors
from ads.aqua.model import AquaModelApp
from ads.aqua.model.constants import ModelTask
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
from ads.aqua.ui import ModelFormat


Expand Down Expand Up @@ -38,8 +49,8 @@ def get(
)
try:
model_format = ModelFormat(model_format.upper())
except ValueError:
raise AquaValueError(f"Invalid model format: {model_format}")
except ValueError as err:
raise AquaValueError(f"Invalid model format: {model_format}") from err
else:
return self.finish(AquaModelApp.get_model_files(os_path, model_format))
elif not model_id:
Expand All @@ -52,7 +63,7 @@ def read(self, model_id):
return self.finish(AquaModelApp().get(model_id))

@handle_exceptions
def delete(self, id=""):
def delete(self):
"""Handles DELETE request for clearing cache"""
url_parse = urlparse(self.request.path)
paths = url_parse.path.strip("/")
Expand Down Expand Up @@ -86,8 +97,8 @@ def post(self, *args, **kwargs):
"""
try:
input_data = self.get_json_body()
except Exception:
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
except Exception as ex:
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex

if not input_data:
raise HTTPError(400, Errors.NO_INPUT_DATA)
Expand Down Expand Up @@ -130,7 +141,141 @@ def get(self, model_id):
return self.finish(AquaModelApp().load_license(model_id))


class AquaHuggingFaceHandler(AquaAPIhandler):
    """Handler for Aqua Hugging Face REST APIs."""

    @staticmethod
    def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
        """
        Finds a matching model in AQUA based on the model ID from Hugging Face.

        Parameters
        ----------
        model_id (str): The Hugging Face model ID to match.

        Returns
        -------
        Optional[AquaModelSummary]
            Returns the matching AquaModelSummary object if found, else None.
        """
        # Convert the Hugging Face model ID to lowercase once
        model_id_lower = model_id.lower()

        aqua_model_app = AquaModelApp()
        model_ocid = aqua_model_app._find_matching_aqua_model(model_id=model_id_lower)
        if model_ocid:
            return aqua_model_app.get(model_ocid, load_model_card=False)

        return None

    def _format_custom_error_message(self, error: HfHubHTTPError):
        """
        Raises a user-friendly error built from the Hugging Face error response.

        This method never returns: every path raises ``AquaRuntimeError`` with
        ``error`` attached as its explicit cause.

        Parameters
        ----------
        error (HfHubHTTPError): The caught exception.

        Raises
        ------
        AquaRuntimeError: A user-friendly error message.
        """
        # Extract the repository URL from the error message if present
        match = re.search(r"(https://huggingface\.co/\S+)", str(error))
        url = match.group(1) if match else "the requested Hugging Face URL."

        # GatedRepoError is a subclass of RepositoryNotFoundError in
        # huggingface_hub, so it has to be tested first; otherwise gated
        # repositories are reported as "repository not found".
        if isinstance(error, GatedRepoError):
            raise AquaRuntimeError(
                reason=f"Access denied to `{url}` "
                "This repository is gated. Access is restricted to authorized users. "
                "Please request access or check with the repository administrator. "
                "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
                "To register your token, run this command in your terminal: `huggingface-cli login`",
                service_payload={"error": "GatedRepoError"},
            ) from error

        if isinstance(error, RepositoryNotFoundError):
            raise AquaRuntimeError(
                reason=f"Failed to access `{url}`. Please check if the provided repository name is correct. "
                "If the repo is private, make sure you are authenticated and have a valid HF token registered. "
                "To register your token, run this command in your terminal: `huggingface-cli login`",
                service_payload={"error": "RepositoryNotFoundError"},
            ) from error

        if isinstance(error, RevisionNotFoundError):
            raise AquaRuntimeError(
                reason=f"The specified revision could not be found at `{url}` "
                "Please check the revision identifier and try again.",
                service_payload={"error": "RevisionNotFoundError"},
            ) from error

        raise AquaRuntimeError(
            reason=f"An error occurred while accessing `{url}` "
            "Please check your network connection and try again. "
            "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
            "To register your token, run this command in your terminal: `huggingface-cli login`",
            service_payload={"error": "Error"},
        ) from error

    @handle_exceptions
    def post(self, *args, **kwargs):
        """Handles post request for the HF Models APIs

        Reads ``model_id`` (required) and ``token`` (optional) from the JSON
        body, fetches the model info from Hugging Face, validates it, and
        responds with an ``HFModelSummary``.

        Raises
        ------
        HTTPError
            Raises HTTPError if inputs are missing or are invalid.
        AquaRuntimeError
            If Hugging Face cannot be queried, the model is disabled, or its
            pipeline tag is not text generation.
        """
        try:
            input_data = self.get_json_body()
        except Exception as ex:
            raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex

        if not input_data:
            raise HTTPError(400, Errors.NO_INPUT_DATA)

        model_id = input_data.get("model_id")
        token = input_data.get("token")

        if not model_id:
            raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))

        # Get model info from the HF
        try:
            hf_model_info = HfApi(token=token).model_info(model_id)
        except HfHubHTTPError as err:
            # Always raises AquaRuntimeError (chained to ``err``); control
            # never continues past this call.
            self._format_custom_error_message(err)

        # Check if model is not disabled
        if hf_model_info.disabled:
            raise AquaRuntimeError(
                f"The chosen model '{hf_model_info.id}' is currently disabled and cannot be imported into AQUA. "
                "Please verify the model's status on the Hugging Face Model Hub or select a different model."
            )

        # Check pipeline_tag, it should be `text-generation`
        if (
            not hf_model_info.pipeline_tag
            or hf_model_info.pipeline_tag.lower() != ModelTask.TEXT_GENERATION
        ):
            raise AquaRuntimeError(
                f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
                f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "
                "Please select a model with a compatible pipeline tag."
            )

        # Check if it is a service/verified model; None when there is no match
        aqua_model_info: Optional[AquaModelSummary] = self._find_matching_aqua_model(
            model_id=hf_model_info.id
        )

        return self.finish(
            HFModelSummary(model_info=hf_model_info, aqua_model_info=aqua_model_info)
        )


# Route table for the model endpoints: (URL regex, handler class) pairs.
# The first two patterns capture an optional path segment after "model/";
# the last one serves the Hugging Face search endpoint.
# NOTE(review): entry order is left untouched — presumably the router tries
# patterns in order; confirm before reordering.
__handlers__ = [
("model/?([^/]*)", AquaModelHandler),
("model/?([^/]*)/license", AquaModelLicenseHandler),
("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
]
12 changes: 12 additions & 0 deletions ads/aqua/model/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from typing import List, Optional

import oci
from huggingface_hub import hf_api

from ads.aqua import logger
from ads.aqua.app import CLIBuilderMixin
from ads.aqua.common import utils
Expand Down Expand Up @@ -105,6 +107,16 @@ class HFModelContainerInfo:
finetuning_container: str = None


@dataclass(repr=False)
class HFModelSummary:
    """Pairs Hugging Face model metadata with the matching AQUA model summary."""

    # Model information object from the Hugging Face Hub API.
    model_info: hf_api.ModelInfo = field(default_factory=hf_api.ModelInfo)
    # Summary of the corresponding AQUA model; typed Optional, so it may be None.
    aqua_model_info: Optional[AquaModelSummary] = field(default_factory=AquaModelSummary)


@dataclass(repr=False)
class AquaEvalFTCommon(DataClassSerializable):
"""Represents common fields for evaluation and fine-tuning."""
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ opctl = [
"py-cpuinfo",
"rich",
"fire",
"cachetools"
"cachetools",
"huggingface_hub==0.23.4"
]
optuna = ["optuna==2.9.0", "oracle_ads[viz]"]
spark = ["pyspark>=3.0.0"]
Expand Down

0 comments on commit 99c4ae4

Please sign in to comment.