Skip to content

Commit

Permalink
Remove code paths that skip LLM db in prod (#16008)
Browse files (browse the repository at this point in the history)
Release Notes:

- N/A
  • Loading branch information
maxbrunsfeld authored Aug 9, 2024
1 parent c1872e9 commit 225726b
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 54 deletions.
77 changes: 29 additions & 48 deletions crates/collab/src/llm.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -27,7 +27,7 @@ pub use token::*;
pub struct LlmState {
pub config: Config,
pub executor: Executor,
pub db: Option<Arc<LlmDatabase>>,
pub db: Arc<LlmDatabase>,
pub http_client: IsahcHttpClient,
active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}
Expand All @@ -36,37 +36,29 @@ const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);

impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
// TODO: This is temporary until we have the LLM database stood up.
let db = if config.is_development() {
let database_url = config
.llm_database_url
.as_ref()
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
let max_connections = config
.llm_database_max_connections
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

let mut db_options = db::ConnectOptions::new(database_url);
db_options.max_connections(max_connections);
let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
db.initialize().await?;

Some(Arc::new(db))
} else {
None
};
let database_url = config
.llm_database_url
.as_ref()
.ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
let max_connections = config
.llm_database_max_connections
.ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;

let mut db_options = db::ConnectOptions::new(database_url);
db_options.max_connections(max_connections);
let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
db.initialize().await?;

let db = Arc::new(db);

let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
let http_client = IsahcHttpClient::builder()
.default_header("User-Agent", user_agent)
.build()
.context("failed to construct http client")?;

let initial_active_user_count = if let Some(db) = &db {
Some((Utc::now(), db.get_active_user_count(Utc::now()).await?))
} else {
None
};
let initial_active_user_count =
Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));

let this = Self {
config,
Expand All @@ -88,14 +80,10 @@ impl LlmState {
}
}

if let Some(db) = &self.db {
let mut cache = self.active_user_count.write().await;
let new_count = db.get_active_user_count(now).await?;
*cache = Some((now, new_count));
Ok(new_count)
} else {
Ok(ActiveUserCount::default())
}
let mut cache = self.active_user_count.write().await;
let new_count = self.db.get_active_user_count(now).await?;
*cache = Some((now, new_count));
Ok(new_count)
}
}

Expand Down Expand Up @@ -165,9 +153,7 @@ async fn perform_completion(

let user_id = claims.user_id as i32;

if state.db.is_some() {
check_usage_limit(&state, params.provider, &model, &claims).await?;
}
check_usage_limit(&state, params.provider, &model, &claims).await?;

match params.provider {
LanguageModelProvider::Anthropic => {
Expand Down Expand Up @@ -199,14 +185,14 @@ async fn perform_completion(
)
.await?;

let mut recorder = state.db.clone().map(|db| UsageRecorder {
db,
let mut recorder = UsageRecorder {
db: state.db.clone(),
executor: state.executor.clone(),
user_id,
provider: params.provider,
model,
token_count: 0,
});
};

let stream = chunks.map(move |event| {
let mut buffer = Vec::new();
Expand All @@ -216,10 +202,8 @@ async fn perform_completion(
message: anthropic::Response { usage, .. },
}
| anthropic::Event::MessageDelta { usage, .. } => {
if let Some(recorder) = &mut recorder {
recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
}
recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
}
_ => {}
}
Expand Down Expand Up @@ -349,12 +333,9 @@ async fn check_usage_limit(
model_name: &str,
claims: &LlmTokenClaims,
) -> Result<()> {
let db = state
let model = state.db.model(provider, model_name)?;
let usage = state
.db
.as_ref()
.ok_or_else(|| anyhow!("LLM database not configured"))?;
let model = db.model(provider, model_name)?;
let usage = db
.get_usage(claims.user_id as i32, provider, model_name, Utc::now())
.await?;

Expand Down
12 changes: 6 additions & 6 deletions crates/collab/src/main.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -248,11 +248,6 @@ async fn setup_app_database(config: &Config) -> Result<()> {
}

async fn setup_llm_database(config: &Config) -> Result<()> {
// TODO: This is temporary until we have the LLM database stood up.
if !config.is_development() {
return Ok(());
}

let database_url = config
.llm_database_url
.as_ref()
Expand Down Expand Up @@ -298,7 +293,12 @@ async fn handle_liveness_probe(
state.db.get_all_users(0, 1).await?;
}

if let Some(_llm_state) = llm_state {}
if let Some(llm_state) = llm_state {
llm_state
.db
.get_active_user_count(chrono::Utc::now())
.await?;
}

Ok("ok".to_string())
}
Expand Down

0 comments on commit 225726b

Please sign in to comment.