diff --git a/tests/automation/agent/test_agent_tools.py b/tests/automation/agent/test_agent_tools.py index 2a7c0f094..311db73a8 100644 --- a/tests/automation/agent/test_agent_tools.py +++ b/tests/automation/agent/test_agent_tools.py @@ -66,22 +66,47 @@ def test_chained_exception_handling(self, function_tool): assert result == expected mock_logger.exception.assert_called_once() - def test_to_dict(self, function_tool): - expected = { - "type": "function", - "function": { - "name": "test_tool", - "description": "A test tool", - "parameters": { - "type": "object", - "properties": { - "param1": { - "type": "string", - "description": "", - } + @pytest.mark.parametrize( + "model, expected", + [ + ( + "gpt", + { + "type": "function", + "function": { + "name": "test_tool", + "description": "A test tool", + "parameters": { + "type": "object", + "properties": { + "param1": { + "type": "string", + "description": "", + } + }, + "required": [], + }, }, - "required": [], }, - }, - } - assert function_tool.to_dict() == expected + ), + ( + "claude", + { + "name": "test_tool", + "description": "A test tool", + "input_schema": { + "type": "object", + "properties": { + "param1": { + "type": "string", + "description": "", + } + }, + "required": [], + }, + }, + ), + ], + ) + def test_to_dict(self, function_tool, model, expected): + assert function_tool.to_dict(model=model) == expected