Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix sagemaker_chinese_toxicity_detector and bedrock_retrieve #12227

Merged
merged 8 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _bedrock_retrieve(

retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}

# 如果有元数据过滤条件,则添加到检索配置中
# Add metadata filter to retrieval configuration if present
if metadata_filter:
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter

Expand Down Expand Up @@ -77,7 +77,7 @@ def _invoke(
if not query:
return self.create_text_message("Please input query")

# 获取元数据过滤条件(如果存在)
# Get metadata filter conditions (if they exist)
metadata_filter_str = tool_parameters.get("metadata_filter")
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None

Expand All @@ -86,7 +86,7 @@ def _invoke(
query_input=query,
knowledge_base_id=self.knowledge_base_id,
num_results=self.topk,
metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法
metadata_filter=metadata_filter,
)

line = 5
Expand All @@ -109,7 +109,7 @@ def validate_parameters(self, parameters: dict[str, Any]) -> None:
if not parameters.get("query"):
raise ValueError("query is required")

# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
# Optional: Validate if metadata filter is a valid JSON string (if provided)
metadata_filter_str = parameters.get("metadata_filter")
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
raise ValueError("metadata_filter must be a valid JSON object")
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ parameters:
llm_description: AWS region where the Bedrock Knowledge Base is located
form: form

- name: metadata_filter
type: string
required: false
- name: metadata_filter # 新增的元数据过滤参数
type: string # 可以是字符串类型,包含 JSON 格式的过滤条件
required: false # 可选参数
label:
en_US: Metadata Filter
zh_Hans: 元数据过滤器
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool

# 定义标签映射
LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"}
# Define label mappings
LABEL_MAPPING = {0: "SAFE", 1: "NO_SAFE"}


class ContentModerationTool(BuiltinTool):
Expand All @@ -28,12 +28,12 @@ def _invoke_sagemaker(self, payload: dict, endpoint: str):
# Handle nested JSON if present
if isinstance(json_obj, dict) and "body" in json_obj:
body_content = json.loads(json_obj["body"])
raw_label = body_content.get("label")
prediction_result = body_content.get("prediction")
else:
raw_label = json_obj.get("label")
prediction_result = json_obj.get("prediction")

# 映射标签并返回
result = LABEL_MAPPING.get(raw_label, "NO_SAFE") # 如果映射中没有找到,默认返回NO_SAFE
# Map labels and return
result = LABEL_MAPPING.get(prediction_result, "NO_SAFE") # If not found in mapping, default to NO_SAFE
return result

def _invoke(
Expand Down
Loading