Skip to content

Commit

Permalink
Fixes: less stringent tool chain checks, correct boolean eval logic f…
Browse files Browse the repository at this point in the history
…or conditional tools
  • Loading branch information
Mindy Long committed Dec 19, 2024
1 parent b4a1534 commit 9d28212
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 67 deletions.
62 changes: 6 additions & 56 deletions letta/helpers/tool_rule_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ def __init__(self, tool_rules: List[BaseToolRule], **kwargs):
assert isinstance(rule, TerminalToolRule)
self.terminal_tool_rules.append(rule)

# Validate the tool rules to ensure they form a DAG
if not self.validate_tool_rules():
raise ToolRuleValidationError("Tool rules does not have a path from Init to Terminal.")

def update_tool_usage(self, tool_name: str):
"""Update the internal state to track the last tool called."""
Expand Down Expand Up @@ -108,53 +105,6 @@ def validate_conditional_tool(self, rule: ConditionalToolRule):
raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.")
return True

def validate_tool_rules(self) -> bool:
"""
Validate that there exists a path from every init tool to a terminal tool.
Returns True if valid (path exists), otherwise False.
"""
# Build adjacency list for the tool graph
adjacency_list: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules}

init_tool_names = {rule.tool_name for rule in self.init_tool_rules}
terminal_tool_names = {rule.tool_name for rule in self.terminal_tool_rules}

# Initial checks
if len(init_tool_names) == 0:
if len(terminal_tool_names) + len(self.tool_rules) > 0:
return False # No init tools defined
else:
return True # No tool rules
if len(terminal_tool_names) == 0:
if len(adjacency_list) > 0:
return False # No terminal tools defined
else:
return True # Only init tools

# Define BFS helper function to find path to terminal tool
def has_path_to_terminal(start_tool: str) -> bool:
visited = set()
queue = deque([start_tool])
visited.add(start_tool)

while queue:
current_tool = queue.popleft()
if current_tool in terminal_tool_names:
return True

for child in adjacency_list.get(current_tool, []):
if child not in visited:
visited.add(child)
queue.append(child)
return False

# Check if each init tool has a path to a terminal tool
for init_tool_name in init_tool_names:
if not has_path_to_terminal(init_tool_name):
return False

return True # All init tools have paths to terminal tools

def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
'''
Parse function response to determine which child tool to use based on the mapping
Expand All @@ -173,18 +123,18 @@ def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_res
for key in tool.child_output_mapping:

# Convert function output to match key type for comparison
if key == "true" or key == "false":
try:
typed_output = function_output.lower()
except AttributeError:
continue
if isinstance(key, bool):
typed_output = function_output.lower() == "true"
elif isinstance(key, int):
try:
typed_output = int(function_output)
except (ValueError, TypeError):
continue
else: # string
typed_output = str(function_output)
if function_output == "True" or function_output == "False":
typed_output = function_output.lower()
else:
typed_output = function_output

if typed_output == key:
return tool.child_output_mapping[key]
Expand Down
1 change: 1 addition & 0 deletions tests/integration_test_agent_tool_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none):
tool_rules = [
InitToolRule(tool_name=t1_name),
ChildToolRule(tool_name=t1_name, children=[t2_name]),
TerminalToolRule(tool_name=t2_name)
]
tools = [t1, t2]

Expand Down
4 changes: 2 additions & 2 deletions tests/integration_test_offline_memory_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def test_ripple_edit(client, mock_e2b_api_key_none):

assert set(conversation_agent.memory.list_block_labels()) == {"persona", "human", "fact_block", "rethink_memory_block"}

rethink_memory_tool = client.create_tool(rethink_memory)
finish_rethinking_memory_tool = client.create_tool(finish_rethinking_memory)
rethink_memory_tool = client.create_or_update_tool(rethink_memory)
finish_rethinking_memory_tool = client.create_or_update_tool(finish_rethinking_memory)
offline_memory_agent = client.create_agent(
name="offline_memory_agent",
agent_type=AgentType.offline_memory_agent,
Expand Down
16 changes: 7 additions & 9 deletions tests/test_tool_rule_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def test_get_allowed_tool_names_with_subsequent_rule():
# Setup: Tool rule sequence
init_rule = InitToolRule(tool_name=START_TOOL)
rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL])
rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[END_TOOL])
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1, rule_2], terminal_tool_rules=[terminal_rule])
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[])

# Action: Update usage and get allowed tools
solver.update_tool_usage(START_TOOL)
Expand All @@ -51,9 +49,8 @@ def test_get_allowed_tool_names_with_subsequent_rule():
def test_is_terminal_tool():
# Setup: Terminal tool rule configuration
init_rule = InitToolRule(tool_name=START_TOOL)
rule_1 = ChildToolRule(tool_name=START_TOOL, children=[END_TOOL])
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[terminal_rule])
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[terminal_rule])

# Action & Assert: Verify terminal and non-terminal tools
assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as a terminal tool"
Expand Down Expand Up @@ -83,9 +80,9 @@ def test_get_allowed_tool_names_no_matching_rule_error():
init_rule = InitToolRule(tool_name=START_TOOL)
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[])

# Action & Assert: Set last tool to an unrecognized tool and expect RuntimeError when error_on_empty=True
# Action & Assert: Set last tool to an unrecognized tool and expect ValueError
solver.update_tool_usage(UNRECOGNIZED_TOOL)
with pytest.raises(RuntimeError, match="resolved to no more possible tool calls"):
with pytest.raises(ValueError, match=f"No tool rule found for {UNRECOGNIZED_TOOL}"):
solver.get_allowed_tool_names(error_on_empty=True)


Expand Down Expand Up @@ -119,7 +116,7 @@ def test_conditional_tool_rule():
rule = ConditionalToolRule(
tool_name=START_TOOL,
children=[START_TOOL, END_TOOL],
default_child=END_TOOL,
default_child=START_TOOL,
child_output_mapping={True: END_TOOL, False: START_TOOL}
)
solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule])
Expand All @@ -130,7 +127,8 @@ def test_conditional_tool_rule():

# Step 2: After using 'start_tool'
solver.update_tool_usage(START_TOOL)
assert set(solver.get_allowed_tool_names()) == set({END_TOOL, START_TOOL}), "After 'start_tool', should allow 'start_tool' or 'end_tool'"
assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [END_TOOL], "After 'start_tool' returns true, should allow 'end_tool'"
assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [START_TOOL], "After 'start_tool' returns false, should allow 'start_tool'"

# Step 3: After using 'end_tool'
assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal"
Expand Down

0 comments on commit 9d28212

Please sign in to comment.