diff --git a/server/bleep/src/llm_gateway.rs b/server/bleep/src/llm_gateway.rs index a1597fc405..ba27464b3f 100644 --- a/server/bleep/src/llm_gateway.rs +++ b/server/bleep/src/llm_gateway.rs @@ -8,6 +8,8 @@ use futures::{Stream, StreamExt}; use reqwest_eventsource::EventSource; use tracing::{debug, error, warn}; +use crate::{periodic::sync_github_status_once, Application}; + use self::api::FunctionCall; pub mod api { @@ -213,13 +215,15 @@ impl From<&api::Message> for tiktoken_rs::ChatCompletionRequestMessage { enum ChatError { BadRequest(String), TooManyRequests(String), + InvalidToken, Other(anyhow::Error), } #[derive(Clone)] pub struct Client { http: reqwest::Client, - pub base_url: String, + app: Application, + pub max_retries: u32, pub bearer_token: Option, @@ -234,10 +238,11 @@ pub struct Client { } impl Client { - pub fn new(base_url: &str) -> Self { + pub fn new(app: Application) -> Self { Self { + app, http: reqwest::Client::new(), - base_url: base_url.to_owned(), + max_retries: 5, bearer_token: None, @@ -305,7 +310,10 @@ impl Client { version: semver::Version, ) -> Result { self.http - .get(format!("{}/v1/compatibility", self.base_url)) + .get(format!( + "{}/v1/compatibility", + self.app.config.answer_api_url + )) .query(&[("version", version)]) .send() .await @@ -365,6 +373,10 @@ impl Client { error!("LLM request failed, request not eligible for retry: {body}"); bail!("request failed (not eligible for retry): {body}"); } + Err(ChatError::InvalidToken) => { + warn!("invalid token, retrying LLM request"); + sync_github_status_once(&self.app).await; + } Err(ChatError::Other(e)) => { // We log the messages in a separate `debug!` statement so that they can be // filtered out, due to their verbosity. @@ -387,7 +399,9 @@ impl Client { ) -> Result>, ChatError> { let mut event_source = Box::pin( EventSource::new({ - let mut builder = self.http.post(format!("{}/v2/q", self.base_url)); + let mut builder = self + .http + .post(format!("{}/v2/q", self.app.config.answer_api_url)); if let Some(bearer) = &self.bearer_token { builder = builder.bearer_auth(bearer); @@ -433,6 +447,11 @@ impl Client { warn!("bad request to LLM: {body}"); return Err(ChatError::BadRequest(body)); } + Some(Err(reqwest_eventsource::Error::InvalidStatusCode(status, _))) + if status == StatusCode::UNAUTHORIZED => + { + return Err(ChatError::InvalidToken); + } Some(Err(reqwest_eventsource::Error::InvalidStatusCode(status, response))) if status == StatusCode::TOO_MANY_REQUESTS => { diff --git a/server/bleep/src/periodic/remotes.rs b/server/bleep/src/periodic/remotes.rs index 1a7fa3354d..2a2d9019a3 100644 --- a/server/bleep/src/periodic/remotes.rs +++ b/server/bleep/src/periodic/remotes.rs @@ -70,12 +70,16 @@ pub(crate) async fn sync_github_status(app: Application) { // credentials from CLI/config loop { // then retrieve username & other maintenance - update_credentials(&app).await; - update_repo_list(&app).await; + sync_github_status_once(&app).await; sleep_systime(POLL_PERIOD).await; } } +pub async fn sync_github_status_once(app: &Application) { + update_credentials(app).await; + update_repo_list(app).await; +} + pub(crate) async fn update_repo_list(app: &Application) { if let Some(gh) = app.credentials.github() { let repos = match gh.current_repo_list().await { diff --git a/server/bleep/src/webserver/middleware.rs b/server/bleep/src/webserver/middleware.rs index 3269ab0b4c..9459837fdd 100644 --- a/server/bleep/src/webserver/middleware.rs +++ b/server/bleep/src/webserver/middleware.rs @@ -75,7 +75,7 @@ impl User { } let access_token = self.access_token().map(str::to_owned); - Ok(llm_gateway::Client::new(&app.config.answer_api_url).bearer(access_token)) + Ok(llm_gateway::Client::new(app.clone()).bearer(access_token)) } pub(crate) async fn paid_features(&self, app: &Application) -> bool { diff --git a/server/bleep/src/webserver/quota.rs b/server/bleep/src/webserver/quota.rs index 273ca80e81..33d497ec75 100644 --- a/server/bleep/src/webserver/quota.rs +++ b/server/bleep/src/webserver/quota.rs @@ -1,9 +1,10 @@ use axum::{Extension, Json}; use chrono::{DateTime, Utc}; +use reqwest::StatusCode; use serde::Deserialize; use tracing::error; -use crate::Application; +use crate::{periodic::sync_github_status_once, Application}; use super::{middleware::User, Error}; @@ -53,37 +54,48 @@ async fn get_request Deserialize<'a>>( Extension(user): Extension, endpoint: &str, ) -> super::Result> { + const MAX_RETRIES: usize = 5; + let Some(api_token) = user.access_token() else { return Err(Error::unauthorized("answer API token was not present")); }; - let response = reqwest::Client::new() - .get(format!("{}{}", app.config.answer_api_url, endpoint)) - .bearer_auth(api_token) - .send() - .await - .map_err(Error::internal)?; + for _ in 0..MAX_RETRIES { + let response = reqwest::Client::new() + .get(format!("{}{}", app.config.answer_api_url, endpoint)) + .bearer_auth(api_token) + .send() + .await + .map_err(Error::internal)?; - if response.status().is_success() { - let body = response.text().await.map_err(Error::internal)?; - match serde_json::from_str::(&body) { - Ok(t) => Ok(Json(t)), - Err(_) => Err(Error::internal(format!( - "quota call return invalid JSON: {body}" - ))), - } - } else { - let status = response.status(); - match response.text().await { - Ok(body) if !body.is_empty() => Err(Error::internal(format!( - "request failed with status code {status}: {body}", - ))), - Ok(_) => Err(Error::internal(format!( - "request failed with status code {status}, response had no body", - ))), - Err(_) => Err(Error::internal(format!( - "request failed with status code {status}, failed to retrieve response body", - ))), + if response.status().is_success() { + let body = response.text().await.map_err(Error::internal)?; + return match serde_json::from_str::(&body) { + Ok(t) => Ok(Json(t)), + Err(_) => Err(Error::internal(format!( + "quota call return invalid JSON: {body}" + ))), + }; + } else if response.status() == StatusCode::UNAUTHORIZED { + sync_github_status_once(&app).await; + continue; + } else { + let status = response.status(); + return match response.text().await { + Ok(body) if !body.is_empty() => Err(Error::internal(format!( + "request failed with status code {status}: {body}", + ))), + Ok(_) => Err(Error::internal(format!( + "request failed with status code {status}, response had no body", + ))), + Err(_) => Err(Error::internal(format!( + "request failed with status code {status}, failed to retrieve response body", + ))), + }; } } + + Err(Error::internal( + "failed to make quota request, potentially failed authorization?", + )) }