-
Notifications
You must be signed in to change notification settings - Fork 8.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add analyticdb as model provider
- Loading branch information
Showing
14 changed files
with
276 additions
and
56 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,3 +41,4 @@ | |
- mixedbread | ||
- nomic | ||
- voyage | ||
- analyticdb |
Empty file.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions
28
api/core/model_runtime/model_providers/analyticdb/analyticdb.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import logging | ||
|
||
from core.model_runtime.entities.model_entities import ModelType | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class AnalyticdbProvider(ModelProvider): | ||
def validate_provider_credentials(self, credentials: dict) -> None: | ||
""" | ||
Validate provider credentials | ||
if validate failed, raise exception | ||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`. | ||
""" | ||
try: | ||
model_instance = self.get_model_instance(ModelType.RERANK) | ||
|
||
# Use `bge-reranker-v2-m3` model for validate, | ||
model_instance.validate_credentials(model="bge-reranker-v2-m3", credentials=credentials) | ||
except CredentialsValidateFailedError as ex: | ||
raise ex | ||
except Exception as ex: | ||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") | ||
raise ex |
55 changes: 55 additions & 0 deletions
55
api/core/model_runtime/model_providers/analyticdb/analyticdb.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
provider: analyticdb | ||
label: | ||
en_US: Analyticdb | ||
description: | ||
en_US: Models provided by Analyticdb, such as bge-reranker-v2-m3, bge-reranker-v2-minicpm-layerwise | ||
zh_Hans: Analyticdb提供的模型,例如 bge-reranker-v2-m3,bge-reranker-v2-minicpm-layerwise | ||
icon_small: | ||
en_US: adbpg.png | ||
icon_large: | ||
en_US: adbpg.png | ||
background: "#EFF1FE" | ||
help: | ||
title: | ||
en_US: Create your instance on AnalyticDB | ||
zh_Hans: 在 AnalyticDB 创建实例 | ||
url: | ||
en_US: https://help.aliyun.com/zh/analyticdb/analyticdb-for-postgresql/user-guide/create-an-instance-instance-management | ||
supported_model_types: | ||
- rerank | ||
configurate_methods: | ||
- predefined-model | ||
provider_credential_schema: | ||
credential_form_schemas: | ||
- variable: access_key_id | ||
label: | ||
en_US: access key id | ||
type: secret-input | ||
required: true | ||
placeholder: | ||
zh_Hans: 在此输入您的 Access Key ID | ||
en_US: Enter your Access Key ID | ||
- variable: access_key_secret | ||
label: | ||
en_US: access key secret | ||
type: secret-input | ||
required: true | ||
placeholder: | ||
zh_Hans: 在此输入您的 Access Key Secret | ||
en_US: Enter your Access Key Secret | ||
- variable: instance_id | ||
label: | ||
en_US: instance id | ||
type: secret-input | ||
required: true | ||
placeholder: | ||
zh_Hans: 在此输入您的 实例ID | ||
en_US: Enter your instance ID | ||
- variable: region_id | ||
label: | ||
en_US: region id | ||
type: text-input | ||
required: true | ||
placeholder: | ||
zh_Hans: 在此输入您的 实例地域ID | ||
en_US: Enter your instance region ID |
Empty file.
4 changes: 4 additions & 0 deletions
4
api/core/model_runtime/model_providers/analyticdb/rerank/bge-reranker-v2-m3.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
model: bge-reranker-v2-m3 | ||
model_type: rerank | ||
model_properties: | ||
context_size: 8192 |
4 changes: 4 additions & 0 deletions
4
...re/model_runtime/model_providers/analyticdb/rerank/bge-reranker-v2-minicpm-layerwise.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
model: bge-reranker-v2-minicpm-layerwise | ||
model_type: rerank | ||
model_properties: | ||
context_size: 2048 |
127 changes: 127 additions & 0 deletions
127
api/core/model_runtime/model_providers/analyticdb/rerank/rerank.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from typing import Optional | ||
|
||
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models | ||
from alibabacloud_gpdb20160503.client import Client | ||
from alibabacloud_tea_openapi import models as open_api_models | ||
|
||
from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult | ||
from core.model_runtime.errors.invoke import ( | ||
InvokeAuthorizationError, | ||
InvokeBadRequestError, | ||
InvokeConnectionError, | ||
InvokeError, | ||
InvokeRateLimitError, | ||
InvokeServerUnavailableError, | ||
) | ||
from core.model_runtime.errors.validate import CredentialsValidateFailedError | ||
from core.model_runtime.model_providers.__base.rerank_model import RerankModel | ||
|
||
|
||
class AnalyticdbRerankModel(RerankModel): | ||
""" | ||
Model class for Analyticdb rerank model. | ||
""" | ||
|
||
def _build_client(self, credentials: dict) -> Client: | ||
config = { | ||
"access_key_id": credentials["access_key_id"], | ||
"access_key_secret": credentials["access_key_secret"], | ||
"region_id": credentials["region_id"], | ||
"read_timeout": 60000, | ||
"user_agent": "dify", | ||
} | ||
config = open_api_models.Config(**config) | ||
return Client(config) | ||
|
||
def validate_credentials(self, model: str, credentials: dict) -> None: | ||
""" | ||
Validate model credentials | ||
:param model: model name | ||
:param credentials: model credentials | ||
:return: | ||
""" | ||
|
||
try: | ||
self._invoke( | ||
model=model, | ||
credentials=credentials, | ||
query="What is the ADBPG?", | ||
docs=[ | ||
"Example doc 1", | ||
"Example doc 2", | ||
"Example doc 3", | ||
], | ||
) | ||
except Exception as ex: | ||
raise CredentialsValidateFailedError(str(ex)) | ||
|
||
def _invoke( | ||
self, | ||
model: str, | ||
credentials: dict, | ||
query: str, | ||
docs: list[str], | ||
score_threshold: Optional[float] = None, | ||
top_n: Optional[int] = None, | ||
user: Optional[str] = None, | ||
) -> RerankResult: | ||
""" | ||
Invoke rerank model | ||
:param model: model name | ||
:param credentials: model credentials | ||
:param query: search query | ||
:param docs: docs for reranking | ||
:param score_threshold: score threshold | ||
:param top_n: top n documents to return | ||
:param user: unique user id | ||
:return: rerank result | ||
""" | ||
if len(docs) == 0: | ||
return RerankResult(model=model, docs=[]) | ||
|
||
client = self._build_client(credentials) | ||
request = gpdb_20160503_models.RerankRequest( | ||
dbinstance_id=credentials["instance_id"], | ||
documents=docs, | ||
query=query, | ||
model=model, | ||
region_id=credentials["region_id"], | ||
top_k=top_n or 3, | ||
) | ||
try: | ||
response = client.rerank(request) | ||
except Exception as e: | ||
raise e | ||
rerank_documents = [] | ||
if not response.body.results: | ||
raise CredentialsValidateFailedError( | ||
""" | ||
Instance ID does not exist or RAM does not have rerank permission. | ||
Visit https://ram.console.aliyun.com/ | ||
to add `gpdb:Rerank` permission. | ||
""" | ||
) | ||
for result in response.body.results.results: | ||
if score_threshold and result["RelevanceScore"] < score_threshold: | ||
continue | ||
rerank_documents.append( | ||
RerankDocument( | ||
index=result.index, | ||
score=result.relevance_score, | ||
text=docs[result.index], | ||
) | ||
) | ||
|
||
return RerankResult(model=model, docs=rerank_documents) | ||
|
||
@property | ||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: | ||
return { | ||
InvokeConnectionError: [InvokeConnectionError], | ||
InvokeServerUnavailableError: [InvokeServerUnavailableError], | ||
InvokeRateLimitError: [InvokeRateLimitError], | ||
InvokeAuthorizationError: [InvokeAuthorizationError], | ||
InvokeBadRequestError: [InvokeBadRequestError], | ||
} |
Oops, something went wrong.