forked from langgenius/dify
feat:support baidu vector db (langgenius#9185)
Showing 15 changed files with 582 additions and 13 deletions.
@@ -0,0 +1,45 @@
from typing import Optional

from pydantic import Field, NonNegativeInt, PositiveInt
from pydantic_settings import BaseSettings


class BaiduVectorDBConfig(BaseSettings):
    """
    Configuration settings for Baidu Vector Database
    """

    BAIDU_VECTOR_DB_ENDPOINT: Optional[str] = Field(
        description="URL of the Baidu Vector Database service (e.g., 'http://vdb.bj.baidubce.com')",
        default=None,
    )

    BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS: PositiveInt = Field(
        description="Timeout in milliseconds for Baidu Vector Database operations (default is 30000 milliseconds)",
        default=30000,
    )

    BAIDU_VECTOR_DB_ACCOUNT: Optional[str] = Field(
        description="Account for authenticating with the Baidu Vector Database",
        default=None,
    )

    BAIDU_VECTOR_DB_API_KEY: Optional[str] = Field(
        description="API key for authenticating with the Baidu Vector Database service",
        default=None,
    )

    BAIDU_VECTOR_DB_DATABASE: Optional[str] = Field(
        description="Name of the specific Baidu Vector Database to connect to",
        default=None,
    )

    BAIDU_VECTOR_DB_SHARD: PositiveInt = Field(
        description="Number of shards for the Baidu Vector Database (default is 1)",
        default=1,
    )

    BAIDU_VECTOR_DB_REPLICAS: NonNegativeInt = Field(
        description="Number of replicas for the Baidu Vector Database (default is 3)",
        default=3,
    )
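For orientation, a minimal sketch of how these settings resolve at runtime, assuming the standard pydantic-settings behavior of populating fields from identically named environment variables (the endpoint value is illustrative, not part of the diff):

import os

os.environ["BAIDU_VECTOR_DB_ENDPOINT"] = "http://vdb.bj.baidubce.com"

config = BaiduVectorDBConfig()
print(config.BAIDU_VECTOR_DB_ENDPOINT)  # http://vdb.bj.baidubce.com
print(config.BAIDU_VECTOR_DB_SHARD)     # 1 (falls back to the default)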
Empty file.
@@ -0,0 +1,272 @@
import json
import time
import uuid
from typing import Any

from pydantic import BaseModel, model_validator
from pymochow import MochowClient
from pymochow.auth.bce_credentials import BceCredentials
from pymochow.configuration import Configuration
from pymochow.model.enum import FieldType, IndexState, IndexType, MetricType, TableState
from pymochow.model.schema import Field, HNSWParams, Schema, VectorIndex
from pymochow.model.table import AnnSearch, HNSWSearchParams, Partition, Row

from configs import dify_config
from core.rag.datasource.entity.embedding import Embeddings
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset


class BaiduConfig(BaseModel):
    endpoint: str
    connection_timeout_in_mills: int = 30 * 1000
    account: str
    api_key: str
    database: str
    index_type: str = "HNSW"
    metric_type: str = "L2"
    shard: int = 1
    replicas: int = 3

    @model_validator(mode="before")
    @classmethod
    def validate_config(cls, values: dict) -> dict:
        if not values["endpoint"]:
            raise ValueError("config BAIDU_VECTOR_DB_ENDPOINT is required")
        if not values["account"]:
            raise ValueError("config BAIDU_VECTOR_DB_ACCOUNT is required")
        if not values["api_key"]:
            raise ValueError("config BAIDU_VECTOR_DB_API_KEY is required")
        if not values["database"]:
            raise ValueError("config BAIDU_VECTOR_DB_DATABASE is required")
        return values
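A quick illustration of the validator (all values hypothetical): the empty api_key trips the third check, and pydantic surfaces the raised ValueError as a validation error at construction time.

BaiduConfig(
    endpoint="http://vdb.bj.baidubce.com",
    account="root",
    api_key="",
    database="dify",
)
# validation error: config BAIDU_VECTOR_DB_API_KEY is required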

class BaiduVector(BaseVector):
    field_id: str = "id"
    field_vector: str = "vector"
    field_text: str = "text"
    field_metadata: str = "metadata"
    field_app_id: str = "app_id"
    field_annotation_id: str = "annotation_id"
    index_vector: str = "vector_idx"

    def __init__(self, collection_name: str, config: BaiduConfig):
        super().__init__(collection_name)
        self._client_config = config
        self._client = self._init_client(config)
        self._db = self._init_database()

    def get_type(self) -> str:
        return VectorType.BAIDU

    def to_index_struct(self) -> dict:
        return {"type": self.get_type(), "vector_store": {"class_prefix": self._collection_name}}

    def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
        self._create_table(len(embeddings[0]))
        self.add_texts(texts, embeddings)

    def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
        texts = [doc.page_content for doc in documents]
        metadatas = [doc.metadata for doc in documents]
        total_count = len(documents)
        batch_size = 1000

        # upsert texts and embeddings batch by batch
        table = self._db.table(self._collection_name)
        for start in range(0, total_count, batch_size):
            end = min(start + batch_size, total_count)
            rows = []
            for i in range(start, end, 1):
                row = Row(
                    id=metadatas[i].get("doc_id", str(uuid.uuid4())),
                    vector=embeddings[i],
                    text=texts[i],
                    metadata=json.dumps(metadatas[i]),
                    app_id=metadatas[i].get("app_id", ""),
                    annotation_id=metadatas[i].get("annotation_id", ""),
                )
                rows.append(row)
            table.upsert(rows=rows)

        # rebuild vector index after upsert finished
        table.rebuild_index(self.index_vector)
        while True:
            time.sleep(1)
            index = table.describe_index(self.index_vector)
            if index.state == IndexState.NORMAL:
                break
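A hypothetical call showing the expected shapes (contents and vector values are illustrative, and vector_store stands in for a constructed BaiduVector); note that the method blocks in the polling loop until the rebuilt index reports IndexState.NORMAL:

docs = [Document(page_content="hello world", metadata={"doc_id": "d1", "app_id": "a1"})]
vectors = [[0.1, 0.2, 0.3, 0.4]]
vector_store.add_texts(docs, vectors)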

    def text_exists(self, id: str) -> bool:
        res = self._db.table(self._collection_name).query(primary_key={self.field_id: id})
        if res and res.code == 0:
            return True
        return False

    def delete_by_ids(self, ids: list[str]) -> None:
        quoted_ids = [f"'{id}'" for id in ids]
        self._db.table(self._collection_name).delete(filter=f"id IN({', '.join(quoted_ids)})")

    def delete_by_metadata_field(self, key: str, value: str) -> None:
        self._db.table(self._collection_name).delete(filter=f"{key} = '{value}'")

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        anns = AnnSearch(
            vector_field=self.field_vector,
            vector_floats=query_vector,
            params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
        )
        res = self._db.table(self._collection_name).search(
            anns=anns,
            projections=[self.field_id, self.field_text, self.field_metadata],
            retrieve_vector=True,
        )
        score_threshold = float(kwargs.get("score_threshold") or 0.0)
        return self._get_search_res(res, score_threshold)
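A hypothetical query showing the recognized keyword arguments: ef widens the HNSW candidate list, top_k caps the result count, and score_threshold is forwarded to _get_search_res to drop low-scoring rows (the values here are illustrative):

matches = vector_store.search_by_vector(
    query_vector,
    top_k=4,
    ef=10,
    score_threshold=0.5,
)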

    def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
        # baidu vector database doesn't support bm25 search on current version
        return []

    def _get_search_res(self, res, score_threshold):
        docs = []
        for row in res.rows:
            row_data = row.get("row", {})
            meta = row_data.get(self.field_metadata)
            if meta is not None:
                meta = json.loads(meta)
            score = row.get("score", 0.0)
            if score > score_threshold:
                meta["score"] = score
                doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
                docs.append(doc)

        return docs

    def delete(self) -> None:
        self._db.drop_table(table_name=self._collection_name)

    def _init_client(self, config) -> MochowClient:
        config = Configuration(credentials=BceCredentials(config.account, config.api_key), endpoint=config.endpoint)
        client = MochowClient(config)
        return client

    def _init_database(self):
        exists = False
        for db in self._client.list_databases():
            if db.database_name == self._client_config.database:
                exists = True
                break
        # Create database if not existed
        if exists:
            return self._client.database(self._client_config.database)
        else:
            return self._client.create_database(database_name=self._client_config.database)

    def _table_existed(self) -> bool:
        tables = self._db.list_table()
        return any(table.table_name == self._collection_name for table in tables)

    def _create_table(self, dimension: int) -> None:
        # Try to grab distributed lock and create table
        lock_name = "vector_indexing_lock_{}".format(self._collection_name)
        with redis_client.lock(lock_name, timeout=20):
            table_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
            if redis_client.get(table_exist_cache_key):
                return

            if self._table_existed():
                return

            self.delete()

            # check IndexType and MetricType
            index_type = None
            for k, v in IndexType.__members__.items():
                if k == self._client_config.index_type:
                    index_type = v
            if index_type is None:
                raise ValueError("unsupported index_type")
            metric_type = None
            for k, v in MetricType.__members__.items():
                if k == self._client_config.metric_type:
                    metric_type = v
            if metric_type is None:
                raise ValueError("unsupported metric_type")

            # Construct field schema
            fields = []
            fields.append(
                Field(
                    self.field_id,
                    FieldType.STRING,
                    primary_key=True,
                    partition_key=True,
                    auto_increment=False,
                    not_null=True,
                )
            )
            fields.append(Field(self.field_metadata, FieldType.STRING, not_null=True))
            fields.append(Field(self.field_app_id, FieldType.STRING))
            fields.append(Field(self.field_annotation_id, FieldType.STRING))
            fields.append(Field(self.field_text, FieldType.TEXT, not_null=True))
            fields.append(Field(self.field_vector, FieldType.FLOAT_VECTOR, not_null=True, dimension=dimension))

            # Construct vector index params
            indexes = []
            indexes.append(
                VectorIndex(
                    index_name="vector_idx",
                    index_type=index_type,
                    field="vector",
                    metric_type=metric_type,
                    params=HNSWParams(m=16, efconstruction=200),
                )
            )

            # Create table
            self._db.create_table(
                table_name=self._collection_name,
                replication=self._client_config.replicas,
                partition=Partition(partition_num=self._client_config.shard),
                schema=Schema(fields=fields, indexes=indexes),
                description="Table for Dify",
            )

            redis_client.set(table_exist_cache_key, 1, ex=3600)

            # Wait for table created
            while True:
                time.sleep(1)
                table = self._db.describe_table(self._collection_name)
                if table.state == TableState.NORMAL:
                    break
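For reference, m and efconstruction are the usual HNSW build knobs: the maximum number of neighbors per graph node and the candidate-list size during construction. A hypothetical variant trading longer build times for higher recall, not part of this diff, might pass larger values:

VectorIndex(
    index_name="vector_idx",
    index_type=IndexType.HNSW,
    field="vector",
    metric_type=MetricType.L2,
    params=HNSWParams(m=32, efconstruction=400),  # hypothetical: denser graph, wider build search
)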

class BaiduVectorFactory(AbstractVectorFactory):
    def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaiduVector:
        if dataset.index_struct_dict:
            class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
            collection_name = class_prefix.lower()
        else:
            dataset_id = dataset.id
            collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
            dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.BAIDU, collection_name))

        return BaiduVector(
            collection_name=collection_name,
            config=BaiduConfig(
                endpoint=dify_config.BAIDU_VECTOR_DB_ENDPOINT,
                connection_timeout_in_mills=dify_config.BAIDU_VECTOR_DB_CONNECTION_TIMEOUT_MS,
                account=dify_config.BAIDU_VECTOR_DB_ACCOUNT,
                api_key=dify_config.BAIDU_VECTOR_DB_API_KEY,
                database=dify_config.BAIDU_VECTOR_DB_DATABASE,
                shard=dify_config.BAIDU_VECTOR_DB_SHARD,
                replicas=dify_config.BAIDU_VECTOR_DB_REPLICAS,
            ),
        )
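Putting it together, a hypothetical end-to-end flow (dataset, embedding_model, docs, vectors, and query_vector are illustrative stand-ins, not part of the diff):

factory = BaiduVectorFactory()
vector_store = factory.init_vector(dataset, attributes=[], embeddings=embedding_model)
vector_store.create(texts=docs, embeddings=vectors)  # creates the table, then upserts
matches = vector_store.search_by_vector(query_vector, top_k=4)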