Skip to content

Commit

Permalink
Use AsyncSession in memory
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Nov 18, 2024
1 parent 3188517 commit ccc8ddc
Show file tree
Hide file tree
Showing 24 changed files with 614 additions and 172 deletions.
3 changes: 1 addition & 2 deletions src/backend/base/langflow/base/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from abc import abstractmethod
from typing import TYPE_CHECKING, cast

Expand Down Expand Up @@ -165,7 +164,7 @@ async def run_agent(
)
except ExceptionWithMessageError as e:
msg_id = e.agent_message.id
await asyncio.to_thread(delete_message, id_=msg_id)
await delete_message(id_=msg_id)
self._send_message_event(e.agent_message, category="remove_message")
raise
except Exception:
Expand Down
37 changes: 18 additions & 19 deletions src/backend/base/langflow/base/agents/events.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Add helper functions for each event type
import asyncio
from collections.abc import AsyncIterator
from time import perf_counter
from typing import Any, Protocol
Expand Down Expand Up @@ -53,7 +52,7 @@ def _calculate_duration(start_time: float) -> int:
return result


def handle_on_chain_start(
async def handle_on_chain_start(
event: dict[str, Any], agent_message: Message, send_message_method: SendMessageFunctionType, start_time: float
) -> tuple[Message, float]:
# Create content blocks if they don't exist
Expand All @@ -75,12 +74,12 @@ def handle_on_chain_start(
header={"title": "Input", "icon": "MessageSquare"},
)
agent_message.content_blocks[0].contents.append(text_content)
agent_message = send_message_method(message=agent_message)
agent_message = await send_message_method(message=agent_message)
start_time = perf_counter()
return agent_message, start_time


def handle_on_chain_end(
async def handle_on_chain_end(
event: dict[str, Any], agent_message: Message, send_message_method: SendMessageFunctionType, start_time: float
) -> tuple[Message, float]:
data_output = event["data"].get("output")
Expand All @@ -97,12 +96,12 @@ def handle_on_chain_end(
header={"title": "Output", "icon": "MessageSquare"},
)
agent_message.content_blocks[0].contents.append(text_content)
agent_message = send_message_method(message=agent_message)
agent_message = await send_message_method(message=agent_message)
start_time = perf_counter()
return agent_message, start_time


def handle_on_tool_start(
async def handle_on_tool_start(
event: dict[str, Any],
agent_message: Message,
tool_blocks_map: dict[str, ToolContent],
Expand Down Expand Up @@ -136,12 +135,12 @@ def handle_on_tool_start(
tool_blocks_map[tool_key] = tool_content
agent_message.content_blocks[0].contents.append(tool_content)

agent_message = send_message_method(message=agent_message)
agent_message = await send_message_method(message=agent_message)
tool_blocks_map[tool_key] = agent_message.content_blocks[0].contents[-1]
return agent_message, new_start_time


def handle_on_tool_end(
async def handle_on_tool_end(
event: dict[str, Any],
agent_message: Message,
tool_blocks_map: dict[str, ToolContent],
Expand All @@ -159,13 +158,13 @@ def handle_on_tool_end(
tool_content.duration = duration
tool_content.header = {"title": f"Executed **{tool_content.name}**", "icon": "Hammer"}

agent_message = send_message_method(message=agent_message)
agent_message = await send_message_method(message=agent_message)
new_start_time = perf_counter() # Get new start time for next operation
return agent_message, new_start_time
return agent_message, start_time


def handle_on_tool_error(
async def handle_on_tool_error(
event: dict[str, Any],
agent_message: Message,
tool_blocks_map: dict[str, ToolContent],
Expand All @@ -181,12 +180,12 @@ def handle_on_tool_error(
tool_content.error = event["data"].get("error", "Unknown error")
tool_content.duration = _calculate_duration(start_time)
tool_content.header = {"title": f"Error using **{tool_content.name}**", "icon": "Hammer"}
agent_message = send_message_method(message=agent_message)
agent_message = await send_message_method(message=agent_message)
start_time = perf_counter()
return agent_message, start_time


def handle_on_chain_stream(
async def handle_on_chain_stream(
event: dict[str, Any],
agent_message: Message,
send_message_method: SendMessageFunctionType,
Expand All @@ -196,13 +195,13 @@ def handle_on_chain_stream(
if isinstance(data_chunk, dict) and data_chunk.get("output"):
agent_message.text = data_chunk.get("output")
agent_message.properties.state = "complete"
agent_message = send_message_method(message=agent_message)
agent_message = await send_message_method(message=agent_message)
start_time = perf_counter()
return agent_message, start_time


class ToolEventHandler(Protocol):
def __call__(
async def __call__(
self,
event: dict[str, Any],
agent_message: Message,
Expand All @@ -213,7 +212,7 @@ def __call__(


class ChainEventHandler(Protocol):
def __call__(
async def __call__(
self,
event: dict[str, Any],
agent_message: Message,
Expand Down Expand Up @@ -250,22 +249,22 @@ async def process_agent_events(
agent_message.properties.icon = "Bot"
agent_message.properties.state = "partial"
# Store the initial message
agent_message = await asyncio.to_thread(send_message_method, message=agent_message)
agent_message = await send_message_method(message=agent_message)
try:
# Create a mapping of run_ids to tool contents
tool_blocks_map: dict[str, ToolContent] = {}
start_time = perf_counter()
async for event in agent_executor:
if event["event"] in TOOL_EVENT_HANDLERS:
tool_handler = TOOL_EVENT_HANDLERS[event["event"]]
agent_message, start_time = tool_handler(
agent_message, start_time = await tool_handler(
event, agent_message, tool_blocks_map, send_message_method, start_time
)
elif event["event"] in CHAIN_EVENT_HANDLERS:
chain_handler = CHAIN_EVENT_HANDLERS[event["event"]]
agent_message, start_time = chain_handler(event, agent_message, send_message_method, start_time)
agent_message, start_time = await chain_handler(event, agent_message, send_message_method, start_time)
agent_message.properties.state = "complete"
except Exception as e:
raise ExceptionWithMessageError(agent_message) from e

return Message(**agent_message.model_dump())
return await Message.create(**agent_message.model_dump())
11 changes: 6 additions & 5 deletions src/backend/base/langflow/base/io/chat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
from typing import cast

from langflow.custom import Component
from langflow.memory import store_message
from langflow.memory import astore_message
from langflow.schema import Data
from langflow.schema.message import Message

Expand All @@ -10,7 +11,7 @@ class ChatComponent(Component):
display_name = "Chat Component"
description = "Use as base for chat components."

def build_with_data(
async def build_with_data(
self,
*,
sender: str | None = "User",
Expand All @@ -20,15 +21,15 @@ def build_with_data(
session_id: str | None = None,
return_message: bool = False,
) -> str | Message:
message = self._create_message(input_value, sender, sender_name, files, session_id)
message = await asyncio.to_thread(self._create_message, input_value, sender, sender_name, files, session_id)
message_text = message.text if not return_message else message

self.status = message_text
if session_id and isinstance(message, Message) and isinstance(message.text, str):
flow_id = self.graph.flow_id if hasattr(self, "graph") else None
messages = store_message(message, flow_id=flow_id)
messages = await astore_message(message, flow_id=flow_id)
self.status = messages
self._send_messages_events(messages)
await asyncio.to_thread(self._send_messages_events, messages)

return cast(str | Message, message_text)

Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/base/tools/component_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def build_description(component: Component, output: Output) -> str:
return f"{output.method}({args}) - {component.description}"


def send_message_noop(
async def send_message_noop(
message: Message,
text: str | None = None, # noqa: ARG001
background_color: str | None = None, # noqa: ARG001
Expand Down
6 changes: 3 additions & 3 deletions src/backend/base/langflow/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def message_response(self) -> Message:
if llm_model is None:
msg = "No language model selected"
raise ValueError(msg)
self.chat_history = self.get_memory_data()
self.chat_history = await self.get_memory_data()

if self.add_current_date_tool:
if not isinstance(self.tools, list): # type: ignore[has-type]
Expand All @@ -87,12 +87,12 @@ async def message_response(self) -> Message:
agent = self.create_agent_runnable()
return await self.run_agent(agent)

def get_memory_data(self):
async def get_memory_data(self):
memory_kwargs = {
component_input.name: getattr(self, f"{component_input.name}") for component_input in self.memory_inputs
}

return MemoryComponent().set(**memory_kwargs).retrieve_messages()
return await MemoryComponent().set(**memory_kwargs).retrieve_messages()

def get_llm(self):
if isinstance(self.agent_llm, str):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from langflow.custom import CustomComponent
from langflow.memory import get_messages, store_message
from langflow.memory import aget_messages, astore_message
from langflow.schema.message import Message


Expand All @@ -13,12 +13,12 @@ def build_config(self):
"message": {"display_name": "Message"},
}

def build(
async def build(
self,
message: Message,
) -> Message:
flow_id = self.graph.flow_id if hasattr(self, "graph") else None
store_message(message, flow_id=flow_id)
self.status = get_messages()
await astore_message(message, flow_id=flow_id)
self.status = await aget_messages()

return message
12 changes: 6 additions & 6 deletions src/backend/base/langflow/components/helpers/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from langflow.helpers.data import data_to_text
from langflow.inputs import HandleInput
from langflow.io import DropdownInput, IntInput, MessageTextInput, MultilineInput, Output
from langflow.memory import LCBuiltinChatMemory, get_messages
from langflow.memory import LCBuiltinChatMemory, aget_messages
from langflow.schema import Data
from langflow.schema.message import Message
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_USER
Expand Down Expand Up @@ -74,7 +74,7 @@ class MemoryComponent(Component):
Output(display_name="Text", name="messages_text", method="retrieve_messages_as_text"),
]

def retrieve_messages(self) -> Data:
async def retrieve_messages(self) -> Data:
sender = self.sender
sender_name = self.sender_name
session_id = self.session_id
Expand All @@ -88,7 +88,7 @@ def retrieve_messages(self) -> Data:
# override session_id
self.memory.session_id = session_id

stored = self.memory.messages
stored = await self.memory.aget_messages()
# langchain memories are supposed to return messages in ascending order
if order == "DESC":
stored = stored[::-1]
Expand All @@ -99,7 +99,7 @@ def retrieve_messages(self) -> Data:
expected_type = MESSAGE_SENDER_AI if sender == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER
stored = [m for m in stored if m.type == expected_type]
else:
stored = get_messages(
stored = await aget_messages(
sender=sender,
sender_name=sender_name,
session_id=session_id,
Expand All @@ -109,8 +109,8 @@ def retrieve_messages(self) -> Data:
self.status = stored
return stored

def retrieve_messages_as_text(self) -> Message:
stored_text = data_to_text(self.template, self.retrieve_messages())
async def retrieve_messages_as_text(self) -> Message:
stored_text = data_to_text(self.template, await self.retrieve_messages())
self.status = stored_text
return Message(text=stored_text)

Expand Down
14 changes: 8 additions & 6 deletions src/backend/base/langflow/components/helpers/store_message.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from langflow.custom import Component
from langflow.inputs import HandleInput, MessageInput
from langflow.inputs.inputs import MessageTextInput
from langflow.memory import get_messages, store_message
from langflow.memory import aget_messages, astore_message
from langflow.schema.message import Message
from langflow.template import Output
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI
Expand Down Expand Up @@ -47,7 +47,7 @@ class StoreMessageComponent(Component):
Output(display_name="Stored Messages", name="stored_messages", method="store_message"),
]

def store_message(self) -> Message:
async def store_message(self) -> Message:
message = self.message

message.session_id = self.session_id or message.session_id
Expand All @@ -58,13 +58,15 @@ def store_message(self) -> Message:
# override session_id
self.memory.session_id = message.session_id
lc_message = message.to_lc_message()
self.memory.add_messages([lc_message])
stored = self.memory.messages
await self.memory.aadd_messages([lc_message])
stored = await self.memory.aget_messages()
stored = [Message.from_lc_message(m) for m in stored]
if message.sender:
stored = [m for m in stored if m.sender == message.sender]
else:
store_message(message, flow_id=self.graph.flow_id)
stored = get_messages(session_id=message.session_id, sender_name=message.sender_name, sender=message.sender)
await astore_message(message, flow_id=self.graph.flow_id)
stored = await aget_messages(
session_id=message.session_id, sender_name=message.sender_name, sender=message.sender
)
self.status = stored
return stored
7 changes: 4 additions & 3 deletions src/backend/base/langflow/components/inputs/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,12 @@ class ChatInput(ChatComponent):
Output(display_name="Message", name="message", method="message_response"),
]

def message_response(self) -> Message:
async def message_response(self) -> Message:
_background_color = self.background_color
_text_color = self.text_color
_icon = self.chat_icon
message = Message(

message = await Message.create(
text=self.input_value,
sender=self.sender,
sender_name=self.sender_name,
Expand All @@ -91,7 +92,7 @@ def message_response(self) -> Message:
properties={"background_color": _background_color, "text_color": _text_color, "icon": _icon},
)
if self.session_id and isinstance(message, Message) and self.should_store_message:
stored_message = self.send_message(
stored_message = await self.send_message(
message,
)
self.message.value = stored_message
Expand Down
4 changes: 2 additions & 2 deletions src/backend/base/langflow/components/outputs/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _build_source(self, _id: str | None, display_name: str | None, source: str |
source_dict["source"] = source
return Source(**source_dict)

def message_response(self) -> Message:
async def message_response(self) -> Message:
_source, _icon, _display_name, _source_id = self.get_properties_from_source_component()
_background_color = self.background_color
_text_color = self.text_color
Expand All @@ -106,7 +106,7 @@ def message_response(self) -> Message:
message.properties.background_color = _background_color
message.properties.text_color = _text_color
if self.session_id and isinstance(message, Message) and self.should_store_message:
stored_message = self.send_message(
stored_message = await self.send_message(
message,
)
self.message.value = stored_message
Expand Down
Loading

0 comments on commit ccc8ddc

Please sign in to comment.