Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Add testing around base tools #2268

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 1 addition & 14 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,6 @@ def __init__(
self.agent_manager = AgentManager()

# State needed for heartbeat pausing
self.pause_heartbeats_start = None
self.pause_heartbeats_minutes = 0

self.first_message_verify_mono = first_message_verify_mono

Expand Down Expand Up @@ -1235,17 +1233,6 @@ def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True,

printd(f"Ran summarizer, messages length {prior_len} -> {len(self.messages)}")

def heartbeat_is_paused(self):
"""Check if there's a requested pause on timed heartbeats"""

# Check if the pause has been initiated
if self.pause_heartbeats_start is None:
return False

# Check if it's been more than pause_heartbeats_minutes since pause_heartbeats_start
elapsed_time = get_utc_time() - self.pause_heartbeats_start
return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60

def _swap_system_message_in_buffer(self, new_system_message: str):
"""Update the system message (NOT prompt) of the Agent (requires updating the internal buffer)"""
assert isinstance(new_system_message, str)
Expand Down Expand Up @@ -1370,7 +1357,7 @@ def attach_source(
agent_manager: AgentManager,
):
"""Attach a source to the agent using the SourcesAgents ORM relationship.

Args:
user: User performing the action
source_id: ID of the source to attach
Expand Down
3 changes: 2 additions & 1 deletion letta/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
DEFAULT_PRESET = "memgpt_chat"

# Base tools that cannot be edited, as they access agent state directly
BASE_TOOLS = ["send_message", "conversation_search", "conversation_search_date", "archival_memory_insert", "archival_memory_search"]
# Note that we don't include "conversation_search_date" for now
BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", "archival_memory_search"]
O1_BASE_TOOLS = ["send_thinking_message", "send_final_message"]
# Base memory tools CAN be edited, and are added by default by the server
BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"]
Expand Down
66 changes: 9 additions & 57 deletions letta/functions/function_sets/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
from datetime import datetime
from typing import Optional

from letta.agent import Agent
from letta.constants import MAX_PAUSE_HEARTBEATS
from letta.services.agent_manager import AgentManager

# import math
# from letta.utils import json_dumps

### Functions / tools the agent can use
# All functions should return a response string (or None)
# If the function fails, throw an exception


def send_message(self: "Agent", message: str) -> Optional[str]:
Expand All @@ -28,36 +18,6 @@ def send_message(self: "Agent", message: str) -> Optional[str]:
return None


# Construct the docstring dynamically (since it should use the external constants)
pause_heartbeats_docstring = f"""
Temporarily ignore timed heartbeats. You may still receive messages from manual heartbeats and other events.

Args:
minutes (int): Number of minutes to ignore heartbeats for. Max value of {MAX_PAUSE_HEARTBEATS} minutes ({MAX_PAUSE_HEARTBEATS // 60} hours).

Returns:
str: Function status response
"""


def pause_heartbeats(self: "Agent", minutes: int) -> Optional[str]:
import datetime

from letta.constants import MAX_PAUSE_HEARTBEATS

minutes = min(MAX_PAUSE_HEARTBEATS, minutes)

# Record the current time
self.pause_heartbeats_start = datetime.datetime.now(datetime.timezone.utc)
# And record how long the pause should go for
self.pause_heartbeats_minutes = int(minutes)

return f"Pausing timed heartbeats for {minutes} min"


pause_heartbeats.__doc__ = pause_heartbeats_docstring


def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search prior conversation history using case-insensitive string matching.
Expand All @@ -84,19 +44,19 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
# TODO: add paging by page number. currently cursor only works with strings.
# original: start=page * count
results = self.message_manager.list_user_messages_for_agent(
messages = self.message_manager.list_user_messages_for_agent(
agent_id=self.agent_state.id,
actor=self.user,
query_text=query,
limit=count,
)
total = len(results)
total = len(messages)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
if len(messages) == 0:
results_str = f"No results found."
else:
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
results_pref = f"Showing {len(messages)} of {total} results (page {page}/{num_pages}):"
results_formatted = [message.text for message in messages]
results_str = f"{results_pref} {json_dumps(results_formatted)}"
return results_str

Expand All @@ -114,6 +74,7 @@ def conversation_search_date(self: "Agent", start_date: str, end_date: str, page
str: Query result string
"""
import math
from datetime import datetime

from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.utils import json_dumps
Expand Down Expand Up @@ -142,7 +103,6 @@ def conversation_search_date(self: "Agent", start_date: str, end_date: str, page
start_date=start_datetime,
end_date=end_datetime,
limit=count,
# start_date=start_date, end_date=end_date, limit=count, start=page * count
)
total = len(results)
num_pages = math.ceil(total / count) - 1 # 0 index
Expand Down Expand Up @@ -186,10 +146,8 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, s
Returns:
str: Query result string
"""
import math

from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.utils import json_dumps

if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
page = 0
Expand All @@ -198,7 +156,7 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, s
except:
raise ValueError(f"'page' argument must be an integer")
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE

try:
# Get results using passage manager
all_results = self.agent_manager.list_passages(
Expand All @@ -207,21 +165,15 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, s
query_text=query,
limit=count + start, # Request enough results to handle offset
embedding_config=self.agent_state.embedding_config,
embed_query=True
embed_query=True,
)

# Apply pagination
end = min(count + start, len(all_results))
paged_results = all_results[start:end]

# Format results to match previous implementation
formatted_results = [
{
"timestamp": str(result.created_at),
"content": result.text
}
for result in paged_results
]
formatted_results = [{"timestamp": str(result.created_at), "content": result.text} for result in paged_results]

return formatted_results, len(formatted_results)

Expand Down
2 changes: 1 addition & 1 deletion letta/functions/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
# append the heartbeat
# TODO: don't hard-code
# TODO: if terminal, don't include this
if function.__name__ not in ["send_message", "pause_heartbeats"]:
if function.__name__ not in ["send_message"]:
schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
Expand Down
2 changes: 1 addition & 1 deletion letta/local_llm/function_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from letta.utils import json_dumps, json_loads

NO_HEARTBEAT_FUNCS = ["send_message", "pause_heartbeats"]
NO_HEARTBEAT_FUNCS = ["send_message"]


def insert_heartbeat(message):
Expand Down
10 changes: 0 additions & 10 deletions scripts/migrate_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,3 @@ def deprecated_tool():
),
actor=fake_user,
)

ToolManager().create_or_update_tool(
Tool(
name="pause_heartbeats",
source_code=source_code,
source_type=source_type,
description=description,
),
actor=fake_user,
)
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,12 @@ def mock_e2b_api_key_none():

# Restore the original value of e2b_api_key
tool_settings.e2b_api_key = original_api_key


@pytest.fixture
def check_e2b_key_is_set():
from letta.settings import tool_settings

original_api_key = tool_settings.e2b_api_key
assert original_api_key is not None, "Missing e2b key! Cannot execute these tests."
yield
Loading
Loading