Skip to content

Commit

Permalink
Add ConditionalToolRule
Browse files Browse the repository at this point in the history
* update ToolRuleSolver to
    * check for Init->Terminal paths
    * remove cycle detection
* updated tests for conditional rules, valid paths
  • Loading branch information
Mindy Long committed Dec 18, 2024
1 parent 6203560 commit 1dea00c
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 53 deletions.
93 changes: 61 additions & 32 deletions letta/helpers/tool_rule_solver.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Dict, List, Optional, Set
from typing import Dict, List, Optional, Union
from collections import deque

from pydantic import BaseModel, Field

from letta.schemas.enums import ToolRuleType
from letta.schemas.tool_rule import (
BaseToolRule,
ChildToolRule,
ConditionalToolRule,
InitToolRule,
TerminalToolRule,
)
Expand All @@ -22,7 +24,7 @@ class ToolRulesSolver(BaseModel):
init_tool_rules: List[InitToolRule] = Field(
default_factory=list, description="Initial tool rules to be used at the start of tool execution."
)
tool_rules: List[ChildToolRule] = Field(
tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field(
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
)
terminal_tool_rules: List[TerminalToolRule] = Field(
Expand All @@ -35,15 +37,22 @@ def __init__(self, tool_rules: List[BaseToolRule], **kwargs):
# Separate the provided tool rules into init, standard, and terminal categories
for rule in tool_rules:
if rule.type == ToolRuleType.run_first:
assert isinstance(rule, InitToolRule)
self.init_tool_rules.append(rule)
elif rule.type == ToolRuleType.constrain_child_tools:
assert isinstance(rule, ChildToolRule)
self.tool_rules.append(rule)
elif rule.type == ToolRuleType.conditional:
assert isinstance(rule, ConditionalToolRule)
self.validate_conditional_tool(rule)
self.tool_rules.append(rule)
elif rule.type == ToolRuleType.exit_loop:
assert isinstance(rule, TerminalToolRule)
self.terminal_tool_rules.append(rule)

# Validate the tool rules to ensure they form a DAG
if not self.validate_tool_rules():
raise ToolRuleValidationError("Tool rules contain cycles, which are not allowed in a valid configuration.")
raise ToolRuleValidationError("Tool rules does not have a path from Init to Terminal.")

def update_tool_usage(self, tool_name: str):
"""Update the internal state to track the last tool called."""
Expand Down Expand Up @@ -78,38 +87,58 @@ def has_children_tools(self, tool_name):
"""Check if the tool has children tools"""
return any(rule.tool_name == tool_name for rule in self.tool_rules)

def validate_conditional_tool(self, rule: ConditionalToolRule):
if rule.children is None or len(rule.children) == 0:
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
if len(rule.children) != len(rule.child_output_mapping):
raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.")
if set(rule.children) != set(rule.child_output_mapping.values()):
raise ToolRuleValidationError("Conditional tool rule must have a child output mapping for each child tool.")
return True

def validate_tool_rules(self) -> bool:
"""
Validate that the tool rules define a directed acyclic graph (DAG).
Returns True if valid (no cycles), otherwise False.
Validate that there exists a path from every init tool to a terminal tool.
Returns True if valid (path exists), otherwise False.
"""
# Build adjacency list for the tool graph
adjacency_list: Dict[str, List[str]] = {rule.tool_name: rule.children for rule in self.tool_rules}

# Track visited nodes
visited: Set[str] = set()
path_stack: Set[str] = set()

# Define DFS helper function
def dfs(tool_name: str) -> bool:
if tool_name in path_stack:
return False # Cycle detected
if tool_name in visited:
return True # Already validated

# Mark the node as visited in the current path
path_stack.add(tool_name)
for child in adjacency_list.get(tool_name, []):
if not dfs(child):
return False # Cycle detected in DFS
path_stack.remove(tool_name) # Remove from current path
visited.add(tool_name)
return True

# Run DFS from each tool in `tool_rules`
for rule in self.tool_rules:
if rule.tool_name not in visited:
if not dfs(rule.tool_name):
return False # Cycle found, invalid tool rules

return True # No cycles, valid DAG
init_tool_names = {rule.tool_name for rule in self.init_tool_rules}
terminal_tool_names = {rule.tool_name for rule in self.terminal_tool_rules}

# Initial checks
if len(init_tool_names) == 0:
if len(terminal_tool_names) + len(self.tool_rules) > 0:
return False # No init tools defined
else:
return True # No tool rules
if len(terminal_tool_names) == 0:
if len(adjacency_list) > 0:
return False # No terminal tools defined
else:
return True # Only init tools

# Define BFS helper function to find path to terminal tool
def has_path_to_terminal(start_tool: str) -> bool:
visited = set()
queue = deque([start_tool])
visited.add(start_tool)

while queue:
current_tool = queue.popleft()
if current_tool in terminal_tool_names:
return True

for child in adjacency_list.get(current_tool, []):
if child not in visited:
visited.add(child)
queue.append(child)
return False

# Check if each init tool has a path to a terminal tool
for init_tool_name in init_tool_names:
if not has_path_to_terminal(init_tool_name):
return False

return True # All init tools have paths to terminal tools
1 change: 1 addition & 0 deletions letta/schemas/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,6 @@ class ToolRuleType(str, Enum):
run_first = "InitToolRule"
exit_loop = "TerminalToolRule" # reasoning loop should exit
continue_loop = "continue_loop" # reasoning loop should continue
conditional = "conditional"
constrain_child_tools = "ToolRule"
require_parent_tools = "require_parent_tools"
15 changes: 13 additions & 2 deletions letta/schemas/tool_rule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import Dict, List, Union

from pydantic import Field

Expand All @@ -21,6 +21,17 @@ class ChildToolRule(BaseToolRule):
children: List[str] = Field(..., description="The children tools that can be invoked.")


class ConditionalToolRule(BaseToolRule):
"""
A ToolRule that conditionally maps to different child tools based on the output.
"""
type: ToolRuleType = ToolRuleType.conditional
default_child: str = Field(..., description="The default child tool to be called")
child_output_mapping: Dict[Union[bool, str, int], str] = Field(..., description="The output case to check for mapping")
children: List[str] = Field(..., description="The child tool to call when output matches the case")
throw_error: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case")


class InitToolRule(BaseToolRule):
"""
Represents the initial tool rule configuration.
Expand All @@ -37,4 +48,4 @@ class TerminalToolRule(BaseToolRule):
type: ToolRuleType = ToolRuleType.exit_loop


ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule]
ToolRule = Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]
104 changes: 85 additions & 19 deletions tests/test_tool_rule_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from letta.helpers import ToolRulesSolver
from letta.helpers.tool_rule_solver import ToolRuleValidationError
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.tool_rule import (
ChildToolRule,
ConditionalToolRule,
InitToolRule,
TerminalToolRule
)

# Constants for tool names used in the tests
START_TOOL = "start_tool"
Expand Down Expand Up @@ -31,7 +36,9 @@ def test_get_allowed_tool_names_with_subsequent_rule():
# Setup: Tool rule sequence
init_rule = InitToolRule(tool_name=START_TOOL)
rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL])
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[])
rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[END_TOOL])
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1, rule_2], terminal_tool_rules=[terminal_rule])

# Action: Update usage and get allowed tools
solver.update_tool_usage(START_TOOL)
Expand All @@ -44,21 +51,22 @@ def test_get_allowed_tool_names_with_subsequent_rule():
def test_is_terminal_tool():
# Setup: Terminal tool rule configuration
init_rule = InitToolRule(tool_name=START_TOOL)
rule_1 = ChildToolRule(tool_name=START_TOOL, children=[END_TOOL])
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[terminal_rule])
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[terminal_rule])

# Action & Assert: Verify terminal and non-terminal tools
assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as a terminal tool"
assert solver.is_terminal_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a terminal tool"


def test_get_allowed_tool_names_no_matching_rule_warning():
# Setup: Tool rules with no matching rule for the last tool
init_rule = InitToolRule(tool_name=START_TOOL)
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[])
# def test_get_allowed_tool_names_no_matching_rule_warning():
# # Setup: Tool rules with no matching rule for the last tool
# init_rule = InitToolRule(tool_name=START_TOOL)
# solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[])

# Action: Set last tool to an unrecognized tool and check warnings
solver.update_tool_usage(UNRECOGNIZED_TOOL)
# # Action: Set last tool to an unrecognized tool and check warnings
# solver.update_tool_usage(UNRECOGNIZED_TOOL)

# NOTE: removed for now since this warning is getting triggered on every LLM call
# with warnings.catch_warnings(record=True) as w:
Expand Down Expand Up @@ -104,7 +112,65 @@ def test_update_tool_usage_and_get_allowed_tool_names_combined():
assert solver.is_terminal_tool(FINAL_TOOL) is True, "Should recognize 'final_tool' as terminal"


def test_tool_rules_with_cycle_detection():
def test_conditional_tool_rule():
# Setup: Define a conditional tool rule
init_rule = InitToolRule(tool_name=START_TOOL)
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
rule = ConditionalToolRule(
tool_name=START_TOOL,
children=[START_TOOL, END_TOOL],
default_child=END_TOOL,
child_output_mapping={True: END_TOOL, False: START_TOOL}
)
solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule])

# Action & Assert: Verify the rule properties
# Step 1: Initially allowed tools
assert solver.get_allowed_tool_names() == [START_TOOL], "Initial allowed tool should be 'start_tool'"

# Step 2: After using 'start_tool'
solver.update_tool_usage(START_TOOL)
assert set(solver.get_allowed_tool_names()) == set({END_TOOL, START_TOOL}), "After 'start_tool', should allow 'start_tool' or 'end_tool'"

# Step 3: After using 'end_tool'
assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal"


def test_invalid_conditional_tool_rule():
# Setup: Define an invalid conditional tool rule
init_rule = InitToolRule(tool_name=START_TOOL)
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
invalid_rule_1 = ConditionalToolRule(
tool_name=START_TOOL,
children=[START_TOOL],
default_child=END_TOOL,
child_output_mapping={True: END_TOOL, False: START_TOOL}
)
invalid_rule_2 = ConditionalToolRule(
tool_name=START_TOOL,
children=[START_TOOL, END_TOOL],
default_child=END_TOOL,
child_output_mapping={True: END_TOOL}
)
invalid_rule_3 = ConditionalToolRule(
tool_name=START_TOOL,
children=[START_TOOL, FINAL_TOOL],
default_child=FINAL_TOOL,
child_output_mapping={True: END_TOOL, False: START_TOOL}
)

# Test 1: Missing child output mapping
with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have a child output mapping for each child tool."):
ToolRulesSolver(tool_rules=[init_rule, invalid_rule_1, terminal_rule])
with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have a child output mapping for each child tool."):
ToolRulesSolver(tool_rules=[init_rule, invalid_rule_2, terminal_rule])

# Test 2: Missing child
with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have a child output mapping for each child tool."):
ToolRulesSolver(tool_rules=[init_rule, invalid_rule_3, terminal_rule])


def test_tool_rules_with_invalid_path():
# Setup: Define tool rules with both connected, disconnected nodes and a cycle
init_rule = InitToolRule(tool_name=START_TOOL)
rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
Expand All @@ -114,14 +180,14 @@ def test_tool_rules_with_cycle_detection():
terminal_rule = TerminalToolRule(tool_name=END_TOOL)

# Action & Assert: Attempt to create the ToolRulesSolver with a cycle should raise ValidationError
with pytest.raises(ToolRuleValidationError, match="Tool rules contain cycles"):
with pytest.raises(ToolRuleValidationError, match="Tool rules does not have a path from Init to Terminal."):
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule])

# Extra setup: Define tool rules without a cycle but with hanging nodes
rule_5 = ChildToolRule(tool_name=PREP_TOOL, children=[FINAL_TOOL]) # Hanging node with no connection to start_tool

# Assert that a configuration without cycles does not raise an error
try:
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_4, rule_5, terminal_rule])
except ToolRuleValidationError:
pytest.fail("ToolRulesSolver raised ValidationError unexpectedly on a valid DAG with hanging nodes")
# Now: add a path from the start tool to the final tool
rule_5 = ConditionalToolRule(
tool_name=HELPER_TOOL,
children=[START_TOOL, FINAL_TOOL],
default_child=FINAL_TOOL,
child_output_mapping={True: START_TOOL, False: FINAL_TOOL},
)
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, rule_5, terminal_rule])

0 comments on commit 1dea00c

Please sign in to comment.