From 5fd276d27642ff4649bf6e917a21d18e791ed05e Mon Sep 17 00:00:00 2001
From: Jenn Mueng <30991498+jennmueng@users.noreply.github.com>
Date: Wed, 17 Jul 2024 03:01:59 +0700
Subject: [PATCH] fix(autofix): Correctly format chained exceptions (#915)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The new exception chaining introduced in #899 caused error messages not
to be returned correctly to the agent:
These changes now result in the correct response being sent to the
agent:
---
src/seer/automation/agent/tools.py | 12 ++-
tests/automation/agent/test_agent_tools.py | 87 ++++++++++++++++++++++
2 files changed, 98 insertions(+), 1 deletion(-)
create mode 100644 tests/automation/agent/test_agent_tools.py
diff --git a/src/seer/automation/agent/tools.py b/src/seer/automation/agent/tools.py
index a74f0ec9a..f85cb144d 100644
--- a/src/seer/automation/agent/tools.py
+++ b/src/seer/automation/agent/tools.py
@@ -6,6 +6,16 @@
logger = logging.getLogger(__name__)
+def get_full_exception_string(exc):
+ result = str(exc)
+ if exc.__cause__:
+ if result:
+ result += f"\n\nThe above exception was the direct cause of the following exception:\n\n{str(exc.__cause__)}"
+ else:
+ result = str(exc.__cause__)
+ return result
+
+
class FunctionTool(BaseModel):
name: str
description: str
@@ -18,7 +28,7 @@ def call(self, **kwargs):
return self.fn(**kwargs)
except Exception as e:
logger.exception(e)
- return f"Error: {e}"
+ return f"Error: {get_full_exception_string(e)}"
def to_dict(self):
return {
diff --git a/tests/automation/agent/test_agent_tools.py b/tests/automation/agent/test_agent_tools.py
new file mode 100644
index 000000000..2a7c0f094
--- /dev/null
+++ b/tests/automation/agent/test_agent_tools.py
@@ -0,0 +1,87 @@
+from unittest.mock import Mock, patch
+
+import pytest
+
+from seer.automation.agent.tools import FunctionTool, get_full_exception_string
+
+
+class TestGetFullExceptionString:
+ def test_simple_exception(self):
+ exc = ValueError("Simple error")
+ assert get_full_exception_string(exc) == "Simple error"
+
+ def test_chained_exception(self):
+ try:
+ raise RuntimeError("Main error") from ValueError("Root cause")
+ except RuntimeError as exc:
+ assert (
+ get_full_exception_string(exc)
+ == "Main error\n\nThe above exception was the direct cause of the following exception:\n\nRoot cause"
+ )
+
+ def test_empty_main_exception(self):
+ try:
+ raise RuntimeError() from ValueError("Root cause")
+ except RuntimeError as exc:
+ assert get_full_exception_string(exc) == "Root cause"
+
+
+class TestFunctionTool:
+ @pytest.fixture
+ def mock_function(self):
+ return Mock(return_value="Success")
+
+ @pytest.fixture
+ def function_tool(self, mock_function):
+ return FunctionTool(
+ name="test_tool",
+ description="A test tool",
+ fn=mock_function,
+ parameters=[{"name": "param1", "type": "string"}],
+ )
+
+ def test_successful_call(self, function_tool):
+ result = function_tool.call(param1="test")
+ assert result == "Success"
+
+ def test_exception_handling(self, function_tool):
+ function_tool.fn.side_effect = ValueError("Test error")
+
+ with patch("seer.automation.agent.tools.logger") as mock_logger:
+ result = function_tool.call(param1="test")
+
+ assert result.startswith("Error: Test error")
+ mock_logger.exception.assert_called_once()
+
+ def test_chained_exception_handling(self, function_tool):
+ cause = ValueError("Root cause")
+ main_error = RuntimeError("Main error")
+ main_error.__cause__ = cause
+ function_tool.fn.side_effect = main_error
+
+ with patch("seer.automation.agent.tools.logger") as mock_logger:
+ result = function_tool.call(param1="test")
+
+ expected = "Error: Main error\n\nThe above exception was the direct cause of the following exception:\n\nRoot cause"
+ assert result == expected
+ mock_logger.exception.assert_called_once()
+
+ def test_to_dict(self, function_tool):
+ expected = {
+ "type": "function",
+ "function": {
+ "name": "test_tool",
+ "description": "A test tool",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "param1": {
+ "type": "string",
+ "description": "",
+ }
+ },
+ "required": [],
+ },
+ },
+ }
+ assert function_tool.to_dict() == expected