Skip to content

Commit

Permalink
moving list_passages(), size() to AgentManager with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Mindy Long committed Dec 15, 2024
1 parent 6f4b7d7 commit 13f02c9
Show file tree
Hide file tree
Showing 5 changed files with 492 additions and 7 deletions.
12 changes: 11 additions & 1 deletion letta/orm/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,17 @@ class Agent(SqlalchemyBase, OrganizationMixin):
lazy="selectin",
doc="Tags associated with the agent.",
)
# passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="agent", lazy="selectin")
source_passages: Mapped[List["SourcePassage"]] = relationship(
"SourcePassage",
secondary="sources_agents", # The join table for Agent -> Source
primaryjoin="Agent.id == sources_agents.c.agent_id",
secondaryjoin="and_(SourcePassage.source_id == sources_agents.c.source_id)",
lazy="selectin",
order_by="SourcePassage.created_at.desc()",
viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship
doc="All passages derived from sources associated with this agent.",
)
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="agent", lazy="selectin", order_by="AgentPassage.created_at.desc()",)

def to_pydantic(self) -> PydanticAgentState:
"""converts to the basic pydantic model counterpart"""
Expand Down
8 changes: 7 additions & 1 deletion letta/orm/passage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import base64
from datetime import datetime
from typing import Optional, TYPE_CHECKING
from typing import TYPE_CHECKING
from sqlalchemy import Column, DateTime, JSON, Index
from sqlalchemy.orm import Mapped, mapped_column, relationship, declared_attr
from sqlalchemy.types import TypeDecorator, BINARY
Expand All @@ -25,6 +25,7 @@

if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.agent import Agent


class CommonVector(TypeDecorator):
Expand Down Expand Up @@ -110,3 +111,8 @@ class AgentPassage(BasePassage, AgentMixin):
@declared_attr
def organization(cls) -> Mapped["Organization"]:
return relationship("Organization", back_populates="agent_passages", lazy="selectin")

@declared_attr
def agent(cls) -> Mapped["Agent"]:
"""Relationship to agent"""
return relationship("Agent", back_populates="agent_passages", lazy="selectin", passive_deletes=True)
250 changes: 249 additions & 1 deletion letta/services/agent_manager.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
from typing import Dict, List, Optional
from datetime import datetime
import numpy as np

from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from sqlalchemy import select, union_all, literal, func


from letta import agent
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model, parse_and_chunk_text
from letta.orm import Agent as AgentModel
from letta.orm import Block as BlockModel
from letta.orm import Source as SourceModel
from letta.orm import Tool as ToolModel
from letta.orm import AgentPassage, SourcePassage
from letta.orm import SourcesAgents
from letta.orm.errors import NoResultFound
from letta.orm.sqlite_functions import adapt_array
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
from letta.schemas.user import User as PydanticUser
Expand All @@ -23,6 +34,7 @@
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
from letta.settings import settings
from letta.utils import enforce_types


