Skip to content

Commit

Permalink
Merge branch 'fix/agent-external-knowledge-retrieval' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnJyong committed Oct 11, 2024
2 parents a63a670 + 99967e6 commit e2bf129
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 94 deletions.
10 changes: 10 additions & 0 deletions api/configs/middleware/vdb/qdrant_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,13 @@ class QdrantConfig(BaseSettings):
description="Port number for gRPC connection to Qdrant server (default is 6334)",
default=6334,
)

AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
description="AWS secret access key for authenticating with the Qdrant server",
default=None,
)

AWS_ACCESS_KEY_ID: Optional[str] = Field(
description="AWS access key ID for authenticating with the Qdrant server",
default=None,
)
18 changes: 18 additions & 0 deletions api/controllers/console/datasets/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetServiceTest


def _validate_name(name):
Expand Down Expand Up @@ -275,8 +276,25 @@ def post(self, dataset_id):
raise InternalServerError(str(e))


class BedrockRetrievalApi(Resource):
# url : <your-endpoint>/retrieval
def post(self):
parser = reqparse.RequestParser()
parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
parser.add_argument("query", nullable=False, required=True, type=str, )
parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
args = parser.parse_args()

# Call the knowledge retrieval service
result = ExternalDatasetServiceTest.knowledge_retrieval(
args["retrieval_setting"], args["query"], args["knowledge_id"]
)
return result, 200


api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing")
api.add_resource(ExternalDatasetCreateApi, "/datasets/external")
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api")
api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
api.add_resource(BedrockRetrievalApi, "/datasets/retrieval")
2 changes: 1 addition & 1 deletion api/core/rag/retrieval/dataset_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def to_dataset_retriever_tool(
continue

# pass if dataset is not available
if dataset and dataset.available_document_count == 0:
if dataset and dataset.provider != "external" and dataset.available_document_count == 0:
continue

available_datasets.append(dataset)
Expand Down
217 changes: 129 additions & 88 deletions api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from core.rag.models.document import Document as RetrievalDocument
from services.external_knowledge_service import ExternalDatasetService

default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
Expand Down Expand Up @@ -53,97 +55,136 @@ def _run(self, query: str) -> str:

for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)

# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
if dataset.provider == "external":
results = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
)
return str("\n".join([document.page_content for document in documents]))
for external_document in external_documents:
document = RetrievalDocument(
page_content=external_document.get("content"),
metadata=external_document.get("metadata"),
provider="external",
)
document.metadata["score"] = external_document.get("score")
document.metadata["title"] = external_document.get("title")
document.metadata["dataset_id"] = dataset.id
document.metadata["dataset_name"] = dataset.name
results.append(document)
# deal with external documents
context_list = []
for position, item in enumerate(results, start=1):
source = {
"position": position,
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": self.retriever_from,
"score": item.metadata.get("score"),
"title": item.metadata.get("title"),
"content": item.page_content,
}
context_list.append(source)
for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)

return str("\n".join([item.page_content for item in results]))
else:
if self.top_k > 0:
# retrieval source

# get retrieval model , if the model is not setting , using default
retrieval_model = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
)
return str("\n".join([document.page_content for document in documents]))
else:
documents = []

for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()

if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
resource_number = 1
if self.top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights", None),
)
else:
documents = []

for hit_callback in self.hit_callbacks:
hit_callback.on_tool_end(documents)
document_score_list = {}
if dataset.indexing_technique != "economy":
for item in documents:
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id == self.dataset_id,
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()

if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
context = {}
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),
}
if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
context_list.append(source)
resource_number += 1

for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)

return str("\n".join(document_context_list))
if segment.answer:
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
else:
document_context_list.append(segment.get_sign_content())
if self.return_resource:
context_list = []
resource_number = 1
for segment in sorted_segments:
context = {}
document = Document.query.filter(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
).first()
if dataset and document:
source = {
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": self.retriever_from,
"score": document_score_list.get(segment.index_node_id, None),
}
if self.retriever_from == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source["content"] = segment.content
context_list.append(source)
resource_number += 1

for hit_callback in self.hit_callbacks:
hit_callback.return_retriever_resource_info(context_list)

return str("\n".join(document_context_list))
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,9 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query:

results = (
db.session.query(Dataset)
.join(subquery, Dataset.id == subquery.c.dataset_id)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)

Expand Down Expand Up @@ -120,10 +121,13 @@ def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query:
)
elif node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value:
if node_data.multiple_retrieval_config.reranking_mode == "reranking_model":
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
if node_data.multiple_retrieval_config.reranking_model:
reranking_model = {
"reranking_provider_name": node_data.multiple_retrieval_config.reranking_model.provider,
"reranking_model_name": node_data.multiple_retrieval_config.reranking_model.model,
}
else:
reranking_model = None
weights = None
elif node_data.multiple_retrieval_config.reranking_mode == "weighted_score":
reranking_model = None
Expand Down
44 changes: 44 additions & 0 deletions api/services/knowledge_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

import boto3

from configs import dify_config


class ExternalDatasetServiceTest:
@staticmethod
def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str):
# get bedrock client
client = boto3.client(
"bedrock-agent-runtime",
aws_secret_access_key=dify_config.AWS_SECRET_ACCESS_KEY,
aws_access_key_id=dify_config.AWS_ACCESS_KEY_ID,
# example: us-east-1
region_name="us-east-1",
)
# fetch external knowledge retrieval
response = client.retrieve(
knowledgeBaseId=knowledge_id,
retrievalConfiguration={
"vectorSearchConfiguration": {"numberOfResults": retrieval_setting.get("top_k"), "overrideSearchType": "HYBRID"}
},
retrievalQuery={"text": query},
)
# parse response
results = []
if response.get("ResponseMetadata") and response.get("ResponseMetadata").get("HTTPStatusCode") == 200:
if response.get("retrievalResults"):
retrieval_results = response.get("retrievalResults")
for retrieval_result in retrieval_results:
# filter out results with score less than threshold
if retrieval_result.get("score") < retrieval_setting.get("score_threshold", .0):
continue
result = {
"metadata": retrieval_result.get("metadata"),
"score": retrieval_result.get("score"),
"title": retrieval_result.get("metadata").get("x-amz-bedrock-kb-source-uri"),
"content": retrieval_result.get("content").get("text"),
}
results.append(result)
return {
"records": results
}

0 comments on commit e2bf129

Please sign in to comment.