From 97a30635e584dced22ec46910d1651692b87013e Mon Sep 17 00:00:00 2001
From: Samuel CHNIBER
Date: Thu, 12 Dec 2024 22:46:31 +0100
Subject: [PATCH] Add Amazon Bedrock Guardrails support for LlamaIndex Bedrock & BedrockConverse

---
 .../llama_index/llms/bedrock_converse/base.py | 27 +++++++++++++++
 .../llms/bedrock_converse/utils.py            | 22 ++++++++++--
 .../tests/test_llms_bedrock_converse.py       |  6 ++++
 .../llama_index/llms/bedrock/base.py          | 21 ++++++++++++
 .../llama_index/llms/bedrock/utils.py         | 34 ++++++++++++++++---
 .../tests/test_bedrock.py                     |  7 ++--
 6 files changed, 109 insertions(+), 8 deletions(-)

diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py
index 706772fec648a..a05e76062c10a 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/base.py
@@ -114,6 +114,15 @@ class BedrockConverse(FunctionCallingLLM):
         default=60.0,
         description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
     )
+    guardrail_identifier: Optional[str] = Field(
+        default=None, description="The unique identifier of the guardrail that you want to use. If you don't provide a value, no guardrail is applied to the invocation."
+    )
+    guardrail_version: Optional[str] = Field(
+        default=None, description="The version number for the guardrail. The value can also be DRAFT."
+    )
+    trace: Optional[str] = Field(
+        default=None, description="Specifies whether to enable or disable the Bedrock trace. If enabled, you can see the full Bedrock trace."
+    )
     additional_kwargs: Dict[str, Any] = Field(
         default_factory=dict,
         description="Additional kwargs for the bedrock invokeModel request.",
@@ -145,6 +154,9 @@ def __init__(
         completion_to_prompt: Optional[Callable[[str], str]] = None,
         pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
         output_parser: Optional[BaseOutputParser] = None,
+        guardrail_identifier: Optional[str] = None,
+        guardrail_version: Optional[str] = None,
+        trace: Optional[str] = None,
     ) -> None:
         additional_kwargs = additional_kwargs or {}
         callback_manager = callback_manager or CallbackManager([])
@@ -178,6 +190,9 @@ def __init__(
             region_name=region_name,
             botocore_session=botocore_session,
             botocore_config=botocore_config,
+            guardrail_identifier=guardrail_identifier,
+            guardrail_version=guardrail_version,
+            trace=trace,
         )

         self._config = None
@@ -292,6 +307,9 @@ def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
             system_prompt=self.system_prompt,
             max_retries=self.max_retries,
             stream=False,
+            guardrail_identifier=self.guardrail_identifier,
+            guardrail_version=self.guardrail_version,
+            trace=self.trace,
             **all_kwargs,
         )

@@ -336,6 +354,9 @@ def stream_chat(
             system_prompt=self.system_prompt,
             max_retries=self.max_retries,
             stream=True,
+            guardrail_identifier=self.guardrail_identifier,
+            guardrail_version=self.guardrail_version,
+            trace=self.trace,
             **all_kwargs,
         )

@@ -416,6 +437,9 @@ async def achat(
             system_prompt=self.system_prompt,
             max_retries=self.max_retries,
             stream=False,
+            guardrail_identifier=self.guardrail_identifier,
+            guardrail_version=self.guardrail_version,
+            trace=self.trace,
             **all_kwargs,
         )

@@ -461,6 +485,9 @@ async def astream_chat(
             system_prompt=self.system_prompt,
             max_retries=self.max_retries,
             stream=True,
+            guardrail_identifier=self.guardrail_identifier,
+            guardrail_version=self.guardrail_version,
+            trace=self.trace,
             **all_kwargs,
         )

diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py
index 7d795dfc5c5e8..e54acfa9625bf 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/llama_index/llms/bedrock_converse/utils.py
@@ -307,6 +307,9 @@ def converse_with_retry(
     max_tokens: int = 1000,
     temperature: float = 0.1,
     stream: bool = False,
+    guardrail_identifier: Optional[str] = None,
+    guardrail_version: Optional[str] = None,
+    trace: Optional[str] = None,
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
@@ -323,8 +326,14 @@ def converse_with_retry(
         converse_kwargs["system"] = [{"text": system_prompt}]
     if tool_config := kwargs.get("tools"):
         converse_kwargs["toolConfig"] = tool_config
+    if guardrail_identifier and guardrail_version:
+        converse_kwargs["guardrailConfig"] = {}
+        converse_kwargs["guardrailConfig"]["guardrailIdentifier"] = guardrail_identifier
+        converse_kwargs["guardrailConfig"]["guardrailVersion"] = guardrail_version
+        if trace:
+            converse_kwargs["guardrailConfig"]["trace"] = trace
     converse_kwargs = join_two_dicts(
-        converse_kwargs, {k: v for k, v in kwargs.items() if k != "tools"}
+        converse_kwargs, {k: v for k, v in kwargs.items() if k not in ("tools", "guardrail_identifier", "guardrail_version", "trace")}
     )

     @retry_decorator
@@ -346,6 +355,9 @@ async def converse_with_retry_async(
     max_tokens: int = 1000,
     temperature: float = 0.1,
     stream: bool = False,
+    guardrail_identifier: Optional[str] = None,
+    guardrail_version: Optional[str] = None,
+    trace: Optional[str] = None,
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
@@ -362,8 +374,14 @@ async def converse_with_retry_async(
         converse_kwargs["system"] = [{"text": system_prompt}]
     if tool_config := kwargs.get("tools"):
         converse_kwargs["toolConfig"] = tool_config
+    if guardrail_identifier and guardrail_version:
+        converse_kwargs["guardrailConfig"] = {}
+        converse_kwargs["guardrailConfig"]["guardrailIdentifier"] = guardrail_identifier
+        converse_kwargs["guardrailConfig"]["guardrailVersion"] = guardrail_version
+        if trace:
+            converse_kwargs["guardrailConfig"]["trace"] = trace
     converse_kwargs = join_two_dicts(
-        converse_kwargs, {k: v for k, v in kwargs.items() if k != "tools"}
+        converse_kwargs, {k: v for k, v in kwargs.items() if k not in ("tools", "guardrail_identifier", "guardrail_version", "trace")}
     )

     ## NOTE: Returning the generator directly from converse_stream doesn't work
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/tests/test_llms_bedrock_converse.py b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/tests/test_llms_bedrock_converse.py
index b47266ee5236d..c1ecd85cc080d 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock-converse/tests/test_llms_bedrock_converse.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock-converse/tests/test_llms_bedrock_converse.py
@@ -14,6 +14,9 @@
 EXP_MAX_TOKENS = 100
 EXP_TEMPERATURE = 0.7
 EXP_MODEL = "anthropic.claude-v2"
+EXP_GUARDRAIL_ID = "IDENTIFIER"
+EXP_GUARDRAIL_VERSION = "DRAFT"
+EXP_GUARDRAIL_TRACE = "ENABLED"

 # Reused chat message and prompt
 messages = [ChatMessage(role=MessageRole.USER, content="Test")]
@@ -88,6 +91,9 @@ def bedrock_converse(mock_boto3_session, mock_aioboto3_session):
         model=EXP_MODEL,
         max_tokens=EXP_MAX_TOKENS,
         temperature=EXP_TEMPERATURE,
+        guardrail_identifier=EXP_GUARDRAIL_ID,
+        guardrail_version=EXP_GUARDRAIL_VERSION,
+        trace=EXP_GUARDRAIL_TRACE,
         callback_manager=CallbackManager(),
     )

diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py
index 55732f8598bdb..f4d159b8670ce 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/base.py
@@ -94,6 +94,15 @@ class Bedrock(LLM):
         default=60.0,
         description="The timeout for the Bedrock API request in seconds. It will be used for both connect and read timeouts.",
     )
+    guardrail_identifier: Optional[str] = Field(
+        default=None, description="The unique identifier of the guardrail that you want to use. If you don't provide a value, no guardrail is applied to the invocation."
+    )
+    guardrail_version: Optional[str] = Field(
+        default=None, description="The version number for the guardrail. The value can also be DRAFT."
+    )
+    trace: Optional[str] = Field(
+        default=None, description="Specifies whether to enable or disable the Bedrock trace. If enabled, you can see the full Bedrock trace."
+    )
     additional_kwargs: Dict[str, Any] = Field(
         default_factory=dict,
         description="Additional kwargs for the bedrock invokeModel request.",
@@ -125,6 +134,9 @@ def __init__(
         completion_to_prompt: Optional[Callable[[str], str]] = None,
         pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT,
         output_parser: Optional[BaseOutputParser] = None,
+        guardrail_identifier: Optional[str] = None,
+        guardrail_version: Optional[str] = None,
+        trace: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         if context_size is None and model not in BEDROCK_FOUNDATION_LLMS:
@@ -187,6 +199,9 @@ def __init__(
             completion_to_prompt=completion_to_prompt,
             pydantic_program_mode=pydantic_program_mode,
             output_parser=output_parser,
+            guardrail_identifier=guardrail_identifier,
+            guardrail_version=guardrail_version,
+            trace=trace,
         )
         self._provider = get_provider(model)
         self.messages_to_prompt = (
@@ -257,6 +272,9 @@ def complete(
             model=self.model,
             request_body=request_body_str,
             max_retries=self.max_retries,
+            guardrail_identifier=self.guardrail_identifier,
+            guardrail_version=self.guardrail_version,
+            trace=self.trace,
             **all_kwargs,
         )
         response_body = response["body"].read()
@@ -287,6 +305,9 @@ def stream_complete(
             request_body=request_body_str,
             max_retries=self.max_retries,
             stream=True,
+            guardrail_identifier=self.guardrail_identifier,
+            guardrail_version=self.guardrail_version,
+            trace=self.trace,
             **all_kwargs,
         )
         response_body = response["body"]
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
index d9b90ed8b899f..e376dc8124720 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/llama_index/llms/bedrock/utils.py
@@ -299,6 +299,9 @@ def completion_with_retry(
     request_body: str,
     max_retries: int,
     stream: bool = False,
+    guardrail_identifier: Optional[str] = None,
+    guardrail_version: Optional[str] = None,
+    trace: Optional[str] = None,
     **kwargs: Any,
 ) -> Any:
     """Use tenacity to retry the completion call."""
@@ -307,9 +310,32 @@ def completion_with_retry(
     @retry_decorator
     def _completion_with_retry(**kwargs: Any) -> Any:
         if stream:
-            return client.invoke_model_with_response_stream(
-                modelId=model, body=request_body
-            )
-        return client.invoke_model(modelId=model, body=request_body)
+            if guardrail_identifier is None or guardrail_version is None:
+                return client.invoke_model_with_response_stream(
+                    modelId=model,
+                    body=request_body,
+                )
+            else:
+                return client.invoke_model_with_response_stream(
+                    modelId=model,
+                    body=request_body,
+                    guardrailIdentifier=guardrail_identifier,
+                    guardrailVersion=guardrail_version,
+                    trace=trace or "DISABLED",
+                )
+        else:
+            if guardrail_identifier is None or guardrail_version is None:
+                return client.invoke_model(
+                    modelId=model,
+                    body=request_body,
+                )
+            else:
+                return client.invoke_model(
+                    modelId=model,
+                    body=request_body,
+                    guardrailIdentifier=guardrail_identifier,
+                    guardrailVersion=guardrail_version,
+                    trace=trace or "DISABLED",
+                )

     return _completion_with_retry(**kwargs)
diff --git a/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py b/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
index 0f59f66bba493..2545ef37173f5 100644
--- a/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
+++ b/llama-index-integrations/llms/llama-index-llms-bedrock/tests/test_bedrock.py
@@ -147,6 +147,9 @@ def test_model_basic(
         profile_name=None,
         region_name="us-east-1",
         aws_access_key_id="test",
+        guardrail_identifier="test",
+        guardrail_version="test",
+        trace="ENABLED",
     )

     bedrock_stubber = Stubber(llm._client)
@@ -155,13 +158,13 @@
     bedrock_stubber.add_response(
         "invoke_model",
         get_invoke_model_response(response_body),
-        {"body": complete_request, "modelId": model},
+        {"body": complete_request, "modelId": model, "guardrailIdentifier": "test", "guardrailVersion": "test", "trace": "ENABLED"},
     )
     # response for llm.chat()
     bedrock_stubber.add_response(
         "invoke_model",
         get_invoke_model_response(response_body),
-        {"body": chat_request, "modelId": model},
+        {"body": chat_request, "modelId": model, "guardrailIdentifier": "test", "guardrailVersion": "test", "trace": "ENABLED"},
     )

     bedrock_stubber.activate()
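
A minimal usage sketch, assuming the patch above is applied; the model ID, region, guardrail identifier, guardrail version, and prompt are placeholder values, not part of the patch:

    from llama_index.llms.bedrock_converse import BedrockConverse

    # Placeholder values: substitute a real guardrail ID (or ARN) and version
    # from your AWS account, plus a model and region you have access to.
    llm = BedrockConverse(
        model="anthropic.claude-v2",
        region_name="us-east-1",
        guardrail_identifier="my-guardrail-id",
        guardrail_version="DRAFT",
        trace="ENABLED",
    )

    # The guardrailConfig built in converse_with_retry() is attached to every
    # converse / converse_stream request issued by this client.
    response = llm.complete("Summarize the refund policy for customers.")
    print(response.text)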