Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add analyticdb as model provider #9220

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/core/model_runtime/model_providers/_position.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@
- mixedbread
- nomic
- voyage
- analyticdb
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions api/core/model_runtime/model_providers/analyticdb/analyticdb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class AnalyticdbProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials

if validate failed, raise exception

:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.RERANK)

# Use `bge-reranker-v2-m3` model for validate,
model_instance.validate_credentials(model="bge-reranker-v2-m3", credentials=credentials)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
raise ex
55 changes: 55 additions & 0 deletions api/core/model_runtime/model_providers/analyticdb/analyticdb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
provider: analyticdb
label:
en_US: Analyticdb
description:
en_US: Models provided by Analyticdb, such as bge-reranker-v2-m3, bge-reranker-v2-minicpm-layerwise
zh_Hans: Analyticdb提供的模型,例如 bge-reranker-v2-m3,bge-reranker-v2-minicpm-layerwise
icon_small:
en_US: adbpg.png
icon_large:
en_US: adbpg.png
background: "#EFF1FE"
help:
title:
en_US: Create your instance on AnalyticDB
zh_Hans: 在 AnalyticDB 创建实例
url:
en_US: https://help.aliyun.com/zh/analyticdb/analyticdb-for-postgresql/user-guide/create-an-instance-instance-management
supported_model_types:
- rerank
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: access_key_id
label:
en_US: access key id
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 Access Key ID
en_US: Enter your Access Key ID
- variable: access_key_secret
label:
en_US: access key secret
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 Access Key Secret
en_US: Enter your Access Key Secret
- variable: instance_id
label:
en_US: instance id
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 实例ID
en_US: Enter your instance ID
- variable: region_id
label:
en_US: region id
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 实例地域ID
en_US: Enter your instance region ID
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
model: bge-reranker-v2-m3
model_type: rerank
model_properties:
context_size: 8192
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
model: bge-reranker-v2-minicpm-layerwise
model_type: rerank
model_properties:
context_size: 2048
127 changes: 127 additions & 0 deletions api/core/model_runtime/model_providers/analyticdb/rerank/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from typing import Optional

from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
from alibabacloud_gpdb20160503.client import Client
from alibabacloud_tea_openapi import models as open_api_models

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class AnalyticdbRerankModel(RerankModel):
"""
Model class for Analyticdb rerank model.
"""

def _build_client(self, credentials: dict) -> Client:
config = {
"access_key_id": credentials["access_key_id"],
"access_key_secret": credentials["access_key_secret"],
"region_id": credentials["region_id"],
"read_timeout": 60000,
"user_agent": "dify",
}
config = open_api_models.Config(**config)
return Client(config)

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials

:param model: model name
:param credentials: model credentials
:return:
"""

try:
self._invoke(
model=model,
credentials=credentials,
query="What is the ADBPG?",
docs=[
"Example doc 1",
"Example doc 2",
"Example doc 3",
],
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model

:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n documents to return
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=[])

client = self._build_client(credentials)
request = gpdb_20160503_models.RerankRequest(
dbinstance_id=credentials["instance_id"],
documents=docs,
query=query,
model=model,
region_id=credentials["region_id"],
top_k=top_n or 3,
)
try:
response = client.rerank(request)
except Exception as e:
raise e
rerank_documents = []
if not response.body.results:
raise CredentialsValidateFailedError(
"""
Instance ID does not exist or RAM does not have rerank permission.
Visit https://ram.console.aliyun.com/
to add `gpdb:Rerank` permission.
"""
)
for result in response.body.results.results:
if score_threshold and result["RelevanceScore"] < score_threshold:
continue
rerank_documents.append(
RerankDocument(
index=result.index,
score=result.relevance_score,
text=docs[result.index],
)
)

return RerankResult(model=model, docs=rerank_documents)

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
return {
InvokeConnectionError: [InvokeConnectionError],
InvokeServerUnavailableError: [InvokeServerUnavailableError],
InvokeRateLimitError: [InvokeRateLimitError],
InvokeAuthorizationError: [InvokeAuthorizationError],
InvokeBadRequestError: [InvokeBadRequestError],
}
19 changes: 7 additions & 12 deletions api/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ tos = "~2.7.1"
# Required by vector store clients
############################################################
[tool.poetry.group.vdb.dependencies]
alibabacloud_gpdb20160503 = "~3.8.0"
alibabacloud_tea_openapi = "~0.3.9"
alibabacloud-gpdb20160503 = "4.1.0"
alibabacloud-tea-openapi = "0.3.12"
chromadb = "0.5.1"
clickhouse-connect = "~0.7.16"
couchbase = "~4.3.0"
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os

import pytest

from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.analyticdb.analyticdb import AnalyticdbProvider


def test_validate_provider_credentials():
provider = AnalyticdbProvider()

with pytest.raises(CredentialsValidateFailedError):
provider.validate_provider_credentials(credentials={})
provider.validate_provider_credentials(
credentials={
"access_key_id": os.environ.get("ANALYTICDB_KEY_ID"),
"access_key_secret": os.environ.get("ANALYTICDB_KEY_SECRET"),
"region_id": os.environ.get("ANALYTICDB_REGION_ID"),
"instance_id": os.environ.get("ANALYTICDB_INSTANCE_ID"),
}
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os

from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.model_providers.analyticdb.rerank.rerank import AnalyticdbRerankModel


def test_invoke_reranker():
model = AnalyticdbRerankModel()

result = model.invoke(
model="bge-reranker-v2-m3",
credentials={
"access_key_id": os.environ.get("ANALYTICDB_KEY_ID"),
"access_key_secret": os.environ.get("ANALYTICDB_KEY_SECRET"),
"region_id": os.environ.get("ANALYTICDB_REGION_ID"),
"instance_id": os.environ.get("ANALYTICDB_INSTANCE_ID"),
},
query="什么是文本排序模型",
docs=[
"文本排序模型广泛用于搜索引擎和推荐系统中,它们根据文本相关性对候选文本进行排序",
"量子计算是计算科学的一个前沿领域",
"预训练语言模型的发展给文本排序模型带来了新的进展",
],
)

assert isinstance(result, RerankResult)
assert result.docs[0].index == 0