Passing tests for separate passage tables
* Added test for two agents attaching the same source
* agent_passages should be deleted when an agent is deleted
* Fixed agent manager bug
Mindy Long committed Dec 15, 2024
1 parent 13f02c9 commit e1a5d04
Showing 6 changed files with 89 additions and 111 deletions.
33 changes: 12 additions & 21 deletions letta/agent.py
@@ -1376,33 +1376,24 @@ def attach_source(
source_id: str,
source_manager: SourceManager,
agent_manager: AgentManager,
page_size: Optional[int] = None,
):
"""Attach data with name `source_name` to the agent from source_connector."""
# TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory
passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size)

for passage in passages:
assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}"
passage.agent_id = self.agent_state.id
self.passage_manager.update_passage_by_id(passage_id=passage.id, passage=passage, actor=user)

agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size)
passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id)
assert all([p.agent_id == self.agent_state.id for p in agents_passages])
assert len(agents_passages) == passage_size # sanity check
assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}"

# attach to agent
"""Attach a source to the agent using the SourcesAgents ORM relationship.
Args:
user: User performing the action
source_id: ID of the source to attach
source_manager: SourceManager instance to verify source exists
agent_manager: AgentManager instance to manage agent-source relationship
"""
# Verify source exists and user has permission to access it
source = source_manager.get_source_by_id(source_id=source_id, actor=user)
assert source is not None, f"Source {source_id} not found in metadata store"
assert source is not None, f"Source {source_id} not found in user's organization ({user.organization_id})"

# NOTE: need this redundant line here because we haven't migrated agent to ORM yet
# TODO: delete @matt and remove
# Use the agent_manager to create the relationship
agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user)

printd(
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. Agent now has {passage_size} embeddings in archival memory.",
f"Attached data source {source.name} to agent {self.agent_state.name}.",
)

def update_message(self, message_id: str, request: MessageUpdate) -> Message:
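
For context, a minimal sketch of the new attach flow from a caller's point of view. The call names come from the diff above; the fixture objects (agent, user, source, source_manager, agent_manager) are assumed to be constructed elsewhere, e.g. in a test harness:

# Sketch only: attaching now records a SourcesAgents relationship row
# instead of copying passages into the agent's archival memory.
agent.attach_source(
    user=user,
    source_id=source.id,
    source_manager=source_manager,
    agent_manager=agent_manager,
)

# The source's passages then become visible through the unified listing.
passages = agent_manager.list_passages(actor=user, agent_id=agent.agent_state.id, source_id=source.id)
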
10 changes: 9 additions & 1 deletion letta/orm/agent.py
Expand Up @@ -179,7 +179,15 @@ class Agent(SqlalchemyBase, OrganizationMixin):
viewonly=True, # Ensures SQLAlchemy doesn't attempt to manage this relationship
doc="All passages derived from sources associated with this agent.",
)
agent_passages: Mapped[List["AgentPassage"]] = relationship("AgentPassage", back_populates="agent", lazy="selectin", order_by="AgentPassage.created_at.desc()",)
agent_passages: Mapped[List["AgentPassage"]] = relationship(
"AgentPassage",
back_populates="agent",
lazy="selectin",
order_by="AgentPassage.created_at.desc()",
cascade="all, delete-orphan",  # delete the agent's passages when the agent is deleted
doc="All passages created by this agent.",
)

def to_pydantic(self) -> PydanticAgentState:
"""converts to the basic pydantic model counterpart"""
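
For readers unfamiliar with delete-orphan cascades, a standalone toy illustration of the behavior this relationship now relies on (generic models and an in-memory SQLite engine, not letta's actual ORM):

from sqlalchemy import ForeignKey, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column, relationship

class Base(DeclarativeBase):
    pass

class ToyAgent(Base):
    __tablename__ = "agents"
    id: Mapped[int] = mapped_column(primary_key=True)
    passages: Mapped[list["ToyPassage"]] = relationship(
        back_populates="agent",
        cascade="all, delete-orphan",  # deleting the agent deletes its passages
    )

class ToyPassage(Base):
    __tablename__ = "agent_passages"
    id: Mapped[int] = mapped_column(primary_key=True)
    agent_id: Mapped[int] = mapped_column(ForeignKey("agents.id"))
    agent: Mapped[ToyAgent] = relationship(back_populates="passages")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(ToyAgent(passages=[ToyPassage(), ToyPassage()]))
    session.commit()
    session.delete(session.query(ToyAgent).first())  # cascades to both passages
    session.commit()
    assert session.query(ToyPassage).count() == 0

This is also why the viewonly flag on the read-only source_passages relationship above does not carry over here: a viewonly relationship cannot manage cascaded deletes.
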
34 changes: 10 additions & 24 deletions letta/server/server.py
@@ -994,18 +994,9 @@ def get_agent_archival(self, user_id: str, agent_id: str, cursor: Optional[str]
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)

# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)

# iterate over records
records = letta_agent.passage_manager.list_passages(
actor=actor,
agent_id=agent_id,
cursor=cursor,
limit=limit,
)
passages = self.agent_manager.list_passages(agent_id=agent_id, actor=actor)

return records
return passages

def get_agent_archival_cursor(
self,
@@ -1019,15 +1010,13 @@ def get_agent_archival_cursor(
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)

# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)

# iterate over records
records = letta_agent.passage_manager.list_passages(
actor=self.default_user,
records = self.agent_manager.list_passages(
actor=actor,
agent_id=agent_id,
cursor=cursor,
limit=limit,
ascending=not reverse,
)
return records

@@ -1201,9 +1190,11 @@ def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor
for agent_state in agent_states:
agent_id = agent_state.id
agent = self.load_agent(agent_id=agent_id, actor=actor)
curr_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id)

# Attach source to agent
curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
agent.attach_source(user=actor, source_id=source_id, source_manager=self.source_manager, agent_manager=self.agent_manager)
new_passage_size = self.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source_id)
new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
assert new_passage_size >= curr_passage_size # in case empty files are added

return job
@@ -1267,14 +1258,9 @@ def detach_source_from_agent(
source = self.source_manager.get_source_by_id(source_id=source_id, actor=actor)
elif source_name:
source = self.source_manager.get_source_by_name(source_name=source_name, actor=actor)
source_id = source.id
else:
raise ValueError("Must provide either source_id or source_name to find the source.")
source_id = source.id

# TODO: This should be done with the ORM?
# delete all Passage objects with source_id==source_id from agent's archival memory
agent = self.load_agent(agent_id=agent_id, actor=actor)
agent.passage_manager.delete_passages(actor=actor, limit=100, source_id=source_id)

# delete agent-source mapping
self.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor)
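
The archival endpoints above now delegate cursor handling entirely to agent_manager.list_passages. A hedged pagination sketch, assuming server, actor, and agent_id test fixtures and the parameter names shown in the diff:

# Page newest-first, then resume from the last item's id as the cursor.
first_page = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, ascending=False, limit=2)
next_page = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, ascending=False, cursor=first_page[-1].id, limit=2)
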
13 changes: 4 additions & 9 deletions letta/services/agent_manager.py
@@ -238,13 +238,6 @@ def delete_agent(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState
with self.session_maker() as session:
# Retrieve the agent
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

# TODO: @mindy delete this piece when we have a proper passages/sources implementation
# TODO: This is done very hacky on purpose
# TODO: 1000 limit is also wack
passage_manager = PassageManager()
passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000)

agent_state = agent.to_pydantic()
agent.hard_delete(session)
return agent_state
@@ -465,8 +458,6 @@ def list_passages(
embedded_text = np.array(embedded_text)
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()

results = []

with self.session_maker() as session:
# Start with base query for source passages

@@ -523,6 +514,10 @@
main_query = main_query.where(combined_query.c.created_at >= start_date)
if end_date:
main_query = main_query.where(combined_query.c.created_at <= end_date)
if source_id:
main_query = main_query.where(combined_query.c.source_id == source_id)
if file_id:
main_query = main_query.where(combined_query.c.file_id == file_id)

# Vector search
if embedded_text:
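
With the two new filters, callers can scope the unified listing by source or file. A usage sketch under the same fixture assumptions (the source_id attribute on returned passages is assumed here for the sanity check):

# Only passages ingested from a given source or file.
by_source = agent_manager.list_passages(actor=actor, agent_id=agent_id, source_id=source_id)
by_file = agent_manager.list_passages(actor=actor, agent_id=agent_id, file_id=file_id)
assert all(p.source_id == source_id for p in by_source)
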
1 change: 0 additions & 1 deletion tests/test_client_legacy.py
@@ -482,7 +482,6 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):

# check agent archival memory size
archival_memories = client.get_archival_memory(agent_id=agent.id)
print(archival_memories)
assert len(archival_memories) == 0

# load a file into a source (non-blocking job)
109 changes: 54 additions & 55 deletions tests/test_server.py
@@ -390,12 +390,16 @@ def test_user_message_memory(server, user_id, agent_id):

@pytest.mark.order(3)
def test_load_data(server, user_id, agent_id):
user = server.user_manager.get_user_or_default(user_id=user_id)

# create source
passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
passages_before = server.agent_manager.list_passages(
actor=user, agent_id=agent_id, cursor=None, limit=10000
)
assert len(passages_before) == 0

source = server.source_manager.create_source(
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=server.default_user
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=user
)

# load data
@@ -409,15 +413,11 @@ def test_load_data(server, user_id, agent_id):
connector = DummyDataConnector(archival_memories)
server.load_data(user_id, connector, source.name)

# @pytest.mark.order(3)
# def test_attach_source_to_agent(server, user_id, agent_id):
# check archival memory size

# attach source
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")

# check archival memory size
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
assert len(passages_after) == 5


@@ -465,7 +465,7 @@ def test_get_archival_memory(server, user_id, agent_id):
user = server.user_manager.get_user_by_id(user_id=user_id)

# List latest 2 passages
passages_1 = server.passage_manager.list_passages(
passages_1 = server.agent_manager.list_passages(
actor=user,
agent_id=agent_id,
ascending=False,
@@ -475,7 +475,7 @@

# List next 3 passages (earliest 3)
cursor1 = passages_1[-1].id
passages_2 = server.passage_manager.list_passages(
passages_2 = server.agent_manager.list_passages(
actor=user,
agent_id=agent_id,
ascending=False,
@@ -484,24 +484,28 @@

# List all 5
cursor2 = passages_1[0].created_at
passages_3 = server.passage_manager.list_passages(
passages_3 = server.agent_manager.list_passages(
actor=user,
agent_id=agent_id,
ascending=False,
end_date=cursor2,
limit=1000,
)
# assert passages_1[0].text == "Cinderella wore a blue dress"
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test

latest = passages_1[0]
earliest = passages_2[-1]

# test archival memory
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, limit=1)
passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True)
assert len(passage_1) == 1
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passage_1[-1].id, limit=1000)
assert passage_1[0].text == "alpha"
passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
assert all("alpha" not in passage.text for passage in passage_2)
# test safe empty return
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passages_1[0].id, limit=1000)
passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
assert len(passage_none) == 0


@@ -955,6 +959,14 @@ def count_system_messages_in_recall() -> Tuple[int, List[LettaMessage]]:
def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
actor = server.user_manager.get_user_or_default(user_id)

existing_sources = server.source_manager.list_sources(actor=actor)
if len(existing_sources) > 0:
for source in existing_sources:
server.agent_manager.detach_source(agent_id=agent_id, source_id=source.id, actor=actor)
initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert initial_passage_count == 0

# Create a source
source = server.source_manager.create_source(
PydanticSource(
@@ -973,10 +985,6 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
# Attach source to agent first
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor)

# Get initial passage count
initial_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
assert initial_passage_count == 0

# Create a job for loading the first file
job = server.job_manager.create_job(
PydanticJob(
@@ -1001,7 +1009,7 @@
assert job.metadata_["num_documents"] == 1

# Verify passages were added
first_file_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert first_file_passage_count > initial_passage_count

# Create a second test file with different content
@@ -1032,14 +1040,13 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
assert job2.metadata_["num_documents"] == 1

# Verify passages were appended (not replaced)
final_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert final_passage_count > first_file_passage_count

# Verify both old and new content is searchable
passages = server.passage_manager.list_passages(
actor=actor,
passages = server.agent_manager.list_passages(
agent_id=agent_id,
source_id=source.id,
actor=actor,
query_text="what does Timber like to eat",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
embed_query=True,
@@ -1048,35 +1055,27 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
assert any("chicken" in passage.text.lower() for passage in passages)
assert any("Anna".lower() in passage.text.lower() for passage in passages)

# TODO: Add this test back in after separation of `Passage tables` (LET-449)
# # Load second agent
# agent2 = server.load_agent(agent_id=other_agent_id)

# # Initially should have no passages
# initial_agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
# assert initial_agent2_passages == 0

# # Attach source to second agent
# agent2.attach_source(user=user, source_id=source.id, source_manager=server.source_manager, ms=server.ms)

# # Verify second agent has same number of passages as first agent
# agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
# agent1_passages = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
# assert agent2_passages == agent1_passages

# # Verify second agent can query the same content
# passages2 = server.passage_manager.list_passages(
# actor=user,
# agent_id=other_agent_id,
# source_id=source.id,
# query_text="what does Timber like to eat",
# embedding_config=EmbeddingConfig.default_config(provider="openai"),
# embed_query=True,
# limit=10,
# )
# assert len(passages2) == len(passages)
# assert any("chicken" in passage.text.lower() for passage in passages2)
# assert any("sleep" in passage.text.lower() for passage in passages2)

# # Cleanup
# server.delete_agent(user_id=user_id, agent_id=agent2_state.id)
# Initially should have no passages
initial_agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
assert initial_agent2_passages == 0

# Attach source to second agent
server.agent_manager.attach_source(agent_id=other_agent_id, source_id=source.id, actor=actor)

# Verify second agent has same number of passages as first agent
agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
agent1_passages = server.agent_manager.passage_size(agent_id=agent_id, actor=actor, source_id=source.id)
assert agent2_passages == agent1_passages

# Verify second agent can query the same content
passages2 = server.agent_manager.list_passages(
actor=actor,
agent_id=other_agent_id,
source_id=source.id,
query_text="what does Timber like to eat",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
embed_query=True,
)
assert len(passages2) == len(passages)
assert any("chicken" in passage.text.lower() for passage in passages2)
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
