Skip to content

Commit

Permalink
Tool chaining in agent flow
Browse files Browse the repository at this point in the history
* added state to track last function result in agent
* added logic in ToolRuleSolver to choose correct next tool
* Integrated test cases for conditional tools in agent
  • Loading branch information
Mindy Long committed Dec 18, 2024
1 parent 1dea00c commit b4a1534
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 15 deletions.
7 changes: 6 additions & 1 deletion letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@ def __init__(

self.first_message_verify_mono = first_message_verify_mono

# State needed for conditional tool chaining
self.last_function_response = None

# Controls if the convo memory pressure warning is triggered
# When an alert is sent in the message queue, set this to True (to avoid repeat alerts)
# When the summarizer is run, set this back to False (to reset)
Expand Down Expand Up @@ -586,7 +589,7 @@ def _get_ai_reply(
) -> ChatCompletionResponse:
"""Get response from LLM API with robust retry mechanism."""

allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names()
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(last_function_response=self.last_function_response)
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]

allowed_functions = (
Expand Down Expand Up @@ -826,6 +829,7 @@ def _handle_ai_response(
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
printd(error_msg_user)
function_response = package_function_response(False, error_msg)
self.last_function_response = function_response
# TODO: truncate error message somehow
messages.append(
Message.dict_to_message(
Expand Down Expand Up @@ -861,6 +865,7 @@ def _handle_ai_response(
) # extend conversation with function response
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1])
self.last_function_response = function_response

else:
# Standard non-function reply
Expand Down
71 changes: 60 additions & 11 deletions letta/helpers/tool_rule_solver.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Dict, List, Optional, Union
from collections import deque

Expand Down Expand Up @@ -58,7 +59,7 @@ def update_tool_usage(self, tool_name: str):
"""Update the internal state to track the last tool called."""
self.last_tool_name = tool_name

def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]:
def get_allowed_tool_names(self, error_on_empty: bool = False, last_function_response: Optional[str] = None) -> List[str]:
"""Get a list of tool names allowed based on the last tool called."""
if self.last_tool_name is None:
# Use initial tool rules if no tool has been called yet
Expand All @@ -67,18 +68,20 @@ def get_allowed_tool_names(self, error_on_empty: bool = False) -> List[str]:
# Find a matching ToolRule for the last tool used
current_rule = next((rule for rule in self.tool_rules if rule.tool_name == self.last_tool_name), None)

# Return children which must exist on ToolRule
if current_rule:
return current_rule.children

# Default to empty if no rule matches
message = "User provided tool rules and execution state resolved to no more possible tool calls."
if error_on_empty:
raise RuntimeError(message)
else:
# warnings.warn(message)
if current_rule is None:
if error_on_empty:
raise ValueError(f"No tool rule found for {self.last_tool_name}")
return []

# If the current rule is a conditional tool rule, use the LLM response to
# determine which child tool to use
if isinstance(current_rule, ConditionalToolRule):
if not last_function_response:
raise ValueError("Conditional tool rule requires an LLM response to determine which child tool to use")
return [self.evaluate_conditional_tool(current_rule, last_function_response)]

return current_rule.children if current_rule.children else []

def is_terminal_tool(self, tool_name: str) -> bool:
"""Check if the tool is defined as a terminal tool in the terminal tool rules."""
return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules)
Expand All @@ -88,6 +91,15 @@ def has_children_tools(self, tool_name):
return any(rule.tool_name == tool_name for rule in self.tool_rules)

def validate_conditional_tool(self, rule: ConditionalToolRule):
'''
Validate a conditional tool rule
Args:
rule (ConditionalToolRule): The conditional tool rule to validate
Raises:
ToolRuleValidationError: If the rule is invalid
'''
if rule.children is None or len(rule.children) == 0:
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
if len(rule.children) != len(rule.child_output_mapping):
Expand Down Expand Up @@ -142,3 +154,40 @@ def has_path_to_terminal(start_tool: str) -> bool:
return False

return True # All init tools have paths to terminal tools

def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
'''
Parse function response to determine which child tool to use based on the mapping
Args:
tool (ConditionalToolRule): The conditional tool rule
last_function_response (str): The function response in JSON format
Returns:
str: The name of the child tool to use next
'''
json_response = json.loads(last_function_response)
function_output = json_response["message"]

# Try to match the function output with a mapping key
for key in tool.child_output_mapping:

# Convert function output to match key type for comparison
if key == "true" or key == "false":
try:
typed_output = function_output.lower()
except AttributeError:
continue
elif isinstance(key, int):
try:
typed_output = int(function_output)
except (ValueError, TypeError):
continue
else: # string
typed_output = str(function_output)

if typed_output == key:
return tool.child_output_mapping[key]

# If no match found, use default
return tool.default_child
7 changes: 5 additions & 2 deletions letta/orm/custom_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule


class EmbeddingConfigColumn(TypeDecorator):
Expand Down Expand Up @@ -80,7 +80,7 @@ def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, Init
return value

@staticmethod
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]:
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
if rule_type == ToolRuleType.run_first:
Expand All @@ -90,6 +90,9 @@ def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, Term
elif rule_type == ToolRuleType.constrain_child_tools:
rule = ChildToolRule(**data)
return rule
elif rule_type == ToolRuleType.conditional:
rule = ConditionalToolRule(**data)
return rule
else:
raise ValueError(f"Unknown tool rule type: {rule_type}")

