Skip to content

Commit

Permalink
assistant: Require user to accept TOS for cloud provider (#16111)
Browse files Browse the repository at this point in the history
This adds the requirement for users to accept the terms of service the
first time they send a message with the Cloud provider.

Once this is out and in a nightly, we need to add the check to the
server side too, to authenticate access to the models.

Demo:


https://github.com/user-attachments/assets/0edebf74-8120-4fa2-b801-bb76f04e8a17



Release Notes:

- N/A
  • Loading branch information
mrnugget authored Aug 12, 2024
1 parent 98f314b commit fbb533b
Show file tree
Hide file tree
Showing 14 changed files with 297 additions and 9 deletions.
44 changes: 44 additions & 0 deletions crates/assistant/src/assistant_panel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ impl AssistantPanel {
}
language_model::Event::ProviderStateChanged => {
this.ensure_authenticated(cx);
cx.notify()
}
language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
Expand Down Expand Up @@ -1712,6 +1713,7 @@ pub struct ContextEditor {
assistant_panel: WeakView<AssistantPanel>,
error_message: Option<SharedString>,
debug_inspector: Option<ContextInspector>,
show_accept_terms: bool,
}

const DEFAULT_TAB_TITLE: &str = "New Context";
Expand Down Expand Up @@ -1772,6 +1774,7 @@ impl ContextEditor {
assistant_panel,
error_message: None,
debug_inspector: None,
show_accept_terms: false,
};
this.update_message_headers(cx);
this.insert_slash_command_output_sections(sections, cx);
Expand Down Expand Up @@ -1804,6 +1807,16 @@ impl ContextEditor {
}

fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
if provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx))
{
self.show_accept_terms = true;
cx.notify();
return;
}

