From 72ebc3fd3f99e78bae73d7aefe59fba3417bd1db Mon Sep 17 00:00:00 2001 From: caoshili Date: Thu, 10 Oct 2024 16:41:40 +0800 Subject: [PATCH 1/5] feat:support baidu vector db --- api/.env.example | 9 + api/commands.py | 8 + .../middleware/vdb/baidu_vector_config.py | 45 +++ api/controllers/console/datasets/datasets.py | 2 + api/core/rag/datasource/vdb/baidu/__init__.py | 0 .../rag/datasource/vdb/baidu/baidu_vector.py | 273 ++++++++++++++++++ api/core/rag/datasource/vdb/vector_factory.py | 4 + api/core/rag/datasource/vdb/vector_type.py | 1 + api/poetry.lock | 29 +- api/pyproject.toml | 2 + .../vdb/__mock/baiduvectordb.py | 155 ++++++++++ .../integration_tests/vdb/baidu/__init__.py | 0 .../integration_tests/vdb/baidu/test_baidu.py | 36 +++ dev/pytest/pytest_vdb.sh | 3 +- docker/.env.example | 9 + docker/docker-compose.yaml | 7 + 16 files changed, 581 insertions(+), 2 deletions(-) create mode 100644 api/configs/middleware/vdb/baidu_vector_config.py create mode 100644 api/core/rag/datasource/vdb/baidu/__init__.py create mode 100644 api/core/rag/datasource/vdb/baidu/baidu_vector.py create mode 100644 api/tests/integration_tests/vdb/__mock/baiduvectordb.py create mode 100644 api/tests/integration_tests/vdb/baidu/__init__.py create mode 100644 api/tests/integration_tests/vdb/baidu/test_baidu.py diff --git a/api/.env.example b/api/.env.example index 71f0e5db8f8b9b..164fedb8847c57 100644 --- a/api/.env.example +++ b/api/.env.example @@ -203,6 +203,15 @@ OPENSEARCH_USER=admin OPENSEARCH_PASSWORD=admin OPENSEARCH_SECURE=true +# Baidu configuration +BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287 +BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000 +BAIDU_VECTOR_DB_ACCOUNT=root +BAIDU_VECTOR_DB_API_KEY=dify +BAIDU_VECTOR_DB_DATABASE=dify +BAIDU_VECTOR_DB_SHARD=1 +BAIDU_VECTOR_DB_REPLICAS=3 + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 UPLOAD_FILE_BATCH_LIMIT=5 diff --git a/api/commands.py b/api/commands.py index 7ef4aed7f77664..dbcd8a744d3a45 100644 --- a/api/commands.py +++ b/api/commands.py @@ -347,6 +347,14 @@ def migrate_knowledge_vector_database(): index_name = Dataset.gen_collection_name_by_id(dataset_id) index_struct_dict = {"type": "elasticsearch", "vector_store": {"class_prefix": index_name}} dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type == VectorType.BAIDU: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + index_struct_dict = { + "type": VectorType.BAIDU, + "vector_store": {"class_prefix": collection_name}, + } + dataset.index_struct = json.dumps(index_struct_dict) else: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py new file mode 100644 index 00000000000000..64c5504b1f8aa5 --- /dev/null +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -0,0 +1,45 @@ +from typing import Optional + +from pydantic import Field, NonNegativeInt, PositiveInt +from pydantic_settings import BaseSettings + + +class BaiduVectorDBConfig(BaseSettings): + """ + Configuration settings for Baidu Vector Database + """ + + BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field( + description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')", + default=None, + ) + + BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field( + description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)", + default=30, + ) + + BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field( + description="Account for authenticating with the Baidu Vector Database", + default=None, + ) + + BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field( + description="API key for authenticating with the Baidu Vector Database service", + default=None, + ) + + BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field( + description="Name of the specific Baidu Vector Database to connect to", + default=None, + ) + + BAIDU_VECTOR_DB_SHARD: PositiveInt = Field( + description="Number of shards for the Baidu Vector Database (default is 1)", + default=1, + ) + + BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field( + description="Number of replicas for the Baidu Vector Database (default is 3)", + default=3, + ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 9561fd8b70e4b9..102089bf071ac2 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -617,6 +617,7 @@ def get(self): | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS + | VectorType.BAIDU ): return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} case ( @@ -653,6 +654,7 @@ def get(self, vector_type): | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS + | VectorType.BAIDU ): return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} case ( diff --git a/api/core/rag/datasource/vdb/baidu/__init__.py b/api/core/rag/datasource/vdb/baidu/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py new file mode 100644 index 00000000000000..e599e03c06d2a6 --- /dev/null +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -0,0 +1,273 @@ +import json +import uuid +import time +from typing import Any + +from pydantic import BaseModel, model_validator + +from pymochow import MochowClient +from pymochow.configuration import Configuration +from pymochow.auth.bce_credentials import BceCredentials +from pymochow.model.enum import FieldType, IndexType, MetricType, TableState, IndexState +from pymochow.model.schema import Schema, Field, VectorIndex, HNSWParams +from pymochow.model.table import Partition, Row, AnnSearch, HNSWSearchParams + +from configs import dify_config +from core.rag.datasource.entity.embedding import Embeddings +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + + +class BaiduConfig(BaseModel): + endpoint: str + connection_timeout_in_mills: int = 30 * 1000 + account: str + api_key: str + database: str + index_type: str = "HNSW" + metric_type: str = "L2" + shard: int = 1 + replicas: int = 3 + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + if not values["endpoint"]: + raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required") + if not values["account"]: + raise ValueError("config BAIDU_VECTOR_DB_ACCOUNT is required") + if not values["api_key"]: + raise ValueError("config BAIDU_VECTOR_DB_API_KEY is required") + if not values["database"]: + raise ValueError("config BAIDU_VECTOR_DB_DATABASE is required") + return values + + +class BaiduVector(BaseVector): + field_id: str = "id" + field_vector: str = "vector" + field_text: str = "text" + field_metadata: str = "metadata" + field_app_id: str = "app_id" + field_annotation_id: str = "annotation_id" + index_vector: str = "vector_idx" + + def __init__(self, collection_name: str, config: BaiduConfig): + super().__init__(collection_name) + self._client_config = config + self._client = self._init_client(config) + self._db = self._init_database() + + def get_type(self) -> str: + return VectorType.BAIDU + + def to_index_struct(self) -> dict: + return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}} + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + self._create_table(len(embeddings[0])) + self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + texts = [doc.page_content for doc in documents] + metadatas = [doc.metadata for doc in documents] + total_count = len(documents) + batch_size = 1000 + + # upsert texts and embeddings batch by batch + table = self._db.table(self._collection_name) + for start in range(0, total_count, batch_size): + end = min(start + batch_size, total_count) + rows = [] + for i in range(start, end, 1): + row = Row( + id=metadatas[i].get("doc_id", str(uuid.uuid4())), + vector=embeddings[i], + text=texts[i], + metadata=json.dumps(metadatas[i]), + app_id=metadatas[i].get("app_id", ""), + annotation_id=metadatas[i].get("annotation_id", ""), + ) + rows.append(row) + table.upsert(rows=rows) + + # rebuild vector index after upsert finished + table.rebuild_index(self.index_vector) + while True: + time.sleep(1) + index = table.describe_index(self.index_vector) + if index.state == IndexState.NORMAL: + break + + def text_exists(self, id: str) -> bool: + res = self._db.table(self._collection_name).query(primary_key={self.field_id: id}) + if res and res.code == 0: + return True + return False + + def delete_by_ids(self, ids: list[str]) -> None: + quoted_ids = [f"'{id}'" for id in ids] + self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})") + + def delete_by_metadata_field(self, key: str, value: str) -> None: + self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'") + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + anns = AnnSearch( + vector_field=self.field_vector, + vector_floats=query_vector, + params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)), + ) + res = self._db.table(self._collection_name).search( + anns=anns, + projections=[self.field_id, self.field_text, self.field_metadata], + retrieve_vector=True, + ) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + return self._get_search_res(res, score_threshold) + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + # baidu vector database doesn't support bm25 search on current version + return [] + + def _get_search_res(self, res, score_threshold): + docs = [] + for row in res.rows: + row_data = row.get("row", {}) + meta = row_data.get(self.field_metadata) + if meta is not None: + meta = json.loads(meta) + score = row.get("score", 0.0) + if score > score_threshold: + meta["score"] = score + doc = Document(page_content=row_data.get(self.field_text), metadata=meta) + docs.append(doc) + + return docs + + def delete(self) -> None: + self._db.drop_table(table_name=self._collection_name) + + def _init_client(self, config) -> MochowClient: + config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint) + client = MochowClient(config) + return client + + def _init_database(self): + exists = False + for db in self._client.list_databases(): + if db.database_name == self._client_config.database: + exists = True + break + # Create database if not existed + if exists: + return self._client.database(self._client_config.database) + else: + return self._client.create_database(database_name=self._client_config.database) + + def _table_existed(self) -> bool: + tables = self._db.list_table() + return any(table.table_name == self._collection_name for table in tables) + + def _create_table(self, dimension: int) -> None: + # Try to grab distributed lock and create table + lock_name = "vector_indexing_lock_{}".format(self._collection_name) + with redis_client.lock(lock_name, timeout=20): + table_exist_cache_key = "vector_indexing_{}".format(self._collection_name) + if redis_client.get(table_exist_cache_key): + return + + if self._table_existed(): + return + + self.delete() + + # check IndexType and MetricType + index_type = None + for k, v in IndexType.__members__.items(): + if k == self._client_config.index_type: + index_type = v + if index_type is None: + raise ValueError("unsupported index_type") + metric_type = None + for k, v in MetricType.__members__.items(): + if k == self._client_config.metric_type: + metric_type = v + if metric_type is None: + raise ValueError("unsupported metric_type") + + # Construct field schema + fields = [] + fields.append( + Field( + self.field_id, + FieldType.STRING, + primary_key=True, + partition_key=True, + auto_increment=False, + not_null=True, + ) + ) + fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True)) + fields.append(Field(self.field_app_id, FieldType.STRING)) + fields.append(Field(self.field_annotation_id, FieldType.STRING)) + fields.append(Field(self.field_text, FieldType.TEXT, not_null=True)) + fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension)) + + # Construct vector index params + indexes = [] + indexes.append( + VectorIndex( + index_name="vector_idx", + index_type=index_type, + field="vector", + metric_type=metric_type, + params=HNSWParams(m=16, efconstruction=200), + ) + ) + + # Create table + self._db.create_table( + table_name=self._collection_name, + replication=self._client_config.replicas, + partition=Partition(partition_num=self._client_config.shard), + schema=Schema(fields=fields, indexes=indexes), + description="Table for Dify", + ) + + redis_client.set(table_exist_cache_key, 1, ex=3600) + + # Wait for table created + while True: + time.sleep(1) + table = self._db.describe_table(self._collection_name) + if table.state == TableState.NORMAL: + break + + +class BaiduVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix.lower() + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.BAIDU, collection_name)) + + return BaiduVector( + collection_name=collection_name, + config=BaiduConfig( + endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT, + connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS, + account=dify_config.BAIDU_VECTOR_DB_ACCOUNT, + api_key=dify_config.BAIDU_VECTOR_DB_API_KEY, + database=dify_config.BAIDU_VECTOR_DB_DATABASE, + shard=dify_config.BAIDU_VECTOR_DB_SHARD, + replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS, + ), + ) diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index 943b23870cc5cb..1f4a4d44a23eea 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -103,6 +103,10 @@ def get_vector_factory(vector_type: str) -> type[AbstractVectorFactory]: from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVectorFactory return AnalyticdbVectorFactory + case VectorType.BAIDU: + from core.rag.datasource.vdb.baidu.baidu_vector import BaiduVectorFactory + + return BaiduVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index ba04ea879d9b43..996ff48615c901 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -16,3 +16,4 @@ class VectorType(str, Enum): TENCENT = "tencent" ORACLE = "oracle" ELASTICSEARCH = "elasticsearch" + BAIDU = "baidu" diff --git a/api/poetry.lock b/api/poetry.lock index b7421ca566929a..c4c31aa1b9fcb5 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -7001,6 +7001,22 @@ bulk-writer = ["azure-storage-blob", "minio (>=7.0.0)", "pyarrow (>=12.0.0)", "r dev = ["black", "grpcio (==1.62.2)", "grpcio-testing (==1.62.2)", "grpcio-tools (==1.62.2)", "pytest (>=5.3.4)", "pytest-cov (>=2.8.1)", "pytest-timeout (>=1.3.4)", "ruff (>0.4.0)"] model = ["milvus-model (>=0.1.0)"] +[[package]] +name = "pymochow" +version = "1.3.1" +description = "Python SDK for mochow" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pymochow-1.3.1-py3-none-any.whl", hash = "sha256:a7f3b34fd6ea5d1d8413650bb6678365aa148fc396ae945e4ccb4f2365a52327"}, + {file = "pymochow-1.3.1.tar.gz", hash = "sha256:1693d10cd0bb7bce45327890a90adafb503155922ccc029acb257699a73a20ba"}, +] + +[package.dependencies] +future = "*" +orjson = "*" +requests = "*" + [[package]] name = "pymysql" version = "1.1.1" @@ -8354,6 +8370,11 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, + {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, + {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, + {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -10339,31 +10360,37 @@ python-versions = ">=3.8" files = [ {file = "zope.interface-7.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:2bd9e9f366a5df08ebbdc159f8224904c1c5ce63893984abb76954e6fbe4381a"}, {file = "zope.interface-7.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:661d5df403cd3c5b8699ac480fa7f58047a3253b029db690efa0c3cf209993ef"}, + {file = "zope.interface-7.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91b6c30689cfd87c8f264acb2fc16ad6b3c72caba2aec1bf189314cf1a84ca33"}, {file = "zope.interface-7.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b6a4924f5bad9fe21d99f66a07da60d75696a136162427951ec3cb223a5570d"}, {file = "zope.interface-7.1.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80a3c00b35f6170be5454b45abe2719ea65919a2f09e8a6e7b1362312a872cd3"}, {file = "zope.interface-7.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:b936d61dbe29572fd2cfe13e30b925e5383bed1aba867692670f5a2a2eb7b4e9"}, {file = "zope.interface-7.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0ac20581fc6cd7c754f6dff0ae06fedb060fa0e9ea6309d8be8b2701d9ea51c4"}, {file = "zope.interface-7.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:848b6fa92d7c8143646e64124ed46818a0049a24ecc517958c520081fd147685"}, + {file = "zope.interface-7.1.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec1ef1fdb6f014d5886b97e52b16d0f852364f447d2ab0f0c6027765777b6667"}, {file = "zope.interface-7.1.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bcff5c09d0215f42ba64b49205a278e44413d9bf9fa688fd9e42bfe472b5f4f"}, {file = "zope.interface-7.1.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07add15de0cc7e69917f7d286b64d54125c950aeb43efed7a5ea7172f000fbc1"}, {file = "zope.interface-7.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:9940d5bc441f887c5f375ec62bcf7e7e495a2d5b1da97de1184a88fb567f06af"}, {file = "zope.interface-7.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f245d039f72e6f802902375755846f5de1ee1e14c3e8736c078565599bcab621"}, {file = "zope.interface-7.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6159e767d224d8f18deff634a1d3722e68d27488c357f62ebeb5f3e2f5288b1f"}, + {file = "zope.interface-7.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e956b1fd7f3448dd5e00f273072e73e50dfafcb35e4227e6d5af208075593c9"}, {file = "zope.interface-7.1.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ff115ef91c0eeac69cd92daeba36a9d8e14daee445b504eeea2b1c0b55821984"}, {file = "zope.interface-7.1.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bec001798ab62c3fc5447162bf48496ae9fba02edc295a9e10a0b0c639a6452e"}, {file = "zope.interface-7.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:124149e2d42067b9c6597f4dafdc7a0983d0163868f897b7bb5dc850b14f9a87"}, {file = "zope.interface-7.1.0-cp313-cp313-macosx_10_9_x86_64.whl", hash = "sha256:9733a9a0f94ef53d7aa64661811b20875b5bc6039034c6e42fb9732170130573"}, {file = "zope.interface-7.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5fcf379b875c610b5a41bc8a891841533f98de0520287d7f85e25386cd10d3e9"}, + {file = "zope.interface-7.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0a45b5af9f72c805ee668d1479480ca85169312211bed6ed18c343e39307d5f"}, {file = "zope.interface-7.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4af4a12b459a273b0b34679a5c3dc5e34c1847c3dd14a628aa0668e19e638ea2"}, {file = "zope.interface-7.1.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a735f82d2e3ed47ca01a20dfc4c779b966b16352650a8036ab3955aad151ed8a"}, {file = "zope.interface-7.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:5501e772aff595e3c54266bc1bfc5858e8f38974ce413a8f1044aae0f32a83a3"}, {file = "zope.interface-7.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ec59fe53db7d32abb96c6d4efeed84aab4a7c38c62d7a901a9b20c09dd936e7a"}, {file = "zope.interface-7.1.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e53c291debef523b09e1fe3dffe5f35dde164f1c603d77f770b88a1da34b7ed6"}, + {file = "zope.interface-7.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:711eebc77f2092c6a8b304bad0b81a6ce3cf5490b25574e7309fbc07d881e3af"}, {file = "zope.interface-7.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a00ead2e24c76436e1b457a5132d87f83858330f6c923640b7ef82d668525d1"}, {file = "zope.interface-7.1.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e28ea0bc4b084fc93a483877653a033062435317082cdc6388dec3438309faf"}, {file = "zope.interface-7.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:27cfb5205d68b12682b6e55ab8424662d96e8ead19550aad0796b08dd2c9a45e"}, {file = "zope.interface-7.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9e3e48f3dea21c147e1b10c132016cb79af1159facca9736d231694ef5a740a8"}, {file = "zope.interface-7.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a99240b1d02dc469f6afbe7da1bf617645e60290c272968f4e53feec18d7dce8"}, + {file = "zope.interface-7.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc8a318162123eddbdf22fcc7b751288ce52e4ad096d3766ff1799244352449d"}, {file = "zope.interface-7.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7b25db127db3e6b597c5f74af60309c4ad65acd826f89609662f0dc33a54728"}, {file = "zope.interface-7.1.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a29ac607e970b5576547f0e3589ec156e04de17af42839eedcf478450687317"}, {file = "zope.interface-7.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:a14c9decf0eb61e0892631271d500c1e306c7b6901c998c7035e194d9150fdd1"}, @@ -10493,4 +10520,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "d324192116c4b243e504d57f4605b79c46592a976201d903b16a910b71d84b57" +content-hash = "fa17453ddd2952e1129322c4cd82b1b3db8cdc31f7ed1543549428989eb357c0" diff --git a/api/pyproject.toml b/api/pyproject.toml index 11bcc255d7cd98..c9bffe77cf4649 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -190,6 +190,7 @@ zhipuai = "1.0.7" # Related transparent dependencies with pinned version # required by main implementations ############################################################ +pymochow = "1.3.1" [tool.poetry.group.indirect.dependencies] kaleido = "0.2.1" rank-bm25 = "~0.2.2" @@ -241,6 +242,7 @@ oracledb = "~2.2.1" pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvector = "0.2.5" pymilvus = "~2.4.4" +pymochow = "1.3.1" qdrant-client = "1.7.3" tcvectordb = "1.3.2" tidb-vector = "0.0.9" diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py new file mode 100644 index 00000000000000..27af4d8e5fe099 --- /dev/null +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -0,0 +1,155 @@ +import os + +import pytest +from _pytest.monkeypatch import MonkeyPatch +from requests.adapters import HTTPAdapter + +from pymochow import MochowClient +from pymochow.model.database import Database +from pymochow.model.table import Table +from pymochow.model.enum import IndexType, MetricType, TableState, IndexState, ReadConsistency +from pymochow.model.schema import VectorIndex, HNSWParams + + +class MockBaiduVectorDBClass: + def mock_vector_db_client( + self, + config=None, + adapter: HTTPAdapter = None, + ): + self._conn = None + self._config = None + + def list_databases(self, config=None) -> list[Database]: + return [ + Database( + conn=self._conn, + database_name="dify", + config=self._config, + ) + ] + + def create_database(self, database_name: str, config=None) -> Database: + return Database(conn=self._conn, database_name=database_name, config=config) + + def list_table(self, config=None) -> list[Table]: + return [] + + def drop_table(self, table_name: str, config=None): + return {"code": 0, "msg": "Success"} + + def create_table( + self, + table_name: str, + replication: int, + partition: int, + schema, + enable_dynamic_field=False, + description: str = "", + config=None, + ) -> Table: + return Table(self, table_name, replication, partition, schema, enable_dynamic_field, description, config) + + def describe_table(self, table_name: str, config=None) -> Table: + return Table( + self, + table_name, + 3, + 1, + None, + enable_dynamic_field=False, + description="table for dify", + config=config, + state=TableState.NORMAL, + ) + + def upsert(self, rows, config=None): + return {"code": 0, "msg": "operation success", "affectedCount": 1} + + def rebuild_index(self, index_name: str, config=None): + return {"code": 0, "msg": "Success"} + + def describe_index(self, index_name: str, config=None): + return VectorIndex( + index_name=index_name, + index_type=IndexType.HNSW, + field="vector", + metric_type=MetricType.L2, + params=HNSWParams(m=16, efconstruction=200), + auto_build=False, + state=IndexState.NORMAL, + ) + + def query( + self, + primary_key, + partition_key=None, + projections=None, + retrieve_vector=False, + read_consistency=ReadConsistency.EVENTUAL, + config=None, + ): + return { + "row": { + "id": "doc_id_001", + "vector": [0.23432432, 0.8923744, 0.89238432], + "text": "text", + "metadata": {"doc_id": "doc_id_001"}, + }, + "code": 0, + "msg": "Success", + } + + def delete(self, primary_key=None, partition_key=None, filter=None, config=None): + return {"code": 0, "msg": "Success"} + + def search( + self, + anns, + partition_key=None, + projections=None, + retrieve_vector=False, + read_consistency=ReadConsistency.EVENTUAL, + config=None, + ): + return { + "rows": [ + { + "row": { + "id": "doc_id_001", + "vector": [0.23432432, 0.8923744, 0.89238432], + "text": "text", + "metadata": {"doc_id": "doc_id_001"}, + }, + "distance": 0.1, + "score": 0.5, + } + ], + "code": 0, + "msg": "Success", + } + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(MochowClient, "__init__", MockBaiduVectorDBClass.mock_vector_db_client) + monkeypatch.setattr(MochowClient, "list_databases", MockBaiduVectorDBClass.list_databases) + monkeypatch.setattr(MochowClient, "create_database", MockBaiduVectorDBClass.create_database) + monkeypatch.setattr(Database, "table", MockBaiduVectorDBClass.describe_table) + monkeypatch.setattr(Database, "list_table", MockBaiduVectorDBClass.list_table) + monkeypatch.setattr(Database, "create_table", MockBaiduVectorDBClass.create_table) + monkeypatch.setattr(Database, "drop_table", MockBaiduVectorDBClass.drop_table) + monkeypatch.setattr(Database, "describe_table", MockBaiduVectorDBClass.describe_table) + monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index) + monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index) + monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete) + monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search) + + yield + + if MOCK: + monkeypatch.undo() diff --git a/api/tests/integration_tests/vdb/baidu/__init__.py b/api/tests/integration_tests/vdb/baidu/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/vdb/baidu/test_baidu.py b/api/tests/integration_tests/vdb/baidu/test_baidu.py new file mode 100644 index 00000000000000..01a7f8853ac367 --- /dev/null +++ b/api/tests/integration_tests/vdb/baidu/test_baidu.py @@ -0,0 +1,36 @@ +from unittest.mock import MagicMock + +from core.rag.datasource.vdb.baidu.baidu_vector import BaiduConfig, BaiduVector +from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + +mock_client = MagicMock() +mock_client.list_databases.return_value = [{"name": "test"}] + + +class BaiduVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + self.vector = BaiduVector( + "dify", + BaiduConfig( + endpoint="http://127.0.0.1:5287", + account="root", + api_key="dify", + database="dify", + shard=1, + replicas=3, + ), + ) + + def search_by_vector(self): + hits_by_vector = self.vector.search_by_vector(query_vector=self.example_embedding) + assert len(hits_by_vector) == 1 + + def search_by_full_text(self): + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + assert len(hits_by_full_text) == 0 + + +def test_baidu_vector(setup_mock_redis, setup_baiduvectordb_mock): + BaiduVectorTest().run_all_tests() diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index bad809cbfdb90a..f0cc9970170a4a 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -7,4 +7,5 @@ pytest api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/pgvector \ api/tests/integration_tests/vdb/qdrant \ api/tests/integration_tests/vdb/weaviate \ - api/tests/integration_tests/vdb/elasticsearch \ No newline at end of file + api/tests/integration_tests/vdb/elasticsearch \ + api/tests/integration_tests/vdb/baidu \ No newline at end of file diff --git a/docker/.env.example b/docker/.env.example index eb05f7aa4f0b25..241b6046eb87bc 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -462,6 +462,15 @@ ELASTICSEARCH_PORT=9200 ELASTICSEARCH_USERNAME=elastic ELASTICSEARCH_PASSWORD=elastic +# baidu vector configurations, only available when VECTOR_STORE is `baidu` +BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287 +BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS=30000 +BAIDU_VECTOR_DB_ACCOUNT=root +BAIDU_VECTOR_DB_API_KEY=dify +BAIDU_VECTOR_DB_DATABASE=dify +BAIDU_VECTOR_DB_SHARD=1 +BAIDU_VECTOR_DB_REPLICAS=3 + # ------------------------------ # Knowledge Configuration # ------------------------------ diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 316a6fb64e64c2..c3ef7b7e686e26 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -165,6 +165,13 @@ x-shared-env: &shared-api-worker-env TENCENT_VECTOR_DB_DATABASE: ${TENCENT_VECTOR_DB_DATABASE:-dify} TENCENT_VECTOR_DB_SHARD: ${TENCENT_VECTOR_DB_SHARD:-1} TENCENT_VECTOR_DB_REPLICAS: ${TENCENT_VECTOR_DB_REPLICAS:-2} + BAIDU_VECTOR_DB_ENDPOINT: ${BAIDU_VECTOR_DB_ENDPOINT:-http://127.0.0.1:5287} + BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: ${BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS:-30000} + BAIDU_VECTOR_DB_ACCOUNT: ${BAIDU_VECTOR_DB_ACCOUNT:-root} + BAIDU_VECTOR_DB_API_KEY: ${BAIDU_VECTOR_DB_API_KEY:-dify} + BAIDU_VECTOR_DB_DATABASE: ${BAIDU_VECTOR_DB_DATABASE:-dify} + BAIDU_VECTOR_DB_SHARD: ${BAIDU_VECTOR_DB_SHARD:-1} + BAIDU_VECTOR_DB_REPLICAS: ${BAIDU_VECTOR_DB_REPLICAS:-3} UPLOAD_FILE_SIZE_LIMIT: ${UPLOAD_FILE_SIZE_LIMIT:-15} UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} ETL_TYPE: ${ETL_TYPE:-dify} From 595c38fd3f98c8e7a55cf1ad181977530226394a Mon Sep 17 00:00:00 2001 From: caoshili Date: Fri, 11 Oct 2024 11:54:18 +0800 Subject: [PATCH 2/5] fix python style --- api/core/rag/datasource/vdb/baidu/baidu_vector.py | 11 +++++------ .../integration_tests/vdb/__mock/baiduvectordb.py | 7 +++---- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index e599e03c06d2a6..543cfa67b35409 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -1,16 +1,15 @@ import json -import uuid import time +import uuid from typing import Any from pydantic import BaseModel, model_validator - from pymochow import MochowClient -from pymochow.configuration import Configuration from pymochow.auth.bce_credentials import BceCredentials -from pymochow.model.enum import FieldType, IndexType, MetricType, TableState, IndexState -from pymochow.model.schema import Schema, Field, VectorIndex, HNSWParams -from pymochow.model.table import Partition, Row, AnnSearch, HNSWSearchParams +from pymochow.configuration import Configuration +from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState +from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex +from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row from configs import dify_config from core.rag.datasource.entity.embedding import Embeddings diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index 27af4d8e5fe099..a8eaf42b7de1de 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -2,13 +2,12 @@ import pytest from _pytest.monkeypatch import MonkeyPatch -from requests.adapters import HTTPAdapter - from pymochow import MochowClient from pymochow.model.database import Database +from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState +from pymochow.model.schema import HNSWParams, VectorIndex from pymochow.model.table import Table -from pymochow.model.enum import IndexType, MetricType, TableState, IndexState, ReadConsistency -from pymochow.model.schema import VectorIndex, HNSWParams +from requests.adapters import HTTPAdapter class MockBaiduVectorDBClass: From a5be94d9ec37caeb01828f7a31c808fe8b88bacf Mon Sep 17 00:00:00 2001 From: caoshili Date: Fri, 11 Oct 2024 13:44:06 +0800 Subject: [PATCH 3/5] fix problems in comments --- api/configs/middleware/vdb/baidu_vector_config.py | 2 +- api/pyproject.toml | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/api/configs/middleware/vdb/baidu_vector_config.py b/api/configs/middleware/vdb/baidu_vector_config.py index 64c5504b1f8aa5..44742c2e2f4349 100644 --- a/api/configs/middleware/vdb/baidu_vector_config.py +++ b/api/configs/middleware/vdb/baidu_vector_config.py @@ -16,7 +16,7 @@ class BaiduVectorDBConfig(BaseSettings): BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field( description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)", - default=30, + default=30000, ) BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field( diff --git a/api/pyproject.toml b/api/pyproject.toml index c9bffe77cf4649..6a1723047cb4be 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -190,7 +190,6 @@ zhipuai = "1.0.7" # Related transparent dependencies with pinned version # required by main implementations ############################################################ -pymochow = "1.3.1" [tool.poetry.group.indirect.dependencies] kaleido = "0.2.1" rank-bm25 = "~0.2.2" From 9d1666105c401ed9fa5d0dc74d720463fcb4d95a Mon Sep 17 00:00:00 2001 From: caoshili Date: Sat, 12 Oct 2024 19:24:28 +0800 Subject: [PATCH 4/5] fix poetry.lock --- api/poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/poetry.lock b/api/poetry.lock index 44ed8a22752376..6565db27ad5725 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -10636,4 +10636,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "29e8172c41238fa7f3f6625d1ad339ead48aec965b824c6b34b9bfc74fec935d" +content-hash = "375ac3a91760513924647e67376cb6018505ec61d967651b254c68af9808d774" From 90e6d9adc88cc8327c2e8e41014a69bb8d3157b8 Mon Sep 17 00:00:00 2001 From: caoshili Date: Sat, 12 Oct 2024 20:59:33 +0800 Subject: [PATCH 5/5] fix test failed --- dev/pytest/pytest_vdb.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index f0cc9970170a4a..bad809cbfdb90a 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -7,5 +7,4 @@ pytest api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/pgvector \ api/tests/integration_tests/vdb/qdrant \ api/tests/integration_tests/vdb/weaviate \ - api/tests/integration_tests/vdb/elasticsearch \ - api/tests/integration_tests/vdb/baidu \ No newline at end of file + api/tests/integration_tests/vdb/elasticsearch \ No newline at end of file