diff --git a/alembic/versions/4e88e702f85e_drop_api_tokens_table_in_oss.py b/alembic/versions/4e88e702f85e_drop_api_tokens_table_in_oss.py new file mode 100644 index 0000000000..75a90445a0 --- /dev/null +++ b/alembic/versions/4e88e702f85e_drop_api_tokens_table_in_oss.py @@ -0,0 +1,42 @@ +"""Drop api tokens table in OSS + +Revision ID: 4e88e702f85e +Revises: d05669b60ebe +Create Date: 2024-12-13 17:19:55.796210 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "4e88e702f85e" +down_revision: Union[str, None] = "d05669b60ebe" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("tokens_idx_key", table_name="tokens") + op.drop_index("tokens_idx_user", table_name="tokens") + op.drop_table("tokens") + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "tokens", + sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("key", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint("id", name="tokens_pkey"), + ) + op.create_index("tokens_idx_user", "tokens", ["user_id"], unique=False) + op.create_index("tokens_idx_key", "tokens", ["key"], unique=False) + # ### end Alembic commands ### diff --git a/alembic/versions/9a505cc7eca9_create_a_baseline_migrations.py b/alembic/versions/9a505cc7eca9_create_a_baseline_migrations.py index 479ca223e0..21f6a39613 100644 --- a/alembic/versions/9a505cc7eca9_create_a_baseline_migrations.py +++ b/alembic/versions/9a505cc7eca9_create_a_baseline_migrations.py @@ -12,7 +12,7 @@ import sqlalchemy as sa from sqlalchemy.dialects import postgresql -import letta.metadata +import letta.orm from alembic import op # revision identifiers, used by Alembic. @@ -43,8 +43,8 @@ def upgrade() -> None: sa.Column("memory", sa.JSON(), nullable=True), sa.Column("system", sa.String(), nullable=True), sa.Column("agent_type", sa.String(), nullable=True), - sa.Column("llm_config", letta.metadata.LLMConfigColumn(), nullable=True), - sa.Column("embedding_config", letta.metadata.EmbeddingConfigColumn(), nullable=True), + sa.Column("llm_config", letta.orm.custom_columns.LLMConfigColumn(), nullable=True), + sa.Column("embedding_config", letta.orm.custom_columns.EmbeddingConfigColumn(), nullable=True), sa.Column("metadata_", sa.JSON(), nullable=True), sa.Column("tools", sa.JSON(), nullable=True), sa.PrimaryKeyConstraint("id"), @@ -119,7 +119,7 @@ def upgrade() -> None: sa.Column("agent_id", sa.String(), nullable=True), sa.Column("source_id", sa.String(), nullable=True), sa.Column("embedding", pgvector.sqlalchemy.Vector(dim=4096), nullable=True), - sa.Column("embedding_config", letta.metadata.EmbeddingConfigColumn(), nullable=True), + sa.Column("embedding_config", letta.orm.custom_columns.EmbeddingConfigColumn(), nullable=True), sa.Column("metadata_", sa.JSON(), nullable=True), sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), sa.PrimaryKeyConstraint("id"), @@ -131,7 +131,7 @@ def upgrade() -> None: sa.Column("user_id", sa.String(), nullable=False), sa.Column("name", sa.String(), nullable=False), sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), - sa.Column("embedding_config", letta.metadata.EmbeddingConfigColumn(), nullable=True), + sa.Column("embedding_config", letta.orm.custom_columns.EmbeddingConfigColumn(), nullable=True), sa.Column("description", sa.String(), nullable=True), sa.Column("metadata_", sa.JSON(), nullable=True), sa.PrimaryKeyConstraint("id"), diff --git a/letta/agent.py b/letta/agent.py index 7484dc7889..a24cf358db 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -26,7 +26,6 @@ from letta.llm_api.llm_api_tools import create from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.memory import summarize_messages -from letta.metadata import MetadataStore from letta.orm import User from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent from letta.schemas.block import BlockUpdate @@ -889,18 +888,14 @@ def step( # additional args chaining: bool = True, max_chaining_steps: Optional[int] = None, - ms: Optional[MetadataStore] = None, **kwargs, ) -> LettaUsageStatistics: """Run Agent.step in a loop, handling chaining via heartbeat requests and function failures""" - # assert ms is not None, "MetadataStore is required" - next_input_message = messages if isinstance(messages, list) else [messages] counter = 0 total_usage = UsageStatistics() step_count = 0 while True: - kwargs["ms"] = ms kwargs["first_message"] = False step_response = self.inner_step( messages=next_input_message, @@ -918,8 +913,7 @@ def step( # logger.debug("Saving agent state") # save updated state - if ms: - save_agent(self) + save_agent(self) # Chain stops if not chaining: @@ -978,7 +972,6 @@ def inner_step( first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS, skip_verify: bool = False, stream: bool = False, # TODO move to config? - ms: Optional[MetadataStore] = None, ) -> AgentStepResponse: """Runs a single step in the agent loop (generates at most one LLM call)""" @@ -1098,7 +1091,6 @@ def inner_step( first_message_retry_limit=first_message_retry_limit, skip_verify=skip_verify, stream=stream, - ms=ms, ) else: diff --git a/letta/chat_only_agent.py b/letta/chat_only_agent.py index eb029e93dd..e340673eba 100644 --- a/letta/chat_only_agent.py +++ b/letta/chat_only_agent.py @@ -3,7 +3,6 @@ from letta.agent import Agent from letta.interface import AgentInterface -from letta.metadata import MetadataStore from letta.prompts import gpt_system from letta.schemas.agent import AgentState, AgentType from letta.schemas.embedding_config import EmbeddingConfig @@ -36,11 +35,9 @@ def step( messages: Union[Message, List[Message]], chaining: bool = True, max_chaining_steps: Optional[int] = None, - ms: Optional[MetadataStore] = None, **kwargs, ) -> LettaUsageStatistics: - # assert ms is not None, "MetadataStore is required" - letta_statistics = super().step(messages=messages, chaining=chaining, max_chaining_steps=max_chaining_steps, ms=ms, **kwargs) + letta_statistics = super().step(messages=messages, chaining=chaining, max_chaining_steps=max_chaining_steps, **kwargs) if self.always_rethink_memory: diff --git a/letta/cli/cli.py b/letta/cli/cli.py index e941589d46..e5a649f7ce 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -18,7 +18,6 @@ ) from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL from letta.log import get_logger -from letta.metadata import MetadataStore from letta.schemas.enums import OptionState from letta.schemas.memory import ChatMemory, Memory from letta.server.server import logger as server_logger @@ -138,7 +137,6 @@ def run( config = LettaConfig.load() # read user id from config - ms = MetadataStore(config) client = create_client() # determine agent to use, if not provided @@ -332,7 +330,6 @@ def run( letta_agent=letta_agent, config=config, first=first, - ms=ms, no_verify=no_verify, stream=stream, ) # TODO: add back no_verify diff --git a/letta/main.py b/letta/main.py index bb5d143118..c426917092 100644 --- a/letta/main.py +++ b/letta/main.py @@ -19,7 +19,6 @@ from letta.cli.cli_load import app as load_app from letta.config import LettaConfig from letta.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE -from letta.metadata import MetadataStore # from letta.interface import CLIInterface as interface # for printing to terminal from letta.streaming_interface import AgentRefreshStreamingInterface @@ -62,7 +61,6 @@ def run_agent_loop( letta_agent: agent.Agent, config: LettaConfig, first: bool, - ms: MetadataStore, no_verify: bool = False, strip_ui: bool = False, stream: bool = False, @@ -92,7 +90,6 @@ def run_agent_loop( # create client client = create_client() - ms = MetadataStore(config) # TODO: remove # run loops while True: @@ -378,7 +375,6 @@ def process_agent_step(user_message, no_verify): first_message=False, skip_verify=no_verify, stream=stream, - ms=ms, ) else: step_response = letta_agent.step_user_message( @@ -386,7 +382,6 @@ def process_agent_step(user_message, no_verify): first_message=False, skip_verify=no_verify, stream=stream, - ms=ms, ) new_messages = step_response.messages heartbeat_request = step_response.heartbeat_request diff --git a/letta/metadata.py b/letta/metadata.py deleted file mode 100644 index 0ecd696ba6..0000000000 --- a/letta/metadata.py +++ /dev/null @@ -1,157 +0,0 @@ -""" Metadata store for user/agent/data_source information""" - -import os -import secrets -from typing import List, Optional - -from sqlalchemy import JSON, Column, Index, String, TypeDecorator - -from letta.config import LettaConfig -from letta.orm.base import Base -from letta.schemas.api_key import APIKey -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.user import User -from letta.settings import settings -from letta.utils import enforce_types - - -class LLMConfigColumn(TypeDecorator): - """Custom type for storing LLMConfig as JSON""" - - impl = JSON - cache_ok = True - - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - - def process_bind_param(self, value, dialect): - if value: - # return vars(value) - if isinstance(value, LLMConfig): - return value.model_dump() - return value - - def process_result_value(self, value, dialect): - if value: - return LLMConfig(**value) - return value - - -class EmbeddingConfigColumn(TypeDecorator): - """Custom type for storing EmbeddingConfig as JSON""" - - impl = JSON - cache_ok = True - - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - - def process_bind_param(self, value, dialect): - if value: - # return vars(value) - if isinstance(value, EmbeddingConfig): - return value.model_dump() - return value - - def process_result_value(self, value, dialect): - if value: - return EmbeddingConfig(**value) - return value - - -class APIKeyModel(Base): - """Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens).""" - - __tablename__ = "tokens" - - id = Column(String, primary_key=True) - # each api key is tied to a user account (that it validates access for) - user_id = Column(String, nullable=False) - # the api key - key = Column(String, nullable=False) - # extra (optional) metadata - name = Column(String) - - Index(__tablename__ + "_idx_user", user_id), - Index(__tablename__ + "_idx_key", key), - - def __repr__(self) -> str: - return f"" - - def to_record(self) -> User: - return APIKey( - id=self.id, - user_id=self.user_id, - key=self.key, - name=self.name, - ) - - -def generate_api_key(prefix="sk-", length=51) -> str: - # Generate 'length // 2' bytes because each byte becomes two hex digits. Adjust length for prefix. - actual_length = max(length - len(prefix), 1) // 2 # Ensure at least 1 byte is generated - random_bytes = secrets.token_bytes(actual_length) - new_key = prefix + random_bytes.hex() - return new_key - - -class MetadataStore: - uri: Optional[str] = None - - def __init__(self, config: LettaConfig): - # TODO: get DB URI or path - if config.metadata_storage_type == "postgres": - # construct URI from enviornment variables - self.uri = settings.pg_uri if settings.pg_uri else config.metadata_storage_uri - - elif config.metadata_storage_type == "sqlite": - path = os.path.join(config.metadata_storage_path, "sqlite.db") - self.uri = f"sqlite:///{path}" - else: - raise ValueError(f"Invalid metadata storage type: {config.metadata_storage_type}") - - # Ensure valid URI - assert self.uri, "Database URI is not provided or is invalid." - - from letta.server.server import db_context - - self.session_maker = db_context - - @enforce_types - def create_api_key(self, user_id: str, name: str) -> APIKey: - """Create an API key for a user""" - new_api_key = generate_api_key() - with self.session_maker() as session: - if session.query(APIKeyModel).filter(APIKeyModel.key == new_api_key).count() > 0: - # NOTE duplicate API keys / tokens should never happen, but if it does don't allow it - raise ValueError(f"Token {new_api_key} already exists") - # TODO store the API keys as hashed - assert user_id and name, "User ID and name must be provided" - token = APIKey(user_id=user_id, key=new_api_key, name=name) - session.add(APIKeyModel(**vars(token))) - session.commit() - return self.get_api_key(api_key=new_api_key) - - @enforce_types - def delete_api_key(self, api_key: str): - """Delete an API key from the database""" - with self.session_maker() as session: - session.query(APIKeyModel).filter(APIKeyModel.key == api_key).delete() - session.commit() - - @enforce_types - def get_api_key(self, api_key: str) -> Optional[APIKey]: - with self.session_maker() as session: - results = session.query(APIKeyModel).filter(APIKeyModel.key == api_key).all() - if len(results) == 0: - return None - assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result - return results[0].to_record() - - @enforce_types - def get_all_api_keys_for_user(self, user_id: str) -> List[APIKey]: - with self.session_maker() as session: - results = session.query(APIKeyModel).filter(APIKeyModel.user_id == user_id).all() - tokens = [r.to_record() for r in results] - return tokens diff --git a/letta/o1_agent.py b/letta/o1_agent.py index eb882bfa01..285ed966fa 100644 --- a/letta/o1_agent.py +++ b/letta/o1_agent.py @@ -2,7 +2,6 @@ from letta.agent import Agent, save_agent from letta.interface import AgentInterface -from letta.metadata import MetadataStore from letta.schemas.agent import AgentState from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import UsageStatistics @@ -56,7 +55,6 @@ def step( messages: Union[Message, List[Message]], chaining: bool = True, max_chaining_steps: Optional[int] = None, - ms: Optional[MetadataStore] = None, **kwargs, ) -> LettaUsageStatistics: """Run Agent.inner_step in a loop, terminate when final thinking message is sent or max_thinking_steps is reached""" @@ -70,7 +68,6 @@ def step( if counter > 0: next_input_message = [] - kwargs["ms"] = ms kwargs["first_message"] = False step_response = self.inner_step( messages=next_input_message, @@ -84,7 +81,6 @@ def step( # check if it is final thinking message if step_response.messages[-1].name == "send_final_message": break - if ms: - save_agent(self) + save_agent(self) return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count) diff --git a/letta/offline_memory_agent.py b/letta/offline_memory_agent.py index 1e71af6c81..f4eeec8a83 100644 --- a/letta/offline_memory_agent.py +++ b/letta/offline_memory_agent.py @@ -2,7 +2,6 @@ from letta.agent import Agent, AgentState, save_agent from letta.interface import AgentInterface -from letta.metadata import MetadataStore from letta.orm import User from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import UsageStatistics @@ -141,7 +140,6 @@ def step( messages: Union[Message, List[Message]], chaining: bool = True, max_chaining_steps: Optional[int] = None, - ms: Optional[MetadataStore] = None, **kwargs, ) -> LettaUsageStatistics: """Go through what is currently in memory core memory and integrate information.""" @@ -153,7 +151,6 @@ def step( while counter < self.max_memory_rethinks: if counter > 0: next_input_message = [] - kwargs["ms"] = ms kwargs["first_message"] = False step_response = self.inner_step( messages=next_input_message, @@ -172,7 +169,6 @@ def step( counter += 1 self.interface.step_complete() - if ms: - save_agent(self) + save_agent(self) return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count) diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 150ab51b4c..c4645c3ed2 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -1,10 +1,15 @@ import uuid -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, String, TypeDecorator, UniqueConstraint +from sqlalchemy import JSON, String, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.block import Block +from letta.orm.custom_columns import ( + EmbeddingConfigColumn, + LLMConfigColumn, + ToolRulesColumn, +) from letta.orm.message import Message from letta.orm.mixins import OrganizationMixin from letta.orm.organization import Organization @@ -12,15 +17,9 @@ from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.agent import AgentType from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import ToolRuleType from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory -from letta.schemas.tool_rule import ( - ChildToolRule, - InitToolRule, - TerminalToolRule, - ToolRule, -) +from letta.schemas.tool_rule import ToolRule if TYPE_CHECKING: from letta.orm.agents_tags import AgentsTags @@ -29,92 +28,6 @@ from letta.orm.tool import Tool -class LLMConfigColumn(TypeDecorator): - """Custom type for storing LLMConfig as JSON""" - - impl = JSON - cache_ok = True - - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - - def process_bind_param(self, value, dialect): - if value: - # return vars(value) - if isinstance(value, LLMConfig): - return value.model_dump() - return value - - def process_result_value(self, value, dialect): - if value: - return LLMConfig(**value) - return value - - -class EmbeddingConfigColumn(TypeDecorator): - """Custom type for storing EmbeddingConfig as JSON""" - - impl = JSON - cache_ok = True - - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - - def process_bind_param(self, value, dialect): - if value: - # return vars(value) - if isinstance(value, EmbeddingConfig): - return value.model_dump() - return value - - def process_result_value(self, value, dialect): - if value: - return EmbeddingConfig(**value) - return value - - -class ToolRulesColumn(TypeDecorator): - """Custom type for storing a list of ToolRules as JSON""" - - impl = JSON - cache_ok = True - - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - - def process_bind_param(self, value, dialect): - """Convert a list of ToolRules to JSON-serializable format.""" - if value: - data = [rule.model_dump() for rule in value] - for d in data: - d["type"] = d["type"].value - - for d in data: - assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field" - return data - return value - - def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]: - """Convert JSON back to a list of ToolRules.""" - if value: - return [self.deserialize_tool_rule(rule_data) for rule_data in value] - return value - - @staticmethod - def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]: - """Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'.""" - rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var - if rule_type == ToolRuleType.run_first: - return InitToolRule(**data) - elif rule_type == ToolRuleType.exit_loop: - return TerminalToolRule(**data) - elif rule_type == ToolRuleType.constrain_child_tools: - rule = ChildToolRule(**data) - return rule - else: - raise ValueError(f"Unknown tool rule type: {rule_type}") - - class Agent(SqlalchemyBase, OrganizationMixin): __tablename__ = "agents" __pydantic_model__ = PydanticAgentState diff --git a/letta/orm/custom_columns.py b/letta/orm/custom_columns.py new file mode 100644 index 0000000000..1d8263e332 --- /dev/null +++ b/letta/orm/custom_columns.py @@ -0,0 +1,152 @@ +import base64 +from typing import List, Union + +import numpy as np +from sqlalchemy import JSON +from sqlalchemy.types import BINARY, TypeDecorator + +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import ToolRuleType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction +from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule + + +class EmbeddingConfigColumn(TypeDecorator): + """Custom type for storing EmbeddingConfig as JSON.""" + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value, dialect): + if value and isinstance(value, EmbeddingConfig): + return value.model_dump() + return value + + def process_result_value(self, value, dialect): + if value: + return EmbeddingConfig(**value) + return value + + +class LLMConfigColumn(TypeDecorator): + """Custom type for storing LLMConfig as JSON.""" + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value, dialect): + if value and isinstance(value, LLMConfig): + return value.model_dump() + return value + + def process_result_value(self, value, dialect): + if value: + return LLMConfig(**value) + return value + + +class ToolRulesColumn(TypeDecorator): + """Custom type for storing a list of ToolRules as JSON""" + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value, dialect): + """Convert a list of ToolRules to JSON-serializable format.""" + if value: + data = [rule.model_dump() for rule in value] + for d in data: + d["type"] = d["type"].value + + for d in data: + assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field" + return data + return value + + def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]: + """Convert JSON back to a list of ToolRules.""" + if value: + return [self.deserialize_tool_rule(rule_data) for rule_data in value] + return value + + @staticmethod + def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]: + """Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'.""" + rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var + if rule_type == ToolRuleType.run_first: + return InitToolRule(**data) + elif rule_type == ToolRuleType.exit_loop: + return TerminalToolRule(**data) + elif rule_type == ToolRuleType.constrain_child_tools: + rule = ChildToolRule(**data) + return rule + else: + raise ValueError(f"Unknown tool rule type: {rule_type}") + + +class ToolCallColumn(TypeDecorator): + + impl = JSON + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(JSON()) + + def process_bind_param(self, value, dialect): + if value: + values = [] + for v in value: + if isinstance(v, ToolCall): + values.append(v.model_dump()) + else: + values.append(v) + return values + + return value + + def process_result_value(self, value, dialect): + if value: + tools = [] + for tool_value in value: + if "function" in tool_value: + tool_call_function = ToolCallFunction(**tool_value["function"]) + del tool_value["function"] + else: + tool_call_function = None + tools.append(ToolCall(function=tool_call_function, **tool_value)) + return tools + return value + + +class CommonVector(TypeDecorator): + """Common type for representing vectors in SQLite""" + + impl = BINARY + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(BINARY()) + + def process_bind_param(self, value, dialect): + if value is None: + return value + if isinstance(value, list): + value = np.array(value, dtype=np.float32) + return base64.b64encode(value.tobytes()) + + def process_result_value(self, value, dialect): + if not value: + return value + if dialect.name == "sqlite": + value = base64.b64decode(value) + return np.frombuffer(value, dtype=np.float32) diff --git a/letta/orm/message.py b/letta/orm/message.py index 77ac075a51..e1d6da7889 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -1,46 +1,12 @@ from typing import Optional -from sqlalchemy import JSON, TypeDecorator from sqlalchemy.orm import Mapped, mapped_column, relationship +from letta.orm.custom_columns import ToolCallColumn from letta.orm.mixins import AgentMixin, OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.message import Message as PydanticMessage -from letta.schemas.openai.chat_completions import ToolCall, ToolCallFunction - - -class ToolCallColumn(TypeDecorator): - - impl = JSON - cache_ok = True - - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - - def process_bind_param(self, value, dialect): - if value: - values = [] - for v in value: - if isinstance(v, ToolCall): - values.append(v.model_dump()) - else: - values.append(v) - return values - - return value - - def process_result_value(self, value, dialect): - if value: - tools = [] - for tool_value in value: - if "function" in tool_value: - tool_call_function = ToolCallFunction(**tool_value["function"]) - del tool_value["function"] - else: - tool_call_function = None - tools.append(ToolCall(function=tool_call_function, **tool_value)) - return tools - return value +from letta.schemas.openai.chat_completions import ToolCall class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): diff --git a/letta/orm/passage.py b/letta/orm/passage.py index d1aeb7db68..dc5289b421 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -1,4 +1,3 @@ -import base64 from datetime import datetime from typing import TYPE_CHECKING from sqlalchemy import Column, DateTime, JSON, Index @@ -13,10 +12,13 @@ from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin, SourceMixin from letta.schemas.passage import Passage as PydanticPassage + from letta.config import LettaConfig from letta.constants import MAX_EMBEDDING_DIM +from letta.orm.custom_columns import CommonVector from letta.orm.mixins import FileMixin, OrganizationMixin -from letta.orm.source import EmbeddingConfigColumn +from letta.orm.source import + from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.passage import Passage as PydanticPassage from letta.settings import settings diff --git a/letta/orm/source.py b/letta/orm/source.py index 0bf160f12e..3ecffda6d9 100644 --- a/letta/orm/source.py +++ b/letta/orm/source.py @@ -1,9 +1,10 @@ from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, TypeDecorator +from sqlalchemy import JSON from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm import FileMetadata +from letta.orm.custom_columns import EmbeddingConfigColumn from letta.orm.mixins import OrganizationMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.embedding_config import EmbeddingConfig @@ -16,28 +17,6 @@ from letta.orm.agent import Agent -class EmbeddingConfigColumn(TypeDecorator): - """Custom type for storing EmbeddingConfig as JSON""" - - impl = JSON - cache_ok = True - - def load_dialect_impl(self, dialect): - return dialect.type_descriptor(JSON()) - - def process_bind_param(self, value, dialect): - if value: - # return vars(value) - if isinstance(value, EmbeddingConfig): - return value.model_dump() - return value - - def process_result_value(self, value, dialect): - if value: - return EmbeddingConfig(**value) - return value - - class Source(SqlalchemyBase, OrganizationMixin): """A source represents an embedded text passage""" diff --git a/letta/schemas/api_key.py b/letta/schemas/api_key.py deleted file mode 100644 index 37a55ab1c3..0000000000 --- a/letta/schemas/api_key.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import Optional - -from pydantic import Field - -from letta.schemas.letta_base import LettaBase - - -class BaseAPIKey(LettaBase): - __id_prefix__ = "sk" # secret key - - -class APIKey(BaseAPIKey): - id: str = BaseAPIKey.generate_id_field() - user_id: str = Field(..., description="The unique identifier of the user associated with the token.") - key: str = Field(..., description="The key value.") - name: str = Field(..., description="Name of the token.") - - -class APIKeyCreate(BaseAPIKey): - user_id: str = Field(..., description="The unique identifier of the user associated with the token.") - name: Optional[str] = Field(None, description="Name of the token.") diff --git a/letta/schemas/sandbox_config.py b/letta/schemas/sandbox_config.py index 246ba8a3ab..9b118cf68e 100644 --- a/letta/schemas/sandbox_config.py +++ b/letta/schemas/sandbox_config.py @@ -3,10 +3,11 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from letta.schemas.agent import AgentState from letta.schemas.letta_base import LettaBase, OrmMetadataBase +from letta.settings import tool_settings # Sandbox Config @@ -45,6 +46,16 @@ class E2BSandboxConfig(BaseModel): def type(self) -> "SandboxType": return SandboxType.E2B + @model_validator(mode="before") + @classmethod + def set_default_template(cls, data: dict): + """ + Assign a default template value if the template field is not provided. + """ + if data.get("template") is None: + data["template"] = tool_settings.e2b_sandbox_template_id + return data + class SandboxConfigBase(OrmMetadataBase): __id_prefix__ = "sandbox" diff --git a/letta/server/rest_api/routers/v1/users.py b/letta/server/rest_api/routers/v1/users.py index 9253d8d29c..27a2feeb03 100644 --- a/letta/server/rest_api/routers/v1/users.py +++ b/letta/server/rest_api/routers/v1/users.py @@ -2,21 +2,9 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Query -from letta.schemas.api_key import APIKey, APIKeyCreate from letta.schemas.user import User, UserCreate, UserUpdate from letta.server.rest_api.utils import get_letta_server -# from letta.server.schemas.users import ( -# CreateAPIKeyRequest, -# CreateAPIKeyResponse, -# CreateUserRequest, -# CreateUserResponse, -# DeleteAPIKeyResponse, -# DeleteUserResponse, -# GetAllUsersResponse, -# GetAPIKeysResponse, -# ) - if TYPE_CHECKING: from letta.schemas.user import User from letta.server.server import SyncServer @@ -84,37 +72,3 @@ def delete_user( except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") return user - - -@router.post("/keys", response_model=APIKey, operation_id="create_api_key") -def create_new_api_key( - create_key: APIKeyCreate = Body(...), - server: "SyncServer" = Depends(get_letta_server), -): - """ - Create a new API key for a user - """ - api_key = server.create_api_key(create_key) - return api_key - - -@router.get("/keys", response_model=List[APIKey], operation_id="list_api_keys") -def get_api_keys( - user_id: str = Query(..., description="The unique identifier of the user."), - server: "SyncServer" = Depends(get_letta_server), -): - """ - Get a list of all API keys for a user - """ - if server.user_manager.get_user_by_id(user_id=user_id) is None: - raise HTTPException(status_code=404, detail=f"User does not exist") - api_keys = server.ms.get_all_api_keys_for_user(user_id=user_id) - return api_keys - - -@router.delete("/keys", response_model=APIKey, operation_id="delete_api_key") -def delete_api_key( - api_key: str = Query(..., description="The API key to be deleted."), - server: "SyncServer" = Depends(get_letta_server), -): - return server.delete_api_key(api_key) diff --git a/letta/server/server.py b/letta/server/server.py index 2b85f243c8..0e00b7b88c 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -25,7 +25,6 @@ from letta.interface import AgentInterface # abstract from letta.interface import CLIInterface # for printing to terminal from letta.log import get_logger -from letta.metadata import MetadataStore from letta.o1_agent import O1Agent from letta.offline_memory_agent import OfflineMemoryAgent from letta.orm import Base @@ -44,7 +43,6 @@ VLLMCompletionsProvider, ) from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent -from letta.schemas.api_key import APIKey, APIKeyCreate from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig @@ -280,7 +278,6 @@ def __init__( config.archival_storage_uri = settings.letta_pg_uri_no_default config.save() self.config = config - self.ms = MetadataStore(self.config) # Managers that interface with data models self.organization_manager = OrganizationManager() @@ -404,9 +401,6 @@ def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface if agent_state is None: raise LettaAgentNotFoundError(f"Agent (agent_id={agent_id}) does not exist") - elif agent_state.created_by_id is None: - raise ValueError(f"Agent (agent_id={agent_id}) does not have a user_id") - actor = self.user_manager.get_user_by_id(user_id=agent_state.created_by_id) interface = interface or self.default_interface_factory() if agent_state.agent_type == AgentType.memgpt_agent: @@ -448,7 +442,7 @@ def _step( try: letta_agent = self.load_agent(agent_id=agent_id, interface=interface, actor=actor) if letta_agent is None: - raise KeyError(f"Agent (user={user_id}, agent={agent_id}) is not loaded") + raise KeyError(f"Agent (user={actor.id}, agent={agent_id}) is not loaded") # Determine whether or not to token stream based on the capability of the interface token_streaming = letta_agent.interface.streaming_mode if hasattr(letta_agent.interface, "streaming_mode") else False @@ -459,7 +453,6 @@ def _step( chaining=self.chaining, max_chaining_steps=self.max_chaining_steps, stream=token_streaming, - ms=self.ms, skip_verify=True, ) @@ -1127,33 +1120,6 @@ def update_agent_core_memory(self, agent_id: str, label: str, value: str, actor: letta_agent = self.load_agent(agent_id=agent_id, actor=actor) return letta_agent.agent_state.memory - def api_key_to_user(self, api_key: str) -> str: - """Decode an API key to a user""" - token = self.ms.get_api_key(api_key=api_key) - user = self.user_manager.get_user_by_id(token.user_id) - if user is None: - raise HTTPException(status_code=403, detail="Invalid credentials") - else: - return user.id - - def create_api_key(self, request: APIKeyCreate) -> APIKey: # TODO: add other fields - """Create a new API key for a user""" - if request.name is None: - request.name = f"API Key {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" - token = self.ms.create_api_key(user_id=request.user_id, name=request.name) - return token - - def list_api_keys(self, user_id: str) -> List[APIKey]: - """List all API keys for a user""" - return self.ms.get_all_api_keys_for_user(user_id=user_id) - - def delete_api_key(self, api_key: str) -> APIKey: - api_key_obj = self.ms.get_api_key(api_key=api_key) - if api_key_obj is None: - raise ValueError("API key does not exist") - self.ms.delete_api_key(api_key=api_key) - return api_key_obj - def delete_source(self, source_id: str, actor: User): """Delete a data source""" self.source_manager.delete_source(source_id=source_id, actor=actor) diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index f98ba65dba..fc86b05fe7 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -13,7 +13,6 @@ class OrganizationManager: DEFAULT_ORG_NAME = "default_org" def __init__(self): - # This is probably horrible but we reuse this technique from metadata.py # TODO: Please refactor this out # I am currently working on a ORM refactor and would like to make a more minimal set of changes # - Matt diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index c91e6669f1..010ae400c8 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -5,7 +5,7 @@ from letta.orm.errors import NoResultFound from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel from letta.orm.sandbox_config import SandboxEnvironmentVariable as SandboxEnvVarModel -from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig +from letta.schemas.sandbox_config import LocalSandboxConfig from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar @@ -27,7 +27,6 @@ def __init__(self, settings): from letta.server.server import db_context self.session_maker = db_context - self.e2b_template_id = settings.e2b_sandbox_template_id @enforce_types def get_or_create_default_sandbox_config(self, sandbox_type: SandboxType, actor: PydanticUser) -> PydanticSandboxConfig: @@ -37,8 +36,9 @@ def get_or_create_default_sandbox_config(self, sandbox_type: SandboxType, actor: # TODO: Add more sandbox types later if sandbox_type == SandboxType.E2B: - default_config = E2BSandboxConfig(template=self.e2b_template_id).model_dump(exclude_none=True) + default_config = {} # Empty else: + # TODO: May want to move this to environment variables v.s. persisting in database default_local_sandbox_path = str(Path(__file__).parent / "tool_sandbox_env") default_config = LocalSandboxConfig(sandbox_dir=default_local_sandbox_path).model_dump(exclude_none=True) diff --git a/scripts/migrate_0.3.17.py b/scripts/migrate_0.3.17.py deleted file mode 100644 index fe5f9f7736..0000000000 --- a/scripts/migrate_0.3.17.py +++ /dev/null @@ -1,67 +0,0 @@ -import os - -from sqlalchemy import DDL, MetaData, Table, create_engine, update - -from letta.config import LettaConfig -from letta.constants import BASE_TOOLS -from letta.metadata import MetadataStore -from letta.presets.presets import add_default_tools -from letta.prompts import gpt_system - -# Replace this with your actual database connection URL -config = LettaConfig.load() -if config.recall_storage_type == "sqlite": - DATABASE_URL = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db") -else: - DATABASE_URL = config.recall_storage_uri -print(DATABASE_URL) -engine = create_engine(DATABASE_URL) -metadata = MetaData() - -# defaults -system_prompt = gpt_system.get_system_text("memgpt_chat") - -# Reflect the existing table -table = Table("agents", metadata, autoload_with=engine) - -# Using a connection to manage adding columns and committing updates -with engine.connect() as conn: - trans = conn.begin() - try: - # Check and add 'system' column if it does not exist - if "system" not in table.c: - ddl_system = DDL("ALTER TABLE agents ADD COLUMN system VARCHAR") - conn.execute(ddl_system) - # Reflect the table again to update metadata - metadata.clear() - table = Table("agents", metadata, autoload_with=conn) - - # Check and add 'tools' column if it does not exist - if "tools" not in table.c: - ddl_tools = DDL("ALTER TABLE agents ADD COLUMN tools JSON") - conn.execute(ddl_tools) - # Reflect the table again to update metadata - metadata.clear() - table = Table("agents", metadata, autoload_with=conn) - - # Update all existing rows with default values for the new columns - conn.execute(update(table).values(system=system_prompt, tools=BASE_TOOLS)) - - # Commit transaction - trans.commit() - print("Columns added and data updated successfully!") - - except Exception as e: - print("An error occurred:", e) - trans.rollback() # Rollback if there are errors - -# remove tool table -tool_model = Table("toolmodel", metadata, autoload_with=engine) -tool_model.drop(engine) - -# re-create tables and add default tools -ms = MetadataStore(config) -add_default_tools(None, ms) -print("Tools", [tool.name for tool in ms.list_tools()]) - -print("Migration completed successfully!") diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index eeb71af5a8..9131797cc5 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -65,7 +65,6 @@ def test_summarizer(config_filename): first_message=False, skip_verify=False, stream=False, - ms=client.server.ms, ) # Invoke a summarize