if !self.apply_active_workflow_step(cx) {
self.error_message = None;
self.send_to_model(cx);
Expand Down Expand Up @@ -3388,7 +3401,14 @@ impl ContextEditor {
None => (ButtonStyle::Filled, None),
};

let provider = LanguageModelRegistry::read_global(cx).active_provider();
let disabled = self.show_accept_terms
&& provider
.as_ref()
.map_or(false, |provider| provider.must_accept_terms(cx));

ButtonLike::new("send_button")
.disabled(disabled)
.style(style)
.when_some(tooltip, |button, tooltip| {
button.tooltip(move |_| tooltip.clone())
Expand Down Expand Up @@ -3437,6 +3457,15 @@ impl EventEmitter<SearchEvent> for ContextEditor {}

impl Render for ContextEditor {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let provider = LanguageModelRegistry::read_global(cx).active_provider();
let accept_terms = if self.show_accept_terms {
provider
.as_ref()
.and_then(|provider| provider.render_accept_terms(cx))
} else {
None
};

v_flex()
.key_context("ContextEditor")
.capture_action(cx.listener(ContextEditor::cancel))
Expand All @@ -3455,6 +3484,21 @@ impl Render for ContextEditor {
.bg(cx.theme().colors().editor_background)
.child(self.editor.clone()),
)
.when_some(accept_terms, |this, element| {
this.child(
div()
.absolute()
.right_4()
.bottom_10()
.max_w_96()
.py_2()
.px_3()
.elevation_2(cx)
.bg(cx.theme().colors().surface_background)
.occlude()
.child(element),
)
})
.child(
h_flex().flex_none().relative().child(
h_flex()
Expand Down
7 changes: 7 additions & 0 deletions crates/client/src/test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{anyhow, Result};
use chrono::Duration;
use futures::{stream::BoxStream, StreamExt};
use gpui::{BackgroundExecutor, Context, Model, TestAppContext};
use parking_lot::Mutex;
Expand Down Expand Up @@ -162,6 +163,11 @@ impl FakeServer {
return Ok(*message.downcast().unwrap());
}

let accepted_tos_at = chrono::Utc::now()
.checked_sub_signed(Duration::hours(5))
.expect("failed to build accepted_tos_at")
.timestamp() as u64;

if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
self.respond(
message
Expand All @@ -172,6 +178,7 @@ impl FakeServer {
metrics_id: "the-metrics-id".into(),
staff: false,
flags: Default::default(),
accepted_tos_at: Some(accepted_tos_at),
},
);
continue;
Expand Down
53 changes: 49 additions & 4 deletions crates/client/src/user.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{proto, Client, Status, TypedEnvelope};
use anyhow::{anyhow, Context, Result};
use chrono::{DateTime, Utc};
use collections::{hash_map::Entry, HashMap, HashSet};
use feature_flags::FeatureFlagAppExt;
use futures::{channel::mpsc, Future, StreamExt};
Expand Down Expand Up @@ -94,6 +95,7 @@ pub struct UserStore {
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
current_plan: Option<proto::Plan>,
current_user: watch::Receiver<Option<Arc<User>>>,
accepted_tos_at: Option<Option<DateTime<Utc>>>,
contacts: Vec<Arc<Contact>>,
incoming_contact_requests: Vec<Arc<User>>,
outgoing_contact_requests: Vec<Arc<User>>,
Expand Down Expand Up @@ -150,6 +152,7 @@ impl UserStore {
by_github_login: Default::default(),
current_user: current_user_rx,
current_plan: None,
accepted_tos_at: None,
contacts: Default::default(),
incoming_contact_requests: Default::default(),
participant_indices: Default::default(),
Expand Down Expand Up @@ -189,9 +192,10 @@ impl UserStore {
} else {
break;
};
let fetch_metrics_id =
let fetch_private_user_info =
client.request(proto::GetPrivateUserInfo {}).log_err();
let (user, info) = futures::join!(fetch_user, fetch_metrics_id);
let (user, info) =
futures::join!(fetch_user, fetch_private_user_info);

cx.update(|cx| {
if let Some(info) = info {
Expand All @@ -202,9 +206,17 @@ impl UserStore {
client.telemetry.set_authenticated_user_info(
Some(info.metrics_id.clone()),
staff,
)
);

this.update(cx, |this, _| {
this.set_current_user_accepted_tos_at(
info.accepted_tos_at,
);
})
} else {
anyhow::Ok(())
}
})?;
})??;

current_user_tx.send(user).await.ok();

Expand Down Expand Up @@ -680,6 +692,39 @@ impl UserStore {
self.current_user.clone()
}

pub fn current_user_has_accepted_terms(&self) -> Option<bool> {
self.accepted_tos_at
.map(|accepted_tos_at| accepted_tos_at.is_some())
}

pub fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
if self.current_user().is_none() {
return Task::ready(Err(anyhow!("no current user")));
};

let client = self.client.clone();
cx.spawn(move |this, mut cx| async move {
if let Some(client) = client.upgrade() {
let response = client
.request(proto::AcceptTermsOfService {})
.await
.context("error accepting tos")?;

this.update(&mut cx, |this, _| {
this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at))
})
} else {
Err(anyhow!("client not found"))
}
})
}

fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option<u64>) {
self.accepted_tos_at = Some(
accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)),
);
}

