diff --git a/backend/retrieval_graph/graph.py b/backend/retrieval_graph/graph.py index 02b1d30c..ffa787fb 100644 --- a/backend/retrieval_graph/graph.py +++ b/backend/retrieval_graph/graph.py @@ -145,7 +145,11 @@ class Plan(TypedDict): {"role": "system", "content": configuration.research_plan_system_prompt} ] + state.messages response = cast(Plan, await model.ainvoke(messages)) - return {"steps": response["steps"], "documents": "delete"} + return { + "steps": response["steps"], + "documents": "delete", + "query": state.messages[-1].content, + } async def conduct_research(state: AgentState) -> dict[str, Any]: diff --git a/backend/retrieval_graph/state.py b/backend/retrieval_graph/state.py index ebd81000..72ddbf39 100644 --- a/backend/retrieval_graph/state.py +++ b/backend/retrieval_graph/state.py @@ -81,3 +81,4 @@ class AgentState(InputState): """Populated by the retriever. This is a list of documents that the agent can reference.""" answer: str = field(default="") """Final answer. Useful for evaluations""" + query: str = field(default="")