Retry LLM gateway by default upon mid-stream errors (#955)
Now, we retry LLM gateway requests when they encounter mid-stream
errors. The `llm_gateway::Client::chat` method now returns a `String`
directly; the previous streaming functionality has moved to
`llm_gateway::Client::chat_stream`.

Internally, `chat` calls `chat_stream` and builds up a full result
buffer. If an error occurs during the initial handshake, the request is
retried with exponential backoff, as per the existing `chat_stream`
behaviour. If the stream errors out mid-collection, the buffer is
discarded and rebuilt from scratch with a new LLM request, up to a
constant number of retries.
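
The call-site change is mechanical; a condensed before/after sketch based on the diffs below (variable names are illustrative):

    // Before: `chat` returned a token stream that each caller collected by hand.
    let answer: String = llm_gateway
        .chat(&messages, None)
        .await?
        .try_collect::<String>()
        .await?;

    // After: `chat` buffers the stream internally, retrying on mid-stream
    // errors, and returns the full response.
    let answer: String = llm_gateway.chat(&messages, None).await?;

    // Callers that still want token-by-token output switch to `chat_stream`.
    let tokens = llm_gateway.chat_stream(&messages, None).await?;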
calyptobai authored Sep 21, 2023
1 parent fe205dd commit 40a1461
Showing 7 changed files with 72 additions and 50 deletions.
2 changes: 1 addition & 1 deletion server/bleep/src/agent.rs
@@ -195,7 +195,7 @@ impl Agent {

         let raw_response = self
             .llm_gateway
-            .chat(
+            .chat_stream(
                 &trim_history(history.clone(), self.model)?,
                 Some(&functions),
             )
2 changes: 1 addition & 1 deletion server/bleep/src/agent/tools/answer.rs
@@ -59,7 +59,7 @@ impl Agent {
             self.llm_gateway
                 .clone()
                 .model(self.model.model_name)
-                .chat(&messages, None)
+                .chat_stream(&messages, None)
                 .await?
         );

3 changes: 0 additions & 3 deletions server/bleep/src/agent/tools/code.rs
@@ -1,5 +1,4 @@
 use anyhow::Result;
-use futures::TryStreamExt;
 use tracing::{info, instrument};

 use crate::{
@@ -99,8 +98,6 @@ impl Agent {
             .clone()
             .model("gpt-3.5-turbo-0613")
             .chat(&prompt, None)
-            .await?
-            .try_collect::<String>()
             .await?;

         tracing::trace!("parsing hyde response");
4 changes: 1 addition & 3 deletions server/bleep/src/agent/tools/proc.rs
@@ -1,5 +1,5 @@
 use anyhow::{anyhow, Context, Result};
-use futures::{stream, StreamExt, TryStreamExt};
+use futures::{stream, StreamExt};
 use tiktoken_rs::CoreBPE;
 use tracing::{debug, instrument};

@@ -89,8 +89,6 @@ impl Agent {
             // Set low frequency penalty to discourage long outputs.
             .frequency_penalty(0.2)
             .chat(&[llm_gateway::api::Message::system(&prompt)], None)
-            .await?
-            .try_collect::<String>()
             .await?;

         #[derive(
35 changes: 33 additions & 2 deletions server/bleep/src/llm_gateway.rs
@@ -290,13 +290,44 @@ impl Client {
         &self,
         messages: &[api::Message],
         functions: Option<&[api::Function]>,
+    ) -> anyhow::Result<String> {
+        const TOTAL_CHAT_RETRIES: usize = 5;
+
+        'retry_loop: for _ in 0..TOTAL_CHAT_RETRIES {
+            let mut buf = String::new();
+            let stream = self.chat_stream(messages, functions).await?;
+            tokio::pin!(stream);
+
+            loop {
+                match stream.next().await {
+                    None => break,
+                    Some(Ok(s)) => buf += &s,
+                    Some(Err(e)) => {
+                        warn!(?e, "token stream errored out, retrying...");
+                        continue 'retry_loop;
+                    }
+                }
+            }
+
+            return Ok(buf);
+        }
+
+        Err(anyhow!(
+            "chat stream errored too many times, failed to generate response"
+        ))
+    }
+
+    pub async fn chat_stream(
+        &self,
+        messages: &[api::Message],
+        functions: Option<&[api::Function]>,
     ) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> {
         const INITIAL_DELAY: Duration = Duration::from_millis(100);
         const SCALE_FACTOR: f32 = 1.5;

         let mut delay = INITIAL_DELAY;
         for _ in 0..self.max_retries {
-            match self.chat_oneshot(messages, functions).await {
+            match self.chat_stream_oneshot(messages, functions).await {
                 Err(ChatError::TooManyRequests) => {
                     warn!(?delay, "too many LLM requests, retrying with delay...");
                     tokio::time::sleep(delay).await;
@@ -324,7 +355,7 @@ impl Client {
     }

     /// Like `chat`, but without exponential backoff.
-    async fn chat_oneshot(
+    async fn chat_stream_oneshot(
         &self,
         messages: &[api::Message],
         functions: Option<&[api::Function]>,
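
The handshake retry in `chat_stream` backs off exponentially from `INITIAL_DELAY` by `SCALE_FACTOR`. A minimal sketch of the resulting schedule, assuming the part of the loop collapsed in this hunk multiplies the delay by `SCALE_FACTOR` after each failed attempt (the exact update is not shown here):

    use std::time::Duration;

    // Illustrative helper only: reproduces the assumed backoff schedule
    // (100ms, 150ms, 225ms, ...) for `max_retries` handshake attempts.
    fn backoff_schedule(max_retries: usize) -> Vec<Duration> {
        const INITIAL_DELAY: Duration = Duration::from_millis(100);
        const SCALE_FACTOR: f32 = 1.5;

        let mut delay = INITIAL_DELAY;
        (0..max_retries)
            .map(|_| {
                let current = delay;
                delay = delay.mul_f32(SCALE_FACTOR);
                current
            })
            .collect()
    }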
61 changes: 32 additions & 29 deletions server/bleep/src/webserver/quota.rs
@@ -2,6 +2,7 @@ use axum::{Extension, Json};
 use chrono::{DateTime, Utc};
 use reqwest::StatusCode;
 use secrecy::ExposeSecret;
+use serde::Deserialize;

 use crate::Application;

@@ -15,50 +16,52 @@ pub struct QuotaResponse {
     reset_at: DateTime<Utc>,
 }

-pub async fn get(app: Extension<Application>) -> super::Result<Json<QuotaResponse>> {
-    let answer_api_token = app
-        .answer_api_token()
-        .map_err(|e| Error::user(e).with_status(StatusCode::UNAUTHORIZED))?
-        .ok_or_else(|| Error::unauthorized("answer API token was not present"))
-        .map(|s| s.expose_secret().to_owned())?;
-
-    reqwest::Client::new()
-        .get(format!("{}/v2/get-usage-quota", app.config.answer_api_url))
-        .bearer_auth(answer_api_token)
-        .send()
-        .await
-        .map_err(Error::internal)?
-        .json()
-        .await
-        .map_err(Error::internal)
-        .map(Json)
-}
-
 #[derive(serde::Deserialize, serde::Serialize)]
 pub struct SubscriptionResponse {
     url: String,
 }

+pub async fn get(app: Extension<Application>) -> super::Result<Json<QuotaResponse>> {
+    get_request(app, "/v2/get-usage-quota").await
+}
+
 pub async fn create_checkout_session(
     app: Extension<Application>,
 ) -> super::Result<Json<SubscriptionResponse>> {
+    get_request(app, "/v2/create-checkout-session").await
+}
+
+async fn get_request<T: for<'a> Deserialize<'a>>(
+    app: Extension<Application>,
+    endpoint: &str,
+) -> super::Result<Json<T>> {
     let answer_api_token = app
         .answer_api_token()
         .map_err(|e| Error::user(e).with_status(StatusCode::UNAUTHORIZED))?
         .ok_or_else(|| Error::unauthorized("answer API token was not present"))
         .map(|s| s.expose_secret().to_owned())?;

-    reqwest::Client::new()
-        .get(format!(
-            "{}/v2/create-checkout-session",
-            app.config.answer_api_url
-        ))
+    let response = reqwest::Client::new()
+        .get(format!("{}{}", app.config.answer_api_url, endpoint))
         .bearer_auth(answer_api_token)
         .send()
         .await
-        .map_err(Error::internal)?
-        .json()
-        .await
-        .map_err(Error::internal)
-        .map(Json)
+        .map_err(Error::internal)?;
+
+    if response.status().is_success() {
+        response.json().await.map_err(Error::internal).map(Json)
+    } else {
+        let status = response.status();
+        match response.text().await {
+            Ok(body) if !body.is_empty() => Err(Error::internal(format!(
+                "request failed with status code {status}: {body}",
+            ))),
+            Ok(_) => Err(Error::internal(format!(
+                "request failed with status code {status}, response had no body",
+            ))),
+            Err(_) => Err(Error::internal(format!(
+                "request failed with status code {status}, failed to retrieve response body",
+            ))),
+        }
+    }
 }
15 changes: 4 additions & 11 deletions server/bleep/src/webserver/studio.rs
@@ -571,7 +571,7 @@ pub async fn generate(
         .chain(messages.iter().map(llm_gateway::api::Message::from))
         .collect::<Vec<_>>();

-    let tokens = llm_gateway.chat(&llm_messages, None).await?;
+    let tokens = llm_gateway.chat_stream(&llm_messages, None).await?;

     let stream = async_stream::try_stream! {
         pin_mut!(tokens);
@@ -722,11 +722,7 @@ async fn populate_studio_name(
         &prompts::studio_name_prompt(&context_json, &messages_json),
     )];

-    let name = llm_gateway
-        .chat(messages, None)
-        .await?
-        .try_collect::<String>()
-        .await?;
+    let name = llm_gateway.chat(messages, None).await?;

     // Normalize studio name by removing:
     // - surrounding whitespace
@@ -894,11 +890,8 @@ async fn extract_relevant_chunks(
     ];

     // Call the LLM gateway
-    let response_stream = llm_gateway.chat(&llm_messages, None).await?;
-
-    // Collect the response into a string
-    let result = response_stream
-        .try_collect()
+    let result = llm_gateway
+        .chat(&llm_messages, None)
         .await
         .and_then(|json: String| serde_json::from_str(&json).map_err(anyhow::Error::new));

