Skip to content

Commit

Permalink
feat: add analyticdb as model provider
Browse files Browse the repository at this point in the history
  • Loading branch information
lpdink committed Nov 18, 2024
1 parent c49efc0 commit 2dbb202
Show file tree
Hide file tree
Showing 14 changed files with 276 additions and 14 deletions.
1 change: 1 addition & 0 deletions api/core/model_runtime/model_providers/_position.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@
- mixedbread
- nomic
- voyage
- analyticdb
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions api/core/model_runtime/model_providers/analyticdb/analyticdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class AnalyticdbProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
if validate failed, raise exception
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.RERANK)

# Use `bge-reranker-v2-m3` model for validate,
model_instance.validate_credentials(model="bge-reranker-v2-m3", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex
55 changes: 55 additions & 0 deletions api/core/model_runtime/model_providers/analyticdb/analyticdb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
provider: analyticdb
label:
en_US: Analyticdb
description:
en_US: Models provided by Analyticdb, such as bge-reranker-v2-m3, bge-reranker-v2-minicpm-layerwise
zh_Hans: Analyticdb提供的模型,例如 bge-reranker-v2-m3,bge-reranker-v2-minicpm-layerwise
icon_small:
en_US: adbpg.png
icon_large:
en_US: adbpg.png
background: "#EFF1FE"
help:
title:
en_US: Create your instance on AnalyticDB
zh_Hans: 在 AnalyticDB 创建实例
url:
en_US: https://help.aliyun.com/zh/analyticdb/analyticdb-for-postgresql/user-guide/create-an-instance-instance-management
supported_model_types:
- rerank
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: access_key_id
label:
en_US: access key id
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 Access Key ID
en_US: Enter your Access Key ID
- variable: access_key_secret
label:
en_US: access key secret
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 Access Key Secret
en_US: Enter your Access Key Secret
- variable: instance_id
label:
en_US: instance id
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 实例ID
en_US: Enter your instance ID
- variable: region_id
label:
en_US: region id
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 实例地域ID
en_US: Enter your instance region ID
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
model: bge-reranker-v2-m3
model_type: rerank
model_properties:
context_size: 8192
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
model: bge-reranker-v2-minicpm-layerwise
model_type: rerank
model_properties:
context_size: 2048
127 changes: 127 additions & 0 deletions api/core/model_runtime/model_providers/analyticdb/rerank/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Optional

from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class AnalyticdbRerankModel(RerankModel):
"""
Model class for Analyticdb rerank model.
"""

def _build_client(self, credentials: dict) -> Client:
config = {
"access_key_id": credentials["access_key_id"],
"access_key_secret": credentials["access_key_secret"],
"region_id": credentials["region_id"],
"read_timeout": 60000,
"user_agent": "dify",
}
config = open_api_models.Config(**config)
return Client(config)

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""

try:
self._invoke(
model=model,
credentials=credentials,
query="What is the ADBPG?",
docs=[
"Example doc 1",
"Example doc 2",
"Example doc 3",
],
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n documents to return
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])

client = self._build_client(credentials)
request = gpdb_20160503_models.RerankRequest(
dbinstance_id=credentials["instance_id"],
documents=docs,
query=query,
model=model,
region_id=credentials["region_id"],
top_k=top_n or 3,
)
try:
response = client.rerank(request)
except Exception as e:
raise e
rerank_documents = []
if not response.body.results:
raise CredentialsValidateFailedError(
"""
Instance ID does not exist or RAM does not have rerank permission.
Visit https://ram.console.aliyun.com/
to add `gpdb:Rerank` permission.
"""
)
for result in response.body.results.results:
if score_threshold and result["RelevanceScore"] < score_threshold:
continue
rerank_documents.append(
RerankDocument(
index=result.index,
score=result.relevance_score,
text=docs[result.index],
)
)

return RerankResult(model=model, docs=rerank_documents)

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [InvokeBadRequestError],
}
19 changes: 7 additions & 12 deletions api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ tos = "~2.7.1"
# Required by vector store clients
############################################################
[tool.poetry.group.vdb.dependencies]
alibabacloud_gpdb20160503 = "~3.8.0"
alibabacloud_tea_openapi = "~0.3.9"
alibabacloud-gpdb20160503 = "4.1.0"
alibabacloud-tea-openapi = "0.3.12"
chromadb = "0.5.1"
clickhouse-connect = "~0.7.16"
couchbase = "~4.3.0"
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.analyticdb.analyticdb import AnalyticdbProvider


def test_validate_provider_credentials():
provider = AnalyticdbProvider()

with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
"access_key_id": os.environ.get("ANALYTICDB_KEY_ID"),
"access_key_secret": os.environ.get("ANALYTICDB_KEY_SECRET"),
"region_id": os.environ.get("ANALYTICDB_REGION_ID"),
"instance_id": os.environ.get("ANALYTICDB_INSTANCE_ID"),
}
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.model_providers.analyticdb.rerank.rerank import AnalyticdbRerankModel


def test_invoke_reranker():
model = AnalyticdbRerankModel()

result = model.invoke(
model="bge-reranker-v2-m3",
credentials={
"access_key_id": os.environ.get("ANALYTICDB_KEY_ID"),
"access_key_secret": os.environ.get("ANALYTICDB_KEY_SECRET"),
"region_id": os.environ.get("ANALYTICDB_REGION_ID"),
"instance_id": os.environ.get("ANALYTICDB_INSTANCE_ID"),
},
query="什么是文本排序模型",
docs=[
"文本排序模型广泛用于搜索引擎和推荐系统中,它们根据文本相关性对候选文本进行排序",
"量子计算是计算科学的一个前沿领域",
"预训练语言模型的发展给文本排序模型带来了新的进展",
],
)

assert isinstance(result, RerankResult)
assert result.docs[0].index == 0

0 comments on commit 2dbb202

Please sign in to comment.