diff --git a/server/bleep/src/agent.rs b/server/bleep/src/agent.rs
index 736e8f3353..f9c4a198fb 100644
--- a/server/bleep/src/agent.rs
+++ b/server/bleep/src/agent.rs
@@ -195,7 +195,7 @@ impl Agent {
 
         let raw_response = self
             .llm_gateway
-            .chat(
+            .chat_stream(
                 &trim_history(history.clone(), self.model)?,
                 Some(&functions),
             )
diff --git a/server/bleep/src/agent/tools/answer.rs b/server/bleep/src/agent/tools/answer.rs
index 9846487ef5..ba11a36945 100644
--- a/server/bleep/src/agent/tools/answer.rs
+++ b/server/bleep/src/agent/tools/answer.rs
@@ -59,7 +59,7 @@ impl Agent {
             self.llm_gateway
                 .clone()
                 .model(self.model.model_name)
-                .chat(&messages, None)
+                .chat_stream(&messages, None)
                 .await?
         );
 
diff --git a/server/bleep/src/agent/tools/code.rs b/server/bleep/src/agent/tools/code.rs
index e9256434b0..3eebc58446 100644
--- a/server/bleep/src/agent/tools/code.rs
+++ b/server/bleep/src/agent/tools/code.rs
@@ -1,5 +1,4 @@
 use anyhow::Result;
-use futures::TryStreamExt;
 use tracing::{info, instrument};
 
 use crate::{
@@ -99,8 +98,6 @@ impl Agent {
             .clone()
             .model("gpt-3.5-turbo-0613")
             .chat(&prompt, None)
-            .await?
-            .try_collect::<String>()
             .await?;
 
         tracing::trace!("parsing hyde response");
diff --git a/server/bleep/src/agent/tools/proc.rs b/server/bleep/src/agent/tools/proc.rs
index 24c89b309d..b6eb2a0afa 100644
--- a/server/bleep/src/agent/tools/proc.rs
+++ b/server/bleep/src/agent/tools/proc.rs
@@ -1,5 +1,5 @@
 use anyhow::{anyhow, Context, Result};
-use futures::{stream, StreamExt, TryStreamExt};
+use futures::{stream, StreamExt};
 use tiktoken_rs::CoreBPE;
 use tracing::{debug, instrument};
 
@@ -89,8 +89,6 @@ impl Agent {
             // Set low frequency penalty to discourage long outputs.
             .frequency_penalty(0.2)
             .chat(&[llm_gateway::api::Message::system(&prompt)], None)
-            .await?
-            .try_collect::<String>()
             .await?;
 
         #[derive(
diff --git a/server/bleep/src/llm_gateway.rs b/server/bleep/src/llm_gateway.rs
index f248c01dce..372a555096 100644
--- a/server/bleep/src/llm_gateway.rs
+++ b/server/bleep/src/llm_gateway.rs
@@ -290,13 +290,44 @@ impl Client {
         &self,
         messages: &[api::Message],
         functions: Option<&[api::Function]>,
+    ) -> anyhow::Result<String> {
+        const TOTAL_CHAT_RETRIES: usize = 5;
+
+        'retry_loop: for _ in 0..TOTAL_CHAT_RETRIES {
+            let mut buf = String::new();
+            let stream = self.chat_stream(messages, functions).await?;
+            tokio::pin!(stream);
+
+            loop {
+                match stream.next().await {
+                    None => break,
+                    Some(Ok(s)) => buf += &s,
+                    Some(Err(e)) => {
+                        warn!(?e, "token stream errored out, retrying...");
+                        continue 'retry_loop;
+                    }
+                }
+            }
+
+            return Ok(buf);
+        }
+
+        Err(anyhow!(
+            "chat stream errored too many times, failed to generate response"
+        ))
+    }
+
+    pub async fn chat_stream(
+        &self,
+        messages: &[api::Message],
+        functions: Option<&[api::Function]>,
     ) -> anyhow::Result<impl Stream<Item = anyhow::Result<String>>> {
         const INITIAL_DELAY: Duration = Duration::from_millis(100);
         const SCALE_FACTOR: f32 = 1.5;
 
         let mut delay = INITIAL_DELAY;
         for _ in 0..self.max_retries {
-            match self.chat_oneshot(messages, functions).await {
+            match self.chat_stream_oneshot(messages, functions).await {
                 Err(ChatError::TooManyRequests) => {
                     warn!(?delay, "too many LLM requests, retrying with delay...");
                     tokio::time::sleep(delay).await;
@@ -324,7 +355,7 @@ impl Client {
     }
 
     /// Like `chat`, but without exponential backoff.
-    async fn chat_oneshot(
+    async fn chat_stream_oneshot(
         &self,
         messages: &[api::Message],
         functions: Option<&[api::Function]>,
diff --git a/server/bleep/src/webserver/quota.rs b/server/bleep/src/webserver/quota.rs
index 3f5bdab031..9b6d4dbc3d 100644
--- a/server/bleep/src/webserver/quota.rs
+++ b/server/bleep/src/webserver/quota.rs
@@ -2,6 +2,7 @@
 use axum::{Extension, Json};
 use chrono::{DateTime, Utc};
 use reqwest::StatusCode;
 use secrecy::ExposeSecret;
+use serde::Deserialize;
 
 use crate::Application;
@@ -15,50 +16,52 @@ pub struct QuotaResponse {
     reset_at: DateTime<Utc>,
 }
 
-pub async fn get(app: Extension<Application>) -> super::Result<Json<QuotaResponse>> {
-    let answer_api_token = app
-        .answer_api_token()
-        .map_err(|e| Error::user(e).with_status(StatusCode::UNAUTHORIZED))?
-        .ok_or_else(|| Error::unauthorized("answer API token was not present"))
-        .map(|s| s.expose_secret().to_owned())?;
-
-    reqwest::Client::new()
-        .get(format!("{}/v2/get-usage-quota", app.config.answer_api_url))
-        .bearer_auth(answer_api_token)
-        .send()
-        .await
-        .map_err(Error::internal)?
-        .json()
-        .await
-        .map_err(Error::internal)
-        .map(Json)
-}
-
 #[derive(serde::Deserialize, serde::Serialize)]
 pub struct SubscriptionResponse {
     url: String,
 }
 
+pub async fn get(app: Extension<Application>) -> super::Result<Json<QuotaResponse>> {
+    get_request(app, "/v2/get-usage-quota").await
+}
+
 pub async fn create_checkout_session(
     app: Extension<Application>,
 ) -> super::Result<Json<SubscriptionResponse>> {
+    get_request(app, "/v2/create-checkout-session").await
+}
+
+async fn get_request<T: for<'a> Deserialize<'a>>(
+    app: Extension<Application>,
+    endpoint: &str,
+) -> super::Result<Json<T>> {
     let answer_api_token = app
         .answer_api_token()
         .map_err(|e| Error::user(e).with_status(StatusCode::UNAUTHORIZED))?
         .ok_or_else(|| Error::unauthorized("answer API token was not present"))
         .map(|s| s.expose_secret().to_owned())?;
 
-    reqwest::Client::new()
-        .get(format!(
-            "{}/v2/create-checkout-session",
-            app.config.answer_api_url
-        ))
+    let response = reqwest::Client::new()
+        .get(format!("{}{}", app.config.answer_api_url, endpoint))
         .bearer_auth(answer_api_token)
         .send()
         .await
-        .map_err(Error::internal)?
-        .json()
-        .await
-        .map_err(Error::internal)
-        .map(Json)
+        .map_err(Error::internal)?;
+
+    if response.status().is_success() {
+        response.json().await.map_err(Error::internal).map(Json)
+    } else {
+        let status = response.status();
+        match response.text().await {
+            Ok(body) if !body.is_empty() => Err(Error::internal(format!(
+                "request failed with status code {status}: {body}",
+            ))),
+            Ok(_) => Err(Error::internal(format!(
+                "request failed with status code {status}, response had no body",
+            ))),
+            Err(_) => Err(Error::internal(format!(
+                "request failed with status code {status}, failed to retrieve response body",
+            ))),
+        }
+    }
 }
diff --git a/server/bleep/src/webserver/studio.rs b/server/bleep/src/webserver/studio.rs
index e43a19da35..91560346df 100644
--- a/server/bleep/src/webserver/studio.rs
+++ b/server/bleep/src/webserver/studio.rs
@@ -571,7 +571,7 @@ pub async fn generate(
         .chain(messages.iter().map(llm_gateway::api::Message::from))
         .collect::<Vec<_>>();
 
-    let tokens = llm_gateway.chat(&llm_messages, None).await?;
+    let tokens = llm_gateway.chat_stream(&llm_messages, None).await?;
 
     let stream = async_stream::try_stream! {
         pin_mut!(tokens);
@@ -722,11 +722,7 @@ async fn populate_studio_name(
         &prompts::studio_name_prompt(&context_json, &messages_json),
     )];
 
-    let name = llm_gateway
-        .chat(messages, None)
-        .await?
-        .try_collect::<String>()
-        .await?;
+    let name = llm_gateway.chat(messages, None).await?;
 
     // Normalize studio name by removing:
     // - surrounding whitespace
@@ -894,11 +890,8 @@ async fn extract_relevant_chunks(
     ];
 
     // Call the LLM gateway
-    let response_stream = llm_gateway.chat(&llm_messages, None).await?;
-
-    // Collect the response into a string
-    let result = response_stream
-        .try_collect()
+    let result = llm_gateway
+        .chat(&llm_messages, None)
         .await
         .and_then(|json: String| serde_json::from_str(&json).map_err(anyhow::Error::new));
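
For context, a minimal caller-side sketch of how the `chat` / `chat_stream` pair introduced in `llm_gateway.rs` is intended to be used. This is not part of the patch: the `example` function, its parameters, and the print sinks are illustrative assumptions; only `Client::chat` and `Client::chat_stream` come from the diff above.

use anyhow::Result;
use futures::StreamExt;

use crate::llm_gateway::{api, Client};

// Hypothetical caller, assuming a constructed `Client` and a prepared message slice.
async fn example(llm_gateway: &Client, messages: &[api::Message]) -> Result<()> {
    // `chat` buffers the whole token stream into one String, retrying the
    // stream from scratch (up to TOTAL_CHAT_RETRIES times) if it errors mid-way.
    let full: String = llm_gateway.chat(messages, None).await?;
    println!("{full}");

    // `chat_stream` keeps the previous behaviour: a stream of tokens the
    // caller can consume incrementally, e.g. forwarding them as they arrive.
    let stream = llm_gateway.chat_stream(messages, None).await?;
    tokio::pin!(stream);
    while let Some(token) = stream.next().await {
        print!("{}", token?);
    }

    Ok(())
}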