Refactor out custom columns
mattzh72 committed Dec 14, 2024
1 parent cbd113f commit 4884967
Showing 5 changed files with 165 additions and 181 deletions.
103 changes: 8 additions & 95 deletions letta/orm/agent.py
@@ -1,26 +1,25 @@
import uuid
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional

from sqlalchemy import JSON, String, TypeDecorator, UniqueConstraint
from sqlalchemy import JSON, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.block import Block
from letta.orm.custom_columns import (
EmbeddingConfigColumn,
LLMConfigColumn,
ToolRulesColumn,
)
from letta.orm.message import Message
from letta.orm.mixins import OrganizationMixin
from letta.orm.organization import Organization
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.tool_rule import (
ChildToolRule,
InitToolRule,
TerminalToolRule,
ToolRule,
)
from letta.schemas.tool_rule import ToolRule

if TYPE_CHECKING:
from letta.orm.agents_tags import AgentsTags
@@ -29,92 +28,6 @@
from letta.orm.tool import Tool


class LLMConfigColumn(TypeDecorator):
"""Custom type for storing LLMConfig as JSON"""

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
if value:
# return vars(value)
if isinstance(value, LLMConfig):
return value.model_dump()
return value

def process_result_value(self, value, dialect):
if value:
return LLMConfig(**value)
return value


class EmbeddingConfigColumn(TypeDecorator):
"""Custom type for storing EmbeddingConfig as JSON"""

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
if value:
# return vars(value)
if isinstance(value, EmbeddingConfig):
return value.model_dump()
return value

def process_result_value(self, value, dialect):
if value:
return EmbeddingConfig(**value)
return value


class ToolRulesColumn(TypeDecorator):
"""Custom type for storing a list of ToolRules as JSON"""

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
"""Convert a list of ToolRules to JSON-serializable format."""
if value:
data = [rule.model_dump() for rule in value]
for d in data:
d["type"] = d["type"].value

for d in data:
assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field"
return data
return value

def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]:
"""Convert JSON back to a list of ToolRules."""
if value:
return [self.deserialize_tool_rule(rule_data) for rule_data in value]
return value

@staticmethod
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
if rule_type == ToolRuleType.run_first:
return InitToolRule(**data)
elif rule_type == ToolRuleType.exit_loop:
return TerminalToolRule(**data)
elif rule_type == ToolRuleType.constrain_child_tools:
rule = ChildToolRule(**data)
return rule
else:
raise ValueError(f"Unknown tool rule type: {rule_type}")


class Agent(SqlalchemyBase, OrganizationMixin):
__tablename__ = "agents"
__pydantic_model__ = PydanticAgentState
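
Not part of this commit: a minimal sketch of how column types imported from letta.orm.custom_columns are typically attached to a declarative model. The class and attribute names below are hypothetical; the real Agent declarations sit in the collapsed portion of this diff and may differ.

from typing import List, Optional

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from letta.orm.custom_columns import (
    EmbeddingConfigColumn,
    LLMConfigColumn,
    ToolRulesColumn,
)
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool_rule import ToolRule


class Base(DeclarativeBase):
    pass


class AgentSketch(Base):
    """Hypothetical stand-in for the real Agent model."""

    __tablename__ = "agent_sketch"

    id: Mapped[str] = mapped_column(primary_key=True)
    # SQLAlchemy calls process_bind_param on write (Pydantic model -> JSON dict)
    # and process_result_value on read (JSON dict -> Pydantic model).
    llm_config: Mapped[Optional[LLMConfig]] = mapped_column(LLMConfigColumn, nullable=True)
    embedding_config: Mapped[Optional[EmbeddingConfig]] = mapped_column(EmbeddingConfigColumn, nullable=True)
    tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, nullable=True)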
152 changes: 152 additions & 0 deletions letta/orm/custom_columns.py
@@ -0,0 +1,152 @@
import base64
from typing import List, Union

import numpy as np
from sqlalchemy import JSON
from sqlalchemy.types import BINARY, TypeDecorator

from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule


class EmbeddingConfigColumn(TypeDecorator):
"""Custom type for storing EmbeddingConfig as JSON."""

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
if value and isinstance(value, EmbeddingConfig):
return value.model_dump()
return value

def process_result_value(self, value, dialect):
if value:
return EmbeddingConfig(**value)
return value


class LLMConfigColumn(TypeDecorator):
"""Custom type for storing LLMConfig as JSON."""

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
if value and isinstance(value, LLMConfig):
return value.model_dump()
return value

def process_result_value(self, value, dialect):
if value:
return LLMConfig(**value)
return value


class ToolRulesColumn(TypeDecorator):
"""Custom type for storing a list of ToolRules as JSON"""

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
"""Convert a list of ToolRules to JSON-serializable format."""
if value:
data = [rule.model_dump() for rule in value]
for d in data:
d["type"] = d["type"].value

for d in data:
assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field"
return data
return value

def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]:
"""Convert JSON back to a list of ToolRules."""
if value:
return [self.deserialize_tool_rule(rule_data) for rule_data in value]
return value

@staticmethod
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
if rule_type == ToolRuleType.run_first:
return InitToolRule(**data)
elif rule_type == ToolRuleType.exit_loop:
return TerminalToolRule(**data)
elif rule_type == ToolRuleType.constrain_child_tools:
rule = ChildToolRule(**data)
return rule
else:
raise ValueError(f"Unknown tool rule type: {rule_type}")


class ToolCallColumn(TypeDecorator):

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
if value:
values = []
for v in value:
if isinstance(v, ToolCall):
values.append(v.model_dump())
else:
values.append(v)
return values

return value

def process_result_value(self, value, dialect):
if value:
tools = []
for tool_value in value:
if "function" in tool_value:
tool_call_function = ToolCallFunction(**tool_value["function"])
del tool_value["function"]
else:
tool_call_function = None
tools.append(ToolCall(function=tool_call_function, **tool_value))
return tools
return value


class CommonVector(TypeDecorator):
"""Common type for representing vectors in SQLite"""

impl = BINARY
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(BINARY())

def process_bind_param(self, value, dialect):
if value is None:
return value
if isinstance(value, list):
value = np.array(value, dtype=np.float32)
return base64.b64encode(value.tobytes())

def process_result_value(self, value, dialect):
if not value:
return value
if dialect.name == "sqlite":
value = base64.b64decode(value)
return np.frombuffer(value, dtype=np.float32)
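
For reference, the float32/base64 round trip that CommonVector performs can be reproduced outside SQLAlchemy with just numpy and base64. This is a standalone sketch of the same transformation, not code from this commit.

import base64

import numpy as np

# Mirrors CommonVector.process_bind_param: list -> float32 array -> base64-encoded bytes.
vector = [0.1, 0.2, 0.3]
stored = base64.b64encode(np.array(vector, dtype=np.float32).tobytes())

# Mirrors CommonVector.process_result_value on the SQLite path: decode and rebuild the array.
restored = np.frombuffer(base64.b64decode(stored), dtype=np.float32)

assert np.allclose(restored, np.array(vector, dtype=np.float32))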
38 changes: 2 additions & 36 deletions letta/orm/message.py
@@ -1,46 +1,12 @@
from typing import Optional

from sqlalchemy import JSON, TypeDecorator
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.orm.custom_columns import ToolCallColumn
from letta.orm.mixins import AgentMixin, OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction


class ToolCallColumn(TypeDecorator):

impl = JSON
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())

def process_bind_param(self, value, dialect):
if value:
values = []
for v in value:
if isinstance(v, ToolCall):
values.append(v.model_dump())
else:
values.append(v)
return values

return value

def process_result_value(self, value, dialect):
if value:
tools = []
for tool_value in value:
if "function" in tool_value:
tool_call_function = ToolCallFunction(**tool_value["function"])
del tool_value["function"]
else:
tool_call_function = None
tools.append(ToolCall(function=tool_call_function, **tool_value))
return tools
return value
from letta.schemas.openai.chat_completions import ToolCall


class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
28 changes: 1 addition & 27 deletions letta/orm/passage.py
@@ -1,14 +1,12 @@
import base64
from datetime import datetime
from typing import TYPE_CHECKING, Optional

import numpy as np
from sqlalchemy import JSON, Column, DateTime, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.types import BINARY, TypeDecorator

from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.orm.custom_columns import CommonVector
from letta.orm.mixins import FileMixin, OrganizationMixin
from letta.orm.source import EmbeddingConfigColumn
from letta.orm.sqlalchemy_base import SqlalchemyBase
@@ -21,30 +19,6 @@
from letta.orm.organization import Organization


class CommonVector(TypeDecorator):
"""Common type for representing vectors in SQLite"""

impl = BINARY
cache_ok = True

def load_dialect_impl(self, dialect):
return dialect.type_descriptor(BINARY())

def process_bind_param(self, value, dialect):
if value is None:
return value
if isinstance(value, list):
value = np.array(value, dtype=np.float32)
return base64.b64encode(value.tobytes())

def process_result_value(self, value, dialect):
if not value:
return value
if dialect.name == "sqlite":
value = base64.b64decode(value)
return np.frombuffer(value, dtype=np.float32)


# TODO: After migration to Passage, will need to manually delete passages where files
# are deleted on web
class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
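
As a usage illustration (assuming the letta package at this commit is importable; the model below is hypothetical, not the real Passage), CommonVector can be exercised end to end through an in-memory SQLite session, which hits the sqlite branch of process_result_value:

from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import DeclarativeBase, Session

from letta.orm.custom_columns import CommonVector


class Base(DeclarativeBase):
    pass


class PassageSketch(Base):
    """Hypothetical stand-in for the real Passage model."""

    __tablename__ = "passage_sketch"

    id = Column(String, primary_key=True)
    embedding = Column(CommonVector, nullable=True)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # The list is converted to base64-encoded float32 bytes on write...
    session.add(PassageSketch(id="p-1", embedding=[0.1, 0.2, 0.3]))
    session.commit()

    # ...and decoded back into a numpy float32 array on read.
    row = session.get(PassageSketch, "p-1")
    print(row.embedding)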
(Diff for the fifth changed file was not loaded in this view.)
