Skip to content

Commit

Permalink
core[minor]: add validation error handler to BaseTool (langchain-ai…
Browse files Browse the repository at this point in the history
…#14007)

- **Description:** add a ValidationError handler as a field of
[`BaseTool`](https://github.com/langchain-ai/langchain/blob/master/libs/core/langchain_core/tools.py#L101)
and add unit tests for the code change.
- **Issue:** langchain-ai#12721 langchain-ai#13662
- **Dependencies:** None
- **Tag maintainer:** 
- **Twitter handle:** @hmdev3
- **NOTE:**
  - I'm wondering if the update of document is required.

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
  • Loading branch information
hmasdev and eyurtsev authored Feb 2, 2024
1 parent bdacfaf commit cc17334
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 4 deletions.
36 changes: 36 additions & 0 deletions libs/core/langchain_core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
BaseModel,
Extra,
Field,
ValidationError,
create_model,
root_validator,
validate_arguments,
Expand Down Expand Up @@ -169,6 +170,11 @@ class ChildTool(BaseTool):
] = False
"""Handle the content of the ToolException thrown."""

handle_validation_error: Optional[
Union[bool, str, Callable[[ValidationError], str]]
] = False
"""Handle the content of the ValidationError thrown."""

class Config(Serializable.Config):
"""Configuration for this pydantic object."""

Expand Down Expand Up @@ -346,6 +352,21 @@ def run(
if new_arg_supported
else self._run(*tool_args, **tool_kwargs)
)
except ValidationError as e:
if not self.handle_validation_error:
raise e
elif isinstance(self.handle_validation_error, bool):
observation = "Tool input validation error"
elif isinstance(self.handle_validation_error, str):
observation = self.handle_validation_error
elif callable(self.handle_validation_error):
observation = self.handle_validation_error(e)
else:
raise ValueError(
f"Got unexpected type of `handle_validation_error`. Expected bool, "
f"str or callable. Received: {self.handle_validation_error}"
)
return observation
except ToolException as e:
if not self.handle_tool_error:
run_manager.on_tool_error(e)
Expand Down Expand Up @@ -422,6 +443,21 @@ async def arun(
if new_arg_supported
else await self._arun(*tool_args, **tool_kwargs)
)
except ValidationError as e:
if not self.handle_validation_error:
raise e
elif isinstance(self.handle_validation_error, bool):
observation = "Tool input validation error"
elif isinstance(self.handle_validation_error, str):
observation = self.handle_validation_error
elif callable(self.handle_validation_error):
observation = self.handle_validation_error(e)
else:
raise ValueError(
f"Got unexpected type of `handle_validation_error`. Expected bool, "
f"str or callable. Received: {self.handle_validation_error}"
)
return observation
except ToolException as e:
if not self.handle_tool_error:
await run_manager.on_tool_error(e)
Expand Down
138 changes: 134 additions & 4 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union

import pytest

from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.tools import (
BaseTool,
SchemaAnnotationError,
Expand Down Expand Up @@ -620,7 +620,10 @@ def test_exception_handling_str() -> None:

def test_exception_handling_callable() -> None:
expected = "foo bar"
handling = lambda _: expected # noqa: E731

def handling(e: ToolException) -> str:
return expected # noqa: E731

_tool = _FakeExceptionTool(handle_tool_error=handling)
actual = _tool.run({})
assert expected == actual
Expand Down Expand Up @@ -648,7 +651,10 @@ async def test_async_exception_handling_str() -> None:

async def test_async_exception_handling_callable() -> None:
expected = "foo bar"
handling = lambda _: expected # noqa: E731

def handling(e: ToolException) -> str:
return expected # noqa: E731

_tool = _FakeExceptionTool(handle_tool_error=handling)
actual = await _tool.arun({})
assert expected == actual
Expand Down Expand Up @@ -691,3 +697,127 @@ def foo(bar: int, baz: str) -> str:
prefix = "foo(bar: int, baz: str) -> str - "
assert foo.__doc__ is not None
assert structured_tool.description == prefix + foo.__doc__.strip()


def test_validation_error_handling_bool() -> None:
"""Test that validation errors are handled correctly."""
expected = "Tool input validation error"
_tool = _MockStructuredTool(handle_validation_error=True)
actual = _tool.run({})
assert expected == actual


def test_validation_error_handling_str() -> None:
"""Test that validation errors are handled correctly."""
expected = "foo bar"
_tool = _MockStructuredTool(handle_validation_error=expected)
actual = _tool.run({})
assert expected == actual


def test_validation_error_handling_callable() -> None:
"""Test that validation errors are handled correctly."""
expected = "foo bar"

def handling(e: ValidationError) -> str:
return expected # noqa: E731

_tool = _MockStructuredTool(handle_validation_error=handling)
actual = _tool.run({})
assert expected == actual


@pytest.mark.parametrize(
"handler",
[
True,
"foo bar",
lambda _: "foo bar",
],
)
def test_validation_error_handling_non_validation_error(
handler: Union[bool, str, Callable[[ValidationError], str]]
) -> None:
"""Test that validation errors are handled correctly."""

class _RaiseNonValidationErrorTool(BaseTool):
name = "raise_non_validation_error_tool"
description = "A tool that raises a non-validation error"

def _parse_input(
self,
tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
raise NotImplementedError()

def _run(self) -> str:
return "dummy"

async def _arun(self) -> str:
return "dummy"

_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
with pytest.raises(NotImplementedError):
_tool.run({})


async def test_async_validation_error_handling_bool() -> None:
"""Test that validation errors are handled correctly."""
expected = "Tool input validation error"
_tool = _MockStructuredTool(handle_validation_error=True)
actual = await _tool.arun({})
assert expected == actual


async def test_async_validation_error_handling_str() -> None:
"""Test that validation errors are handled correctly."""
expected = "foo bar"
_tool = _MockStructuredTool(handle_validation_error=expected)
actual = await _tool.arun({})
assert expected == actual


async def test_async_validation_error_handling_callable() -> None:
"""Test that validation errors are handled correctly."""
expected = "foo bar"

def handling(e: ValidationError) -> str:
return expected # noqa: E731

_tool = _MockStructuredTool(handle_validation_error=handling)
actual = await _tool.arun({})
assert expected == actual


@pytest.mark.parametrize(
"handler",
[
True,
"foo bar",
lambda _: "foo bar",
],
)
async def test_async_validation_error_handling_non_validation_error(
handler: Union[bool, str, Callable[[ValidationError], str]]
) -> None:
"""Test that validation errors are handled correctly."""

class _RaiseNonValidationErrorTool(BaseTool):
name = "raise_non_validation_error_tool"
description = "A tool that raises a non-validation error"

def _parse_input(
self,
tool_input: Union[str, Dict],
) -> Union[str, Dict[str, Any]]:
raise NotImplementedError()

def _run(self) -> str:
return "dummy"

async def _arun(self) -> str:
return "dummy"

_tool = _RaiseNonValidationErrorTool(handle_validation_error=handler)
with pytest.raises(NotImplementedError):
await _tool.arun({})

0 comments on commit cc17334

Please sign in to comment.