Expand Down Expand Up @@ -403,3 +415,239 @@ def detach_block_with_label(

agent.update(session, actor=actor)
return agent.to_pydantic()

# ======================================================================================================================
# Passage Management
# ======================================================================================================================
@enforce_types
def list_passages(
self,
agent_id : str,
actor : PydanticUser,
file_id : Optional[str] = None,
limit : Optional[int] = 50,
query_text : Optional[str] = None,
start_date : Optional[datetime] = None,
end_date : Optional[datetime] = None,
cursor : Optional[str] = None,
source_id : Optional[str] = None,
embed_query : bool = False,
ascending : bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only : bool = False
) -> List[PydanticPassage]:
"""
Lists all passages attached to an agent.
Args:
agent_id: ID of the agent to list passages for
actor: User performing the action
file_id: Optional ID of file to filter passages by
limit: Optional maximum number of passages to return
query_text: Optional query text to search for
start_date: Optional start date to filter passages by
end_date: Optional end date to filter passages by
cursor: Optional passage ID to start from
source_id: Optional ID of source to filter passages by
ascending: If True, sort by created_at ascending, if False sort descending
embed_query: If True, embed the query text using the specified embedding config
embedding_config: Optional embedding config to use for embedding the query text
agent_only: If True, only return agent passages, if False return agent and source passages
Returns:
List[PydanticPassage]: List of passages attached to the agent
"""
embedded_text = None
if embed_query:
assert embedding_config is not None, "embedding_config must be specified for vector search"
assert query_text is not None, "query_text must be specified for vector search"
embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
embedded_text = np.array(embedded_text)
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()

results = []

with self.session_maker() as session:
# Start with base query for source passages

source_passages = None
if not agent_only: # Include source passages
source_passages = (
select(
SourcePassage,
literal(None).label('agent_id')
)
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
.where(SourcesAgents.agent_id == agent_id)
.where(SourcePassage.organization_id == actor.organization_id)
)

if source_id:
source_passages = source_passages.where(SourcePassage.source_id == source_id)
if file_id:
source_passages = source_passages.where(SourcePassage.file_id == file_id)

# Add agent passages query
agent_passages = (
select(
AgentPassage.id,
AgentPassage.text,
AgentPassage.embedding_config,
AgentPassage.metadata_,
AgentPassage.created_at,
AgentPassage.embedding,
AgentPassage.updated_at,
AgentPassage.is_deleted,
AgentPassage._created_by_id,
AgentPassage._last_updated_by_id,
AgentPassage.organization_id,
literal(None).label('file_id'),
literal(None).label('source_id'),
AgentPassage.agent_id
)
.where(AgentPassage.agent_id == agent_id)
.where(AgentPassage.organization_id == actor.organization_id)
)

# Combine queries
if source_passages is not None:
combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
else:
combined_query = agent_passages.cte('combined_passages')

# Build main query from combined CTE
main_query = select(combined_query)

# Apply filters
if start_date:
main_query = main_query.where(combined_query.c.created_at >= start_date)
if end_date:
main_query = main_query.where(combined_query.c.created_at <= end_date)

# Vector search
if embedded_text:
from letta.settings import settings

if settings.letta_pg_uri_no_default:
# PostgreSQL with pgvector
main_query = main_query.order_by(
combined_query.c.embedding.cosine_distance(embedded_text).asc()
)
else:
# SQLite with custom vector type
query_embedding_binary = adapt_array(embedded_text)
if ascending:
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.asc(),
combined_query.c.id.asc()
)
else:
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.desc(),
combined_query.c.id.asc()
)
else:
if query_text:
main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))

# Handle cursor-based pagination
if cursor:
cursor_query = select(combined_query.c.created_at).where(
combined_query.c.id == cursor
).scalar_subquery()

if ascending:
main_query = main_query.where(
combined_query.c.created_at > cursor_query
)
else:
main_query = main_query.where(
combined_query.c.created_at < cursor_query
)

# Add ordering
if not embed_query: # Skip if already ordered by similarity
if ascending:
main_query = main_query.order_by(
combined_query.c.created_at.asc(),
combined_query.c.id.asc(),
)
else:
main_query = main_query.order_by(
combined_query.c.created_at.desc(),
combined_query.c.id.asc(),
)

# Add limit
if limit:
main_query = main_query.limit(limit)

# Execute query
results = list(session.execute(main_query))

passages = []
for row in results:
data = dict(row._mapping)
if data['agent_id'] is not None:
# This is an AgentPassage - remove source fields
data.pop('source_id', None)
data.pop('file_id', None)
passage = AgentPassage(**data)
else:
# This is a SourcePassage - remove agent field
data.pop('agent_id', None)
passage = SourcePassage(**data)
passages.append(passage)

return [p.to_pydantic() for p in passages]


# # Verify agent exists and user has permission to access it
# agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

# # Use the lazy-loaded relationships to get passages
# # Sort by created_at
# all_passages = sorted(
# [*agent.source_passages, *agent.agent_passages],
# key=lambda x: x.created_at,
# reverse=(not ascending)
# )


# return [passage.to_pydantic() for passage in all_passages]

@enforce_types
def passage_size(
self,
agent_id : str,
actor : PydanticUser,
file_id : Optional[str] = None,
limit : Optional[int] = 50,
query_text : Optional[str] = None,
start_date : Optional[datetime] = None,
end_date : Optional[datetime] = None,
cursor : Optional[str] = None,
source_id : Optional[str] = None,
embed_query : bool = False,
ascending : bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only : bool = False
) -> int:
return len(
self.list_passages(
agent_id=agent_id,
actor=actor,
file_id=file_id,
limit=limit,
query_text=query_text,
start_date=start_date,
end_date=end_date,
cursor=cursor,
source_id=source_id,
embed_query=embed_query,
ascending=ascending,
embedding_config=embedding_config,
agent_only=agent_only)
)
3 changes: 0 additions & 3 deletions letta/services/passage_manager.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from typing import List, Optional

from datetime import datetime
from typing import List, Optional

import numpy as np

from sqlalchemy import select, union_all, literal
Expand Down
Loading

0 comments on commit 13f02c9

Please sign in to comment.