Skip to content

Commit

Permalink
propagate error on tool failure
Browse files Browse the repository at this point in the history
  • Loading branch information
Caren Thomas committed Dec 18, 2024
1 parent b135223 commit 3527008
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 4 deletions.
25 changes: 23 additions & 2 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ def _handle_ai_response(
function_args.pop("self", None)
# error_msg = f"Error calling function {function_name} with args {function_args}: {str(e)}"
# Less detailed - don't provide full args, idea is that it should be in recent context so no need (just adds noise)
error_msg = f"Error calling function {function_name}: {str(e)}"
error_msg = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
printd(error_msg_user)
function_response = package_function_response(False, error_msg)
Expand All @@ -844,8 +844,29 @@ def _handle_ai_response(
self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1])
return messages, False, True # force a heartbeat to allow agent to handle error

# Step 4: check if function response is an error
if function_response_string.startswith("Error"):
function_response = package_function_response(False, function_response_string)
# TODO: truncate error message somehow
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "tool",
"name": function_name,
"content": function_response,
"tool_call_id": tool_call_id,
},
)
) # extend conversation with function response
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
self.interface.function_message(f"Error: {function_response_string}", msg_obj=messages[-1])
return messages, False, True # force a heartbeat to allow agent to handle error

# If no failures happened along the way: ...
# Step 4: send the info on the function call and function response to GPT
# Step 5: send the info on the function call and function response to GPT
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
Expand Down
4 changes: 2 additions & 2 deletions letta/server/rest_api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_
new_message = {"function_return": msg, "status": "success"}

elif msg.startswith("Error: "):
msg = msg.replace("Error: ", "")
msg = msg.replace("Error: ", "", count=1)
new_message = {"function_return": msg, "status": "error"}

else:
Expand Down Expand Up @@ -951,7 +951,7 @@ def function_message(self, msg: str, msg_obj: Optional[Message] = None):
)

elif msg.startswith("Error: "):
msg = msg.replace("Error: ", "")
msg = msg.replace("Error: ", "", count=1)
# new_message = {"function_return": msg, "status": "error"}
assert msg_obj.tool_call_id is not None
new_message = FunctionReturn(
Expand Down
34 changes: 34 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import json
import os
import threading
import time
Expand Down Expand Up @@ -382,6 +383,39 @@ def big_return():
client.delete_agent(agent_id=agent.id)


def test_function_always_error(client: Union[LocalClient, RESTClient]):
"""Test to see if function that errors works correctly"""

def always_error():
"""
Always throw an error.
"""
return 5/0

tool = client.create_or_update_tool(func=always_error)
agent = client.create_agent(tool_ids=[tool.id])
# get function response
response = client.send_message(agent_id=agent.id, message="call the always_error function", role="user")
print(response.messages)

response_message = None
for message in response.messages:
if isinstance(message, FunctionReturn):
response_message = message
break

assert response_message, "FunctionReturn message not found in response"
assert response_message.status == "error"
if isinstance(client, RESTClient):
assert response_message.function_return == "Error executing function always_error: ZeroDivisionError: division by zero"
else:
response_json = json.loads(response_message.function_return)
assert response_json['status'] == "Failed"
assert response_json['message'] == "Error executing function always_error: ZeroDivisionError: division by zero"

client.delete_agent(agent_id=agent.id)


@pytest.mark.asyncio
async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
"""
Expand Down

0 comments on commit 3527008

Please sign in to comment.