Skip to content

Commit

Permalink
Fix langgenius#12448 - update bedrock retrieve tool, support hybrid s…
Browse files Browse the repository at this point in the history
…earch type and re… (langgenius#12446)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
  • Loading branch information
2 people authored and alexcodelf committed Jan 21, 2025
1 parent 9665be5 commit 0eaa8a5
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 5 deletions.
39 changes: 34 additions & 5 deletions api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,38 @@ class BedrockRetrieveTool(BuiltinTool):
topk: int = None

def _bedrock_retrieve(
self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
self,
query_input: str,
knowledge_base_id: str,
num_results: int,
search_type: str,
rerank_model_id: str,
metadata_filter: Optional[dict] = None,
):
try:
retrieval_query = {"text": query_input}

retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
if search_type not in ["HYBRID", "SEMANTIC"]:
raise RuntimeException("search_type should be HYBRID or SEMANTIC")

retrieval_configuration = {
"vectorSearchConfiguration": {"numberOfResults": num_results, "overrideSearchType": search_type}
}

if rerank_model_id != "default":
model_for_rerank_arn = f"arn:aws:bedrock:us-west-2::foundation-model/{rerank_model_id}"
rerankingConfiguration = {
"bedrockRerankingConfiguration": {
"numberOfRerankedResults": num_results,
"modelConfiguration": {"modelArn": model_for_rerank_arn},
},
"type": "BEDROCK_RERANKING_MODEL",
}

# Add metadata filter to retrieval configuration if present
retrieval_configuration["vectorSearchConfiguration"]["rerankingConfiguration"] = rerankingConfiguration
retrieval_configuration["vectorSearchConfiguration"]["numberOfResults"] = num_results * 5

# 如果有元数据过滤条件,则添加到检索配置中
if metadata_filter:
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter

Expand Down Expand Up @@ -77,15 +101,20 @@ def _invoke(
if not query:
return self.create_text_message("Please input query")

# Get metadata filter conditions (if they exist)
# 获取元数据过滤条件(如果存在)
metadata_filter_str = tool_parameters.get("metadata_filter")
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None

search_type = tool_parameters.get("search_type")
rerank_model_id = tool_parameters.get("rerank_model_id")

line = 4
retrieved_docs = self._bedrock_retrieve(
query_input=query,
knowledge_base_id=self.knowledge_base_id,
num_results=self.topk,
search_type=search_type,
rerank_model_id=rerank_model_id,
metadata_filter=metadata_filter,
)

Expand All @@ -109,7 +138,7 @@ def validate_parameters(self, parameters: dict[str, Any]) -> None:
if not parameters.get("query"):
raise ValueError("query is required")

# Optional: Validate if metadata filter is a valid JSON string (if provided)
# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
metadata_filter_str = parameters.get("metadata_filter")
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
raise ValueError("metadata_filter must be a valid JSON object")
51 changes: 51 additions & 0 deletions api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,57 @@ parameters:
max: 10
default: 5

- name: search_type
type: select
required: false
label:
en_US: search type
zh_Hans: 搜索类型
pt_BR: search type
human_description:
en_US: search type
zh_Hans: 搜索类型
pt_BR: search type
llm_description: search type
default: SEMANTIC
options:
- value: SEMANTIC
label:
en_US: SEMANTIC
zh_Hans: 语义搜索
- value: HYBRID
label:
en_US: HYBRID
zh_Hans: 混合搜索
form: form

- name: rerank_model_id
type: select
required: false
label:
en_US: rerank model id
zh_Hans: 重拍模型ID
pt_BR: rerank model id
human_description:
en_US: rerank model id
zh_Hans: 重拍模型ID
pt_BR: rerank model id
llm_description: rerank model id
options:
- value: default
label:
en_US: default
zh_Hans: 默认
- value: cohere.rerank-v3-5:0
label:
en_US: cohere.rerank-v3-5:0
zh_Hans: cohere.rerank-v3-5:0
- value: amazon.rerank-v1:0
label:
en_US: amazon.rerank-v1:0
zh_Hans: amazon.rerank-v1:0
form: form

- name: aws_region
type: string
required: false
Expand Down

0 comments on commit 0eaa8a5

Please sign in to comment.