From 1dea00c268e6851637eee0004ffc3a1f919aa8e3 Mon Sep 17 00:00:00 2001 From: Mindy Long Date: Wed, 18 Dec 2024 11:23:09 -0800 Subject: [PATCH] Add ConditionalToolRule * update ToolRuleSolver to * check for Init->Terminal paths * remove cycle detection * updated tests for conditional rules, valid paths --- letta/helpers/tool_rule_solver.py | 93 +++++++++++++++++--------- letta/schemas/enums.py | 1 + letta/schemas/tool_rule.py | 15 ++++- tests/test_tool_rule_solver.py | 104 ++++++++++++++++++++++++------ 4 files changed, 160 insertions(+), 53 deletions(-) diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index ef4d9a9b37..8d5aec466e 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -1,4 +1,5 @@ -from typing import Dict, List, Optional, Set +from typing import Dict, List, Optional, Union +from collections import deque from pydantic import BaseModel, Field @@ -6,6 +7,7 @@ from letta.schemas.tool_rule import ( BaseToolRule, ChildToolRule, + ConditionalToolRule, InitToolRule, TerminalToolRule, ) @@ -22,7 +24,7 @@ class ToolRulesSolver(BaseModel): init_tool_rules: List[InitToolRule] = Field( default_factory=list, description="Initial tool rules to be used at the start of tool execution." ) - tool_rules: List[ChildToolRule] = Field( + tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field( default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions." ) terminal_tool_rules: List[TerminalToolRule] = Field( @@ -35,15 +37,22 @@ def __init__(self, tool_rules: List[BaseToolRule], **kwargs): # Separate the provided tool rules into init, standard, and terminal categories for rule in tool_rules: if rule.type == ToolRuleType.run_first: + assert isinstance(rule, InitToolRule) self.init_tool_rules.append(rule) elif rule.type == ToolRuleType.constrain_child_tools: + assert isinstance(rule, ChildToolRule) + self.tool_rules.append(rule) + elif rule.type == ToolRuleType.conditional: + assert isinstance(rule, ConditionalToolRule) + self.validate_conditional_tool(rule) self.tool_rules.append(rule) elif rule.type == ToolRuleType.exit_loop: + assert isinstance(rule, TerminalToolRule) self.terminal_tool_rules.append(rule) # Validate the tool rules to ensure they form a DAG if not self.validate_tool_rules(): - raise ToolRuleValidationError("Tool rules contain cycles, which are not allowed in a valid configuration.") + raise ToolRuleValidationError("Tool rules does not have a path from Init to Terminal.") def update_tool_usage(self, tool_name: str): """Update the internal state to track the last tool called.""" @@ -78,38 +87,58 @@ def has_children_tools(self, tool_name): """Check if the tool has children tools""" return any(rule.tool_name == tool_name for rule in self.tool_rules) + def validate_conditional_tool(self, rule: ConditionalToolRule): + if rule.children is None or len(rule.children) == 0: + raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.") + if len(rule.children) != len(rule.child_output_mapping): + raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.") + if set(rule.children) != set(rule.child_output_mapping.values()): + raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.") + return True + def validate_tool_rules(self) -> bool: """ - Validate that the tool rules define a directed acyclic graph (DAG). - Returns True if valid (no cycles), otherwise False. + Validate that there exists a path from every init tool to a terminal tool. + Returns True if valid (path exists), otherwise False. """ # Build adjacency list for the tool graph adjacency_list: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules} - # Track visited nodes - visited: Set[str] = set() - path_stack: Set[str] = set() - - # Define DFS helper function - def dfs(tool_name: str) -> bool: - if tool_name in path_stack: - return False # Cycle detected - if tool_name in visited: - return True # Already validated - - # Mark the node as visited in the current path - path_stack.add(tool_name) - for child in adjacency_list.get(tool_name, []): - if not dfs(child): - return False # Cycle detected in DFS - path_stack.remove(tool_name) # Remove from current path - visited.add(tool_name) - return True - - # Run DFS from each tool in `tool_rules` - for rule in self.tool_rules: - if rule.tool_name not in visited: - if not dfs(rule.tool_name): - return False # Cycle found, invalid tool rules - - return True # No cycles, valid DAG + init_tool_names = {rule.tool_name for rule in self.init_tool_rules} + terminal_tool_names = {rule.tool_name for rule in self.terminal_tool_rules} + + # Initial checks + if len(init_tool_names) == 0: + if len(terminal_tool_names) + len(self.tool_rules) > 0: + return False # No init tools defined + else: + return True # No tool rules + if len(terminal_tool_names) == 0: + if len(adjacency_list) > 0: + return False # No terminal tools defined + else: + return True # Only init tools + + # Define BFS helper function to find path to terminal tool + def has_path_to_terminal(start_tool: str) -> bool: + visited = set() + queue = deque([start_tool]) + visited.add(start_tool) + + while queue: + current_tool = queue.popleft() + if current_tool in terminal_tool_names: + return True + + for child in adjacency_list.get(current_tool, []): + if child not in visited: + visited.add(child) + queue.append(child) + return False + + # Check if each init tool has a path to a terminal tool + for init_tool_name in init_tool_names: + if not has_path_to_terminal(init_tool_name): + return False + + return True # All init tools have paths to terminal tools diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 8b74b83732..6183033f54 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -45,5 +45,6 @@ class ToolRuleType(str, Enum): run_first = "InitToolRule" exit_loop = "TerminalToolRule" # reasoning loop should exit continue_loop = "continue_loop" # reasoning loop should continue + conditional = "conditional" constrain_child_tools = "ToolRule" require_parent_tools = "require_parent_tools" diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index b320917d25..a7f4f7cc1b 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -1,4 +1,4 @@ -from typing import List, Union +from typing import Dict, List, Union from pydantic import Field @@ -21,6 +21,17 @@ class ChildToolRule(BaseToolRule): children: List[str] = Field(..., description="The children tools that can be invoked.") +class ConditionalToolRule(BaseToolRule): + """ + A ToolRule that conditionally maps to different child tools based on the output. + """ + type: ToolRuleType = ToolRuleType.conditional + default_child: str = Field(..., description="The default child tool to be called") + child_output_mapping: Dict[Union[bool, str, int], str] = Field(..., description="The output case to check for mapping") + children: List[str] = Field(..., description="The child tool to call when output matches the case") + throw_error: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case") + + class InitToolRule(BaseToolRule): """ Represents the initial tool rule configuration. @@ -37,4 +48,4 @@ class TerminalToolRule(BaseToolRule): type: ToolRuleType = ToolRuleType.exit_loop -ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule] +ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule] diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index 9de6a6302b..e1170e4b73 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -2,7 +2,12 @@ from letta.helpers import ToolRulesSolver from letta.helpers.tool_rule_solver import ToolRuleValidationError -from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule +from letta.schemas.tool_rule import ( + ChildToolRule, + ConditionalToolRule, + InitToolRule, + TerminalToolRule +) # Constants for tool names used in the tests START_TOOL = "start_tool" @@ -31,7 +36,9 @@ def test_get_allowed_tool_names_with_subsequent_rule(): # Setup: Tool rule sequence init_rule = InitToolRule(tool_name=START_TOOL) rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL]) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[]) + rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[END_TOOL]) + terminal_rule = TerminalToolRule(tool_name=END_TOOL) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1, rule_2], terminal_tool_rules=[terminal_rule]) # Action: Update usage and get allowed tools solver.update_tool_usage(START_TOOL) @@ -44,21 +51,22 @@ def test_get_allowed_tool_names_with_subsequent_rule(): def test_is_terminal_tool(): # Setup: Terminal tool rule configuration init_rule = InitToolRule(tool_name=START_TOOL) + rule_1 = ChildToolRule(tool_name=START_TOOL, children=[END_TOOL]) terminal_rule = TerminalToolRule(tool_name=END_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[terminal_rule]) + solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[terminal_rule]) # Action & Assert: Verify terminal and non-terminal tools assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as a terminal tool" assert solver.is_terminal_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a terminal tool" -def test_get_allowed_tool_names_no_matching_rule_warning(): - # Setup: Tool rules with no matching rule for the last tool - init_rule = InitToolRule(tool_name=START_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[]) +# def test_get_allowed_tool_names_no_matching_rule_warning(): +# # Setup: Tool rules with no matching rule for the last tool +# init_rule = InitToolRule(tool_name=START_TOOL) +# solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[]) - # Action: Set last tool to an unrecognized tool and check warnings - solver.update_tool_usage(UNRECOGNIZED_TOOL) +# # Action: Set last tool to an unrecognized tool and check warnings +# solver.update_tool_usage(UNRECOGNIZED_TOOL) # NOTE: removed for now since this warning is getting triggered on every LLM call # with warnings.catch_warnings(record=True) as w: @@ -104,7 +112,65 @@ def test_update_tool_usage_and_get_allowed_tool_names_combined(): assert solver.is_terminal_tool(FINAL_TOOL) is True, "Should recognize 'final_tool' as terminal" -def test_tool_rules_with_cycle_detection(): +def test_conditional_tool_rule(): + # Setup: Define a conditional tool rule + init_rule = InitToolRule(tool_name=START_TOOL) + terminal_rule = TerminalToolRule(tool_name=END_TOOL) + rule = ConditionalToolRule( + tool_name=START_TOOL, + children=[START_TOOL, END_TOOL], + default_child=END_TOOL, + child_output_mapping={True: END_TOOL, False: START_TOOL} + ) + solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule]) + + # Action & Assert: Verify the rule properties + # Step 1: Initially allowed tools + assert solver.get_allowed_tool_names() == [START_TOOL], "Initial allowed tool should be 'start_tool'" + + # Step 2: After using 'start_tool' + solver.update_tool_usage(START_TOOL) + assert set(solver.get_allowed_tool_names()) == set({END_TOOL, START_TOOL}), "After 'start_tool', should allow 'start_tool' or 'end_tool'" + + # Step 3: After using 'end_tool' + assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal" + + +def test_invalid_conditional_tool_rule(): + # Setup: Define an invalid conditional tool rule + init_rule = InitToolRule(tool_name=START_TOOL) + terminal_rule = TerminalToolRule(tool_name=END_TOOL) + invalid_rule_1 = ConditionalToolRule( + tool_name=START_TOOL, + children=[START_TOOL], + default_child=END_TOOL, + child_output_mapping={True: END_TOOL, False: START_TOOL} + ) + invalid_rule_2 = ConditionalToolRule( + tool_name=START_TOOL, + children=[START_TOOL, END_TOOL], + default_child=END_TOOL, + child_output_mapping={True: END_TOOL} + ) + invalid_rule_3 = ConditionalToolRule( + tool_name=START_TOOL, + children=[START_TOOL, FINAL_TOOL], + default_child=FINAL_TOOL, + child_output_mapping={True: END_TOOL, False: START_TOOL} + ) + + # Test 1: Missing child output mapping + with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have a child output mapping for each child tool."): + ToolRulesSolver(tool_rules=[init_rule, invalid_rule_1, terminal_rule]) + with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have a child output mapping for each child tool."): + ToolRulesSolver(tool_rules=[init_rule, invalid_rule_2, terminal_rule]) + + # Test 2: Missing child + with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have a child output mapping for each child tool."): + ToolRulesSolver(tool_rules=[init_rule, invalid_rule_3, terminal_rule]) + + +def test_tool_rules_with_invalid_path(): # Setup: Define tool rules with both connected, disconnected nodes and a cycle init_rule = InitToolRule(tool_name=START_TOOL) rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL]) @@ -114,14 +180,14 @@ def test_tool_rules_with_cycle_detection(): terminal_rule = TerminalToolRule(tool_name=END_TOOL) # Action & Assert: Attempt to create the ToolRulesSolver with a cycle should raise ValidationError - with pytest.raises(ToolRuleValidationError, match="Tool rules contain cycles"): + with pytest.raises(ToolRuleValidationError, match="Tool rules does not have a path from Init to Terminal."): ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule]) - # Extra setup: Define tool rules without a cycle but with hanging nodes - rule_5 = ChildToolRule(tool_name=PREP_TOOL, children=[FINAL_TOOL]) # Hanging node with no connection to start_tool - - # Assert that a configuration without cycles does not raise an error - try: - ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_4, rule_5, terminal_rule]) - except ToolRuleValidationError: - pytest.fail("ToolRulesSolver raised ValidationError unexpectedly on a valid DAG with hanging nodes") + # Now: add a path from the start tool to the final tool + rule_5 = ConditionalToolRule( + tool_name=HELPER_TOOL, + children=[START_TOOL, FINAL_TOOL], + default_child=FINAL_TOOL, + child_output_mapping={True: START_TOOL, False: FINAL_TOOL}, + ) + ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, rule_5, terminal_rule])