From 8688b2ad19032c1e49750a05932cd9e53c91e783 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 9 Aug 2024 15:15:57 -0700 Subject: [PATCH] Add telemetry for LLM usage (#16049) Release Notes: - N/A Co-authored-by: Marshall --- crates/collab/src/llm.rs | 74 ++++++++++++---- crates/collab/src/llm/db/queries/usages.rs | 98 ++++++++++++---------- crates/collab/src/llm/telemetry.rs | 25 ++++++ 3 files changed, 137 insertions(+), 60 deletions(-) create mode 100644 crates/collab/src/llm/telemetry.rs diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 5bab54799cfc8..7d21d070589d4 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -1,8 +1,12 @@ mod authorization; pub mod db; +mod telemetry; mod token; -use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result}; +use crate::{ + api::CloudflareIpCountryHeader, build_clickhouse_client, executor::Executor, Config, Error, + Result, +}; use anyhow::{anyhow, Context as _}; use authorization::authorize_access_to_language_model; use axum::{ @@ -17,12 +21,15 @@ use chrono::{DateTime, Duration, Utc}; use db::{ActiveUserCount, LlmDatabase}; use futures::{Stream, StreamExt as _}; use http_client::IsahcHttpClient; -use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME}; +use rpc::{ + proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME, +}; use std::{ pin::Pin, sync::Arc, task::{Context, Poll}, }; +use telemetry::{report_llm_usage, LlmUsageEventRow}; use tokio::sync::RwLock; use util::ResultExt; @@ -33,6 +40,7 @@ pub struct LlmState { pub executor: Executor, pub db: Arc, pub http_client: IsahcHttpClient, + pub clickhouse_client: Option, active_user_count: RwLock, ActiveUserCount)>>, } @@ -65,11 +73,15 @@ impl LlmState { Some((Utc::now(), db.get_active_user_count(Utc::now()).await?)); let this = Self { - config, executor, db, http_client, + clickhouse_client: config + .clickhouse_url + .as_ref() + .and_then(|_| build_clickhouse_client(&config).log_err()), active_user_count: RwLock::new(initial_active_user_count), + config, }; Ok(Arc::new(this)) @@ -155,8 +167,6 @@ async fn perform_completion( &model, )?; - let user_id = claims.user_id as i32; - check_usage_limit(&state, params.provider, &model, &claims).await?; let stream = match params.provider { @@ -310,9 +320,8 @@ async fn perform_completion( }; Ok(Response::new(Body::wrap_stream(TokenCountingStream { - db: state.db.clone(), - executor: state.executor.clone(), - user_id, + state, + claims, provider: params.provider, model, input_tokens: 0, @@ -403,9 +412,8 @@ async fn check_usage_limit( } struct TokenCountingStream { - db: Arc, - executor: Executor, - user_id: i32, + state: Arc, + claims: LlmTokenClaims, provider: LanguageModelProvider, model: String, input_tokens: usize, @@ -436,15 +444,49 @@ where impl Drop for TokenCountingStream { fn drop(&mut self) { - let db = self.db.clone(); - let user_id = self.user_id; + let state = self.state.clone(); + let claims = self.claims.clone(); let provider = self.provider; let model = std::mem::take(&mut self.model); - let token_count = self.input_tokens + self.output_tokens; - self.executor.spawn_detached(async move { - db.record_usage(user_id, provider, &model, token_count, Utc::now()) + let input_token_count = self.input_tokens; + let output_token_count = self.output_tokens; + self.state.executor.spawn_detached(async move { + let usage = state + .db + .record_usage( + claims.user_id as i32, + provider, + &model, + input_token_count + output_token_count, + Utc::now(), + ) .await .log_err(); + + if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) { + report_llm_usage( + clickhouse_client, + LlmUsageEventRow { + time: Utc::now().timestamp_millis(), + user_id: claims.user_id as i32, + is_staff: claims.is_staff, + plan: match claims.plan { + Plan::Free => "free".to_string(), + Plan::ZedPro => "zed_pro".to_string(), + }, + model, + provider: provider.to_string(), + input_token_count: input_token_count as u64, + output_token_count: output_token_count as u64, + requests_this_minute: usage.requests_this_minute as u64, + tokens_this_minute: usage.tokens_this_minute as u64, + tokens_this_day: usage.tokens_this_day as u64, + tokens_this_month: usage.tokens_this_month as u64, + }, + ) + .await + .log_err(); + } }) } } diff --git a/crates/collab/src/llm/db/queries/usages.rs b/crates/collab/src/llm/db/queries/usages.rs index 6ea7439811b93..108d0e4111f07 100644 --- a/crates/collab/src/llm/db/queries/usages.rs +++ b/crates/collab/src/llm/db/queries/usages.rs @@ -107,7 +107,7 @@ impl LlmDatabase { model_name: &str, token_count: usize, now: DateTimeUtc, - ) -> Result<()> { + ) -> Result { self.transaction(|tx| async move { let model = self.model(provider, model_name)?; @@ -120,48 +120,57 @@ impl LlmDatabase { .all(&*tx) .await?; - self.update_usage_for_measure( - user_id, - model.id, - &usages, - UsageMeasure::RequestsPerMinute, - now, - 1, - &tx, - ) - .await?; - self.update_usage_for_measure( - user_id, - model.id, - &usages, - UsageMeasure::TokensPerMinute, - now, - token_count, - &tx, - ) - .await?; - self.update_usage_for_measure( - user_id, - model.id, - &usages, - UsageMeasure::TokensPerDay, - now, - token_count, - &tx, - ) - .await?; - self.update_usage_for_measure( - user_id, - model.id, - &usages, - UsageMeasure::TokensPerMonth, - now, - token_count, - &tx, - ) - .await?; + let requests_this_minute = self + .update_usage_for_measure( + user_id, + model.id, + &usages, + UsageMeasure::RequestsPerMinute, + now, + 1, + &tx, + ) + .await?; + let tokens_this_minute = self + .update_usage_for_measure( + user_id, + model.id, + &usages, + UsageMeasure::TokensPerMinute, + now, + token_count, + &tx, + ) + .await?; + let tokens_this_day = self + .update_usage_for_measure( + user_id, + model.id, + &usages, + UsageMeasure::TokensPerDay, + now, + token_count, + &tx, + ) + .await?; + let tokens_this_month = self + .update_usage_for_measure( + user_id, + model.id, + &usages, + UsageMeasure::TokensPerMonth, + now, + token_count, + &tx, + ) + .await?; - Ok(()) + Ok(Usage { + requests_this_minute, + tokens_this_minute, + tokens_this_day, + tokens_this_month, + }) }) .await } @@ -205,7 +214,7 @@ impl LlmDatabase { now: DateTimeUtc, usage_to_add: usize, tx: &DatabaseTransaction, - ) -> Result<()> { + ) -> Result { let now = now.naive_utc(); let measure_id = *self .usage_measure_ids @@ -230,6 +239,7 @@ impl LlmDatabase { } *buckets.last_mut().unwrap() += usage_to_add as i64; + let total_usage = buckets.iter().sum::() as usize; let mut model = usage::ActiveModel { user_id: ActiveValue::set(user_id), @@ -249,7 +259,7 @@ impl LlmDatabase { .await?; } - Ok(()) + Ok(total_usage) } fn get_usage_for_measure( diff --git a/crates/collab/src/llm/telemetry.rs b/crates/collab/src/llm/telemetry.rs new file mode 100644 index 0000000000000..941fe9a16de1a --- /dev/null +++ b/crates/collab/src/llm/telemetry.rs @@ -0,0 +1,25 @@ +use anyhow::Result; +use serde::Serialize; + +#[derive(Serialize, Debug, clickhouse::Row)] +pub struct LlmUsageEventRow { + pub time: i64, + pub user_id: i32, + pub is_staff: bool, + pub plan: String, + pub model: String, + pub provider: String, + pub input_token_count: u64, + pub output_token_count: u64, + pub requests_this_minute: u64, + pub tokens_this_minute: u64, + pub tokens_this_day: u64, + pub tokens_this_month: u64, +} + +pub async fn report_llm_usage(client: &clickhouse::Client, row: LlmUsageEventRow) -> Result<()> { + let mut insert = client.insert("llm_usage_events")?; + insert.write(&row).await?; + insert.end().await?; + Ok(()) +}