diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index 89fccef6598fdd..2b5d76804e7b35 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -41,3 +41,4 @@ - mixedbread - nomic - voyage +- analyticdb diff --git a/api/core/model_runtime/model_providers/analyticdb/__init__.py b/api/core/model_runtime/model_providers/analyticdb/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/analyticdb/_assets/adbpg.png b/api/core/model_runtime/model_providers/analyticdb/_assets/adbpg.png new file mode 100644 index 00000000000000..95fd8fcec897bd Binary files /dev/null and b/api/core/model_runtime/model_providers/analyticdb/_assets/adbpg.png differ diff --git a/api/core/model_runtime/model_providers/analyticdb/analyticdb.py b/api/core/model_runtime/model_providers/analyticdb/analyticdb.py new file mode 100644 index 00000000000000..ae39ebeb81eaf9 --- /dev/null +++ b/api/core/model_runtime/model_providers/analyticdb/analyticdb.py @@ -0,0 +1,28 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class AnalyticdbProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.RERANK) + + # Use `bge-reranker-v2-m3` model for validate, + model_instance.validate_credentials(model="bge-reranker-v2-m3", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/analyticdb/analyticdb.yaml b/api/core/model_runtime/model_providers/analyticdb/analyticdb.yaml new file mode 100644 index 00000000000000..f59d10a87518b2 --- /dev/null +++ b/api/core/model_runtime/model_providers/analyticdb/analyticdb.yaml @@ -0,0 +1,55 @@ +provider: analyticdb +label: + en_US: Analyticdb +description: + en_US: Models provided by Analyticdb, such as bge-reranker-v2-m3, bge-reranker-v2-minicpm-layerwise + zh_Hans: Analyticdb提供的模型,例如 bge-reranker-v2-m3,bge-reranker-v2-minicpm-layerwise +icon_small: + en_US: adbpg.png +icon_large: + en_US: adbpg.png +background: "#EFF1FE" +help: + title: + en_US: Create your instance on AnalyticDB + zh_Hans: 在 AnalyticDB 创建实例 + url: + en_US: https://help.aliyun.com/zh/analyticdb/analyticdb-for-postgresql/user-guide/create-an-instance-instance-management +supported_model_types: + - rerank +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: access_key_id + label: + en_US: access key id + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 Access Key ID + en_US: Enter your Access Key ID + - variable: access_key_secret + label: + en_US: access key secret + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 Access Key Secret + en_US: Enter your Access Key Secret + - variable: instance_id + label: + en_US: instance id + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 实例ID + en_US: Enter your instance ID + - variable: region_id + label: + en_US: region id + type: text-input + required: true + placeholder: + zh_Hans: 在此输入您的 实例地域ID + en_US: Enter your instance region ID diff --git a/api/core/model_runtime/model_providers/analyticdb/rerank/__init__.py b/api/core/model_runtime/model_providers/analyticdb/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/analyticdb/rerank/bge-reranker-v2-m3.yaml b/api/core/model_runtime/model_providers/analyticdb/rerank/bge-reranker-v2-m3.yaml new file mode 100644 index 00000000000000..a623a8738b5d46 --- /dev/null +++ b/api/core/model_runtime/model_providers/analyticdb/rerank/bge-reranker-v2-m3.yaml @@ -0,0 +1,4 @@ +model: bge-reranker-v2-m3 +model_type: rerank +model_properties: + context_size: 8192 diff --git a/api/core/model_runtime/model_providers/analyticdb/rerank/bge-reranker-v2-minicpm-layerwise.yaml b/api/core/model_runtime/model_providers/analyticdb/rerank/bge-reranker-v2-minicpm-layerwise.yaml new file mode 100644 index 00000000000000..8727b107e9fad2 --- /dev/null +++ b/api/core/model_runtime/model_providers/analyticdb/rerank/bge-reranker-v2-minicpm-layerwise.yaml @@ -0,0 +1,4 @@ +model: bge-reranker-v2-minicpm-layerwise +model_type: rerank +model_properties: + context_size: 2048 diff --git a/api/core/model_runtime/model_providers/analyticdb/rerank/rerank.py b/api/core/model_runtime/model_providers/analyticdb/rerank/rerank.py new file mode 100644 index 00000000000000..5d87c1fcc2490d --- /dev/null +++ b/api/core/model_runtime/model_providers/analyticdb/rerank/rerank.py @@ -0,0 +1,127 @@ +from typing import Optional + +from alibabacloud_gpdb20160503 import models as gpdb_20160503_models +from alibabacloud_gpdb20160503.client import Client +from alibabacloud_tea_openapi import models as open_api_models + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class AnalyticdbRerankModel(RerankModel): + """ + Model class for Analyticdb rerank model. + """ + + def _build_client(self, credentials: dict) -> Client: + config = { + "access_key_id": credentials["access_key_id"], + "access_key_secret": credentials["access_key_secret"], + "region_id": credentials["region_id"], + "read_timeout": 60000, + "user_agent": "dify", + } + config = open_api_models.Config(**config) + return Client(config) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + + try: + self._invoke( + model=model, + credentials=credentials, + query="What is the ADBPG?", + docs=[ + "Example doc 1", + "Example doc 2", + "Example doc 3", + ], + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + client = self._build_client(credentials) + request = gpdb_20160503_models.RerankRequest( + dbinstance_id=credentials["instance_id"], + documents=docs, + query=query, + model=model, + region_id=credentials["region_id"], + top_k=top_n or 3, + ) + try: + response = client.rerank(request) + except Exception as e: + raise e + rerank_documents = [] + if not response.body.results: + raise CredentialsValidateFailedError( + """ + Instance ID does not exist or RAM does not have rerank permission. + Visit https://ram.console.aliyun.com/ + to add `gpdb:Rerank` permission. + """ + ) + for result in response.body.results.results: + if score_threshold and result["RelevanceScore"] < score_threshold: + continue + rerank_documents.append( + RerankDocument( + index=result.index, + score=result.relevance_score, + text=docs[result.index], + ) + ) + + return RerankResult(model=model, docs=rerank_documents) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return { + InvokeConnectionError: [InvokeConnectionError], + InvokeServerUnavailableError: [InvokeServerUnavailableError], + InvokeRateLimitError: [InvokeRateLimitError], + InvokeAuthorizationError: [InvokeAuthorizationError], + InvokeBadRequestError: [InvokeBadRequestError], + } diff --git a/api/poetry.lock b/api/poetry.lock index 6021ae5c740ab7..5cbe17256d6606 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -229,13 +229,13 @@ alibabacloud_credentials = ">=0.3.4,<1.0.0" [[package]] name = "alibabacloud-gpdb20160503" -version = "3.8.3" +version = "4.1.0" description = "Alibaba Cloud AnalyticDB for PostgreSQL (20160503) SDK Library for Python" optional = false python-versions = ">=3.6" files = [ - {file = "alibabacloud_gpdb20160503-3.8.3-py3-none-any.whl", hash = "sha256:06e1c46ce5e4e9d1bcae76e76e51034196c625799d06b2efec8d46a7df323fe8"}, - {file = "alibabacloud_gpdb20160503-3.8.3.tar.gz", hash = "sha256:4dfcc0d9cff5a921d529d76f4bf97e2ceb9dc2fa53f00ab055f08509423d8e30"}, + {file = "alibabacloud_gpdb20160503-4.1.0-py3-none-any.whl", hash = "sha256:eeed96abce482e331293241e8339f6b41fa78c284122c11aac60ec0f18ebc8c1"}, + {file = "alibabacloud_gpdb20160503-4.1.0.tar.gz", hash = "sha256:837fbf288f4af0fcbad5e424ca65f47662f6b9da0ca7ae2e0226bdd9350bedd7"}, ] [package.dependencies] @@ -245,8 +245,8 @@ alibabacloud-openplatform20191219 = ">=2.0.0,<3.0.0" alibabacloud-oss-sdk = ">=0.1.0,<1.0.0" alibabacloud-oss-util = ">=0.0.5,<1.0.0" alibabacloud-tea-fileform = ">=0.0.3,<1.0.0" -alibabacloud-tea-openapi = ">=0.3.10,<1.0.0" -alibabacloud-tea-util = ">=0.3.12,<1.0.0" +alibabacloud-tea-openapi = ">=0.3.11,<1.0.0" +alibabacloud-tea-util = ">=0.3.13,<1.0.0" [[package]] name = "alibabacloud-openapi-util" @@ -8761,11 +8761,6 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, - {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -11041,4 +11036,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "69a3f471f85dce9e5fb889f739e148a4a6d95aaf94081414503867c7157dba69" +content-hash = "f1979c80c3856f35c4207f57abe1b3772a08b0fb7cbac1540469d24078768826" diff --git a/api/pyproject.toml b/api/pyproject.toml index 0d87c1b1c8988f..0a73401b6a4aef 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -240,8 +240,8 @@ tos = "~2.7.1" # Required by vector store clients ############################################################ [tool.poetry.group.vdb.dependencies] -alibabacloud_gpdb20160503 = "~3.8.0" -alibabacloud_tea_openapi = "~0.3.9" +alibabacloud-gpdb20160503 = "4.1.0" +alibabacloud-tea-openapi = "0.3.12" chromadb = "0.5.1" clickhouse-connect = "~0.7.16" couchbase = "~4.3.0" diff --git a/api/tests/integration_tests/model_runtime/analyticdb/__init__.py b/api/tests/integration_tests/model_runtime/analyticdb/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/model_runtime/analyticdb/test_provider.py b/api/tests/integration_tests/model_runtime/analyticdb/test_provider.py new file mode 100644 index 00000000000000..949a5bdffa7fa6 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/analyticdb/test_provider.py @@ -0,0 +1,21 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.analyticdb.analyticdb import AnalyticdbProvider + + +def test_validate_provider_credentials(): + provider = AnalyticdbProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials(credentials={}) + provider.validate_provider_credentials( + credentials={ + "access_key_id": os.environ.get("ANALYTICDB_KEY_ID"), + "access_key_secret": os.environ.get("ANALYTICDB_KEY_SECRET"), + "region_id": os.environ.get("ANALYTICDB_REGION_ID"), + "instance_id": os.environ.get("ANALYTICDB_INSTANCE_ID"), + } + ) diff --git a/api/tests/integration_tests/model_runtime/analyticdb/test_rerank.py b/api/tests/integration_tests/model_runtime/analyticdb/test_rerank.py new file mode 100644 index 00000000000000..ba66e889cb851d --- /dev/null +++ b/api/tests/integration_tests/model_runtime/analyticdb/test_rerank.py @@ -0,0 +1,27 @@ +import os + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.model_providers.analyticdb.rerank.rerank import AnalyticdbRerankModel + + +def test_invoke_reranker(): + model = AnalyticdbRerankModel() + + result = model.invoke( + model="bge-reranker-v2-m3", + credentials={ + "access_key_id": os.environ.get("ANALYTICDB_KEY_ID"), + "access_key_secret": os.environ.get("ANALYTICDB_KEY_SECRET"), + "region_id": os.environ.get("ANALYTICDB_REGION_ID"), + "instance_id": os.environ.get("ANALYTICDB_INSTANCE_ID"), + }, + query="什么是文本排序模型", + docs=[ + "文本排序模型广泛用于搜索引擎和推荐系统中,它们根据文本相关性对候选文本进行排序", + "量子计算是计算科学的一个前沿领域", + "预训练语言模型的发展给文本排序模型带来了新的进展", + ], + ) + + assert isinstance(result, RerankResult) + assert result.docs[0].index == 0