Skip to content

Commit

Permalink
Use AsyncSession in crud log and find_flow
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Nov 18, 2024
1 parent 3188517 commit 2f4eadc
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 37 deletions.
12 changes: 6 additions & 6 deletions src/backend/base/langflow/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from langflow.services.database.models.transactions.model import TransactionBase
from langflow.services.database.models.vertex_builds.crud import log_vertex_build as crud_log_vertex_build
from langflow.services.database.models.vertex_builds.model import VertexBuildBase
from langflow.services.database.utils import session_getter
from langflow.services.database.utils import async_session_getter
from langflow.services.deps import get_db_service, get_settings_service

if TYPE_CHECKING:
Expand Down Expand Up @@ -157,14 +157,14 @@ async def log_transaction(
error=error,
flow_id=flow_id if isinstance(flow_id, UUID) else UUID(flow_id),
)
with session_getter(get_db_service()) as session:
inserted = crud_log_transaction(session, transaction)
async with async_session_getter(get_db_service()) as session:
inserted = await crud_log_transaction(session, transaction)
logger.debug(f"Logged transaction: {inserted.id}")
except Exception: # noqa: BLE001
logger.exception("Error logging transaction")


def log_vertex_build(
async def log_vertex_build(
*,
flow_id: str,
vertex_id: str,
Expand All @@ -186,8 +186,8 @@ def log_vertex_build(
# ugly hack to get the model dump with weird datatypes
artifacts=json.loads(json.dumps(artifacts, default=str)),
)
with session_getter(get_db_service()) as session:
inserted = crud_log_vertex_build(session, vertex_build)
with async_session_getter(get_db_service()) as session:
inserted = await crud_log_vertex_build(session, vertex_build)
logger.debug(f"Logged vertex build: {inserted.build_id}")
except Exception: # noqa: BLE001
logger.exception("Error logging vertex build")
Expand Down
2 changes: 1 addition & 1 deletion src/backend/base/langflow/graph/vertex/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ async def stream(self):
and hasattr(self.custom_component, "store_message")
):
self.custom_component.store_message(message)
log_vertex_build(
await log_vertex_build(
flow_id=self.graph.flow_id,
vertex_id=self.id,
valid=True,
Expand Down
23 changes: 12 additions & 11 deletions src/backend/base/langflow/helpers/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from langflow.schema.schema import INPUT_FIELD_NAME
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.flow.model import FlowRead
from langflow.services.deps import get_settings_service, session_scope
from langflow.services.deps import async_session_scope, get_settings_service, session_scope

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
Expand Down Expand Up @@ -53,13 +53,13 @@ async def load_flow(
msg = "Flow ID or Flow Name is required"
raise ValueError(msg)
if not flow_id and flow_name:
flow_id = find_flow(flow_name, user_id)
flow_id = await find_flow(flow_name, user_id)
if not flow_id:
msg = f"Flow {flow_name} not found"
raise ValueError(msg)

with session_scope() as session:
graph_data = flow.data if (flow := session.get(Flow, flow_id)) else None
async with async_session_scope() as session:
graph_data = flow.data if (flow := await session.get(Flow, flow_id)) else None
if not graph_data:
msg = f"Flow {flow_id} not found"
raise ValueError(msg)
Expand All @@ -68,9 +68,10 @@ async def load_flow(
return Graph.from_payload(graph_data, flow_id=flow_id, user_id=user_id)


def find_flow(flow_name: str, user_id: str) -> str | None:
with session_scope() as session:
flow = session.exec(select(Flow).where(Flow.name == flow_name).where(Flow.user_id == user_id)).first()
async def find_flow(flow_name: str, user_id: str) -> str | None:
with async_session_scope() as session:
stmt = select(Flow).where(Flow.name == flow_name).where(Flow.user_id == user_id)
flow = (await session.exec(stmt)).first()
return flow.id if flow else None


Expand Down Expand Up @@ -273,18 +274,18 @@ def get_arg_names(inputs: list[Vertex]) -> list[dict[str, str]]:
]


def get_flow_by_id_or_endpoint_name(flow_id_or_name: str, user_id: UUID | None = None) -> FlowRead | None:
with session_scope() as session:
async def get_flow_by_id_or_endpoint_name(flow_id_or_name: str, user_id: UUID | None = None) -> FlowRead | None:
with async_session_scope() as session:
endpoint_name = None
try:
flow_id = UUID(flow_id_or_name)
flow = session.get(Flow, flow_id)
flow = await session.get(Flow, flow_id)
except ValueError:
endpoint_name = flow_id_or_name
stmt = select(Flow).where(Flow.endpoint_name == endpoint_name)
if user_id:
stmt = stmt.where(Flow.user_id == user_id)
flow = session.exec(stmt).first()
flow = (await session.exec(stmt)).first()
if flow is None:
raise HTTPException(status_code=404, detail=f"Flow identifier {flow_id_or_name} not found")
return FlowRead.model_validate(flow, from_attributes=True)
Expand Down
10 changes: 5 additions & 5 deletions src/backend/base/langflow/helpers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,19 @@
from langflow.services.deps import get_db_service


def get_user_by_flow_id_or_endpoint_name(flow_id_or_name: str) -> UserRead | None:
with get_db_service().with_session() as session:
async def get_user_by_flow_id_or_endpoint_name(flow_id_or_name: str) -> UserRead | None:
async with get_db_service().with_async_session() as session:
try:
flow_id = UUID(flow_id_or_name)
flow = session.get(Flow, flow_id)
flow = await session.get(Flow, flow_id)
except ValueError:
stmt = select(Flow).where(Flow.endpoint_name == flow_id_or_name)
flow = session.exec(stmt).first()
flow = (await session.exec(stmt)).first()

if flow is None:
raise HTTPException(status_code=404, detail=f"Flow identifier {flow_id_or_name} not found")

user = session.get(User, flow.user_id)
user = await session.get(User, flow.user_id)
if user is None:
raise HTTPException(status_code=404, detail=f"User for flow {flow_id_or_name} not found")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from uuid import UUID

from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, col, select
from sqlmodel import col, select
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.services.database.models.transactions.model import TransactionBase, TransactionTable
Expand All @@ -21,12 +21,12 @@ async def get_transactions_by_flow_id(
return list(transactions)


def log_transaction(db: Session, transaction: TransactionBase) -> TransactionTable:
async def log_transaction(db: AsyncSession, transaction: TransactionBase) -> TransactionTable:
table = TransactionTable(**transaction.model_dump())
db.add(table)
try:
db.commit()
await db.commit()
except IntegrityError:
db.rollback()
await db.rollback()
raise
return table
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from uuid import UUID

from sqlalchemy.exc import IntegrityError
from sqlmodel import Session, col, delete, select
from sqlmodel import col, delete, select
from sqlmodel.ext.asyncio.session import AsyncSession

from langflow.services.database.models.vertex_builds.model import VertexBuildBase, VertexBuildTable
Expand All @@ -21,13 +21,13 @@ async def get_vertex_builds_by_flow_id(
return list(builds)


def log_vertex_build(db: Session, vertex_build: VertexBuildBase) -> VertexBuildTable:
async def log_vertex_build(db: AsyncSession, vertex_build: VertexBuildBase) -> VertexBuildTable:
table = VertexBuildTable(**vertex_build.model_dump())
db.add(table)
try:
db.commit()
await db.commit()
except IntegrityError:
db.rollback()
await db.rollback()
raise
return table

Expand Down
9 changes: 5 additions & 4 deletions src/backend/base/langflow/services/socket/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from langflow.graph.utils import log_vertex_build
from langflow.graph.vertex.base import Vertex
from langflow.services.database.models.flow.model import Flow
from langflow.services.deps import get_session
from langflow.services.deps import get_async_session


def set_socketio_server(socketio_server) -> None:
Expand All @@ -23,8 +23,9 @@ def set_socketio_server(socketio_server) -> None:

async def get_vertices(sio, sid, flow_id, chat_service) -> None:
try:
session = next(get_session())
flow: Flow = session.exec(select(Flow).where(Flow.id == flow_id)).first()
session = await anext(get_async_session())
stmt = select(Flow).where(Flow.id == flow_id)
flow: Flow = (await session.exec(stmt)).first()
if not flow or not flow.data:
await sio.emit("error", data="Invalid flow ID", to=sid)
return
Expand Down Expand Up @@ -87,7 +88,7 @@ async def build_vertex(
result_dict = ResultDataResponse(results={})
artifacts = {}
await set_cache(flow_id, graph)
log_vertex_build(
await log_vertex_build(
flow_id=flow_id,
vertex_id=vertex_id,
valid=valid,
Expand Down
4 changes: 2 additions & 2 deletions src/backend/tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ class VertexTuple(NamedTuple):
"vertex_id": "vid",
"flow_id": flow_id,
}
log_vertex_build(
await log_vertex_build(
flow_id=build["flow_id"],
vertex_id=build["vertex_id"],
valid=build["valid"],
Expand Down Expand Up @@ -376,7 +376,7 @@ class VertexTuple(NamedTuple):
"vertex_id": "vid",
"flow_id": flow_id,
}
log_vertex_build(
await log_vertex_build(
flow_id=build["flow_id"],
vertex_id=build["vertex_id"],
valid=build["valid"],
Expand Down

0 comments on commit 2f4eadc

Please sign in to comment.