diff --git a/letta/client/streaming.py b/letta/client/streaming.py
index 80a8a814e5..688cba5df2 100644
--- a/letta/client/streaming.py
+++ b/letta/client/streaming.py
@@ -6,7 +6,7 @@

 from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
 from letta.errors import LLMError
-from letta.schemas.enums import MessageStreamStatus
+from letta.llm_api.openai import OPENAI_SSE_DONE
 from letta.schemas.letta_message import (
     FunctionCallMessage,
     FunctionReturn,
@@ -47,10 +47,10 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
                 # if sse.data == OPENAI_SSE_DONE:
                 #     print("finished")
                 #     break
-                if sse.data in [status.value for status in MessageStreamStatus]:
+                if sse.data == OPENAI_SSE_DONE:
                     # break
                     # print("sse.data::", sse.data)
-                    yield MessageStreamStatus(sse.data)
+                    yield sse.data
                 else:
                     chunk_data = json.loads(sse.data)
                     if "internal_monologue" in chunk_data:
diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py
index 8b74b83732..c4a3807ac8 100644
--- a/letta/schemas/enums.py
+++ b/letta/schemas/enums.py
@@ -29,12 +29,6 @@ class JobStatus(str, Enum):
     pending = "pending"


-class MessageStreamStatus(str, Enum):
-    done_generation = "[DONE_GEN]"
-    done_step = "[DONE_STEP]"
-    done = "[DONE]"
-
-
 class ToolRuleType(str, Enum):
     """
     Type of tool rule.
diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py
index 58dbf42929..be61a55e3b 100644
--- a/letta/schemas/letta_response.py
+++ b/letta/schemas/letta_response.py
@@ -1,17 +1,17 @@
 import html
 import json
 import re
-from typing import List, Union
+from typing import List, Literal, Union

 from pydantic import BaseModel, Field

-from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
 from letta.schemas.usage import LettaUsageStatistics
 from letta.utils import json_dumps

 # TODO: consider moving into own file
+StreamDoneStatus = Literal["[DONE]"]


 class LettaResponse(BaseModel):
     """
@@ -144,5 +144,5 @@ def format_json(json_str):
     return html_output


-# The streaming response is either [DONE], [DONE_STEP], [DONE], an error, or a LettaMessage
-LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStatistics]
+# The streaming response is either [DONE], usage statistics, an error, or a LettaMessage
+LettaStreamingResponse = Union[LettaMessage, StreamDoneStatus, LettaUsageStatistics]
diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py
index 11843250c7..e512496748 100644
--- a/letta/server/rest_api/interface.py
+++ b/letta/server/rest_api/interface.py
@@ -6,10 +6,10 @@
 from datetime import datetime
 from typing import AsyncGenerator, Literal, Optional, Union

+from letta.schemas.letta_response import StreamDoneStatus
 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.interface import AgentInterface
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG
-from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.letta_message import (
     AssistantMessage,
     FunctionCall,
@@ -295,8 +295,6 @@ def __init__(
         # if multi_step = True, the stream ends when the agent yields
         # if multi_step = False, the stream ends when the step ends
         self.multi_step = multi_step
-        self.multi_step_indicator = MessageStreamStatus.done_step
-        self.multi_step_gen_indicator = MessageStreamStatus.done_generation

         # Support for AssistantMessage
         self.use_assistant_message = False  # TODO: Remove this
@@ -325,7 +323,7 @@ def _reset_inner_thoughts_json_reader(self):
         self.function_args_buffer = None
         self.function_id_buffer = None

-    async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
+    async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, StreamDoneStatus], None]:
         """An asynchronous generator that yields chunks as they become available."""
         while self._active:
             try:
@@ -350,8 +348,6 @@ def get_generator(self) -> AsyncGenerator:
     def _push_to_buffer(
         self,
         item: Union[
-            # signal on SSE stream status [DONE_GEN], [DONE_STEP], [DONE]
-            MessageStreamStatus,
             # the non-streaming message types
             LettaMessage,
             LegacyLettaMessage,
@@ -362,7 +358,7 @@ def _push_to_buffer(
         """Add an item to the deque"""
         assert self._active, "Generator is inactive"
         assert (
-            isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage) or isinstance(item, MessageStreamStatus)
+            isinstance(item, LettaMessage) or isinstance(item, LegacyLettaMessage)
         ), f"Wrong type: {type(item)}"

         self._chunks.append(item)
@@ -381,9 +377,6 @@ def stream_end(self):
         """Clean up the stream by deactivating and clearing chunks."""
         self.streaming_chat_completion_mode_function_name = None

-        if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
-            self._push_to_buffer(self.multi_step_gen_indicator)
-
         # Wipe the inner thoughts buffers
         self._reset_inner_thoughts_json_reader()

@@ -393,9 +386,6 @@ def step_complete(self):
             # end the stream
             self._active = False
             self._event.set()  # Unblock the generator if it's waiting to allow it to complete
-        elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
-            # signal that a new step has started in the stream
-            self._push_to_buffer(self.multi_step_indicator)

         # Wipe the inner thoughts buffers
         self._reset_inner_thoughts_json_reader()
diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py
index fc5ce50794..c72eed8593 100644
--- a/letta/server/rest_api/routers/v1/agents.py
+++ b/letta/server/rest_api/routers/v1/agents.py
@@ -17,6 +17,7 @@
 from pydantic import Field

 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
+from letta.llm_api.openai import OPENAI_SSE_DONE
 from letta.orm.errors import NoResultFound
 from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
 from letta.schemas.block import (  # , BlockLabelUpdate, BlockLimitUpdate
@@ -24,7 +25,6 @@
     BlockUpdate,
     CreateBlock,
 )
-from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.job import Job, JobStatus, JobUpdate
 from letta.schemas.letta_message import (
     LegacyLettaMessage,
@@ -729,14 +729,14 @@ async def send_message_to_agent(
             generated_stream = []
             async for message in streaming_interface.get_generator():
                 assert (
-                    isinstance(message, LettaMessage) or isinstance(message, LegacyLettaMessage) or isinstance(message, MessageStreamStatus)
+                    isinstance(message, LettaMessage) or isinstance(message, LegacyLettaMessage) or message == OPENAI_SSE_DONE
                 ), type(message)
                 generated_stream.append(message)
-                if message == MessageStreamStatus.done:
+                if message == OPENAI_SSE_DONE:
                     break

             # Get rid of the stream status messages
-            filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
+            filtered_stream = [d for d in generated_stream if d != OPENAI_SSE_DONE]
             usage = await task

             # By default the stream will be messages of type LettaMessage or LettaLegacyMessage
diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py
index 7c634e5fa3..2864c56720 100644
--- a/tests/test_client_legacy.py
+++ b/tests/test_client_legacy.py
@@ -12,10 +12,11 @@
 from letta import create_client
 from letta.client.client import LocalClient, RESTClient
 from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET
+from letta.llm_api.openai import OPENAI_SSE_DONE
 from letta.orm import FileMetadata, Source
 from letta.schemas.agent import AgentState
 from letta.schemas.embedding_config import EmbeddingConfig
-from letta.schemas.enums import MessageRole, MessageStreamStatus
+from letta.schemas.enums import MessageRole
 from letta.schemas.letta_message import (
     AssistantMessage,
     FunctionCallMessage,
@@ -245,45 +246,35 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
     inner_thoughts_count = 0
     # 2. Check that the agent runs `send_message`
     send_message_ran = False
-    # 3. Check that we get all the start/stop/end tokens we want
-    # This includes all of the MessageStreamStatus enums
-    done_gen = False
-    done_step = False
+    # 3. Check that we get the end token we want (StreamDoneStatus)
     done = False
-    # print(response)
+    print(response)

     assert response, "Sending message failed"
     for chunk in response:
-        assert isinstance(chunk, LettaStreamingResponse)
-        if isinstance(chunk, InternalMonologue) and chunk.internal_monologue and chunk.internal_monologue != "":
-            inner_thoughts_exist = True
-            inner_thoughts_count += 1
-        if isinstance(chunk, FunctionCallMessage) and chunk.function_call and chunk.function_call.name == "send_message":
-            send_message_ran = True
-        if isinstance(chunk, MessageStreamStatus):
-            if chunk == MessageStreamStatus.done:
-                assert not done, "Message stream already done"
-                done = True
-            elif chunk == MessageStreamStatus.done_step:
-                assert not done_step, "Message stream already done step"
-                done_step = True
-            elif chunk == MessageStreamStatus.done_generation:
-                assert not done_gen, "Message stream already done generation"
-                done_gen = True
-        if isinstance(chunk, LettaUsageStatistics):
+        if isinstance(chunk, LettaMessage):
+            if isinstance(chunk, InternalMonologue) and chunk.internal_monologue and chunk.internal_monologue != "":
+                inner_thoughts_exist = True
+                inner_thoughts_count += 1
+            if isinstance(chunk, FunctionCallMessage) and chunk.function_call and chunk.function_call.name == "send_message":
+                send_message_ran = True
+        elif chunk == OPENAI_SSE_DONE:
+            assert not done, "Message stream already done"
+            done = True
+        elif isinstance(chunk, LettaUsageStatistics):
             # Some rough metrics for a reasonable usage pattern
             assert chunk.step_count == 1
             assert chunk.completion_tokens > 10
             assert chunk.prompt_tokens > 1000
             assert chunk.total_tokens > 1000
+        else:
+            # isinstance() against LettaStreamingResponse would raise TypeError
+            # (the Union now contains a Literal), so fail explicitly instead
+            raise AssertionError(f"Unexpected chunk type: {type(chunk)}")

     # If stream tokens, we expect at least one inner thought
     assert inner_thoughts_count >= 1, "Expected more than one inner thought"
     assert inner_thoughts_exist, "No inner thoughts found"
     assert send_message_ran, "send_message function call not found"
     assert done, "Message stream not done"
-    assert done_step, "Message stream not done step"
-    assert done_gen, "Message stream not done generation"


 def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState):
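Note for reviewers: after this change, stream consumers match on a plain string
rather than a MessageStreamStatus member. Below is a minimal sketch of the new
client-side contract, assuming OPENAI_SSE_DONE == "[DONE]" (consistent with the
StreamDoneStatus Literal introduced above); the URL, payload, and headers are
illustrative placeholders, not values taken from this diff:

    from letta.client.streaming import _sse_post

    # Hypothetical endpoint and request body, for illustration only
    url = "http://localhost:8283/v1/agents/agent-xyz/messages"
    request = {"messages": [{"role": "user", "text": "hello"}], "stream_tokens": True}
    headers = {"Content-Type": "application/json"}

    for chunk in _sse_post(url, request, headers):
        if chunk == "[DONE]":  # plain-string terminator (was MessageStreamStatus.done)
            break
        print(chunk)  # a LettaMessage subtype or LettaUsageStatistics

Matching on the raw SSE payload keeps the wire format and the Python-level
sentinel identical, at the cost of the exhaustive enum matching that the old
MessageStreamStatus type provided.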