Skip to content

Commit

Permalink
feat: enable configuration of response_char_limit for tools (#2207)
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahwooders authored Dec 10, 2024
1 parent 85a9046 commit af5ef6d
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""add column to tools table to contain function return limit return_char_limit
Revision ID: a91994b9752f
Revises: e1a625072dbf
Create Date: 2024-12-09 18:27:25.650079
"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT

# revision identifiers, used by Alembic.
revision: str = "a91994b9752f"
down_revision: Union[str, None] = "e1a625072dbf"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("tools", sa.Column("return_char_limit", sa.Integer(), nullable=True))

# Populate `return_char_limit` column
op.execute(
f"""
UPDATE tools
SET return_char_limit = {FUNCTION_RETURN_CHAR_LIMIT}
"""
)


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("tools", "return_char_limit")
# ### end Alembic commands ###
8 changes: 7 additions & 1 deletion letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,13 @@ def _handle_ai_response(
# but by default, we add a truncation safeguard to prevent bad functions from
# overflow the agent context window
truncate = True
function_response_string = validate_function_response(function_response, truncate=truncate)

# get the function response limit
tool_obj = [tool for tool in self.agent_state.tools if tool.name == function_name][0]
return_char_limit = tool_obj.return_char_limit
function_response_string = validate_function_response(
function_response, return_char_limit=return_char_limit, truncate=truncate
)
function_args.pop("self", None)
function_response = package_function_response(True, function_response_string)
function_failed = False
Expand Down
44 changes: 33 additions & 11 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
BASE_TOOLS,
DEFAULT_HUMAN,
DEFAULT_PERSONA,
FUNCTION_RETURN_CHAR_LIMIT,
)
from letta.data_sources.connectors import DataConnector
from letta.functions.functions import parse_source_code
Expand Down Expand Up @@ -200,18 +201,12 @@ def load_composio_tool(self, action: "ActionType") -> Tool:
raise NotImplementedError

def create_tool(
self,
func,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
self, func, name: Optional[str] = None, tags: Optional[List[str]] = None, return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT
) -> Tool:
raise NotImplementedError

def create_or_update_tool(
self,
func,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
self, func, name: Optional[str] = None, tags: Optional[List[str]] = None, return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT
) -> Tool:
raise NotImplementedError

Expand All @@ -222,6 +217,7 @@ def update_tool(
description: Optional[str] = None,
func: Optional[Callable] = None,
tags: Optional[List[str]] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
raise NotImplementedError

Expand Down Expand Up @@ -1465,6 +1461,7 @@ def create_tool(
func: Callable,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
"""
Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
Expand All @@ -1473,6 +1470,7 @@ def create_tool(
func (callable): The function to create a tool for.
name: (str): Name of the tool (must be unique per-user.)
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
Returns:
tool (Tool): The created tool.
Expand All @@ -1481,7 +1479,9 @@ def create_tool(
source_type = "python"

# call server function
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags)
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, return_char_limit=return_char_limit)
if tags:
request.tags = tags
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create tool: {response.text}")
Expand All @@ -1492,6 +1492,7 @@ def create_or_update_tool(
func: Callable,
name: Optional[str] = None,
tags: Optional[List[str]] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
"""
Creates or updates a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
Expand All @@ -1500,6 +1501,7 @@ def create_or_update_tool(
func (callable): The function to create a tool for.
name: (str): Name of the tool (must be unique per-user.)
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
Returns:
tool (Tool): The created tool.
Expand All @@ -1508,7 +1510,9 @@ def create_or_update_tool(
source_type = "python"

# call server function
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, tags=tags)
request = ToolCreate(source_type=source_type, source_code=source_code, name=name, return_char_limit=return_char_limit)
if tags:
request.tags = tags
response = requests.put(f"{self.base_url}/{self.api_prefix}/tools", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create tool: {response.text}")
Expand All @@ -1521,6 +1525,7 @@ def update_tool(
description: Optional[str] = None,
func: Optional[Callable] = None,
tags: Optional[List[str]] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
"""
Update a tool with provided parameters (name, func, tags)
Expand All @@ -1530,6 +1535,7 @@ def update_tool(
name (str): Name of the tool
func (callable): Function to wrap in a tool
tags (List[str]): Tags for the tool
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
Returns:
tool (Tool): Updated tool
Expand All @@ -1541,7 +1547,14 @@ def update_tool(

source_type = "python"

request = ToolUpdate(description=description, source_type=source_type, source_code=source_code, tags=tags, name=name)
request = ToolUpdate(
description=description,
source_type=source_type,
source_code=source_code,
tags=tags,
name=name,
return_char_limit=return_char_limit,
)
response = requests.patch(f"{self.base_url}/{self.api_prefix}/tools/{id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update tool: {response.text}")
Expand Down Expand Up @@ -2726,6 +2739,7 @@ def create_tool(
name: Optional[str] = None,
tags: Optional[List[str]] = None,
description: Optional[str] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
"""
Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
Expand All @@ -2735,6 +2749,7 @@ def create_tool(
name: (str): Name of the tool (must be unique per-user.)
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
description (str, optional): The description.
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
Returns:
tool (Tool): The created tool.
Expand All @@ -2755,6 +2770,7 @@ def create_tool(
name=name,
tags=tags,
description=description,
return_char_limit=return_char_limit,
),
actor=self.user,
)
Expand All @@ -2765,6 +2781,7 @@ def create_or_update_tool(
name: Optional[str] = None,
tags: Optional[List[str]] = None,
description: Optional[str] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
"""
Creates or updates a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent.
Expand All @@ -2774,6 +2791,7 @@ def create_or_update_tool(
name: (str): Name of the tool (must be unique per-user.)
tags (Optional[List[str]], optional): Tags for the tool. Defaults to None.
description (str, optional): The description.
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
Returns:
tool (Tool): The created tool.
Expand All @@ -2791,6 +2809,7 @@ def create_or_update_tool(
name=name,
tags=tags,
description=description,
return_char_limit=return_char_limit,
),
actor=self.user,
)
Expand All @@ -2802,6 +2821,7 @@ def update_tool(
description: Optional[str] = None,
func: Optional[callable] = None,
tags: Optional[List[str]] = None,
return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT,
) -> Tool:
"""
Update a tool with provided parameters (name, func, tags)
Expand All @@ -2811,6 +2831,7 @@ def update_tool(
name (str): Name of the tool
func (callable): Function to wrap in a tool
tags (List[str]): Tags for the tool
return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT.
Returns:
tool (Tool): Updated tool
Expand All @@ -2821,6 +2842,7 @@ def update_tool(
"tags": tags,
"name": name,
"description": description,
"return_char_limit": return_char_limit,
}

# Filter out any None values from the dictionary
Expand Down
14 changes: 6 additions & 8 deletions letta/orm/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class Tool(SqlalchemyBase, OrganizationMixin):
__table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),)

name: Mapped[str] = mapped_column(doc="The display name of the tool.")
return_char_limit: Mapped[int] = mapped_column(nullable=True, doc="The maximum number of characters the tool can return.")
description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json)
Expand All @@ -45,19 +46,16 @@ class Tool(SqlalchemyBase, OrganizationMixin):


# Add event listener to update tool_name in ToolsAgents when Tool name changes
@event.listens_for(Tool, 'before_update')
@event.listens_for(Tool, "before_update")
def update_tool_name_in_tools_agents(mapper, connection, target):
"""Update tool_name in ToolsAgents when Tool name changes."""
state = target._sa_instance_state
history = state.get_history('name', passive=True)
history = state.get_history("name", passive=True)
if not history.has_changes():
return

# Get the new name and update all associated ToolsAgents records
new_name = target.name
from letta.orm.tools_agents import ToolsAgents
connection.execute(
ToolsAgents.__table__.update().where(
ToolsAgents.tool_id == target.id
).values(tool_name=new_name)
)

connection.execute(ToolsAgents.__table__.update().where(ToolsAgents.tool_id == target.id).values(tool_name=new_name))
5 changes: 5 additions & 0 deletions letta/schemas/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pydantic import Field, model_validator

from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
from letta.functions.functions import derive_openai_json_schema
from letta.functions.helpers import (
generate_composio_tool_wrapper,
Expand Down Expand Up @@ -41,6 +42,9 @@ class Tool(BaseTool):
source_code: str = Field(..., description="The source code of the function.")
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")

# tool configuration
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")

# metadata fields
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
Expand Down Expand Up @@ -91,6 +95,7 @@ class ToolCreate(LettaBase):
json_schema: Optional[Dict] = Field(
None, description="The JSON schema of the function (auto-generated from source_code if not provided)"
)
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")

@classmethod
def from_composio(cls, action_name: str, api_key: Optional[str] = None) -> "ToolCreate":
Expand Down
11 changes: 5 additions & 6 deletions letta/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
CLI_WARNING_PREFIX,
CORE_MEMORY_HUMAN_CHAR_LIMIT,
CORE_MEMORY_PERSONA_CHAR_LIMIT,
FUNCTION_RETURN_CHAR_LIMIT,
LETTA_DIR,
MAX_FILENAME_LENGTH,
TOOL_CALL_ID_MAX_LEN,
Expand Down Expand Up @@ -906,8 +905,8 @@ def parse_json(string) -> dict:
raise e


def validate_function_response(function_response_string: any, strict: bool = False, truncate: bool = True) -> str:
"""Check to make sure that a function used by Letta returned a valid response
def validate_function_response(function_response_string: any, return_char_limit: int, strict: bool = False, truncate: bool = True) -> str:
"""Check to make sure that a function used by Letta returned a valid response. Truncates to return_char_limit if necessary.
Responses need to be strings (or None) that fall under a certain text count limit.
"""
Expand Down Expand Up @@ -943,11 +942,11 @@ def validate_function_response(function_response_string: any, strict: bool = Fal

# Now check the length and make sure it doesn't go over the limit
# TODO we should change this to a max token limit that's variable based on tokens remaining (or context-window)
if truncate and len(function_response_string) > FUNCTION_RETURN_CHAR_LIMIT:
if truncate and len(function_response_string) > return_char_limit:
print(
f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT}) and was truncated"
f"{CLI_WARNING_PREFIX}function return was over limit ({len(function_response_string)} > {return_char_limit}) and was truncated"
)
function_response_string = f"{function_response_string[:FUNCTION_RETURN_CHAR_LIMIT]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {FUNCTION_RETURN_CHAR_LIMIT})]"
function_response_string = f"{function_response_string[:return_char_limit]}... [NOTE: function output was truncated since it exceeded the character limit ({len(function_response_string)} > {return_char_limit})]"

return function_response_string

Expand Down
Loading

0 comments on commit af5ef6d

Please sign in to comment.