diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index b8f2abbd71..7fd8410eb9 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -51,9 +51,6 @@ def __init__(self, tool_rules: List[BaseToolRule], **kwargs): assert isinstance(rule, TerminalToolRule) self.terminal_tool_rules.append(rule) - # Validate the tool rules to ensure they form a DAG - if not self.validate_tool_rules(): - raise ToolRuleValidationError("Tool rules does not have a path from Init to Terminal.") def update_tool_usage(self, tool_name: str): """Update the internal state to track the last tool called.""" @@ -108,53 +105,6 @@ def validate_conditional_tool(self, rule: ConditionalToolRule): raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.") return True - def validate_tool_rules(self) -> bool: - """ - Validate that there exists a path from every init tool to a terminal tool. - Returns True if valid (path exists), otherwise False. - """ - # Build adjacency list for the tool graph - adjacency_list: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules} - - init_tool_names = {rule.tool_name for rule in self.init_tool_rules} - terminal_tool_names = {rule.tool_name for rule in self.terminal_tool_rules} - - # Initial checks - if len(init_tool_names) == 0: - if len(terminal_tool_names) + len(self.tool_rules) > 0: - return False # No init tools defined - else: - return True # No tool rules - if len(terminal_tool_names) == 0: - if len(adjacency_list) > 0: - return False # No terminal tools defined - else: - return True # Only init tools - - # Define BFS helper function to find path to terminal tool - def has_path_to_terminal(start_tool: str) -> bool: - visited = set() - queue = deque([start_tool]) - visited.add(start_tool) - - while queue: - current_tool = queue.popleft() - if current_tool in terminal_tool_names: - return True - - for child in adjacency_list.get(current_tool, []): - if child not in visited: - visited.add(child) - queue.append(child) - return False - - # Check if each init tool has a path to a terminal tool - for init_tool_name in init_tool_names: - if not has_path_to_terminal(init_tool_name): - return False - - return True # All init tools have paths to terminal tools - def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str: ''' Parse function response to determine which child tool to use based on the mapping @@ -173,18 +123,18 @@ def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_res for key in tool.child_output_mapping: # Convert function output to match key type for comparison - if key == "true" or key == "false": - try: - typed_output = function_output.lower() - except AttributeError: - continue + if isinstance(key, bool): + typed_output = function_output.lower() == "true" elif isinstance(key, int): try: typed_output = int(function_output) except (ValueError, TypeError): continue else: # string - typed_output = str(function_output) + if function_output == "True" or function_output == "False": + typed_output = function_output.lower() + else: + typed_output = function_output if typed_output == key: return tool.child_output_mapping[key] diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index aa59c47c4d..e5acfd4824 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -250,6 +250,7 @@ def test_claude_initial_tool_rule_enforced(mock_e2b_api_key_none): tool_rules = [ InitToolRule(tool_name=t1_name), ChildToolRule(tool_name=t1_name, children=[t2_name]), + TerminalToolRule(tool_name=t2_name) ] tools = [t1, t2] diff --git a/tests/integration_test_offline_memory_agent.py b/tests/integration_test_offline_memory_agent.py index 07b7c732b2..15d4161d5e 100644 --- a/tests/integration_test_offline_memory_agent.py +++ b/tests/integration_test_offline_memory_agent.py @@ -74,8 +74,8 @@ def test_ripple_edit(client, mock_e2b_api_key_none): assert set(conversation_agent.memory.list_block_labels()) == {"persona", "human", "fact_block", "rethink_memory_block"} - rethink_memory_tool = client.create_tool(rethink_memory) - finish_rethinking_memory_tool = client.create_tool(finish_rethinking_memory) + rethink_memory_tool = client.create_or_update_tool(rethink_memory) + finish_rethinking_memory_tool = client.create_or_update_tool(finish_rethinking_memory) offline_memory_agent = client.create_agent( name="offline_memory_agent", agent_type=AgentType.offline_memory_agent, diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index e1170e4b73..25434ca28d 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -36,9 +36,7 @@ def test_get_allowed_tool_names_with_subsequent_rule(): # Setup: Tool rule sequence init_rule = InitToolRule(tool_name=START_TOOL) rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL]) - rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[END_TOOL]) - terminal_rule = TerminalToolRule(tool_name=END_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1, rule_2], terminal_tool_rules=[terminal_rule]) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[]) # Action: Update usage and get allowed tools solver.update_tool_usage(START_TOOL) @@ -51,9 +49,8 @@ def test_get_allowed_tool_names_with_subsequent_rule(): def test_is_terminal_tool(): # Setup: Terminal tool rule configuration init_rule = InitToolRule(tool_name=START_TOOL) - rule_1 = ChildToolRule(tool_name=START_TOOL, children=[END_TOOL]) terminal_rule = TerminalToolRule(tool_name=END_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[terminal_rule]) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[terminal_rule]) # Action & Assert: Verify terminal and non-terminal tools assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as a terminal tool" @@ -83,9 +80,9 @@ def test_get_allowed_tool_names_no_matching_rule_error(): init_rule = InitToolRule(tool_name=START_TOOL) solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[]) - # Action & Assert: Set last tool to an unrecognized tool and expect RuntimeError when error_on_empty=True + # Action & Assert: Set last tool to an unrecognized tool and expect ValueError solver.update_tool_usage(UNRECOGNIZED_TOOL) - with pytest.raises(RuntimeError, match="resolved to no more possible tool calls"): + with pytest.raises(ValueError, match=f"No tool rule found for {UNRECOGNIZED_TOOL}"): solver.get_allowed_tool_names(error_on_empty=True) @@ -119,7 +116,7 @@ def test_conditional_tool_rule(): rule = ConditionalToolRule( tool_name=START_TOOL, children=[START_TOOL, END_TOOL], - default_child=END_TOOL, + default_child=START_TOOL, child_output_mapping={True: END_TOOL, False: START_TOOL} ) solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule]) @@ -130,7 +127,8 @@ def test_conditional_tool_rule(): # Step 2: After using 'start_tool' solver.update_tool_usage(START_TOOL) - assert set(solver.get_allowed_tool_names()) == set({END_TOOL, START_TOOL}), "After 'start_tool', should allow 'start_tool' or 'end_tool'" + assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [END_TOOL], "After 'start_tool' returns true, should allow 'end_tool'" + assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [START_TOOL], "After 'start_tool' returns false, should allow 'start_tool'" # Step 3: After using 'end_tool' assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal"