From ee7e019fb51b46007c40a4fbfffe5bb53e0b5e7f Mon Sep 17 00:00:00 2001 From: calyptobai Date: Tue, 19 Dec 2023 17:07:29 -0500 Subject: [PATCH] Use `conversation_id` instead of `thread_id` in `GET /answer` Rather than returning an initial JSON object, we introduce a new `ChatEvent` type, and return the conversation ID on stream end upon successful store. --- server/bleep/src/webserver/answer.rs | 72 +++++++++++++++------------- 1 file changed, 39 insertions(+), 33 deletions(-) diff --git a/server/bleep/src/webserver/answer.rs b/server/bleep/src/webserver/answer.rs index e7a225c965..7d62aed50f 100644 --- a/server/bleep/src/webserver/answer.rs +++ b/server/bleep/src/webserver/answer.rs @@ -106,8 +106,6 @@ pub(super) async fn answer( None => Conversation::new(project_id), }; - let conversation_id = conversation.store(&app.sql, user_id).await?; - let Answer { parent_exchange_id, q, @@ -151,7 +149,6 @@ pub(super) async fn answer( query_id, project_id, conversation, - conversation_id, action, } .execute() @@ -166,10 +163,22 @@ struct AgentExecutor { query_id: uuid::Uuid, project_id: i64, conversation: Conversation, - conversation_id: i64, action: Action, } +#[derive(serde::Serialize)] +enum AnswerEvent { + ChatEvent(Exchange), + StreamEnd(StreamEnd), +} + +#[derive(serde::Serialize)] +struct StreamEnd { + thread_id: String, + query_id: uuid::Uuid, + conversation_id: i64, +} + type SseDynStream = Sse + Send>>>; impl AgentExecutor { @@ -254,14 +263,12 @@ impl AgentExecutor { } }; - let initial_message = json!({ - "thread_id": self.conversation.thread_id.to_string(), - "query_id": self.query_id, - "conversation_id": self.conversation_id, - }); - // let project: Project = serde_json::from_str(&self.params.project).unwrap(); - let Answer { agent_model, answer_model, .. } = self.params.clone(); + let Answer { + agent_model, + answer_model, + .. + } = self.params.clone(); let (exchange_tx, exchange_rx) = tokio::sync::mpsc::channel(10); @@ -276,7 +283,7 @@ impl AgentExecutor { repo_refs, exchange_state: ExchangeState::Pending, answer_model, - agent_model + agent_model, }; let stream = async_stream::try_stream! { @@ -304,7 +311,7 @@ impl AgentExecutor { timeout, ) { match item { - Ok(Either::Left(exchange)) => yield exchange.compressed(), + Ok(Either::Left(exchange)) => yield AnswerEvent::ChatEvent(exchange.compressed()), Ok(Either::Right(next_action)) => match next_action { Ok(n) => break next = n, Err(e) => break 'outer Err(agent::Error::Processing(e)), @@ -319,7 +326,7 @@ impl AgentExecutor { // of the above loop without ever processing the final message. Here, we empty the // queue. while let Some(Some(exchange)) = exchange_rx.next().now_or_never() { - yield exchange.compressed(); + yield AnswerEvent::ChatEvent(exchange.compressed()); } match next { @@ -331,7 +338,21 @@ impl AgentExecutor { agent.complete(result.is_ok()); match result { - Ok(_) => {} + Ok(_) => { + let conversation_id = agent.conversation.store( + &agent.app.sql, + agent.user.username().context("agent failed to get user ID")?, + ) + .await?; + + let final_message = StreamEnd { + thread_id: agent.conversation.thread_id.to_string(), + query_id: agent.query_id, + conversation_id, + }; + + yield AnswerEvent::StreamEnd(final_message); + } Err(agent::Error::Timeout(duration)) => { warn!("Timeout reached."); agent.track_query( @@ -347,30 +368,19 @@ impl AgentExecutor { ); Err(e)?; } - } + }; }; - let init_stream = futures::stream::once(async move { - Ok(sse::Event::default() - .json_data(initial_message) - // This should never happen, so we force an unwrap. - .expect("failed to serialize initialization object")) - }); - // We know the stream is unwind safe as it doesn't use synchronization primitives like locks. - let answer_stream = AssertUnwindSafe(stream) + let stream = AssertUnwindSafe(stream) .catch_unwind() .map(|res| res.unwrap_or_else(|_| Err(anyhow!("stream panicked")))) - .map(|ex: Result| { + .map(|ex: Result| { sse::Event::default() .json_data(ex.map_err(|e| e.to_string())) .map_err(anyhow::Error::new) }); - let done_stream = futures::stream::once(async { Ok(sse::Event::default().data("[DONE]")) }); - - let stream = init_stream.chain(answer_stream).chain(done_stream); - Ok(Sse::new(Box::pin(stream))) } } @@ -455,9 +465,6 @@ pub async fn explain( let mut conversation = Conversation::new(project_id); conversation.exchanges.push(exchange); - let user_id = user.username().ok_or_else(super::no_user_id)?; - let conversation_id = conversation.store(&app.sql, user_id).await?; - let action = Action::Answer { paths: vec![0] }; AgentExecutor { @@ -467,7 +474,6 @@ pub async fn explain( query_id, project_id, conversation, - conversation_id, action, } .execute()