Skip to content

Commit

Permalink
Merge branch 'main' into feature/ad_merlion
Browse files Browse the repository at this point in the history
  • Loading branch information
codeloop committed Oct 16, 2024
2 parents 4679082 + a37882c commit 19269a3
Show file tree
Hide file tree
Showing 71 changed files with 19,740 additions and 1,914 deletions.
12 changes: 12 additions & 0 deletions THIRD_PARTY_LICENSES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,18 @@ langchain
* Source code: https://github.com/langchain-ai/langchain
* Project home: https://www.langchain.com/

langchain-community
* Copyright (c) 2023 LangChain, Inc.
* License: MIT license
* Source code: https://github.com/langchain-ai/langchain/tree/master/libs/community
* Project home: https://github.com/langchain-ai/langchain/tree/master/libs/community

langchain-openai
* Copyright (c) 2023 LangChain, Inc.
* License: MIT license
* Source code: https://github.com/langchain-ai/langchain/tree/master/libs/partners/openai
* Project home: https://github.com/langchain-ai/langchain/tree/master/libs/partners/openai

lightgbm
* Copyright (c) 2023 Microsoft Corporation
* License: MIT license
Expand Down
1 change: 1 addition & 0 deletions ads/aqua/config/evaluation/evaluation_service_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def search_shapes(

class Config:
extra = "ignore"
protected_namespaces = ()


class EvaluationServiceConfig(Serializable):
Expand Down
38 changes: 17 additions & 21 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ads.aqua.extension.base_handler import AquaAPIhandler
from ads.aqua.extension.errors import Errors
from ads.aqua.model import AquaModelApp
from ads.aqua.model.constants import ModelTask
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
from ads.aqua.ui import ModelFormat

Expand Down Expand Up @@ -68,7 +67,7 @@ def read(self, model_id):
return self.finish(AquaModelApp().get(model_id))

@handle_exceptions
def delete(self):
def delete(self, id=""):
"""Handles DELETE request for clearing cache"""
url_parse = urlparse(self.request.path)
paths = url_parse.path.strip("/")
Expand Down Expand Up @@ -177,10 +176,8 @@ def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:

return None



@handle_exceptions
def get(self,*args, **kwargs):
def get(self, *args, **kwargs):
"""
Finds a list of matching models from hugging face based on query string provided from users.
Expand All @@ -194,13 +191,11 @@ def get(self,*args, **kwargs):
Returns the matching model ids string
"""

query=self.get_argument("query",default=None)
query = self.get_argument("query", default=None)
if not query:
raise HTTPError(400,Errors.MISSING_REQUIRED_PARAMETER.format("query"))
models=list_hf_models(query)
return self.finish({"models":models})


raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("query"))
models = list_hf_models(query)
return self.finish({"models": models})

@handle_exceptions
def post(self, *args, **kwargs):
Expand Down Expand Up @@ -234,16 +229,17 @@ def post(self, *args, **kwargs):
"Please verify the model's status on the Hugging Face Model Hub or select a different model."
)

# Check pipeline_tag, it should be `text-generation`
if (
not hf_model_info.pipeline_tag
or hf_model_info.pipeline_tag.lower() != ModelTask.TEXT_GENERATION
):
raise AquaRuntimeError(
f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "
"Please select a model with a compatible pipeline tag."
)
# Commented the validation below to let users to register any model type.
# # Check pipeline_tag, it should be `text-generation`
# if not (
# hf_model_info.pipeline_tag
# and hf_model_info.pipeline_tag.lower() in ModelTask
# ):
# raise AquaRuntimeError(
# f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
# f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "
# "Please select a model with a compatible pipeline tag."
# )

# Check if it is a service/verified model
aqua_model_info: AquaModelSummary = self._find_matching_aqua_model(
Expand Down
4 changes: 3 additions & 1 deletion ads/aqua/model/constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

Expand All @@ -9,6 +8,7 @@
This module contains constants/enums used in Aqua Model.
"""

from ads.common.extended_enum import ExtendedEnumMeta


Expand All @@ -21,6 +21,8 @@ class ModelCustomMetadataFields(str, metaclass=ExtendedEnumMeta):

class ModelTask(str, metaclass=ExtendedEnumMeta):
TEXT_GENERATION = "text-generation"
IMAGE_TEXT_TO_TEXT = "image-text-to-text"
IMAGE_TO_TEXT = "image-to-text"


class FineTuningMetricCategories(str, metaclass=ExtendedEnumMeta):
Expand Down
14 changes: 10 additions & 4 deletions ads/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@

try:
import langchain
from ads.llm.langchain.plugins.llm_gen_ai import GenerativeAI
from ads.llm.langchain.plugins.llm_md import ModelDeploymentTGI
from ads.llm.langchain.plugins.llm_md import ModelDeploymentVLLM
from ads.llm.langchain.plugins.embeddings import GenerativeAIEmbeddings
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
OCIModelDeploymentVLLM,
OCIModelDeploymentTGI,
)
from ads.llm.langchain.plugins.chat_models.oci_data_science import (
ChatOCIModelDeployment,
ChatOCIModelDeploymentVLLM,
ChatOCIModelDeploymentTGI,
)
from ads.llm.chat_template import ChatTemplates
except ImportError as ex:
if ex.name == "langchain":
raise ImportError(
Expand Down
31 changes: 31 additions & 0 deletions ads/llm/chat_template.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


import os


class ChatTemplates:
"""Contains chat templates."""

@staticmethod
def _read_template(filename):
with open(
os.path.join(os.path.dirname(__file__), "templates", filename),
mode="r",
encoding="utf-8",
) as f:
return f.read()

@staticmethod
def mistral():
"""Chat template for auto tool calling with Mistral model deploy with vLLM."""
return ChatTemplates._read_template("tool_chat_template_mistral_parallel.jinja")

@staticmethod
def hermes():
"""Chat template for auto tool calling with Hermes model deploy with vLLM."""
return ChatTemplates._read_template("tool_chat_template_hermes.jinja")
5 changes: 3 additions & 2 deletions ads/llm/guardrails/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Any, List, Dict, Tuple
from langchain.schema.prompt import PromptValue
from langchain.tools.base import BaseTool, ToolException
from langchain.pydantic_v1 import BaseModel, root_validator
from pydantic import BaseModel, model_validator


class RunInfo(BaseModel):
Expand Down Expand Up @@ -190,7 +190,8 @@ class Config:
This is used by the ``apply_filter()`` method.
"""

@root_validator
@model_validator(mode="before")
@classmethod
def default_name(cls, values):
"""Sets the default name of the guardrail."""
if not values.get("name"):
Expand Down
2 changes: 1 addition & 1 deletion ads/llm/guardrails/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


import evaluate
from langchain.pydantic_v1 import root_validator
from pydantic.v1 import root_validator
from .base import Guardrail


Expand Down
118 changes: 0 additions & 118 deletions ads/llm/langchain/plugins/base.py

This file was deleted.

5 changes: 5 additions & 0 deletions ads/llm/langchain/plugins/chat_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
Loading

0 comments on commit 19269a3

Please sign in to comment.