-
Notifications
You must be signed in to change notification settings - Fork 8.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add analyticdb as model provider
- Loading branch information
Showing
14 changed files
with
276 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,3 +41,4 @@ | |
- mixedbread | ||
- nomic | ||
- voyage | ||
- analyticdb |
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions
28
api/core/model_runtime/model_providers/analyticdb/analyticdb.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import logging | ||
|
||
from core.model_runtime.entities.model_entities import ModelType | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AnalyticdbProvider(ModelProvider):
    """Model provider entry point for Analyticdb."""

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Check provider-level credentials by validating them against a rerank model.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        :raises CredentialsValidateFailedError: re-raised unchanged when validation fails
        :raises Exception: any other error is logged with its traceback and re-raised
        """
        try:
            reranker = self.get_model_instance(ModelType.RERANK)

            # `bge-reranker-v2-m3` is the model used to probe the credentials.
            reranker.validate_credentials(model="bge-reranker-v2-m3", credentials=credentials)
        except CredentialsValidateFailedError:
            # Already the expected error type: propagate without extra logging.
            raise
        except Exception:
            logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
            raise
55 changes: 55 additions & 0 deletions
55
api/core/model_runtime/model_providers/analyticdb/analyticdb.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Provider manifest for the `analyticdb` model provider.
provider: analyticdb
label:
  en_US: Analyticdb
description:
  en_US: Models provided by Analyticdb, such as bge-reranker-v2-m3, bge-reranker-v2-minicpm-layerwise
  zh_Hans: Analyticdb提供的模型,例如 bge-reranker-v2-m3,bge-reranker-v2-minicpm-layerwise
# The same image file is referenced for both the small and the large icon.
icon_small:
  en_US: adbpg.png
icon_large:
  en_US: adbpg.png
background: "#EFF1FE"
# Help link shown to users; points at the instance-creation guide.
help:
  title:
    en_US: Create your instance on AnalyticDB
    zh_Hans: 在 AnalyticDB 创建实例
  url:
    en_US: https://help.aliyun.com/zh/analyticdb/analyticdb-for-postgresql/user-guide/create-an-instance-instance-management
# Only rerank models are declared, and only as predefined models.
supported_model_types:
  - rerank
configurate_methods:
  - predefined-model
# Credential form: four required fields. The variable names are the keys the
# rerank implementation reads from its `credentials` dict.
provider_credential_schema:
  credential_form_schemas:
    - variable: access_key_id
      label:
        en_US: access key id
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 Access Key ID
        en_US: Enter your Access Key ID
    - variable: access_key_secret
      label:
        en_US: access key secret
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 Access Key Secret
        en_US: Enter your Access Key Secret
    # NOTE(review): declared as secret-input although it is an instance
    # identifier rather than a secret — confirm this masking is intended.
    - variable: instance_id
      label:
        en_US: instance id
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 实例ID
        en_US: Enter your instance ID
    - variable: region_id
      label:
        en_US: region id
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 实例地域ID
        en_US: Enter your instance region ID
Empty file.
4 changes: 4 additions & 0 deletions
4
api/core/model_runtime/model_providers/analyticdb/rerank/bge-reranker-v2-m3.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Predefined rerank model definition for `bge-reranker-v2-m3`.
model: bge-reranker-v2-m3
model_type: rerank
model_properties:
  # Context size declared for this model (presumably in tokens — confirm).
  context_size: 8192
4 changes: 4 additions & 0 deletions
4
...re/model_runtime/model_providers/analyticdb/rerank/bge-reranker-v2-minicpm-layerwise.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Predefined rerank model definition for `bge-reranker-v2-minicpm-layerwise`.
model: bge-reranker-v2-minicpm-layerwise
model_type: rerank
model_properties:
  # Context size declared for this model (presumably in tokens — confirm).
  context_size: 2048
127 changes: 127 additions & 0 deletions
127
api/core/model_runtime/model_providers/analyticdb/rerank/rerank.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from typing import Optional | ||
|
||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | ||
from alibabacloud_gpdb20160503.client import Client | ||
from alibabacloud_tea_openapi import models as open_api_models | ||
|
||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult | ||
from core.model_runtime.errors.invoke import ( | ||
InvokeAuthorizationError, | ||
InvokeBadRequestError, | ||
InvokeConnectionError, | ||
InvokeError, | ||
InvokeRateLimitError, | ||
InvokeServerUnavailableError, | ||
) | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel | ||
|
||
|
||
class AnalyticdbRerankModel(RerankModel):
    """
    Model class for Analyticdb rerank model.

    Reranking is delegated to the `rerank` call of the Alibaba Cloud GPDB
    client, built from the `access_key_id`, `access_key_secret`, `region_id`
    and `instance_id` entries of the credentials dict.
    """

    def _build_client(self, credentials: dict) -> Client:
        """
        Build an OpenAPI client from the provider credentials.

        :param credentials: must contain `access_key_id`, `access_key_secret` and `region_id`
        :return: configured GPDB client
        :raises KeyError: if a required credential entry is missing
        """
        config = open_api_models.Config(
            access_key_id=credentials["access_key_id"],
            access_key_secret=credentials["access_key_secret"],
            region_id=credentials["region_id"],
            # NOTE(review): presumably milliseconds (i.e. 60 s) — confirm against the SDK.
            read_timeout=60000,
            user_agent="dify",
        )
        return Client(config)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials by issuing a small rerank request.

        :param model: model name
        :param credentials: model credentials
        :raises CredentialsValidateFailedError: if the probe request fails for any reason
        """
        try:
            self._invoke(
                model=model,
                credentials=credentials,
                query="What is the ADBPG?",
                docs=[
                    "Example doc 1",
                    "Example doc 2",
                    "Example doc 3",
                ],
            )
        except Exception as ex:
            # Any failure (missing keys, SDK errors, empty result) is reported
            # as a credentials problem; keep the original error as the cause.
            raise CredentialsValidateFailedError(str(ex)) from ex

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: score threshold; results scoring below it are dropped
        :param top_n: top n documents to return (3 is requested when not given)
        :param user: unique user id (not used by this implementation)
        :return: rerank result
        :raises CredentialsValidateFailedError: if the service returns no results
        """
        if not docs:
            # Nothing to rank: skip the remote call entirely.
            return RerankResult(model=model, docs=[])

        client = self._build_client(credentials)
        request = gpdb_20160503_models.RerankRequest(
            dbinstance_id=credentials["instance_id"],
            documents=docs,
            query=query,
            model=model,
            region_id=credentials["region_id"],
            top_k=top_n or 3,
        )
        # SDK errors propagate to the caller unchanged.
        response = client.rerank(request)

        if not response.body.results:
            raise CredentialsValidateFailedError(
                """
                Instance ID does not exist or RAM does not have rerank permission.
                Visit https://ram.console.aliyun.com/
                to add `gpdb:Rerank` permission.
                """
            )

        rerank_documents = []
        for result in response.body.results.results:
            # Read the score through the same attribute that is used to build
            # the RerankDocument, and honour an explicit threshold of 0.0.
            score = result.relevance_score
            if score_threshold is not None and score < score_threshold:
                continue
            rerank_documents.append(
                RerankDocument(
                    index=result.index,
                    score=score,
                    text=docs[result.index],
                )
            )

        return RerankResult(model=model, docs=rerank_documents)

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map raised exception types to unified invoke error types.

        Each invoke error type is mapped only to itself; no SDK-specific
        exception classes are listed.
        """
        return {
            InvokeConnectionError: [InvokeConnectionError],
            InvokeServerUnavailableError: [InvokeServerUnavailableError],
            InvokeRateLimitError: [InvokeRateLimitError],
            InvokeAuthorizationError: [InvokeAuthorizationError],
            InvokeBadRequestError: [InvokeBadRequestError],
        }
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
21 changes: 21 additions & 0 deletions
21
api/tests/integration_tests/model_runtime/analyticdb/test_provider.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import os | ||
|
||
import pytest | ||
|
||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.analyticdb.analyticdb import AnalyticdbProvider | ||
|
||
|
||
def test_validate_provider_credentials():
    """
    Provider credential validation rejects empty credentials and accepts real ones.

    The positive case reads ANALYTICDB_KEY_ID, ANALYTICDB_KEY_SECRET,
    ANALYTICDB_REGION_ID and ANALYTICDB_INSTANCE_ID from the environment.
    """
    provider = AnalyticdbProvider()

    with pytest.raises(CredentialsValidateFailedError):
        provider.validate_provider_credentials(credentials={})

    # Kept outside the `pytest.raises` context: inside it, this call would
    # never run because the call above has already raised.
    provider.validate_provider_credentials(
        credentials={
            "access_key_id": os.environ.get("ANALYTICDB_KEY_ID"),
            "access_key_secret": os.environ.get("ANALYTICDB_KEY_SECRET"),
            "region_id": os.environ.get("ANALYTICDB_REGION_ID"),
            "instance_id": os.environ.get("ANALYTICDB_INSTANCE_ID"),
        }
    )
