Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Use AsyncSession in crud log and find_flow #4691

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/backend/base/langflow/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from langflow.services.database.models.transactions.model import TransactionBase
from langflow.services.database.models.vertex_builds.crud import log_vertex_build as crud_log_vertex_build
from langflow.services.database.models.vertex_builds.model import VertexBuildBase
from langflow.services.database.utils import session_getter
from langflow.services.database.utils import async_session_getter
from langflow.services.deps import get_db_service, get_settings_service

if TYPE_CHECKING:
Expand Down Expand Up @@ -157,14 +157,14 @@ async def log_transaction(
error=error,
flow_id=flow_id if isinstance(flow_id, UUID) else UUID(flow_id),
)
with session_getter(get_db_service()) as session:
inserted = crud_log_transaction(session, transaction)
async with async_session_getter(get_db_service()) as session:
inserted = await crud_log_transaction(session, transaction)
logger.debug(f"Logged transaction: {inserted.id}")
except Exception: # noqa: BLE001
logger.exception("Error logging transaction")


def log_vertex_build(
async def log_vertex_build(
*,
flow_id: str,
vertex_id: str,
Expand All @@ -186,8 +186,8 @@ def log_vertex_build(
# ugly hack to get the model dump with weird datatypes
artifacts=json.loads(json.dumps(artifacts, default=str)),
)
with session_getter(get_db_service()) as session:
inserted = crud_log_vertex_build(session, vertex_build)
async with async_session_getter(get_db_service()) as session:
inserted = await crud_log_vertex_build(session, vertex_build)
logger.debug(f"Logged vertex build: {inserted.build_id}")
except Exception: # noqa: BLE001
logger.exception("Error logging vertex build")
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/graph/vertex/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ async def stream(self):
and hasattr(self.custom_component, "store_message")
):
self.custom_component.store_message(message)
log_vertex_build(
await log_vertex_build(
flow_id=self.graph.flow_id,
vertex_id=self.id,
valid=True,
Expand Down
23 changes: 12 additions & 11 deletions src/backend/base/langflow/helpers/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from langflow.schema.schema import INPUT_FIELD_NAME
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.flow.model import FlowRead
from langflow.services.deps import get_settings_service, session_scope
from langflow.services.deps import async_session_scope, get_settings_service, session_scope

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
Expand Down Expand Up @@ -53,13 +53,13 @@ async def load_flow(
msg = "Flow ID or Flow Name is required"
raise ValueError(msg)
if not flow_id and flow_name:
flow_id = find_flow(flow_name, user_id)
flow_id = await find_flow(flow_name, user_id)
if not flow_id:
msg = f"Flow {flow_name} not found"
raise ValueError(msg)

with session_scope() as session:
graph_data = flow.data if (flow := session.get(Flow, flow_id)) else None
async with async_session_scope() as session:
graph_data = flow.data if (flow := await session.get(Flow, flow_id)) else None
if not graph_data:
msg = f"Flow {flow_id} not found"
raise ValueError(msg)
Expand All @@ -68,9 +68,10 @@ async def load_flow(
return Graph.from_payload(graph_data, flow_id=flow_id, user_id=user_id)


def find_flow(flow_name: str, user_id: str) -> str | None:
with session_scope() as session:
flow = session.exec(select(Flow).where(Flow.name == flow_name).where(Flow.user_id == user_id)).first()
async def find_flow(flow_name: str, user_id: str) -> str | None:
async with async_session_scope() as session:
stmt = select(Flow).where(Flow.name == flow_name).where(Flow.user_id == user_id)
flow = (await session.exec(stmt)).first()
return flow.id if flow else None


Expand Down Expand Up @@ -273,18 +274,18 @@ def get_arg_names(inputs: list[Vertex]) -> list[dict[str, str]]:
]


def get_flow_by_id_or_endpoint_name(flow_id_or_name: str, user_id: UUID | None = None) -> FlowRead | None:
with session_scope() as session:
async def get_flow_by_id_or_endpoint_name(flow_id_or_name: str, user_id: UUID | None = None) -> FlowRead | None:
async with async_session_scope() as session:
endpoint_name = None
try:
flow_id = UUID(flow_id_or_name)
flow = session.get(Flow, flow_id)
flow = await session.get(Flow, flow_id)
except ValueError:
endpoint_name = flow_id_or_name
stmt = select(Flow).where(Flow.endpoint_name == endpoint_name)
if user_id:
stmt = stmt.where(Flow.user_id == user_id)
flow = session.exec(stmt).first()
flow = (await session.exec(stmt)).first()
if flow is None:
raise HTTPException(status_code=404, detail=f"Flow identifier {flow_id_or_name} not found")
return FlowRead.model_validate(flow, from_attributes=True)
Expand Down
10 changes: 5 additions & 5 deletions src/backend/base/langflow/helpers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
from langflow.services.deps import get_db_service


def get_user_by_flow_id_or_endpoint_name(flow_id_or_name: str) -> UserRead | None:
with get_db_service().with_session() as session:
async def get_user_by_flow_id_or_endpoint_name(flow_id_or_name: str) -> UserRead | None:
async with get_db_service().with_async_session() as session:
try:
flow_id = UUID(flow_id_or_name)
flow = session.get(Flow, flow_id)
flow = await session.get(Flow, flow_id)
except ValueError:
stmt = select(Flow).where(Flow.endpoint_name == flow_id_or_name)
flow = session.exec(stmt).first()
flow = (await session.exec(stmt)).first()

if flow is None:
raise HTTPException(status_code=404, detail=f"Flow identifier {flow_id_or_name} not found")

user = session.get(User, flow.user_id)
user = await session.get(User, flow.user_id)
if user is None:
raise HTTPException(status_code=404, detail=f"User for flow {flow_id_or_name} not found")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from uuid import UUID

from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, col, select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.services.database.models.transactions.model import TransactionBase, TransactionTable
Expand All @@ -21,12 +21,13 @@ async def get_transactions_by_flow_id(
return list(transactions)


def log_transaction(db: Session, transaction: TransactionBase) -> TransactionTable:
async def log_transaction(db: AsyncSession, transaction: TransactionBase) -> TransactionTable:
table = TransactionTable(**transaction.model_dump())
db.add(table)
try:
db.commit()
await db.commit()
await db.refresh(table)
except IntegrityError:
db.rollback()
await db.rollback()
raise
return table
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from uuid import UUID

from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, col, delete, select
from sqlmodel import col, delete, select
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.services.database.models.vertex_builds.model import VertexBuildBase, VertexBuildTable
Expand All @@ -21,13 +21,14 @@ async def get_vertex_builds_by_flow_id(
return list(builds)


def log_vertex_build(db: Session, vertex_build: VertexBuildBase) -> VertexBuildTable:
async def log_vertex_build(db: AsyncSession, vertex_build: VertexBuildBase) -> VertexBuildTable:
table = VertexBuildTable(**vertex_build.model_dump())
db.add(table)
try:
db.commit()
await db.commit()
await db.refresh(table)
except IntegrityError:
db.rollback()
await db.rollback()
raise
return table

Expand Down
9 changes: 5 additions & 4 deletions src/backend/base/langflow/services/socket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from langflow.graph.utils import log_vertex_build
from langflow.graph.vertex.base import Vertex
from langflow.services.database.models.flow.model import Flow
from langflow.services.deps import get_session
from langflow.services.deps import get_async_session


def set_socketio_server(socketio_server) -> None:
Expand All @@ -23,8 +23,9 @@ def set_socketio_server(socketio_server) -> None:

async def get_vertices(sio, sid, flow_id, chat_service) -> None:
try:
session = next(get_session())
flow: Flow = session.exec(select(Flow).where(Flow.id == flow_id)).first()
session = await anext(get_async_session())
stmt = select(Flow).where(Flow.id == flow_id)
flow: Flow = (await session.exec(stmt)).first()
if not flow or not flow.data:
await sio.emit("error", data="Invalid flow ID", to=sid)
return
Expand Down Expand Up @@ -87,7 +88,7 @@ async def build_vertex(
result_dict = ResultDataResponse(results={})
artifacts = {}
await set_cache(flow_id, graph)
log_vertex_build(
await log_vertex_build(
flow_id=flow_id,
vertex_id=vertex_id,
valid=valid,
Expand Down
4 changes: 2 additions & 2 deletions src/backend/tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ class VertexTuple(NamedTuple):
"vertex_id": "vid",
"flow_id": flow_id,
}
log_vertex_build(
await log_vertex_build(
flow_id=build["flow_id"],
vertex_id=build["vertex_id"],
valid=build["valid"],
Expand Down Expand Up @@ -376,7 +376,7 @@ class VertexTuple(NamedTuple):
"vertex_id": "vid",
"flow_id": flow_id,
}
log_vertex_build(
await log_vertex_build(
flow_id=build["flow_id"],
vertex_id=build["vertex_id"],
valid=build["valid"],
Expand Down
Loading