diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 8626236856ddbf..5fec991d6e2d97 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -191,6 +191,22 @@ def BROKER_USE_SSL(self) -> bool: return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False +class InternalTestConfig(BaseSettings): + """ + Configuration settings for Internal Test + """ + + AWS_SECRET_ACCESS_KEY: Optional[str] = Field( + description="Internal test AWS secret access key", + default=None, + ) + + AWS_ACCESS_KEY_ID: Optional[str] = Field( + description="Internal test AWS access key ID", + default=None, + ) + + class MiddlewareConfig( # place the configs in alphabet order CeleryConfig, @@ -224,5 +240,6 @@ class MiddlewareConfig( TiDBVectorConfig, WeaviateConfig, ElasticsearchConfig, + InternalTestConfig, ): pass diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index dcef979360ba62..2dc054cfbdf4af 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -13,6 +13,7 @@ from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService +from services.knowledge_service import ExternalDatasetTestService def _validate_name(name): @@ -232,8 +233,31 @@ def post(self, dataset_id): raise InternalServerError(str(e)) +class BedrockRetrievalApi(Resource): + # this api is only for internal testing + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json") + parser.add_argument( + "query", + nullable=False, + required=True, + type=str, + ) + parser.add_argument("knowledge_id", nullable=False, required=True, type=str) + args = parser.parse_args() + + # Call the knowledge retrieval service + result = ExternalDatasetTestService.knowledge_retrieval( + args["retrieval_setting"], args["query"], args["knowledge_id"] + ) + return result, 200 + + api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets//external-hit-testing") api.add_resource(ExternalDatasetCreateApi, "/datasets/external") api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api") api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/") api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api//use-check") +# this api is only for internal test +api.add_resource(BedrockRetrievalApi, "/test/retrieval") diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index fe0bcf73384ed2..9e8a53bbfbf2a7 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -126,13 +126,12 @@ class ModelProviderIconApi(Resource): Get model provider icon """ - @setup_required - @login_required - @account_initialization_required def get(self, provider: str, icon_type: str, lang: str): model_provider_service = ModelProviderService() icon, mimetype = model_provider_service.get_model_provider_icon( - provider=provider, icon_type=icon_type, lang=lang + provider=provider, + icon_type=icon_type, + lang=lang, ) return send_file(io.BytesIO(icon), mimetype=mimetype) diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py index c1868b6ad02b83..4f8f4e0f61c70d 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/llm.py +++ b/api/core/model_runtime/model_providers/siliconflow/llm/llm.py @@ -1,8 +1,18 @@ from collections.abc import Generator from typing import Optional, Union -from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel @@ -29,3 +39,53 @@ def validate_credentials(self, model: str, credentials: dict) -> None: def _add_custom_parameters(cls, credentials: dict) -> None: credentials["mode"] = "chat" credentials["endpoint_url"] = "https://api.siliconflow.cn/v1" + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + return AIModelEntity( + model=model, + label=I18nObject(en_US=model, zh_Hans=model), + model_type=ModelType.LLM, + features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL] + if credentials.get("function_calling_type") == "tool_call" + else [], + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)), + ModelPropertyKey.MODE: LLMMode.CHAT.value, + }, + parameter_rules=[ + ParameterRule( + name="temperature", + use_template="temperature", + label=I18nObject(en_US="Temperature", zh_Hans="温度"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="max_tokens", + use_template="max_tokens", + default=512, + min=1, + max=int(credentials.get("max_tokens", 1024)), + label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"), + type=ParameterType.INT, + ), + ParameterRule( + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P", zh_Hans="Top P"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="top_k", + use_template="top_k", + label=I18nObject(en_US="Top K", zh_Hans="Top K"), + type=ParameterType.FLOAT, + ), + ParameterRule( + name="frequency_penalty", + use_template="frequency_penalty", + label=I18nObject(en_US="Frequency Penalty", zh_Hans="重复惩罚"), + type=ParameterType.FLOAT, + ), + ], + ) diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml index c46a891604c480..71f9a9238145c0 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml @@ -20,6 +20,7 @@ supported_model_types: - speech2text configurate_methods: - predefined-model + - customizable-model provider_credential_schema: credential_form_schemas: - variable: api_key @@ -30,3 +31,57 @@ provider_credential_schema: placeholder: zh_Hans: 在此输入您的 API Key en_US: Enter your API Key +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: context_size + label: + zh_Hans: 模型上下文长度 + en_US: Model context size + required: true + type: text-input + default: '4096' + placeholder: + zh_Hans: 在此输入您的模型上下文长度 + en_US: Enter your Model context size + - variable: max_tokens + label: + zh_Hans: 最大 token 上限 + en_US: Upper bound for max tokens + default: '4096' + type: text-input + show_on: + - variable: __model_type + value: llm + - variable: function_calling_type + label: + en_US: Function calling + type: select + required: false + default: no_call + options: + - value: no_call + label: + en_US: Not Support + zh_Hans: 不支持 + - value: function_call + label: + en_US: Support + zh_Hans: 支持 + show_on: + - variable: __model_type + value: llm diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/__init__.py b/api/core/model_runtime/model_providers/tongyi/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/_position.yaml b/api/core/model_runtime/model_providers/tongyi/rerank/_position.yaml new file mode 100644 index 00000000000000..439afda99263ad --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/rerank/_position.yaml @@ -0,0 +1 @@ +- gte-rerank diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/gte-rerank.yaml b/api/core/model_runtime/model_providers/tongyi/rerank/gte-rerank.yaml new file mode 100644 index 00000000000000..44d51b9b0d9cdd --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/rerank/gte-rerank.yaml @@ -0,0 +1,4 @@ +model: gte-rerank +model_type: rerank +model_properties: + context_size: 4000 diff --git a/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py new file mode 100644 index 00000000000000..c9245bd82ddb08 --- /dev/null +++ b/api/core/model_runtime/model_providers/tongyi/rerank/rerank.py @@ -0,0 +1,136 @@ +from typing import Optional + +import dashscope +from dashscope.common.error import ( + AuthenticationError, + InvalidParameter, + RequestFailure, + ServiceUnavailableError, + UnsupportedHTTPMethod, + UnsupportedModel, +) + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class GTERerankModel(RerankModel): + """ + Model class for GTE rerank model. + """ + + def _invoke( + self, + model: str, + credentials: dict, + query: str, + docs: list[str], + score_threshold: Optional[float] = None, + top_n: Optional[int] = None, + user: Optional[str] = None, + ) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=docs) + + # initialize client + dashscope.api_key = credentials["dashscope_api_key"] + + response = dashscope.TextReRank.call( + query=query, + documents=docs, + model=model, + top_n=top_n, + return_documents=True, + ) + + rerank_documents = [] + for _, result in enumerate(response.output.results): + # format document + rerank_document = RerankDocument( + index=result.index, + score=result.relevance_score, + text=result["document"]["text"], + ) + + # score threshold check + if score_threshold is not None: + if result.relevance_score >= score_threshold: + rerank_documents.append(rerank_document) + else: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self.invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8, + ) + except Exception as ex: + print(ex) + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + RequestFailure, + ], + InvokeServerUnavailableError: [ + ServiceUnavailableError, + ], + InvokeRateLimitError: [], + InvokeAuthorizationError: [ + AuthenticationError, + ], + InvokeBadRequestError: [ + InvalidParameter, + UnsupportedModel, + UnsupportedHTTPMethod, + ], + } diff --git a/api/core/model_runtime/model_providers/tongyi/tongyi.yaml b/api/core/model_runtime/model_providers/tongyi/tongyi.yaml index 1a09c20fd93f61..6349c227142656 100644 --- a/api/core/model_runtime/model_providers/tongyi/tongyi.yaml +++ b/api/core/model_runtime/model_providers/tongyi/tongyi.yaml @@ -18,6 +18,7 @@ supported_model_types: - llm - tts - text-embedding + - rerank configurate_methods: - predefined-model - customizable-model diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 8d578551209e09..f585e12b2e99c3 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -1,5 +1,6 @@ import json import logging +import math from typing import Any, Optional from urllib.parse import urlparse @@ -112,7 +113,8 @@ def delete(self) -> None: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: top_k = kwargs.get("top_k", 10) - knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k} + num_candidates = math.ceil(top_k * 1.5) + knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates} results = self._client.search(index=self._collection_name, knn=knn, size=top_k) diff --git a/api/core/rag/entities/context_entities.py b/api/core/rag/entities/context_entities.py index dde3beccf6e46a..cd18ad081ff4fd 100644 --- a/api/core/rag/entities/context_entities.py +++ b/api/core/rag/entities/context_entities.py @@ -1,3 +1,5 @@ +from typing import Optional + from pydantic import BaseModel @@ -7,4 +9,4 @@ class DocumentContext(BaseModel): """ content: str - score: float + score: Optional[float] = None diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index ae61ba7112243c..633e41d5cf1ed6 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -231,6 +231,9 @@ def retrieve( source["content"] = segment.content retrieval_resource_list.append(source) if hit_callback and retrieval_resource_list: + retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score"), reverse=True) + for position, item in enumerate(retrieval_resource_list, start=1): + item["position"] = position hit_callback.return_retriever_resource_info(retrieval_resource_list) if document_context_list: document_context_list = sorted(document_context_list, key=lambda x: x.score, reverse=True) @@ -536,7 +539,7 @@ def to_dataset_retriever_tool( continue # pass if dataset is not available - if dataset and dataset.available_document_count == 0: + if dataset and dataset.provider != "external" and dataset.available_document_count == 0: continue available_datasets.append(dataset) diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index 8dc60408c93b41..987f94a35046e9 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -1,10 +1,12 @@ from pydantic import BaseModel, Field from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment +from services.external_knowledge_service import ExternalDatasetService default_retrieval_model = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, @@ -53,97 +55,137 @@ def _run(self, query: str) -> str: for hit_callback in self.hit_callbacks: hit_callback.on_query(query, dataset.id) - - # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model or default_retrieval_model - if dataset.indexing_technique == "economy": - # use keyword table query - documents = RetrievalService.retrieve( - retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k + if dataset.provider == "external": + results = [] + external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + query=query, + external_retrieval_parameters=dataset.retrieval_model, ) - return str("\n".join([document.page_content for document in documents])) + for external_document in external_documents: + document = RetrievalDocument( + page_content=external_document.get("content"), + metadata=external_document.get("metadata"), + provider="external", + ) + document.metadata["score"] = external_document.get("score") + document.metadata["title"] = external_document.get("title") + document.metadata["dataset_id"] = dataset.id + document.metadata["dataset_name"] = dataset.name + results.append(document) + # deal with external documents + context_list = [] + for position, item in enumerate(results, start=1): + source = { + "position": position, + "dataset_id": item.metadata.get("dataset_id"), + "dataset_name": item.metadata.get("dataset_name"), + "document_name": item.metadata.get("title"), + "data_source_type": "external", + "retriever_from": self.retriever_from, + "score": item.metadata.get("score"), + "title": item.metadata.get("title"), + "content": item.page_content, + } + context_list.append(source) + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join([item.page_content for item in results])) else: - if self.top_k > 0: - # retrieval source + # get retrieval model , if the model is not setting , using default + retrieval_model = dataset.retrieval_model or default_retrieval_model + if dataset.indexing_technique == "economy": + # use keyword table query documents = RetrievalService.retrieve( - retrieval_method=retrieval_model.get("search_method", "semantic_search"), - dataset_id=dataset.id, - query=query, - top_k=self.top_k, - score_threshold=retrieval_model.get("score_threshold", 0.0) - if retrieval_model["score_threshold_enabled"] - else 0.0, - reranking_model=retrieval_model.get("reranking_model", None) - if retrieval_model["reranking_enable"] - else None, - reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", - weights=retrieval_model.get("weights", None), + retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k ) + return str("\n".join([document.page_content for document in documents])) else: - documents = [] - - for hit_callback in self.hit_callbacks: - hit_callback.on_tool_end(documents) - document_score_list = {} - if dataset.indexing_technique != "economy": - for item in documents: - if item.metadata.get("score"): - document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - document_context_list = [] - index_node_ids = [document.metadata["doc_id"] for document in documents] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id == self.dataset_id, - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() - - if segments: - index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} - sorted_segments = sorted( - segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) - ) - for segment in sorted_segments: - if segment.answer: - document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}") - else: - document_context_list.append(segment.get_sign_content()) - if self.return_resource: - context_list = [] - resource_number = 1 + if self.top_k > 0: + # retrieval source + documents = RetrievalService.retrieve( + retrieval_method=retrieval_model.get("search_method", "semantic_search"), + dataset_id=dataset.id, + query=query, + top_k=self.top_k, + score_threshold=retrieval_model.get("score_threshold", 0.0) + if retrieval_model["score_threshold_enabled"] + else 0.0, + reranking_model=retrieval_model.get("reranking_model", None) + if retrieval_model["reranking_enable"] + else None, + reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", + weights=retrieval_model.get("weights", None), + ) + else: + documents = [] + + for hit_callback in self.hit_callbacks: + hit_callback.on_tool_end(documents) + document_score_list = {} + if dataset.indexing_technique != "economy": + for item in documents: + if item.metadata.get("score"): + document_score_list[item.metadata["doc_id"]] = item.metadata["score"] + document_context_list = [] + index_node_ids = [document.metadata["doc_id"] for document in documents] + segments = DocumentSegment.query.filter( + DocumentSegment.dataset_id == self.dataset_id, + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ).all() + + if segments: + index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} + sorted_segments = sorted( + segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf")) + ) for segment in sorted_segments: - context = {} - document = Document.query.filter( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() - if dataset and document: - source = { - "position": resource_number, - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, - "segment_id": segment.id, - "retriever_from": self.retriever_from, - "score": document_score_list.get(segment.index_node_id, None), - } - if self.retriever_from == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash - if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" - else: - source["content"] = segment.content - context_list.append(source) - resource_number += 1 - - for hit_callback in self.hit_callbacks: - hit_callback.return_retriever_resource_info(context_list) - - return str("\n".join(document_context_list)) + if segment.answer: + document_context_list.append( + f"question:{segment.get_sign_content()} answer:{segment.answer}" + ) + else: + document_context_list.append(segment.get_sign_content()) + if self.return_resource: + context_list = [] + resource_number = 1 + for segment in sorted_segments: + context = {} + document = Document.query.filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ).first() + if dataset and document: + source = { + "position": resource_number, + "dataset_id": dataset.id, + "dataset_name": dataset.name, + "document_id": document.id, + "document_name": document.name, + "data_source_type": document.data_source_type, + "segment_id": segment.id, + "retriever_from": self.retriever_from, + "score": document_score_list.get(segment.index_node_id, None), + } + if self.retriever_from == "dev": + source["hit_count"] = segment.hit_count + source["word_count"] = segment.word_count + source["segment_position"] = segment.position + source["index_node_hash"] = segment.index_node_hash + if segment.answer: + source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + else: + source["content"] = segment.content + context_list.append(source) + resource_number += 1 + + for hit_callback in self.hit_callbacks: + hit_callback.return_retriever_resource_info(context_list) + + return str("\n".join(document_context_list)) diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 0d801d36c4dcd5..5867a11bb39da1 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -288,7 +288,7 @@ def parse_openai_plugin_json_to_tool_bundle( @staticmethod def auto_parse_to_tool_bundle( - content: str, extra_info: Optional[dict], warning: Optional[dict] + content: str, extra_info: Optional[dict] = None, warning: Optional[dict] = None ) -> tuple[list[ApiToolBundle], str]: """ auto parse to tool bundle diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 0caf99a963822c..0b3e9bd6a888ec 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -79,8 +79,9 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: results = ( db.session.query(Dataset) - .join(subquery, Dataset.id == subquery.c.dataset_id) + .outerjoin(subquery, Dataset.id == subquery.c.dataset_id) .filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) + .filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) .all() ) @@ -121,10 +122,13 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: ) elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value: if node_data.multiple_retrieval_config.reranking_mode == "reranking_model": - reranking_model = { - "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, - "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, - } + if node_data.multiple_retrieval_config.reranking_model: + reranking_model = { + "reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider, + "reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model, + } + else: + reranking_model = None weights = None elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score": reranking_model = None diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index b8f80a9f77e2a4..ede87640862199 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -8,6 +8,7 @@ from flask_login import current_user from sqlalchemy import func +from werkzeug.exceptions import NotFound from configs import dify_config from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError @@ -233,6 +234,7 @@ def update_dataset(dataset_id, data, user): dataset.name = data.get("name", dataset.name) dataset.description = data.get("description", "") external_knowledge_id = data.get("external_knowledge_id", None) + dataset.permission = data.get("permission") db.session.add(dataset) if not external_knowledge_id: raise ValueError("External knowledge id is required.") @@ -975,6 +977,8 @@ def update_document_with_dataset_id( ): DatasetService.check_dataset_model_setting(dataset) document = DocumentService.get_document(dataset.id, document_data["original_document_id"]) + if document is None: + raise NotFound("Document not found") if document.display_status != "available": raise ValueError("Document is not available") # update document name diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index ddee52164b746a..7d4fdfd2d039de 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -7,11 +7,16 @@ class EnterpriseRequest: base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") secret_key = os.environ.get("ENTERPRISE_API_SECRET_KEY", "ENTERPRISE_API_SECRET_KEY") + proxies = { + "http": None, + "https": None, + } + @classmethod def send_request(cls, method, endpoint, json=None, params=None): headers = {"Content-Type": "application/json", "Enterprise-Api-Secret-Key": cls.secret_key} url = f"{cls.base_url}{endpoint}" - response = requests.request(method, url, json=json, params=params, headers=headers) + response = requests.request(method, url, json=json, params=params, headers=headers, proxies=cls.proxies) return response.json() diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py new file mode 100644 index 00000000000000..02fe1d19bc42be --- /dev/null +++ b/api/services/knowledge_service.py @@ -0,0 +1,45 @@ +import boto3 + +from configs import dify_config + + +class ExternalDatasetTestService: + # this service is only for internal testing + @staticmethod + def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str): + # get bedrock client + client = boto3.client( + "bedrock-agent-runtime", + aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY, + aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID, + # example: us-east-1 + region_name="us-east-1", + ) + # fetch external knowledge retrieval + response = client.retrieve( + knowledgeBaseId=knowledge_id, + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": retrieval_setting.get("top_k"), + "overrideSearchType": "HYBRID", + } + }, + retrievalQuery={"text": query}, + ) + # parse response + results = [] + if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200: + if response.get("retrievalResults"): + retrieval_results = response.get("retrievalResults") + for retrieval_result in retrieval_results: + # filter out results with score less than threshold + if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0): + continue + result = { + "metadata": retrieval_result.get("metadata"), + "score": retrieval_result.get("score"), + "title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"), + "content": retrieval_result.get("content").get("text"), + } + results.append(result) + return {"records": results} diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 9f5298a506752a..257c6cf52bbb34 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -80,7 +80,9 @@ def parser_api_schema(schema: str) -> list[ApiToolBundle]: raise ValueError(f"invalid schema: {str(e)}") @staticmethod - def convert_schema_to_tool_bundles(schema: str, extra_info: Optional[dict] = None) -> list[ApiToolBundle]: + def convert_schema_to_tool_bundles( + schema: str, extra_info: Optional[dict] = None + ) -> tuple[list[ApiToolBundle], str]: """ convert schema to tool bundles diff --git a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py index 518cee1a3ad2d5..64f2884c4b828c 100644 --- a/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py +++ b/api/tests/artifact_tests/dependencies/test_dependencies_sorted.py @@ -2,36 +2,23 @@ import toml -ALL_DEPENDENCY_GROUP_NAMES = [ - # default main group - "", - # required groups - "indirect", - "storage", - "tools", - "vdb", - # optional groups - "dev", - "lint", -] - def load_api_poetry_configs() -> dict[str, Any]: pyproject_toml = toml.load("api/pyproject.toml") - return pyproject_toml.get("tool").get("poetry") + return pyproject_toml["tool"]["poetry"] -def load_dependency_groups() -> dict[str, dict[str, dict[str, Any]]]: - poetry_configs = load_api_poetry_configs() - group_name_to_dependencies = { - group_name: (poetry_configs.get("group").get(group_name) if group_name else poetry_configs).get("dependencies") - for group_name in ALL_DEPENDENCY_GROUP_NAMES - } - return group_name_to_dependencies +def load_all_dependency_groups() -> dict[str, dict[str, dict[str, Any]]]: + configs = load_api_poetry_configs() + configs_by_group = {"main": configs} + for group_name in configs["group"]: + configs_by_group[group_name] = configs["group"][group_name] + dependencies_by_group = {group_name: base["dependencies"] for group_name, base in configs_by_group.items()} + return dependencies_by_group def test_group_dependencies_sorted(): - for group_name, dependencies in load_dependency_groups().items(): + for group_name, dependencies in load_all_dependency_groups().items(): dependency_names = list(dependencies.keys()) expected_dependency_names = sorted(set(dependency_names)) section = f"tool.poetry.group.{group_name}.dependencies" if group_name else "tool.poetry.dependencies" @@ -42,17 +29,18 @@ def test_group_dependencies_sorted(): def test_group_dependencies_version_operator(): - for group_name, dependencies in load_dependency_groups().items(): + for group_name, dependencies in load_all_dependency_groups().items(): for dependency_name, specification in dependencies.items(): - version_spec = specification if isinstance(specification, str) else specification.get("version") + version_spec = specification if isinstance(specification, str) else specification["version"] assert not version_spec.startswith("^"), ( - f"'^' is not allowed in dependency version," f" but found in '{dependency_name} = {version_spec}'" + f"Please replace '{dependency_name} = {version_spec}' with '{dependency_name} = ~{version_spec[1:]}' " + f"'^' operator is too wide and not allowed in the version specification." ) def test_duplicated_dependency_crossing_groups(): all_dependency_names: list[str] = [] - for dependencies in load_dependency_groups().values(): + for dependencies in load_all_dependency_groups().values(): dependency_names = list(dependencies.keys()) all_dependency_names.extend(dependency_names) expected_all_dependency_names = set(all_dependency_names) diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py new file mode 100644 index 00000000000000..2dcfb92c63fee2 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/tongyi/test_rerank.py @@ -0,0 +1,40 @@ +import os + +import dashscope +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.tongyi.rerank.rerank import GTERerankModel + + +def test_validate_credentials(): + model = GTERerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials(model="get-rank", credentials={"dashscope_api_key": "invalid_key"}) + + model.validate_credentials( + model="get-rank", credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")} + ) + + +def test_invoke_model(): + model = GTERerankModel() + + result = model.invoke( + model=dashscope.TextReRank.Models.gte_rerank, + credentials={"dashscope_api_key": os.environ.get("TONGYI_DASHSCOPE_API_KEY")}, + query="什么是文本排序模型", + docs=[ + "文本排序模型广泛用于搜索引擎和推荐系统中,它们根据文本相关性对候选文本进行排序", + "量子计算是计算科学的一个前沿领域", + "预训练语言模型的发展给文本排序模型带来了新的进展", + ], + score_threshold=0.7, + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 0 + assert result.docs[0].score >= 0.7 diff --git a/docker/.env.example b/docker/.env.example index eb05f7aa4f0b25..87d7709a1830af 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -797,4 +797,6 @@ POSITION_TOOL_EXCLUDES= # Example: POSITION_PROVIDER_PINS=openai,openllm POSITION_PROVIDER_PINS= POSITION_PROVIDER_INCLUDES= -POSITION_PROVIDER_EXCLUDES= \ No newline at end of file +POSITION_PROVIDER_EXCLUDES= +# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP +CSP_WHITELIST= \ No newline at end of file diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 316a6fb64e64c2..62d798a695967a 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -209,6 +209,7 @@ x-shared-env: &shared-api-worker-env SSRF_PROXY_HTTPS_URL: ${SSRF_PROXY_HTTPS_URL:-http://ssrf_proxy:3128} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} + APP_MAX_EXECUTION_TIME: ${APP_MAX_EXECUTION_TIME:-12000} services: # API service @@ -260,6 +261,7 @@ services: SENTRY_DSN: ${WEB_SENTRY_DSN:-} NEXT_TELEMETRY_DISABLED: ${NEXT_TELEMETRY_DISABLED:-0} TEXT_GENERATION_TIMEOUT_MS: ${TEXT_GENERATION_TIMEOUT_MS:-60000} + CSP_WHITELIST: ${CSP_WHITELIST:-} # The postgres database. db: @@ -279,7 +281,7 @@ services: volumes: - ./volumes/db/data:/var/lib/postgresql/data healthcheck: - test: [ "CMD", "pg_isready" ] + test: ['CMD', 'pg_isready'] interval: 1s timeout: 3s retries: 30 @@ -294,7 +296,7 @@ services: # Set the redis password when startup redis server. command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} healthcheck: - test: [ "CMD", "redis-cli", "ping" ] + test: ['CMD', 'redis-cli', 'ping'] # The DifySandbox sandbox: @@ -314,7 +316,7 @@ services: volumes: - ./volumes/sandbox/dependencies:/dependencies healthcheck: - test: [ "CMD", "curl", "-f", "http://localhost:8194/health" ] + test: ['CMD', 'curl', '-f', 'http://localhost:8194/health'] networks: - ssrf_proxy_network @@ -327,7 +329,12 @@ services: volumes: - ./ssrf_proxy/squid.conf.template:/etc/squid/squid.conf.template - ./ssrf_proxy/docker-entrypoint.sh:/docker-entrypoint-mount.sh - entrypoint: [ "sh", "-c", "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] + entrypoint: + [ + 'sh', + '-c', + "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", + ] environment: # pls clearly modify the squid env vars to fit your network environment. HTTP_PORT: ${SSRF_HTTP_PORT:-3128} @@ -356,8 +363,8 @@ services: - CERTBOT_EMAIL=${CERTBOT_EMAIL} - CERTBOT_DOMAIN=${CERTBOT_DOMAIN} - CERTBOT_OPTIONS=${CERTBOT_OPTIONS:-} - entrypoint: [ "/docker-entrypoint.sh" ] - command: [ "tail", "-f", "/dev/null" ] + entrypoint: ['/docker-entrypoint.sh'] + command: ['tail', '-f', '/dev/null'] # The nginx reverse proxy. # used for reverse proxying the API service and Web service. @@ -374,7 +381,12 @@ services: - ./volumes/certbot/conf/live:/etc/letsencrypt/live # cert dir (with certbot container) - ./volumes/certbot/conf:/etc/letsencrypt - ./volumes/certbot/www:/var/www/html - entrypoint: [ "sh", "-c", "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh" ] + entrypoint: + [ + 'sh', + '-c', + "cp /docker-entrypoint-mount.sh /docker-entrypoint.sh && sed -i 's/\r$$//' /docker-entrypoint.sh && chmod +x /docker-entrypoint.sh && /docker-entrypoint.sh", + ] environment: NGINX_SERVER_NAME: ${NGINX_SERVER_NAME:-_} NGINX_HTTPS_ENABLED: ${NGINX_HTTPS_ENABLED:-false} @@ -396,14 +408,14 @@ services: - api - web ports: - - "${EXPOSE_NGINX_PORT:-80}:${NGINX_PORT:-80}" - - "${EXPOSE_NGINX_SSL_PORT:-443}:${NGINX_SSL_PORT:-443}" + - '${EXPOSE_NGINX_PORT:-80}:${NGINX_PORT:-80}' + - '${EXPOSE_NGINX_SSL_PORT:-443}:${NGINX_SSL_PORT:-443}' # The Weaviate vector store. weaviate: image: semitechnologies/weaviate:1.19.0 profiles: - - "" + - '' - weaviate restart: always volumes: @@ -452,7 +464,7 @@ services: volumes: - ./volumes/pgvector/data:/var/lib/postgresql/data healthcheck: - test: [ "CMD", "pg_isready" ] + test: ['CMD', 'pg_isready'] interval: 1s timeout: 3s retries: 30 @@ -474,7 +486,7 @@ services: volumes: - ./volumes/pgvecto_rs/data:/var/lib/postgresql/data healthcheck: - test: [ "CMD", "pg_isready" ] + test: ['CMD', 'pg_isready'] interval: 1s timeout: 3s retries: 30 @@ -522,7 +534,7 @@ services: - ./volumes/milvus/etcd:/etcd command: etcd -advertise-client-urls=http://127.0.0.1:2379 -listen-client-urls http://0.0.0.0:2379 --data-dir /etcd healthcheck: - test: [ "CMD", "etcdctl", "endpoint", "health" ] + test: ['CMD', 'etcdctl', 'endpoint', 'health'] interval: 30s timeout: 20s retries: 3 @@ -541,7 +553,7 @@ services: - ./volumes/milvus/minio:/minio_data command: minio server /minio_data --console-address ":9001" healthcheck: - test: [ "CMD", "curl", "-f", "http://localhost:9000/minio/health/live" ] + test: ['CMD', 'curl', '-f', 'http://localhost:9000/minio/health/live'] interval: 30s timeout: 20s retries: 3 @@ -553,7 +565,7 @@ services: image: milvusdb/milvus:v2.3.1 profiles: - milvus - command: [ "milvus", "run", "standalone" ] + command: ['milvus', 'run', 'standalone'] environment: ETCD_ENDPOINTS: ${ETCD_ENDPOINTS:-etcd:2379} MINIO_ADDRESS: ${MINIO_ADDRESS:-minio:9000} @@ -561,7 +573,7 @@ services: volumes: - ./volumes/milvus/milvus:/var/lib/milvus healthcheck: - test: [ "CMD", "curl", "-f", "http://localhost:9091/healthz" ] + test: ['CMD', 'curl', '-f', 'http://localhost:9091/healthz'] interval: 30s start_period: 90s timeout: 20s @@ -643,13 +655,13 @@ services: node.name: dify-es0 discovery.type: single-node xpack.license.self_generated.type: trial - xpack.security.enabled: "true" - xpack.security.enrollment.enabled: "false" - xpack.security.http.ssl.enabled: "false" + xpack.security.enabled: 'true' + xpack.security.enrollment.enabled: 'false' + xpack.security.http.ssl.enabled: 'false' ports: - ${ELASTICSEARCH_PORT:-9200}:9200 healthcheck: - test: [ "CMD", "curl", "-s", "http://localhost:9200/_cluster/health?pretty" ] + test: ['CMD', 'curl', '-s', 'http://localhost:9200/_cluster/health?pretty'] interval: 30s timeout: 10s retries: 50 @@ -667,17 +679,17 @@ services: environment: XPACK_ENCRYPTEDSAVEDOBJECTS_ENCRYPTIONKEY: d1a66dfd-c4d3-4a0a-8290-2abcb83ab3aa NO_PROXY: localhost,127.0.0.1,elasticsearch,kibana - XPACK_SECURITY_ENABLED: "true" - XPACK_SECURITY_ENROLLMENT_ENABLED: "false" - XPACK_SECURITY_HTTP_SSL_ENABLED: "false" - XPACK_FLEET_ISAIRGAPPED: "true" + XPACK_SECURITY_ENABLED: 'true' + XPACK_SECURITY_ENROLLMENT_ENABLED: 'false' + XPACK_SECURITY_HTTP_SSL_ENABLED: 'false' + XPACK_FLEET_ISAIRGAPPED: 'true' I18N_LOCALE: zh-CN - SERVER_PORT: "5601" + SERVER_PORT: '5601' ELASTICSEARCH_HOSTS: http://elasticsearch:9200 ports: - ${KIBANA_PORT:-5601}:5601 healthcheck: - test: [ "CMD-SHELL", "curl -s http://localhost:5601 >/dev/null || exit 1" ] + test: ['CMD-SHELL', 'curl -s http://localhost:5601 >/dev/null || exit 1'] interval: 30s timeout: 10s retries: 3 diff --git a/web/.env.example b/web/.env.example index 8e254082b385fc..13ea01a2c7b120 100644 --- a/web/.env.example +++ b/web/.env.example @@ -22,3 +22,6 @@ NEXT_PUBLIC_UPLOAD_IMAGE_AS_ICON=false # The timeout for the text generation in millisecond NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS=60000 + +# CSP https://developer.mozilla.org/en-US/docs/Web/HTTP/CSP +NEXT_PUBLIC_CSP_WHITELIST= diff --git a/web/app/(shareLayout)/layout.tsx b/web/app/(shareLayout)/layout.tsx index 9c4632cd450dfa..259af2bc2dc845 100644 --- a/web/app/(shareLayout)/layout.tsx +++ b/web/app/(shareLayout)/layout.tsx @@ -1,7 +1,12 @@ import React from 'react' import type { FC } from 'react' +import type { Metadata } from 'next' import GA, { GaType } from '@/app/components/base/ga' +export const metadata: Metadata = { + icons: 'data:,', // prevent browser from using default favicon +} + const Layout: FC<{ children: React.ReactNode }> = ({ children }) => { diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 149e877fa4ac2d..7dff48c20a28ec 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -299,10 +299,14 @@ function DetailPanel { - if (appDetail?.id && detail.id && appDetail?.mode !== 'completion') + if (appDetail?.id && detail.id && appDetail?.mode !== 'completion' && !fetchInitiated.current) { + fetchInitiated.current = true fetchData() - }, [appDetail?.id, detail.id, appDetail?.mode]) + } + }, [appDetail?.id, detail.id, appDetail?.mode, fetchData]) const isChatMode = appDetail?.mode !== 'completion' const isAdvanced = appDetail?.mode === 'advanced-chat' diff --git a/web/app/components/base/chat/chat/citation/popup.tsx b/web/app/components/base/chat/chat/citation/popup.tsx index b61bf623febe99..9b98329e70beb4 100644 --- a/web/app/components/base/chat/chat/citation/popup.tsx +++ b/web/app/components/base/chat/chat/citation/popup.tsx @@ -100,7 +100,7 @@ const Popup: FC = ({ /> } /> { diff --git a/web/app/components/base/ga/index.tsx b/web/app/components/base/ga/index.tsx index ec0089ff70efc3..219724113f681d 100644 --- a/web/app/components/base/ga/index.tsx +++ b/web/app/components/base/ga/index.tsx @@ -1,6 +1,7 @@ import type { FC } from 'react' import React from 'react' import Script from 'next/script' +import { headers } from 'next/headers' import { IS_CE_EDITION } from '@/config' export enum GaType { @@ -23,9 +24,16 @@ const GA: FC = ({ if (IS_CE_EDITION) return null + const nonce = process.env.NODE_ENV === 'production' ? headers().get('x-nonce') : '' + return ( <> - + diff --git a/web/app/components/base/topbar/index.tsx b/web/app/components/base/topbar/index.tsx deleted file mode 100644 index cf67456bd3423a..00000000000000 --- a/web/app/components/base/topbar/index.tsx +++ /dev/null @@ -1,16 +0,0 @@ -'use client' - -import { AppProgressBar as ProgressBar } from 'next-nprogress-bar' - -const Topbar = () => { - return ( - <> - - ) -} - -export default Topbar diff --git a/web/app/components/datasets/documents/list.tsx b/web/app/components/datasets/documents/list.tsx index 540474e7a5739c..0e0eebb034df52 100644 --- a/web/app/components/datasets/documents/list.tsx +++ b/web/app/components/datasets/documents/list.tsx @@ -122,6 +122,7 @@ export const OperationAction: FC<{ }> = ({ embeddingAvailable, datasetId, detail, onUpdate, scene = 'list', className = '' }) => { const { id, enabled = false, archived = false, data_source_type } = detail || {} const [showModal, setShowModal] = useState(false) + const [deleting, setDeleting] = useState(false) const { notify } = useContext(ToastContext) const { t } = useTranslation() const router = useRouter() @@ -153,6 +154,7 @@ export const OperationAction: FC<{ break default: opApi = deleteDocument + setDeleting(true) break } const [e] = await asyncRunSafe(opApi({ datasetId, documentId: id }) as Promise) @@ -160,6 +162,8 @@ export const OperationAction: FC<{ notify({ type: 'success', message: t('common.actionMsg.modifiedSuccessfully') }) else notify({ type: 'error', message: t('common.actionMsg.modifiedUnsuccessfully') }) + if (operationName === 'delete') + setDeleting(false) onUpdate(operationName) } @@ -295,6 +299,8 @@ export const OperationAction: FC<{ {showModal && {