Skip to content

Commit

Permalink
Send llm events to snowflake too (#21091)
Browse files Browse the repository at this point in the history
Closes #ISSUE

Release Notes:

- N/A
  • Loading branch information
ConradIrwin authored Nov 23, 2024
1 parent 5766afe commit 984bb19
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 25 deletions.
8 changes: 8 additions & 0 deletions crates/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,8 @@ impl Client {
let proxy = http.proxy().cloned();
let credentials = credentials.clone();
let rpc_url = self.rpc_url(http, release_channel);
let system_id = self.telemetry.system_id();
let metrics_id = self.telemetry.metrics_id();
cx.background_executor().spawn(async move {
use HttpOrHttps::*;

Expand Down Expand Up @@ -1118,6 +1120,12 @@ impl Client {
"x-zed-release-channel",
HeaderValue::from_str(release_channel.map(|r| r.dev_name()).unwrap_or("unknown"))?,
);
if let Some(system_id) = system_id {
request_headers.insert("x-zed-system-id", HeaderValue::from_str(&system_id)?);
}
if let Some(metrics_id) = metrics_id {
request_headers.insert("x-zed-metrics-id", HeaderValue::from_str(&metrics_id)?);
}

match url_scheme {
Https => {
Expand Down
4 changes: 4 additions & 0 deletions crates/client/src/telemetry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,10 @@ impl Telemetry {
self.state.lock().metrics_id.clone()
}

pub fn system_id(self: &Arc<Self>) -> Option<Arc<str>> {
self.state.lock().system_id.clone()
}

pub fn installation_id(self: &Arc<Self>) -> Option<Arc<str>> {
self.state.lock().installation_id.clone()
}
Expand Down
33 changes: 33 additions & 0 deletions crates/collab/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,39 @@ impl std::fmt::Display for CloudflareIpCountryHeader {
}
}

pub struct SystemIdHeader(String);

impl Header for SystemIdHeader {
fn name() -> &'static HeaderName {
static SYSTEM_ID_HEADER: OnceLock<HeaderName> = OnceLock::new();
SYSTEM_ID_HEADER.get_or_init(|| HeaderName::from_static("x-zed-system-id"))
}

fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i axum::http::HeaderValue>,
{
let system_id = values
.next()
.ok_or_else(axum::headers::Error::invalid)?
.to_str()
.map_err(|_| axum::headers::Error::invalid())?;

Ok(Self(system_id.to_string()))
}

fn encode<E: Extend<axum::http::HeaderValue>>(&self, _values: &mut E) {
unimplemented!()
}
}

impl std::fmt::Display for SystemIdHeader {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}

pub fn routes(rpc_server: Arc<rpc::Server>) -> Router<(), Body> {
Router::new()
.route("/user", get(get_authenticated_user))
Expand Down
43 changes: 41 additions & 2 deletions crates/collab/src/api/events.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1578,8 +1578,8 @@ fn for_snowflake(
})
}

#[derive(Serialize, Deserialize)]
struct SnowflakeRow {
#[derive(Serialize, Deserialize, Debug)]
pub struct SnowflakeRow {
pub time: chrono::DateTime<chrono::Utc>,
pub user_id: Option<String>,
pub device_id: Option<String>,
Expand All @@ -1588,3 +1588,42 @@ struct SnowflakeRow {
pub user_properties: Option<serde_json::Value>,
pub insert_id: Option<String>,
}

impl SnowflakeRow {
pub fn new(
event_type: impl Into<String>,
metrics_id: Option<Uuid>,
is_staff: bool,
system_id: Option<String>,
event_properties: serde_json::Value,
) -> Self {
Self {
time: chrono::Utc::now(),
event_type: event_type.into(),
device_id: system_id,
user_id: metrics_id.map(|id| id.to_string()),
insert_id: Some(uuid::Uuid::new_v4().to_string()),
event_properties,
user_properties: Some(json!({"is_staff": is_staff})),
}
}

pub async fn write(
self,
client: &Option<aws_sdk_kinesis::Client>,
stream: &Option<String>,
) -> anyhow::Result<()> {
let Some((client, stream)) = client.as_ref().zip(stream.as_ref()) else {
return Ok(());
};
let row = serde_json::to_vec(&self)?;
client
.put_record()
.stream_name(stream)
.partition_key(&self.user_id.unwrap_or_default())
.data(row.into())
.send()
.await?;
Ok(())
}
}
3 changes: 3 additions & 0 deletions crates/collab/src/cents.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use serde::Serialize;

/// A number of cents.
#[derive(
Debug,
Expand All @@ -12,6 +14,7 @@
derive_more::AddAssign,
derive_more::Sub,
derive_more::SubAssign,
Serialize,
)]
pub struct Cents(pub u32);

Expand Down
95 changes: 75 additions & 20 deletions crates/collab/src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ pub mod db;
mod telemetry;
mod token;

use crate::api::events::SnowflakeRow;
use crate::api::CloudflareIpCountryHeader;
use crate::build_kinesis_client;
use crate::{
api::CloudflareIpCountryHeader, build_clickhouse_client, db::UserId, executor::Executor, Cents,
Config, Error, Result,
build_clickhouse_client, db::UserId, executor::Executor, Cents, Config, Error, Result,
};
use anyhow::{anyhow, Context as _};
use authorization::authorize_access_to_language_model;
Expand All @@ -28,6 +30,7 @@ use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
use serde_json::json;
use std::{
pin::Pin,
sync::Arc,
Expand All @@ -45,6 +48,7 @@ pub struct LlmState {
pub executor: Executor,
pub db: Arc<LlmDatabase>,
pub http_client: ReqwestClient,
pub kinesis_client: Option<aws_sdk_kinesis::Client>,
pub clickhouse_client: Option<clickhouse::Client>,
active_user_count_by_model:
RwLock<HashMap<(LanguageModelProvider, String), (DateTime<Utc>, ActiveUserCount)>>,
Expand Down Expand Up @@ -77,6 +81,11 @@ impl LlmState {
executor,
db,
http_client,
kinesis_client: if config.kinesis_access_key.is_some() {
build_kinesis_client(&config).await.log_err()
} else {
None
},
clickhouse_client: config
.clickhouse_url
.as_ref()
Expand Down Expand Up @@ -521,25 +530,50 @@ async fn check_usage_limit(
UsageMeasure::TokensPerDay => "tokens_per_day",
};

if let Some(client) = state.clickhouse_client.as_ref() {
tracing::info!(
target: "user rate limit",
user_id = claims.user_id,
login = claims.github_user_login,
authn.jti = claims.jti,
is_staff = claims.is_staff,
provider = provider.to_string(),
model = model.name,
requests_this_minute = usage.requests_this_minute,
tokens_this_minute = usage.tokens_this_minute,
tokens_this_day = usage.tokens_this_day,
users_in_recent_minutes = users_in_recent_minutes,
users_in_recent_days = users_in_recent_days,
max_requests_per_minute = per_user_max_requests_per_minute,
max_tokens_per_minute = per_user_max_tokens_per_minute,
max_tokens_per_day = per_user_max_tokens_per_day,
);
tracing::info!(
target: "user rate limit",
user_id = claims.user_id,
login = claims.github_user_login,
authn.jti = claims.jti,
is_staff = claims.is_staff,
provider = provider.to_string(),
model = model.name,
requests_this_minute = usage.requests_this_minute,
tokens_this_minute = usage.tokens_this_minute,
tokens_this_day = usage.tokens_this_day,
users_in_recent_minutes = users_in_recent_minutes,
users_in_recent_days = users_in_recent_days,
max_requests_per_minute = per_user_max_requests_per_minute,
max_tokens_per_minute = per_user_max_tokens_per_minute,
max_tokens_per_day = per_user_max_tokens_per_day,
);

SnowflakeRow::new(
"Language Model Rate Limited",
claims.metrics_id,
claims.is_staff,
claims.system_id.clone(),
json!({
"usage": usage,
"users_in_recent_minutes": users_in_recent_minutes,
"users_in_recent_days": users_in_recent_days,
"max_requests_per_minute": per_user_max_requests_per_minute,
"max_tokens_per_minute": per_user_max_tokens_per_minute,
"max_tokens_per_day": per_user_max_tokens_per_day,
"plan": match claims.plan {
Plan::Free => "free".to_string(),
Plan::ZedPro => "zed_pro".to_string(),
},
"model": model.name.clone(),
"provider": provider.to_string(),
"usage_measure": resource.to_string(),
}),
)
.write(&state.kinesis_client, &state.config.kinesis_stream)
.await
.log_err();

if let Some(client) = state.clickhouse_client.as_ref() {
report_llm_rate_limit(
client,
LlmRateLimitEventRow {
Expand Down Expand Up @@ -652,6 +686,27 @@ impl<S> Drop for TokenCountingStream<S> {
tokens_this_minute = usage.tokens_this_minute,
);

let properties = json!({
"plan": match claims.plan {
Plan::Free => "free".to_string(),
Plan::ZedPro => "zed_pro".to_string(),
},
"model": model,
"provider": provider,
"usage": usage,
"tokens": tokens
});
SnowflakeRow::new(
"Language Model Used",
claims.metrics_id,
claims.is_staff,
claims.system_id.clone(),
properties,
)
.write(&state.kinesis_client, &state.config.kinesis_stream)
.await
.log_err();

if let Some(clickhouse_client) = state.clickhouse_client.as_ref() {
report_llm_usage(
clickhouse_client,
Expand Down
4 changes: 2 additions & 2 deletions crates/collab/src/llm/db/queries/usages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use strum::IntoEnumIterator as _;

use super::*;

#[derive(Debug, PartialEq, Clone, Copy, Default)]
#[derive(Debug, PartialEq, Clone, Copy, Default, serde::Serialize)]
pub struct TokenUsage {
pub input: usize,
pub input_cache_creation: usize,
Expand All @@ -23,7 +23,7 @@ impl TokenUsage {
}
}

#[derive(Debug, PartialEq, Clone, Copy)]
#[derive(Debug, PartialEq, Clone, Copy, serde::Serialize)]
pub struct Usage {
pub requests_this_minute: usize,
pub tokens_this_minute: usize,
Expand Down
8 changes: 8 additions & 0 deletions crates/collab/src/llm/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
use uuid::Uuid;

#[derive(Clone, Debug, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
Expand All @@ -16,6 +17,10 @@ pub struct LlmTokenClaims {
pub exp: u64,
pub jti: String,
pub user_id: u64,
#[serde(default)]
pub system_id: Option<String>,
#[serde(default)]
pub metrics_id: Option<Uuid>,
pub github_user_login: String,
pub is_staff: bool,
pub has_llm_closed_beta_feature_flag: bool,
Expand All @@ -36,6 +41,7 @@ impl LlmTokenClaims {
has_llm_closed_beta_feature_flag: bool,
has_llm_subscription: bool,
plan: rpc::proto::Plan,
system_id: Option<String>,
config: &Config,
) -> Result<String> {
let secret = config
Expand All @@ -49,6 +55,8 @@ impl LlmTokenClaims {
exp: (now + LLM_TOKEN_LIFETIME).timestamp() as u64,
jti: uuid::Uuid::new_v4().to_string(),
user_id: user.id.to_proto(),
system_id,
metrics_id: Some(user.metrics_id),
github_user_login: user.github_login.clone(),
is_staff,
has_llm_closed_beta_feature_flag,
Expand Down
9 changes: 8 additions & 1 deletion crates/collab/src/rpc.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod connection_pool;

use crate::api::CloudflareIpCountryHeader;
use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
use crate::llm::LlmTokenClaims;
use crate::{
auth,
Expand Down Expand Up @@ -137,6 +137,7 @@ struct Session {
/// The GeoIP country code for the user.
#[allow(unused)]
geoip_country_code: Option<String>,
system_id: Option<String>,
_executor: Executor,
}

Expand Down Expand Up @@ -682,6 +683,7 @@ impl Server {
principal: Principal,
zed_version: ZedVersion,
geoip_country_code: Option<String>,
system_id: Option<String>,
send_connection_id: Option<oneshot::Sender<ConnectionId>>,
executor: Executor,
) -> impl Future<Output = ()> {
Expand Down Expand Up @@ -737,6 +739,7 @@ impl Server {
app_state: this.app_state.clone(),
http_client,
geoip_country_code,
system_id,
_executor: executor.clone(),
supermaven_client,
};
Expand Down Expand Up @@ -1056,13 +1059,15 @@ pub fn routes(server: Arc<Server>) -> Router<(), Body> {
.layer(Extension(server))
}

#[allow(clippy::too_many_arguments)]
pub async fn handle_websocket_request(
TypedHeader(ProtocolVersion(protocol_version)): TypedHeader<ProtocolVersion>,
app_version_header: Option<TypedHeader<AppVersionHeader>>,
ConnectInfo(socket_address): ConnectInfo<SocketAddr>,
Extension(server): Extension<Arc<Server>>,
Extension(principal): Extension<Principal>,
country_code_header: Option<TypedHeader<CloudflareIpCountryHeader>>,
system_id_header: Option<TypedHeader<SystemIdHeader>>,
ws: WebSocketUpgrade,
) -> axum::response::Response {
if protocol_version != rpc::PROTOCOL_VERSION {
Expand Down Expand Up @@ -1104,6 +1109,7 @@ pub async fn handle_websocket_request(
principal,
version,
country_code_header.map(|header| header.to_string()),
system_id_header.map(|header| header.to_string()),
None,
Executor::Production,
)
Expand Down Expand Up @@ -4053,6 +4059,7 @@ async fn get_llm_api_token(
has_llm_closed_beta_feature_flag,
has_llm_subscription,
session.current_plan(&db).await?,
session.system_id.clone(),
&session.app_state.config,
)?;
response.send(proto::GetLlmTokenResponse { token })?;
Expand Down
Loading

0 comments on commit 984bb19

Please sign in to comment.