Expand Down
211 changes: 210 additions & 1 deletion tests/integration_test_agent_tool_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
import pytest
from letta import create_client
from letta.schemas.letta_message import FunctionCallMessage
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.tool_rule import (
ChildToolRule,
ConditionalToolRule,
InitToolRule,
TerminalToolRule,
)
from tests.helpers.endpoints_helper import (
assert_invoked_function_call,
assert_invoked_send_message_with_keyword,
Expand Down Expand Up @@ -68,6 +73,50 @@ def fourth_secret_word(prev_secret_word: str):
return "banana"


def flip_coin():
"""
Call this to retrieve the password to the secret word, which you will need to output in a send_message later.
If it returns an empty string, try flipping again!
Returns:
str: The password or an empty string
"""
import random

# Flip a coin with 50% chance
if random.random() < 0.5:
return ""
return "hj2hwibbqm"


def flip_coin_hard():
"""
Call this to retrieve the password to the secret word, which you will need to output in a send_message later.
If it returns an empty string, try flipping again!
Returns:
str: The password or an empty string
"""
import random

# Flip a coin with 50% chance
result = random.random()
if result < 0.5:
return ""
if result < 0.75:
return "START_OVER"
return "hj2hwibbqm"


def can_play_game():
"""
Call this to start the tool chain.
"""
import random

return random.random() < 0.5


def auto_error():
"""
If you call this function, it will throw an error automatically.
Expand Down Expand Up @@ -282,3 +331,163 @@ def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none):

print(f"Got successful response from client: \n\n{response}")
cleanup(client=client, agent_uuid=agent_uuid)


@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
def test_agent_conditional_tool_easy(mock_e2b_api_key_none):
"""
Test the agent with a conditional tool that has a child tool.
Tool Flow:
-------
| |
| v
-- flip_coin
|
v
reveal_secret_word
"""

client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)

coin_flip_name = "flip_coin"
secret_word_tool = "fourth_secret_word"
flip_coin_tool = client.create_or_update_tool(flip_coin, name=coin_flip_name)
reveal_secret = client.create_or_update_tool(fourth_secret_word, name=secret_word_tool)

# Make tool rules
tool_rules = [
InitToolRule(tool_name=coin_flip_name),
ConditionalToolRule(
tool_name=coin_flip_name,
default_child=coin_flip_name,
children=[secret_word_tool],
child_output_mapping={
"hj2hwibbqm": secret_word_tool,
}
),
TerminalToolRule(tool_name=secret_word_tool),
]
tools = [flip_coin_tool, reveal_secret]

config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word")

# Make checks
assert_sanity_checks(response)

# Assert the tools were called
assert_invoked_function_call(response.messages, "flip_coin")
assert_invoked_function_call(response.messages, "fourth_secret_word")

# Check ordering of tool calls
found_secret_word = False
for m in response.messages:
if isinstance(m, FunctionCallMessage):
if m.function_call.name == secret_word_tool:
# Should be the last tool call
found_secret_word = True
else:
# Before finding secret_word, only flip_coin should be called
assert m.function_call.name == coin_flip_name
assert not found_secret_word

# Ensure we found the secret word exactly once
assert found_secret_word

print(f"Got successful response from client: \n\n{response}")
cleanup(client=client, agent_uuid=agent_uuid)



@pytest.mark.timeout(90) # Longer timeout since this test has more steps
def test_agent_conditional_tool_hard(mock_e2b_api_key_none):
"""
Test the agent with a complex conditional tool graph
Tool Flow:
can_play_game <---+
| |
v |
flip_coin -----+
|
v
fourth_secret_word
"""
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)

# Create tools
play_game = "can_play_game"
coin_flip_name = "flip_coin_hard"
final_tool = "fourth_secret_word"
play_game_tool = client.create_or_update_tool(can_play_game, name=play_game)
flip_coin_tool = client.create_or_update_tool(flip_coin_hard, name=coin_flip_name)
reveal_secret = client.create_or_update_tool(fourth_secret_word, name=final_tool)

# Make tool rules - chain them together with conditional rules
tool_rules = [
InitToolRule(tool_name=play_game),
ConditionalToolRule(
tool_name=play_game,
default_child=play_game, # Keep trying if we can't play
children=[coin_flip_name],
child_output_mapping={
True: coin_flip_name # Only allow access when can_play_game returns True
}
),
ConditionalToolRule(
tool_name=coin_flip_name,
default_child=coin_flip_name,
children=[play_game, final_tool],
child_output_mapping={
"hj2hwibbqm": final_tool, "START_OVER": play_game
}
),
TerminalToolRule(tool_name=final_tool),
]

# Setup agent with all tools
tools = [play_game_tool, flip_coin_tool, reveal_secret]
config_file = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json"
agent_state = setup_agent(
client,
config_file,
agent_uuid=agent_uuid,
tool_ids=[t.id for t in tools],
tool_rules=tool_rules
)

# Ask agent to try to get all secret words
response = client.user_message(agent_id=agent_state.id, message="hi")

# Make checks
assert_sanity_checks(response)

# Assert all tools were called
assert_invoked_function_call(response.messages, play_game)
assert_invoked_function_call(response.messages, final_tool)

# Check ordering of tool calls
found_words = []
for m in response.messages:
if isinstance(m, FunctionCallMessage):
name = m.function_call.name
if name in [play_game, coin_flip_name]:
# Before finding secret_word, only can_play_game and flip_coin should be called
assert name in [play_game, coin_flip_name]
else:
# Should find secret words in order
expected_word = final_tool
assert name == expected_word, f"Found {name} but expected {expected_word}"
found_words.append(name)

# Ensure we found all secret words in order
assert found_words == [final_tool]

print(f"Got successful response from client: \n\n{response}")
cleanup(client=client, agent_uuid=agent_uuid)

0 comments on commit b4a1534

Please sign in to comment.