Commit eb117a9

remove incremental done steps

Caren Thomas committed Dec 19, 2024
1 parent e5f230e commit eb117a9

Showing 6 changed files with 30 additions and 55 deletions.
6 changes: 3 additions & 3 deletions letta/client/streaming.py
@@ -6,7 +6,7 @@
 
 from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
 from letta.errors import LLMError
-from letta.schemas.enums import MessageStreamStatus
+from letta.llm_api.openai import OPENAI_SSE_DONE
 from letta.schemas.letta_message import (
     FunctionCallMessage,
     FunctionReturn,
@@ -47,10 +47,10 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
             # if sse.data == OPENAI_SSE_DONE:
             #     print("finished")
             #     break
-            if sse.data in [status.value for status in MessageStreamStatus]:
+            if sse.data == OPENAI_SSE_DONE:
                 # break
                 # print("sse.data::", sse.data)
-                yield MessageStreamStatus(sse.data)
+                yield sse.data
             else:
                 chunk_data = json.loads(sse.data)
                 if "internal_monologue" in chunk_data:
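Note: after this change the raw SSE sentinel string is yielded straight through to callers instead of being wrapped in an enum member. A minimal sketch of the new consumption shape, loosely mirroring _sse_post above (it assumes OPENAI_SSE_DONE is the literal "[DONE]" string, consistent with the comparisons in the diff; sse_events is a hypothetical iterable of SSE data payloads):

import json
from typing import Generator, Iterable, Union

OPENAI_SSE_DONE = "[DONE]"  # assumed value of the sentinel imported above

def parse_stream(sse_events: Iterable[str]) -> Generator[Union[dict, str], None, None]:
    """Yield parsed JSON chunks, then the raw [DONE] sentinel."""
    for data in sse_events:
        if data == OPENAI_SSE_DONE:
            yield data  # the sentinel passes through unwrapped
            break
        yield json.loads(data)  # every other event carries a JSON-encoded chunk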
6 changes: 0 additions & 6 deletions letta/schemas/enums.py
@@ -29,12 +29,6 @@ class JobStatus(str, Enum):
     pending = "pending"
 
 
-class MessageStreamStatus(str, Enum):
-    done_generation = "[DONE_GEN]"
-    done_step = "[DONE_STEP]"
-    done = "[DONE]"
-
-
 class ToolRuleType(str, Enum):
     """
     Type of tool rule.
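Side by side, this is the simplification the commit makes: the three incremental stream-status sentinels collapse into the single terminal one (the Literal alias is introduced in letta/schemas/letta_response.py below).

from enum import Enum
from typing import Literal

# Before: three stream-status sentinels
class MessageStreamStatus(str, Enum):
    done_generation = "[DONE_GEN]"
    done_step = "[DONE_STEP]"
    done = "[DONE]"

# After: only the terminal sentinel survives, typed as a Literal
StreamDoneStatus = Literal["[DONE]"]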
8 changes: 4 additions & 4 deletions letta/schemas/letta_response.py
@@ -1,17 +1,17 @@
 import html
 import json
 import re
-from typing import List, Union
+from typing import List, Literal, Union
 
 from pydantic import BaseModel, Field
 
-from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
 from letta.schemas.usage import LettaUsageStatistics
 from letta.utils import json_dumps
 
 # TODO: consider moving into own file
 
+StreamDoneStatus = Literal["[DONE]"]
 
 class LettaResponse(BaseModel):
     """
@@ -144,5 +144,5 @@ def format_json(json_str):
     return html_output
 
 
-# The streaming response is either [DONE], [DONE_STEP], [DONE], an error, or a LettaMessage
-LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStatistics]
+# The streaming response is either [DONE], an error, or a LettaMessage
+LettaStreamingResponse = Union[LettaMessage, StreamDoneStatus, LettaUsageStatistics]
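Because StreamDoneStatus is a string Literal rather than an enum, consumers now compare against the string itself instead of using isinstance checks on a status type. A hedged sketch of narrowing the new union (FakeMessage is a stand-in for the real Pydantic message models):

from typing import List, Literal, Union

StreamDoneStatus = Literal["[DONE]"]
FakeMessage = dict  # stand-in for LettaMessage / LettaUsageStatistics

def is_done(chunk: Union[FakeMessage, StreamDoneStatus]) -> bool:
    """True once the terminal sentinel arrives."""
    return isinstance(chunk, str) and chunk == "[DONE]"

stream: List[Union[FakeMessage, StreamDoneStatus]] = [{"internal_monologue": "hi"}, "[DONE]"]
received = []
for item in stream:
    if is_done(item):
        break
    received.append(item)
assert received == [{"internal_monologue": "hi"}]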
16 changes: 3 additions & 13 deletions letta/server/rest_api/interface.py
@@ -6,10 +6,10 @@
 from datetime import datetime
 from typing import AsyncGenerator, Literal, Optional, Union
 
+from letta.schemas.letta_response import StreamDoneStatus
 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.interface import AgentInterface
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG
-from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.letta_message import (
     AssistantMessage,
     FunctionCall,
@@ -295,8 +295,6 @@ def __init__(
         # if multi_step = True, the stream ends when the agent yields
         # if multi_step = False, the stream ends when the step ends
         self.multi_step = multi_step
-        self.multi_step_indicator = MessageStreamStatus.done_step
-        self.multi_step_gen_indicator = MessageStreamStatus.done_generation
 
         # Support for AssistantMessage
         self.use_assistant_message = False  # TODO: Remove this
@@ -325,7 +323,7 @@ def _reset_inner_thoughts_json_reader(self):
         self.function_args_buffer = None
         self.function_id_buffer = None
 
-    async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
+    async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, StreamDoneStatus], None]:
         """An asynchronous generator that yields chunks as they become available."""
         while self._active:
             try:
@@ -350,8 +348,6 @@ def get_generator(self) -> AsyncGenerator:
     def _push_to_buffer(
         self,
         item: Union[
-            # signal on SSE stream status [DONE_GEN], [DONE_STEP], [DONE]
-            MessageStreamStatus,
             # the non-streaming message types
             LettaMessage,
             LegacyLettaMessage,
@@ -362,7 +358,7 @@
         """Add an item to the deque"""
         assert self._active, "Generator is inactive"
         assert (
-            isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage) or isinstance(item, MessageStreamStatus)
+            isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage)
         ), f"Wrong type: {type(item)}"
 
         self._chunks.append(item)
@@ -381,9 +377,6 @@ def stream_end(self):
         """Clean up the stream by deactivating and clearing chunks."""
         self.streaming_chat_completion_mode_function_name = None
 
-        if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
-            self._push_to_buffer(self.multi_step_gen_indicator)
-
         # Wipe the inner thoughts buffers
         self._reset_inner_thoughts_json_reader()
 
@@ -393,9 +386,6 @@ def step_complete(self):
             # end the stream
             self._active = False
             self._event.set()  # Unblock the generator if it's waiting to allow it to complete
-        elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
-            # signal that a new step has started in the stream
-            self._push_to_buffer(self.multi_step_indicator)
 
         # Wipe the inner thoughts buffers
         self._reset_inner_thoughts_json_reader()
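The net effect in the interface: stream_end and step_complete no longer push [DONE_GEN] or [DONE_STEP] markers into the buffer; completing a step simply deactivates the generator. A stripped-down sketch of that producer/consumer shape (names simplified and the multi_step branching omitted; the real class carries far more state):

import asyncio
from collections import deque
from typing import AsyncGenerator

class StreamBuffer:
    """Producers push chunks; step_complete ends the stream with no status markers."""

    def __init__(self) -> None:
        self._chunks: deque = deque()
        self._event = asyncio.Event()
        self._active = True

    def push(self, item: str) -> None:
        assert self._active, "Generator is inactive"
        self._chunks.append(item)
        self._event.set()

    def step_complete(self) -> None:
        self._active = False  # no [DONE_STEP] indicator is pushed anymore
        self._event.set()  # unblock the generator so it can exit cleanly

    async def generator(self) -> AsyncGenerator[str, None]:
        while self._active or self._chunks:
            while self._chunks:
                yield self._chunks.popleft()
            if not self._active:
                break
            self._event.clear()
            await self._event.wait()

async def _demo() -> list:
    buf = StreamBuffer()
    buf.push("hello")
    buf.step_complete()
    return [chunk async for chunk in buf.generator()]

assert asyncio.run(_demo()) == ["hello"]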
8 changes: 4 additions & 4 deletions letta/server/rest_api/routers/v1/agents.py
@@ -17,14 +17,14 @@
 from pydantic import Field
 
 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
+from letta.llm_api.openai import OPENAI_SSE_DONE
 from letta.orm.errors import NoResultFound
 from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
 from letta.schemas.block import (  # , BlockLabelUpdate, BlockLimitUpdate
     Block,
     BlockUpdate,
     CreateBlock,
 )
-from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.job import Job, JobStatus, JobUpdate
 from letta.schemas.letta_message import (
     LegacyLettaMessage,
@@ -729,14 +729,14 @@ async def send_message_to_agent(
         generated_stream = []
         async for message in streaming_interface.get_generator():
             assert (
-                isinstance(message, LettaMessage) or isinstance(message, LegacyLettaMessage) or isinstance(message, MessageStreamStatus)
+                isinstance(message, LettaMessage) or isinstance(message, LegacyLettaMessage) or message == OPENAI_SSE_DONE
             ), type(message)
             generated_stream.append(message)
-            if message == MessageStreamStatus.done:
+            if message == OPENAI_SSE_DONE:
                 break
 
         # Get rid of the stream status messages
-        filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
+        filtered_stream = [d for d in generated_stream if d != OPENAI_SSE_DONE]
         usage = await task
 
         # By default the stream will be messages of type LettaMessage or LettaLegacyMessage
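For the non-streaming path, the route now collects chunks until it sees the sentinel and then filters the sentinel back out. The same pattern in miniature (a hedged sketch: OPENAI_SSE_DONE is assumed to be "[DONE]", and plain dicts stand in for the LettaMessage models):

import asyncio
from typing import AsyncIterator, List, Union

OPENAI_SSE_DONE = "[DONE]"  # assumed sentinel value

async def drain(stream: AsyncIterator[Union[dict, str]]) -> List[dict]:
    """Collect messages until the sentinel, then drop the sentinel itself."""
    generated = []
    async for message in stream:
        generated.append(message)
        if message == OPENAI_SSE_DONE:
            break
    return [m for m in generated if m != OPENAI_SSE_DONE]

async def _fake_stream():
    yield {"assistant_message": "hi"}
    yield OPENAI_SSE_DONE

assert asyncio.run(drain(_fake_stream())) == [{"assistant_message": "hi"}]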
41 changes: 16 additions & 25 deletions tests/test_client_legacy.py
@@ -12,10 +12,11 @@
 from letta import create_client
 from letta.client.client import LocalClient, RESTClient
 from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET
+from letta.llm_api.openai import OPENAI_SSE_DONE
 from letta.orm import FileMetadata, Source
 from letta.schemas.agent import AgentState
 from letta.schemas.embedding_config import EmbeddingConfig
-from letta.schemas.enums import MessageRole, MessageStreamStatus
+from letta.schemas.enums import MessageRole
 from letta.schemas.letta_message import (
     AssistantMessage,
     FunctionCallMessage,
@@ -245,45 +246,35 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
     inner_thoughts_count = 0
     # 2. Check that the agent runs `send_message`
     send_message_ran = False
-    # 3. Check that we get all the start/stop/end tokens we want
-    # This includes all of the MessageStreamStatus enums
-    done_gen = False
-    done_step = False
+    # 3. Check that we get the end token we want (StreamDoneStatus)
     done = False
 
-    # print(response)
+    print(response)
     assert response, "Sending message failed"
     for chunk in response:
-        assert isinstance(chunk, LettaStreamingResponse)
-        if isinstance(chunk, InternalMonologue) and chunk.internal_monologue and chunk.internal_monologue != "":
-            inner_thoughts_exist = True
-            inner_thoughts_count += 1
-        if isinstance(chunk, FunctionCallMessage) and chunk.function_call and chunk.function_call.name == "send_message":
-            send_message_ran = True
-        if isinstance(chunk, MessageStreamStatus):
-            if chunk == MessageStreamStatus.done:
-                assert not done, "Message stream already done"
-                done = True
-            elif chunk == MessageStreamStatus.done_step:
-                assert not done_step, "Message stream already done step"
-                done_step = True
-            elif chunk == MessageStreamStatus.done_generation:
-                assert not done_gen, "Message stream already done generation"
-                done_gen = True
-        if isinstance(chunk, LettaUsageStatistics):
+        if isinstance(chunk, LettaMessage):
+            if isinstance(chunk, InternalMonologue) and chunk.internal_monologue and chunk.internal_monologue != "":
+                inner_thoughts_exist = True
+                inner_thoughts_count += 1
+            if isinstance(chunk, FunctionCallMessage) and chunk.function_call and chunk.function_call.name == "send_message":
+                send_message_ran = True
+        elif chunk == OPENAI_SSE_DONE:
+            assert not done, "Message stream already done"
+            done = True
+        elif isinstance(chunk, LettaUsageStatistics):
             # Some rough metrics for a reasonable usage pattern
             assert chunk.step_count == 1
             assert chunk.completion_tokens > 10
             assert chunk.prompt_tokens > 1000
             assert chunk.total_tokens > 1000
+        else:
+            assert isinstance(chunk, LettaStreamingResponse)
 
     # If stream tokens, we expect at least one inner thought
     assert inner_thoughts_count >= 1, "Expected more than one inner thought"
     assert inner_thoughts_exist, "No inner thoughts found"
     assert send_message_ran, "send_message function call not found"
     assert done, "Message stream not done"
-    assert done_step, "Message stream not done step"
-    assert done_gen, "Message stream not done generation"
 
 
 def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState):