27 changes: 27 additions & 0 deletions
27
api/tests/integration_tests/model_runtime/analyticdb/test_rerank.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import os | ||
|
||
from core.model_runtime.entities.rerank_entities import RerankResult | ||
from core.model_runtime.model_providers.analyticdb.rerank.rerank import AnalyticdbRerankModel | ||
|
||
|
||
def test_invoke_reranker():
    """
    Rerank three documents against a query and check the top result.

    Credentials are read from the ANALYTICDB_* environment variables.
    """
    reranker = AnalyticdbRerankModel()

    env_credentials = {
        "access_key_id": os.environ.get("ANALYTICDB_KEY_ID"),
        "access_key_secret": os.environ.get("ANALYTICDB_KEY_SECRET"),
        "region_id": os.environ.get("ANALYTICDB_REGION_ID"),
        "instance_id": os.environ.get("ANALYTICDB_INSTANCE_ID"),
    }
    candidate_docs = [
        "文本排序模型广泛用于搜索引擎和推荐系统中,它们根据文本相关性对候选文本进行排序",
        "量子计算是计算科学的一个前沿领域",
        "预训练语言模型的发展给文本排序模型带来了新的进展",
    ]

    result = reranker.invoke(
        model="bge-reranker-v2-m3",
        credentials=env_credentials,
        query="什么是文本排序模型",
        docs=candidate_docs,
    )

    assert isinstance(result, RerankResult)
    # The first returned document is expected to be the first input document.
    assert result.docs[0].index == 0