Skip to content

Commit

Permalink
Use conversation_id instead of thread_id in GET /answer
Browse files Browse the repository at this point in the history
Rather than returning an initial JSON object, we introduce a new
`ChatEvent` type, and return the conversation ID on stream end upon
successful store.
  • Loading branch information
calyptobai committed Dec 19, 2023
1 parent a2baa4c commit ee7e019
Showing 1 changed file with 39 additions and 33 deletions.
72 changes: 39 additions & 33 deletions server/bleep/src/webserver/answer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ pub(super) async fn answer(
None => Conversation::new(project_id),
};

let conversation_id = conversation.store(&app.sql, user_id).await?;

let Answer {
parent_exchange_id,
q,
Expand Down Expand Up @@ -151,7 +149,6 @@ pub(super) async fn answer(
query_id,
project_id,
conversation,
conversation_id,
action,
}
.execute()
Expand All @@ -166,10 +163,22 @@ struct AgentExecutor {
query_id: uuid::Uuid,
project_id: i64,
conversation: Conversation,
conversation_id: i64,
action: Action,
}

#[derive(serde::Serialize)]
enum AnswerEvent {
ChatEvent(Exchange),
StreamEnd(StreamEnd),
}

#[derive(serde::Serialize)]
struct StreamEnd {
thread_id: String,
query_id: uuid::Uuid,
conversation_id: i64,
}

type SseDynStream<T> = Sse<std::pin::Pin<Box<dyn tokio_stream::Stream<Item = T> + Send>>>;

impl AgentExecutor {
Expand Down Expand Up @@ -254,14 +263,12 @@ impl AgentExecutor {
}
};

let initial_message = json!({
"thread_id": self.conversation.thread_id.to_string(),
"query_id": self.query_id,
"conversation_id": self.conversation_id,
});

// let project: Project = serde_json::from_str(&self.params.project).unwrap();
let Answer { agent_model, answer_model, .. } = self.params.clone();
let Answer {
agent_model,
answer_model,
..
} = self.params.clone();

let (exchange_tx, exchange_rx) = tokio::sync::mpsc::channel(10);

Expand All @@ -276,7 +283,7 @@ impl AgentExecutor {
repo_refs,
exchange_state: ExchangeState::Pending,
answer_model,
agent_model
agent_model,
};

let stream = async_stream::try_stream! {
Expand Down Expand Up @@ -304,7 +311,7 @@ impl AgentExecutor {
timeout,
) {
match item {
Ok(Either::Left(exchange)) => yield exchange.compressed(),
Ok(Either::Left(exchange)) => yield AnswerEvent::ChatEvent(exchange.compressed()),
Ok(Either::Right(next_action)) => match next_action {
Ok(n) => break next = n,
Err(e) => break 'outer Err(agent::Error::Processing(e)),
Expand All @@ -319,7 +326,7 @@ impl AgentExecutor {
// of the above loop without ever processing the final message. Here, we empty the
// queue.
while let Some(Some(exchange)) = exchange_rx.next().now_or_never() {
yield exchange.compressed();
yield AnswerEvent::ChatEvent(exchange.compressed());
}

match next {
Expand All @@ -331,7 +338,21 @@ impl AgentExecutor {
agent.complete(result.is_ok());

match result {
Ok(_) => {}
Ok(_) => {
let conversation_id = agent.conversation.store(
&agent.app.sql,
agent.user.username().context("agent failed to get user ID")?,
)
.await?;

let final_message = StreamEnd {
thread_id: agent.conversation.thread_id.to_string(),
query_id: agent.query_id,
conversation_id,
};

yield AnswerEvent::StreamEnd(final_message);
}
Err(agent::Error::Timeout(duration)) => {
warn!("Timeout reached.");
agent.track_query(
Expand All @@ -347,30 +368,19 @@ impl AgentExecutor {
);
Err(e)?;
}
}
};
};

let init_stream = futures::stream::once(async move {
Ok(sse::Event::default()
.json_data(initial_message)
// This should never happen, so we force an unwrap.
.expect("failed to serialize initialization object"))
});

// We know the stream is unwind safe as it doesn't use synchronization primitives like locks.
let answer_stream = AssertUnwindSafe(stream)
let stream = AssertUnwindSafe(stream)
.catch_unwind()
.map(|res| res.unwrap_or_else(|_| Err(anyhow!("stream panicked"))))
.map(|ex: Result<Exchange>| {
.map(|ex: Result<AnswerEvent>| {
sse::Event::default()
.json_data(ex.map_err(|e| e.to_string()))
.map_err(anyhow::Error::new)
});

let done_stream = futures::stream::once(async { Ok(sse::Event::default().data("[DONE]")) });

let stream = init_stream.chain(answer_stream).chain(done_stream);

Ok(Sse::new(Box::pin(stream)))
}
}
Expand Down Expand Up @@ -455,9 +465,6 @@ pub async fn explain(
let mut conversation = Conversation::new(project_id);
conversation.exchanges.push(exchange);

let user_id = user.username().ok_or_else(super::no_user_id)?;
let conversation_id = conversation.store(&app.sql, user_id).await?;

let action = Action::Answer { paths: vec![0] };

AgentExecutor {
Expand All @@ -467,7 +474,6 @@ pub async fn explain(
query_id,
project_id,
conversation,
conversation_id,
action,
}
.execute()
Expand Down

0 comments on commit ee7e019

Please sign in to comment.