diff --git a/libs/core/langchain_core/tools.py b/libs/core/langchain_core/tools.py index 10cc95cdebfc7..b1d9e9ead3c34 100644 --- a/libs/core/langchain_core/tools.py +++ b/libs/core/langchain_core/tools.py @@ -20,6 +20,7 @@ BaseModel, Extra, Field, + ValidationError, create_model, root_validator, validate_arguments, @@ -169,6 +170,11 @@ class ChildTool(BaseTool): ] = False """Handle the content of the ToolException thrown.""" + handle_validation_error: Optional[ + Union[bool, str, Callable[[ValidationError], str]] + ] = False + """Handle the content of the ValidationError thrown.""" + class Config(Serializable.Config): """Configuration for this pydantic object.""" @@ -346,6 +352,21 @@ def run( if new_arg_supported else self._run(*tool_args, **tool_kwargs) ) + except ValidationError as e: + if not self.handle_validation_error: + raise e + elif isinstance(self.handle_validation_error, bool): + observation = "Tool input validation error" + elif isinstance(self.handle_validation_error, str): + observation = self.handle_validation_error + elif callable(self.handle_validation_error): + observation = self.handle_validation_error(e) + else: + raise ValueError( + f"Got unexpected type of `handle_validation_error`. Expected bool, " + f"str or callable. Received: {self.handle_validation_error}" + ) + return observation except ToolException as e: if not self.handle_tool_error: run_manager.on_tool_error(e) @@ -422,6 +443,21 @@ async def arun( if new_arg_supported else await self._arun(*tool_args, **tool_kwargs) ) + except ValidationError as e: + if not self.handle_validation_error: + raise e + elif isinstance(self.handle_validation_error, bool): + observation = "Tool input validation error" + elif isinstance(self.handle_validation_error, str): + observation = self.handle_validation_error + elif callable(self.handle_validation_error): + observation = self.handle_validation_error(e) + else: + raise ValueError( + f"Got unexpected type of `handle_validation_error`. Expected bool, " + f"str or callable. Received: {self.handle_validation_error}" + ) + return observation except ToolException as e: if not self.handle_tool_error: await run_manager.on_tool_error(e) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index dda3be342f3a5..33b5d0ff9b562 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -3,7 +3,7 @@ from datetime import datetime from enum import Enum from functools import partial -from typing import Any, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, Union import pytest @@ -11,7 +11,7 @@ AsyncCallbackManagerForToolRun, CallbackManagerForToolRun, ) -from langchain_core.pydantic_v1 import BaseModel +from langchain_core.pydantic_v1 import BaseModel, ValidationError from langchain_core.tools import ( BaseTool, SchemaAnnotationError, @@ -620,7 +620,10 @@ def test_exception_handling_str() -> None: def test_exception_handling_callable() -> None: expected = "foo bar" - handling = lambda _: expected # noqa: E731 + + def handling(e: ToolException) -> str: + return expected # noqa: E731 + _tool = _FakeExceptionTool(handle_tool_error=handling) actual = _tool.run({}) assert expected == actual @@ -648,7 +651,10 @@ async def test_async_exception_handling_str() -> None: async def test_async_exception_handling_callable() -> None: expected = "foo bar" - handling = lambda _: expected # noqa: E731 + + def handling(e: ToolException) -> str: + return expected # noqa: E731 + _tool = _FakeExceptionTool(handle_tool_error=handling) actual = await _tool.arun({}) assert expected == actual @@ -691,3 +697,127 @@ def foo(bar: int, baz: str) -> str: prefix = "foo(bar: int, baz: str) -> str - " assert foo.__doc__ is not None assert structured_tool.description == prefix + foo.__doc__.strip() + + +def test_validation_error_handling_bool() -> None: + """Test that validation errors are handled correctly.""" + expected = "Tool input validation error" + _tool = _MockStructuredTool(handle_validation_error=True) + actual = _tool.run({}) + assert expected == actual + + +def test_validation_error_handling_str() -> None: + """Test that validation errors are handled correctly.""" + expected = "foo bar" + _tool = _MockStructuredTool(handle_validation_error=expected) + actual = _tool.run({}) + assert expected == actual + + +def test_validation_error_handling_callable() -> None: + """Test that validation errors are handled correctly.""" + expected = "foo bar" + + def handling(e: ValidationError) -> str: + return expected # noqa: E731 + + _tool = _MockStructuredTool(handle_validation_error=handling) + actual = _tool.run({}) + assert expected == actual + + +@pytest.mark.parametrize( + "handler", + [ + True, + "foo bar", + lambda _: "foo bar", + ], +) +def test_validation_error_handling_non_validation_error( + handler: Union[bool, str, Callable[[ValidationError], str]] +) -> None: + """Test that validation errors are handled correctly.""" + + class _RaiseNonValidationErrorTool(BaseTool): + name = "raise_non_validation_error_tool" + description = "A tool that raises a non-validation error" + + def _parse_input( + self, + tool_input: Union[str, Dict], + ) -> Union[str, Dict[str, Any]]: + raise NotImplementedError() + + def _run(self) -> str: + return "dummy" + + async def _arun(self) -> str: + return "dummy" + + _tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) + with pytest.raises(NotImplementedError): + _tool.run({}) + + +async def test_async_validation_error_handling_bool() -> None: + """Test that validation errors are handled correctly.""" + expected = "Tool input validation error" + _tool = _MockStructuredTool(handle_validation_error=True) + actual = await _tool.arun({}) + assert expected == actual + + +async def test_async_validation_error_handling_str() -> None: + """Test that validation errors are handled correctly.""" + expected = "foo bar" + _tool = _MockStructuredTool(handle_validation_error=expected) + actual = await _tool.arun({}) + assert expected == actual + + +async def test_async_validation_error_handling_callable() -> None: + """Test that validation errors are handled correctly.""" + expected = "foo bar" + + def handling(e: ValidationError) -> str: + return expected # noqa: E731 + + _tool = _MockStructuredTool(handle_validation_error=handling) + actual = await _tool.arun({}) + assert expected == actual + + +@pytest.mark.parametrize( + "handler", + [ + True, + "foo bar", + lambda _: "foo bar", + ], +) +async def test_async_validation_error_handling_non_validation_error( + handler: Union[bool, str, Callable[[ValidationError], str]] +) -> None: + """Test that validation errors are handled correctly.""" + + class _RaiseNonValidationErrorTool(BaseTool): + name = "raise_non_validation_error_tool" + description = "A tool that raises a non-validation error" + + def _parse_input( + self, + tool_input: Union[str, Dict], + ) -> Union[str, Dict[str, Any]]: + raise NotImplementedError() + + def _run(self) -> str: + return "dummy" + + async def _arun(self) -> str: + return "dummy" + + _tool = _RaiseNonValidationErrorTool(handle_validation_error=handler) + with pytest.raises(NotImplementedError): + await _tool.arun({})