fn load_users(
&mut self,
request: impl RequestMessage<Response = UsersResponse>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ CREATE TABLE "users" (
"connected_once" BOOLEAN NOT NULL DEFAULT false,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"metrics_id" TEXT,
"github_user_id" INTEGER
"github_user_id" INTEGER,
"accepted_tos_at" TIMESTAMP WITHOUT TIME ZONE
);
CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login");
CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE users ADD accepted_tos_at TIMESTAMP WITHOUT TIME ZONE;
20 changes: 20 additions & 0 deletions crates/collab/src/db/queries/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,26 @@ impl Database {
.await
}

/// Sets "accepted_tos_at" on the user to the given timestamp.
pub async fn set_user_accepted_tos_at(
&self,
id: UserId,
accepted_tos_at: Option<DateTime>,
) -> Result<()> {
self.transaction(|tx| async move {
user::Entity::update_many()
.filter(user::Column::Id.eq(id))
.set(user::ActiveModel {
accepted_tos_at: ActiveValue::set(accepted_tos_at),
..Default::default()
})
.exec(&*tx)
.await?;
Ok(())
})
.await
}

/// hard delete the user.
pub async fn destroy_user(&self, id: UserId) -> Result<()> {
self.transaction(|tx| async move {
Expand Down
1 change: 1 addition & 0 deletions crates/collab/src/db/tables/user.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub struct Model {
pub connected_once: bool,
pub metrics_id: Uuid,
pub created_at: DateTime,
pub accepted_tos_at: Option<DateTime>,
}

#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
Expand Down
1 change: 1 addition & 0 deletions crates/collab/src/db/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod extension_tests;
mod feature_flag_tests;
mod message_tests;
mod processed_stripe_event_tests;
mod user_tests;

use crate::migrations::run_database_migrations;

Expand Down
45 changes: 45 additions & 0 deletions crates/collab/src/db/tests/user_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use chrono::Utc;

use crate::{
db::{Database, NewUserParams},
test_both_dbs,
};
use std::sync::Arc;

test_both_dbs!(
test_accepted_tos,
test_accepted_tos_postgres,
test_accepted_tos_sqlite
);

async fn test_accepted_tos(db: &Arc<Database>) {
let user_id = db
.create_user(
"user1@example.com",
false,
NewUserParams {
github_login: "user1".to_string(),
github_user_id: 1,
},
)
.await
.unwrap()
.user_id;

let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
assert!(user.accepted_tos_at.is_none());

let accepted_tos_at = Utc::now().naive_utc();
db.set_user_accepted_tos_at(user_id, Some(accepted_tos_at))
.await
.unwrap();

let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
assert!(user.accepted_tos_at.is_some());
assert_eq!(user.accepted_tos_at, Some(accepted_tos_at));

db.set_user_accepted_tos_at(user_id, None).await.unwrap();

let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
assert!(user.accepted_tos_at.is_none());
}
21 changes: 21 additions & 0 deletions crates/collab/src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use axum::{
routing::get,
Extension, Router, TypedHeader,
};
use chrono::Utc;
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
Expand Down Expand Up @@ -604,6 +605,7 @@ impl Server {
.add_message_handler(user_message_handler(update_followers))
.add_request_handler(user_handler(get_private_user_info))
.add_request_handler(user_handler(get_llm_api_token))
.add_request_handler(user_handler(accept_terms_of_service))
.add_message_handler(user_message_handler(acknowledge_channel_message))
.add_message_handler(user_message_handler(acknowledge_buffer_version))
.add_request_handler(user_handler(get_supermaven_api_key))
Expand Down Expand Up @@ -4882,6 +4884,25 @@ async fn get_private_user_info(
metrics_id,
staff: user.admin,
flags,
accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
})?;
Ok(())
}

/// Accept the terms of service (tos) on behalf of the current user
async fn accept_terms_of_service(
_request: proto::AcceptTermsOfService,
response: Response<proto::AcceptTermsOfService>,
session: UserSession,
) -> Result<()> {
let db = session.db().await;

let accepted_tos_at = Utc::now();
db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
.await?;

response.send(proto::AcceptTermsOfServiceResponse {
accepted_tos_at: accepted_tos_at.timestamp() as u64,
})?;
Ok(())
}
Expand Down
10 changes: 9 additions & 1 deletion crates/language_model/src/language_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ pub mod settings;
use anyhow::Result;
use client::{Client, UserStore};
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext};
use gpui::{
AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
};
pub use model::*;
use project::Fs;
use proto::Plan;
Expand Down Expand Up @@ -114,6 +116,12 @@ pub trait LanguageModelProvider: 'static {
fn is_authenticated(&self, cx: &AppContext) -> bool;
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
fn configuration_view(&self, cx: &mut WindowContext) -> AnyView;
fn must_accept_terms(&self, _cx: &AppContext) -> bool {
false
}
fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option<AnyElement> {
None
}
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
}

Expand Down
Loading

0 comments on commit fbb533b

Please sign in to comment.