Skip to content

Commit

Permalink
feat: returning a description of each workflow node (#3539)
Browse files Browse the repository at this point in the history
# Description

By returning a description of each node executed by a (LangGraph)
workflow, we can show it to the user and thus inform them about the
status of the task execution.
  • Loading branch information
jacopo-chevallard authored Dec 30, 2024
1 parent e0ccd3d commit d835fc6
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
1 change: 1 addition & 0 deletions core/quivr_core/rag/entities/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def resolve_special_edges(self):

class NodeConfig(QuivrBaseConfig):
name: str
description: str | None = None
edges: List[str] | None = None
conditional_edge: ConditionalEdgeConfig | None = None
tools: List[Dict[str, Any]] | None = None
Expand Down
1 change: 1 addition & 0 deletions core/quivr_core/rag/entities/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class RAGResponseMetadata(BaseModel):
followup_questions: list[str] = Field(default_factory=list)
sources: list[Any] = Field(default_factory=list)
metadata_model: ChatLLMMetadata | None = None
workflow_step: str | None = None


class ParsedRAGResponse(BaseModel):
Expand Down
23 changes: 22 additions & 1 deletion core/quivr_core/rag/quivr_rag_langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
TypedDict,
)
from uuid import UUID, uuid4

import openai
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
Expand All @@ -38,6 +37,7 @@
from quivr_core.rag.entities.models import (
ParsedRAGChunkResponse,
QuivrKnowledge,
RAGResponseMetadata,
)
from quivr_core.rag.prompts import custom_prompts
from quivr_core.rag.utils import (
Expand Down Expand Up @@ -950,6 +950,8 @@ async def answer_astream(
version="v1",
config={"metadata": metadata, "callbacks": [langfuse_handler]},
):
node_name = self._extract_node_name(event)

if self._is_final_node_with_docs(event):
tasks = event["data"]["output"]["tasks"]
docs = tasks.docs if tasks else []
Expand All @@ -965,9 +967,17 @@ async def answer_astream(

if new_content:
chunk_metadata = get_chunk_metadata(rolling_message, docs)
if node_name:
chunk_metadata.workflow_step = node_name
yield ParsedRAGChunkResponse(
answer=new_content, metadata=chunk_metadata
)
else:
if node_name:
yield ParsedRAGChunkResponse(
answer="",
metadata=RAGResponseMetadata(workflow_step=node_name),
)

# Yield final metadata chunk
yield ParsedRAGChunkResponse(
Expand All @@ -991,6 +1001,17 @@ def _is_final_node_and_chat_model_stream(self, event: dict) -> bool:
and event["metadata"]["langgraph_node"] in self.final_nodes
)

def _extract_node_name(self, event: dict) -> str:
if "metadata" in event and "langgraph_node" in event["metadata"]:
name = event["metadata"]["langgraph_node"]
for node in self.retrieval_config.workflow_config.nodes:
if node.name == name:
if node.description:
return node.description
else:
return node.name
return ""

async def ainvoke_structured_output(
self, prompt: str, output_class: Type[BaseModel]
) -> Any:
Expand Down

0 comments on commit d835fc6

Please sign in to comment.