Skip to content

Commit

Permalink
Add telemetry for LLM usage (#16049)
Browse files Browse the repository at this point in the history
Release Notes:

- N/A

Co-authored-by: Marshall <marshall@zed.dev>
  • Loading branch information
maxbrunsfeld and maxdeviant authored Aug 9, 2024
1 parent 423c7b9 commit 8688b2a
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 60 deletions.
74 changes: 58 additions & 16 deletions crates/collab/src/llm.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
mod authorization;
pub mod db;
mod telemetry;
mod token;

use crate::{api::CloudflareIpCountryHeader, executor::Executor, Config, Error, Result};
use crate::{
api::CloudflareIpCountryHeader, build_clickhouse_client, executor::Executor, Config, Error,
Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
use axum::{
Expand All @@ -17,12 +21,15 @@ use chrono::{DateTime, Duration, Utc};
use db::{ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use std::{
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use telemetry::{report_llm_usage, LlmUsageEventRow};
use tokio::sync::RwLock;
use util::ResultExt;

Expand All @@ -33,6 +40,7 @@ pub struct LlmState {
pub executor: Executor,
pub db: Arc<LlmDatabase>,
pub http_client: IsahcHttpClient,
pub clickhouse_client: Option<clickhouse::Client>,
active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}

Expand Down Expand Up @@ -65,11 +73,15 @@ impl LlmState {
Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));

let this = Self {
config,
executor,
db,
http_client,
clickhouse_client: config
.clickhouse_url
.as_ref()
.and_then(|_| build_clickhouse_client(&config).log_err()),
active_user_count: RwLock::new(initial_active_user_count),
config,
};

Ok(Arc::new(this))
Expand Down Expand Up @@ -155,8 +167,6 @@ async fn perform_completion(
&model,
)?;

let user_id = claims.user_id as i32;

check_usage_limit(&state, params.provider, &model, &claims).await?;

let stream = match params.provider {
Expand Down Expand Up @@ -310,9 +320,8 @@ async fn perform_completion(
};

Ok(Response::new(Body::wrap_stream(TokenCountingStream {
db: state.db.clone(),
executor: state.executor.clone(),
user_id,
state,
claims,
provider: params.provider,
model,
input_tokens: 0,
Expand Down Expand Up @@ -403,9 +412,8 @@ async fn check_usage_limit(
}

struct TokenCountingStream<S> {
db: Arc<LlmDatabase>,
executor: Executor,
user_id: i32,
state: Arc<LlmState>,
claims: LlmTokenClaims,
provider: LanguageModelProvider,
model: String,
input_tokens: usize,
Expand Down Expand Up @@ -436,15 +444,49 @@ where

impl<S> Drop for TokenCountingStream<S> {
fn drop(&mut self) {
let db = self.db.clone();
let user_id = self.user_id;
let state = self.state.clone();
let claims = self.claims.clone();
let provider = self.provider;
let model = std::mem::take(&mut self.model);
let token_count = self.input_tokens + self.output_tokens;
self.executor.spawn_detached(async move {
db.record_usage(user_id, provider, &model, token_count, Utc::now())
let input_token_count = self.input_tokens;
let output_token_count = self.output_tokens;
self.state.executor.spawn_detached(async move {
let usage = state
.db
.record_usage(
claims.user_id as i32,
provider,
&model,
input_token_count + output_token_count,
Utc::now(),
)
.await
.log_err();

if let Some((clickhouse_client, usage)) = state.clickhouse_client.as_ref().zip(usage) {
report_llm_usage(
clickhouse_client,
LlmUsageEventRow {
time: Utc::now().timestamp_millis(),
user_id: claims.user_id as i32,
is_staff: claims.is_staff,
plan: match claims.plan {
Plan::Free => "free".to_string(),
Plan::ZedPro => "zed_pro".to_string(),
},
model,
provider: provider.to_string(),
input_token_count: input_token_count as u64,
output_token_count: output_token_count as u64,
requests_this_minute: usage.requests_this_minute as u64,
tokens_this_minute: usage.tokens_this_minute as u64,
tokens_this_day: usage.tokens_this_day as u64,
tokens_this_month: usage.tokens_this_month as u64,
},
)
.await
.log_err();
}
})
}
}
98 changes: 54 additions & 44 deletions crates/collab/src/llm/db/queries/usages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ impl LlmDatabase {
model_name: &str,
token_count: usize,
now: DateTimeUtc,
) -> Result<()> {
) -> Result<Usage> {
self.transaction(|tx| async move {
let model = self.model(provider, model_name)?;

Expand All @@ -120,48 +120,57 @@ impl LlmDatabase {
.all(&*tx)
.await?;

self.update_usage_for_measure(
user_id,
model.id,
&usages,
UsageMeasure::RequestsPerMinute,
now,
1,
&tx,
)
.await?;
self.update_usage_for_measure(
user_id,
model.id,
&usages,
UsageMeasure::TokensPerMinute,
now,
token_count,
&tx,
)
.await?;
self.update_usage_for_measure(
user_id,
model.id,
&usages,
UsageMeasure::TokensPerDay,
now,
token_count,
&tx,
)
.await?;
self.update_usage_for_measure(
user_id,
model.id,
&usages,
UsageMeasure::TokensPerMonth,
now,
token_count,
&tx,
)
.await?;
let requests_this_minute = self
.update_usage_for_measure(
user_id,
model.id,
&usages,
UsageMeasure::RequestsPerMinute,
now,
1,
&tx,
)
.await?;
let tokens_this_minute = self
.update_usage_for_measure(
user_id,
model.id,
&usages,
UsageMeasure::TokensPerMinute,
now,
token_count,
&tx,
)
.await?;
let tokens_this_day = self
.update_usage_for_measure(
user_id,
model.id,
&usages,
UsageMeasure::TokensPerDay,
now,
token_count,
&tx,
)
.await?;
let tokens_this_month = self
.update_usage_for_measure(
user_id,
model.id,
&usages,
UsageMeasure::TokensPerMonth,
now,
token_count,
&tx,
)
.await?;

Ok(())
Ok(Usage {
requests_this_minute,
tokens_this_minute,
tokens_this_day,
tokens_this_month,
})
})
.await
}
Expand Down Expand Up @@ -205,7 +214,7 @@ impl LlmDatabase {
now: DateTimeUtc,
usage_to_add: usize,
tx: &DatabaseTransaction,
) -> Result<()> {
) -> Result<usize> {
let now = now.naive_utc();
let measure_id = *self
.usage_measure_ids
Expand All @@ -230,6 +239,7 @@ impl LlmDatabase {
}

*buckets.last_mut().unwrap() += usage_to_add as i64;
let total_usage = buckets.iter().sum::<i64>() as usize;

let mut model = usage::ActiveModel {
user_id: ActiveValue::set(user_id),
Expand All @@ -249,7 +259,7 @@ impl LlmDatabase {
.await?;
}

Ok(())
Ok(total_usage)
}

fn get_usage_for_measure(
Expand Down
25 changes: 25 additions & 0 deletions crates/collab/src/llm/telemetry.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use anyhow::Result;
use serde::Serialize;

#[derive(Serialize, Debug, clickhouse::Row)]
pub struct LlmUsageEventRow {
pub time: i64,
pub user_id: i32,
pub is_staff: bool,
pub plan: String,
pub model: String,
pub provider: String,
pub input_token_count: u64,
pub output_token_count: u64,
pub requests_this_minute: u64,
pub tokens_this_minute: u64,
pub tokens_this_day: u64,
pub tokens_this_month: u64,
}

pub async fn report_llm_usage(client: &clickhouse::Client, row: LlmUsageEventRow) -> Result<()> {
let mut insert = client.insert("llm_usage_events")?;
insert.write(&row).await?;
insert.end().await?;
Ok(())
}

0 comments on commit 8688b2a

Please sign in to comment.