From 6e1f7c6e1dc97813aa7bd0da6db1aa520bc404cd Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 29 Jul 2024 16:42:08 +0200 Subject: [PATCH] Use tool calling instead of XML parsing to generate edit operations (#15385) Release Notes: - N/A --------- Co-authored-by: Nathan --- Cargo.lock | 13 +- crates/anthropic/src/anthropic.rs | 341 +++++++----- crates/assistant/Cargo.toml | 1 - crates/assistant/src/assistant_panel.rs | 39 +- crates/assistant/src/context.rs | 340 ++++-------- crates/assistant/src/inline_assistant.rs | 511 ++++++++++-------- crates/assistant/src/prompt_library.rs | 42 +- crates/assistant/src/prompts.rs | 13 +- .../src/terminal_inline_assistant.rs | 21 +- crates/collab/src/rpc.rs | 240 +++++--- crates/completion/Cargo.toml | 2 + crates/completion/src/completion.rs | 38 +- crates/language_model/src/language_model.rs | 16 + .../language_model/src/provider/anthropic.rs | 132 ++++- crates/language_model/src/provider/cloud.rs | 119 ++-- crates/language_model/src/provider/fake.rs | 23 +- crates/language_model/src/provider/google.rs | 13 +- crates/language_model/src/provider/ollama.rs | 13 +- crates/language_model/src/provider/open_ai.rs | 13 +- crates/language_model/src/request.rs | 16 +- crates/proto/proto/zed.proto | 45 +- crates/proto/src/proto.rs | 15 +- 22 files changed, 1154 insertions(+), 852 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4322539f7b423..10d230a167f95 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -435,7 +435,6 @@ dependencies = [ "rand 0.8.5", "regex", "rope", - "roxmltree 0.20.0", "schemars", "search", "semantic_index", @@ -2641,7 +2640,9 @@ dependencies = [ "language_model", "project", "rand 0.8.5", + "schemars", "serde", + "serde_json", "settings", "smol", "text", @@ -4237,7 +4238,7 @@ version = "0.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a595cb550439a117696039dfc69830492058211b771a2a165379f2a1a53d84d" dependencies = [ - "roxmltree 0.19.0", + "roxmltree", ] [[package]] @@ -8918,12 +8919,6 @@ version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3cd14fd5e3b777a7422cca79358c57a8f6e3a703d9ac187448d0daf220c2407f" -[[package]] -name = "roxmltree" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c20b6793b5c2fa6553b250154b78d6d0db37e72700ae35fad9387a46f487c97" - [[package]] name = "rpc" version = "0.1.0" @@ -11878,7 +11873,7 @@ dependencies = [ "kurbo", "log", "pico-args", - "roxmltree 0.19.0", + "roxmltree", "simplecss", "siphasher 1.0.1", "strict-num", diff --git a/crates/anthropic/src/anthropic.rs b/crates/anthropic/src/anthropic.rs index 45a4dfc0d3464..c24d19bd1d3b2 100644 --- a/crates/anthropic/src/anthropic.rs +++ b/crates/anthropic/src/anthropic.rs @@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest}; use isahc::config::Configurable; use serde::{Deserialize, Serialize}; -use std::{convert::TryFrom, time::Duration}; +use std::time::Duration; use strum::EnumIter; pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com"; @@ -70,112 +70,53 @@ impl Model { } } -#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] -#[serde(rename_all = "lowercase")] -pub enum Role { - User, - Assistant, -} +pub async fn complete( + client: &dyn HttpClient, + api_url: &str, + api_key: &str, + request: Request, +) -> Result { + let uri = format!("{api_url}/v1/messages"); + let 
request_builder = HttpRequest::builder() + .method(Method::POST) + .uri(uri) + .header("Anthropic-Version", "2023-06-01") + .header("Anthropic-Beta", "tools-2024-04-04") + .header("X-Api-Key", api_key) + .header("Content-Type", "application/json"); -impl TryFrom for Role { - type Error = anyhow::Error; + let serialized_request = serde_json::to_string(&request)?; + let request = request_builder.body(AsyncBody::from(serialized_request))?; - fn try_from(value: String) -> Result { - match value.as_str() { - "user" => Ok(Self::User), - "assistant" => Ok(Self::Assistant), - _ => Err(anyhow!("invalid role '{value}'")), - } - } -} - -impl From for String { - fn from(val: Role) -> Self { - match val { - Role::User => "user".to_owned(), - Role::Assistant => "assistant".to_owned(), - } + let mut response = client.send(request).await?; + if response.status().is_success() { + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + let response_message: Response = serde_json::from_slice(&body)?; + Ok(response_message) + } else { + let mut body = Vec::new(); + response.body_mut().read_to_end(&mut body).await?; + let body_str = std::str::from_utf8(&body)?; + Err(anyhow!( + "Failed to connect to API: {} {}", + response.status(), + body_str + )) } } -#[derive(Debug, Serialize, Deserialize)] -pub struct Request { - pub model: String, - pub messages: Vec, - pub stream: bool, - pub system: String, - pub max_tokens: u32, -} - -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct RequestMessage { - pub role: Role, - pub content: String, -} - -#[derive(Deserialize, Serialize, Debug)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum ResponseEvent { - MessageStart { - message: ResponseMessage, - }, - ContentBlockStart { - index: u32, - content_block: ContentBlock, - }, - Ping {}, - ContentBlockDelta { - index: u32, - delta: TextDelta, - }, - ContentBlockStop { - index: u32, - }, - MessageDelta { - delta: ResponseMessage, - usage: Usage, - }, - MessageStop {}, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct ResponseMessage { - #[serde(rename = "type")] - pub message_type: Option, - pub id: Option, - pub role: Option, - pub content: Option>, - pub model: Option, - pub stop_reason: Option, - pub stop_sequence: Option, - pub usage: Option, -} - -#[derive(Serialize, Deserialize, Debug)] -pub struct Usage { - pub input_tokens: Option, - pub output_tokens: Option, -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum ContentBlock { - Text { text: String }, -} - -#[derive(Serialize, Deserialize, Debug)] -#[serde(tag = "type", rename_all = "snake_case")] -pub enum TextDelta { - TextDelta { text: String }, -} - pub async fn stream_completion( client: &dyn HttpClient, api_url: &str, api_key: &str, request: Request, low_speed_timeout: Option, -) -> Result>> { +) -> Result>> { + let request = StreamingRequest { + base: request, + stream: true, + }; let uri = format!("{api_url}/v1/messages"); let mut request_builder = HttpRequest::builder() .method(Method::POST) @@ -187,7 +128,9 @@ pub async fn stream_completion( if let Some(low_speed_timeout) = low_speed_timeout { request_builder = request_builder.low_speed_timeout(100, low_speed_timeout); } - let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?; + let serialized_request = serde_json::to_string(&request)?; + let request = request_builder.body(AsyncBody::from(serialized_request))?; + let mut response = 
client.send(request).await?;
     if response.status().is_success() {
         let reader = BufReader::new(response.into_body());
@@ -212,7 +155,7 @@ pub async fn stream_completion(

         let body_str = std::str::from_utf8(&body)?;

-        match serde_json::from_str::<ResponseEvent>(body_str) {
+        match serde_json::from_str::<Event>(body_str) {
             Ok(_) => Err(anyhow!(
                 "Unexpected success response while expecting an error: {}",
                 body_str,
@@ -227,16 +170,18 @@ pub async fn stream_completion(
 }

 pub fn extract_text_from_events(
-    response: impl Stream<Item = Result<ResponseEvent>>,
+    response: impl Stream<Item = Result<Event>>,
 ) -> impl Stream<Item = Result<String>> {
     response.filter_map(|response| async move {
         match response {
             Ok(response) => match response {
-                ResponseEvent::ContentBlockStart { content_block, .. } => match content_block {
-                    ContentBlock::Text { text } => Some(Ok(text)),
+                Event::ContentBlockStart { content_block, .. } => match content_block {
+                    Content::Text { text } => Some(Ok(text)),
+                    _ => None,
                 },
-                ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
-                    TextDelta::TextDelta { text } => Some(Ok(text)),
+                Event::ContentBlockDelta { delta, .. } => match delta {
+                    ContentDelta::TextDelta { text } => Some(Ok(text)),
+                    _ => None,
                 },
                 _ => None,
             },
@@ -245,42 +190,162 @@ pub fn extract_text_from_events(
     })
 }

-// #[cfg(test)]
-// mod tests {
-//     use super::*;
-//     use http::IsahcHttpClient;
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Message {
+    pub role: Role,
+    pub content: Vec<Content>,
+}

-//     #[tokio::test]
-//     async fn stream_completion_success() {
-//         let http_client = IsahcHttpClient::new().unwrap();
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+}

-//         let request = Request {
-//             model: Model::Claude3Opus,
-//             messages: vec![RequestMessage {
-//                 role: Role::User,
-//                 content: "Ping".to_string(),
-//             }],
-//             stream: true,
-//             system: "Respond to ping with pong".to_string(),
-//             max_tokens: 4096,
-//         };
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum Content {
+    #[serde(rename = "text")]
+    Text { text: String },
+    #[serde(rename = "image")]
+    Image { source: ImageSource },
+    #[serde(rename = "tool_use")]
+    ToolUse {
+        id: String,
+        name: String,
+        input: serde_json::Value,
+    },
+    #[serde(rename = "tool_result")]
+    ToolResult {
+        tool_use_id: String,
+        content: String,
+    },
+}

-//         let stream = stream_completion(
-//             &http_client,
-//             "https://api.anthropic.com",
-//             &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
-//             request,
-//         )
-//         .await
-//         .unwrap();
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ImageSource {
+    #[serde(rename = "type")]
+    pub source_type: String,
+    pub media_type: String,
+    pub data: String,
+}

-//         stream
-//             .for_each(|event| async {
-//                 match event {
-//                     Ok(event) => println!("{:?}", event),
-//                     Err(e) => eprintln!("Error: {:?}", e),
-//                 }
-//             })
-//             .await;
-//     }
-// }
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Tool {
+    pub name: String,
+    pub description: String,
+    pub input_schema: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ToolChoice {
+    Auto,
+    Any,
+    Tool { name: String },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Request {
+    pub model: String,
+    pub max_tokens: u32,
+    pub messages: Vec<Message>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub tools: Vec<Tool>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub system: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<Metadata>,
+    #[serde(default, skip_serializing_if = "Vec::is_empty")]
+    pub stop_sequences: Vec<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub top_k: Option<u32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct StreamingRequest {
+    #[serde(flatten)]
+    pub base: Request,
+    pub stream: bool,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Metadata {
+    pub user_id: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Usage {
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub input_tokens: Option<u32>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub output_tokens: Option<u32>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Response {
+    pub id: String,
+    #[serde(rename = "type")]
+    pub response_type: String,
+    pub role: Role,
+    pub content: Vec<Content>,
+    pub model: String,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub stop_reason: Option<String>,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub stop_sequence: Option<String>,
+    pub usage: Usage,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum Event {
+    #[serde(rename = "message_start")]
+    MessageStart { message: Response },
+    #[serde(rename = "content_block_start")]
+    ContentBlockStart {
+        index: usize,
+        content_block: Content,
+    },
+    #[serde(rename = "content_block_delta")]
+    ContentBlockDelta { index: usize, delta: ContentDelta },
+    #[serde(rename = "content_block_stop")]
+    ContentBlockStop { index: usize },
+    #[serde(rename = "message_delta")]
+    MessageDelta { delta: MessageDelta, usage: Usage },
+    #[serde(rename = "message_stop")]
+    MessageStop,
+    #[serde(rename = "ping")]
+    Ping,
+    #[serde(rename = "error")]
+    Error { error: ApiError },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(tag = "type")]
+pub enum ContentDelta {
+    #[serde(rename = "text_delta")]
+    TextDelta { text: String },
+    #[serde(rename = "input_json_delta")]
+    InputJsonDelta { partial_json: String },
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct MessageDelta {
+    pub stop_reason: Option<String>,
+    pub stop_sequence: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct ApiError {
+    #[serde(rename = "type")]
+    pub error_type: String,
+    pub message: String,
+}
diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml
index a284b289a8e14..6e7bef88ca1ef 100644
--- a/crates/assistant/Cargo.toml
+++ b/crates/assistant/Cargo.toml
@@ -75,7 +75,6 @@ util.workspace = true
 uuid.workspace = true
 workspace.workspace = true
 picker.workspace = true
-roxmltree = "0.20.0"

 [dev-dependencies]
 completion = { workspace = true, features = ["test-support"] }
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index b8121f88fba74..ec72cbd1ff754 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -1232,12 +1232,16 @@ impl ContextEditor {

     fn apply_edit_step(&mut self, cx: &mut ViewContext<Self>) -> bool {
         if let Some(step) = self.active_edit_step.as_ref() {
-            InlineAssistant::update_global(cx, |assistant, cx| {
-                for assist_id in &step.assist_ids {
-                    assistant.start_assist(*assist_id, cx);
-                }
-                !step.assist_ids.is_empty()
-            })
+            let assist_ids = step.assist_ids.clone();
+
cx.window_context().defer(|cx| { + InlineAssistant::update_global(cx, |assistant, cx| { + for assist_id in assist_ids { + assistant.start_assist(assist_id, cx); + } + }) + }); + + !step.assist_ids.is_empty() } else { false } @@ -1286,11 +1290,7 @@ impl ContextEditor { .collect::() )); match &step.operations { - Some(EditStepOperations::Parsed { - operations, - raw_output, - }) => { - output.push_str(&format!("Raw Output:\n{raw_output}\n")); + Some(EditStepOperations::Ready(operations)) => { output.push_str("Parsed Operations:\n"); for op in operations { output.push_str(&format!(" {:?}\n", op)); @@ -1794,13 +1794,12 @@ impl ContextEditor { .anchor_in_excerpt(excerpt_id, suggestion.range.end) .unwrap() }; - let initial_text = suggestion.prepend_newline.then(|| "\n".into()); InlineAssistant::update_global(cx, |assistant, cx| { assist_ids.push(assistant.suggest_assist( &editor, range, description, - initial_text, + suggestion.initial_insertion, Some(workspace.clone()), assistant_panel.upgrade().as_ref(), cx, @@ -1862,9 +1861,11 @@ impl ContextEditor { .anchor_in_excerpt(excerpt_id, suggestion.range.end) .unwrap() }; - let initial_text = - suggestion.prepend_newline.then(|| "\n".to_string()); - inline_assist_suggestions.push((range, description, initial_text)); + inline_assist_suggestions.push(( + range, + description, + suggestion.initial_insertion, + )); } } } @@ -1875,12 +1876,12 @@ impl ContextEditor { .new_view(|cx| Editor::for_multibuffer(multibuffer, Some(project), true, cx))?; cx.update(|cx| { InlineAssistant::update_global(cx, |assistant, cx| { - for (range, description, initial_text) in inline_assist_suggestions { + for (range, description, initial_insertion) in inline_assist_suggestions { assist_ids.push(assistant.suggest_assist( &editor, range, description, - initial_text, + initial_insertion, Some(workspace.clone()), assistant_panel.upgrade().as_ref(), cx, @@ -2188,7 +2189,7 @@ impl ContextEditor { let button_text = match self.edit_step_for_cursor(cx) { Some(edit_step) => match &edit_step.operations { Some(EditStepOperations::Pending(_)) => "Computing Changes...", - Some(EditStepOperations::Parsed { .. 
}) => "Apply Changes",
+            Some(EditStepOperations::Ready(_)) => "Apply Changes",
             None => "Send",
         },
         None => "Send",
diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs
index 0fa5a894d6ef8..217fe9ca0f428 100644
--- a/crates/assistant/src/context.rs
+++ b/crates/assistant/src/context.rs
@@ -1,6 +1,6 @@
 use crate::{
-    prompt_library::PromptStore, slash_command::SlashCommandLine, LanguageModelCompletionProvider,
-    MessageId, MessageStatus,
+    prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
+    LanguageModelCompletionProvider, MessageId, MessageStatus,
 };
 use anyhow::{anyhow, Context as _, Result};
 use assistant_slash_command::{
@@ -18,11 +18,11 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
 use language::{
     AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
 };
-use language_model::LanguageModelRequestMessage;
-use language_model::{LanguageModelRequest, Role};
+use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
 use open_ai::Model as OpenAiModel;
 use paths::contexts_dir;
 use project::Project;
+use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use std::{
     cmp,
@@ -352,7 +352,7 @@ pub struct EditSuggestion {
     pub range: Range<language::Anchor>,
     /// If None, assume this is a suggestion to delete the range rather than transform it.
     pub description: Option<String>,
-    pub prepend_newline: bool,
+    pub initial_insertion: Option<InitialInsertion>,
 }

 impl EditStep {
@@ -361,7 +361,7 @@ impl EditStep {
         project: &Model<Project>,
         cx: &AppContext,
     ) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestion>>> {
-        let Some(EditStepOperations::Parsed { operations, .. }) = &self.operations else {
+        let Some(EditStepOperations::Ready(operations)) = &self.operations else {
             return Task::ready(HashMap::default());
         };
@@ -471,32 +471,28 @@ impl EditStep {
 }

 pub enum EditStepOperations {
-    Pending(Task<Result<()>>),
-    Parsed {
-        operations: Vec<EditOperation>,
-        raw_output: String,
-    },
+    Pending(Task<Option<()>>),
+    Ready(Vec<EditOperation>),
 }

 impl Debug for EditStepOperations {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         match self {
             EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
-            EditStepOperations::Parsed {
-                operations,
-                raw_output,
-            } => f
+            EditStepOperations::Ready(operations) => f
                 .debug_struct("EditStepOperations::Ready")
                 .field("operations", operations)
-                .field("raw_output", raw_output)
                 .finish(),
         }
     }
 }

-#[derive(Clone, Debug, PartialEq, Eq)]
+/// A description of an operation to apply to one location in the codebase.
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
 pub struct EditOperation {
+    /// The path to the file containing the relevant operation.
     pub path: String,
+    #[serde(flatten)]
     pub kind: EditOperationKind,
 }
@@ -523,7 +519,7 @@ impl EditOperation {
             parse_status.changed().await?;
         }

-        let prepend_newline = kind.prepend_newline();
+        let initial_insertion = kind.initial_insertion();
         let suggestion_range = if let Some(symbol) = kind.symbol() {
             let outline = buffer
                 .update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
@@ -601,39 +597,61 @@ impl EditOperation {
                 EditSuggestion {
                     range: suggestion_range,
                     description: kind.description().map(ToString::to_string),
-                    prepend_newline,
+                    initial_insertion,
                 },
             ))
         })
     }
 }

-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
+#[serde(tag = "kind")]
 pub enum EditOperationKind {
+    /// Rewrite the specified symbol in its entirety based on the given description.
     Update {
+        /// A full path to the symbol to be rewritten from the provided list.
         symbol: String,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Create a new file with the given path based on the given description.
     Create {
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Insert a new symbol based on the given description before the specified symbol.
     InsertSiblingBefore {
+        /// A full path to the symbol to be rewritten from the provided list.
         symbol: String,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Insert a new symbol based on the given description after the specified symbol.
     InsertSiblingAfter {
+        /// A full path to the symbol to be rewritten from the provided list.
         symbol: String,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Insert a new symbol as a child of the specified symbol at the start.
     PrependChild {
+        /// An optional full path to the symbol to be rewritten from the provided list.
+        /// If not provided, the edit should be applied at the top of the file.
         symbol: Option<String>,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Insert a new symbol as a child of the specified symbol at the end.
     AppendChild {
+        /// An optional full path to the symbol to be rewritten from the provided list.
+        /// If not provided, the edit should be applied at the end of the file.
        symbol: Option<String>,
+        /// A brief one-line description of the change that should be applied.
         description: String,
     },
+    /// Delete the specified symbol.
     Delete {
+        /// A full path to the symbol to be deleted from the provided list.
         symbol: String,
     },
 }
@@ -663,13 +681,13 @@ impl EditOperationKind {
         }
     }

-    pub fn prepend_newline(&self) -> bool {
+    pub fn initial_insertion(&self) -> Option<InitialInsertion> {
         match self {
-            Self::PrependChild { .. }
-            | Self::AppendChild { .. }
-            | Self::InsertSiblingAfter { .. }
-            | Self::InsertSiblingBefore { .. } => true,
-            _ => false,
+            EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
+            EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
+            EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
+            EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
+            _ => None,
         }
     }
 }
@@ -1137,18 +1155,15 @@ impl Context {
                     .timer(Duration::from_millis(200))
                     .await;

-                if let Some(token_count) = cx.update(|cx| {
-                    LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
-                })? {
-                    let token_count = token_count.await?;
-
-                    this.update(&mut cx, |this, cx| {
-                        this.token_count = Some(token_count);
-                        cx.notify()
-                    })?;
-                }
-
-                anyhow::Ok(())
+                let token_count = cx
+                    .update(|cx| {
+                        LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+                    })?
+                    .await?;
+                this.update(&mut cx, |this, cx| {
+                    this.token_count = Some(token_count);
+                    cx.notify()
+                })
             }
             .log_err()
         });
@@ -1304,7 +1319,24 @@ impl Context {
         &self,
         edit_step: &EditStep,
         cx: &mut ModelContext<Self>,
-    ) -> Task<Result<()>> {
+    ) -> Task<Option<()>> {
+        #[derive(Debug, Deserialize, JsonSchema)]
+        struct EditTool {
+            /// A sequence of operations to apply to the codebase.
+            /// When a step requires multiple operations, include them all in this list.
+ operations: Vec, + } + + impl LanguageModelTool for EditTool { + fn name() -> String { + "edit".into() + } + + fn description() -> String { + "suggest edits to one or more locations in the codebase".into() + } + } + let mut request = self.to_completion_request(cx); let edit_step_range = edit_step.source_range.clone(); let step_text = self @@ -1313,160 +1345,41 @@ impl Context { .text_for_range(edit_step_range.clone()) .collect::(); - cx.spawn(|this, mut cx| async move { - let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?; - - let mut prompt = prompt_store.operations_prompt(); - prompt.push_str(&step_text); + cx.spawn(|this, mut cx| { + async move { + let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?; - request.messages.push(LanguageModelRequestMessage { - role: Role::User, - content: prompt, - }); + let mut prompt = prompt_store.operations_prompt(); + prompt.push_str(&step_text); - let raw_output = cx - .update(|cx| { - LanguageModelCompletionProvider::read_global(cx).complete(request, cx) - })? - .await?; + request.messages.push(LanguageModelRequestMessage { + role: Role::User, + content: prompt, + }); - let operations = Self::parse_edit_operations(&raw_output); - this.update(&mut cx, |this, cx| { - let step_index = this - .edit_steps - .binary_search_by(|step| { - step.source_range - .cmp(&edit_step_range, this.buffer.read(cx)) - }) - .map_err(|_| anyhow!("edit step not found"))?; - if let Some(edit_step) = this.edit_steps.get_mut(step_index) { - edit_step.operations = Some(EditStepOperations::Parsed { - operations, - raw_output, - }); - cx.emit(ContextEvent::EditStepsChanged); - } - anyhow::Ok(()) - })? - }) - } + let tool_use = cx + .update(|cx| { + LanguageModelCompletionProvider::read_global(cx) + .use_tool::(request, cx) + })? + .await?; - fn parse_edit_operations(xml: &str) -> Vec { - let Some(start_ix) = xml.find("") else { - return Vec::new(); - }; - let Some(end_ix) = xml[start_ix..].find("") else { - return Vec::new(); - }; - let end_ix = end_ix + start_ix + "".len(); - - let doc = roxmltree::Document::parse(&xml[start_ix..end_ix]).log_err(); - doc.map_or(Vec::new(), |doc| { - doc.root_element() - .children() - .map(|node| { - let tag_name = node.tag_name().name(); - let path = node - .attribute("path") - .with_context(|| { - format!("invalid node {node:?}, missing attribute 'path'") - })? - .to_string(); - let kind = match tag_name { - "update" => EditOperationKind::Update { - symbol: node - .attribute("symbol") - .with_context(|| { - format!("invalid node {node:?}, missing attribute 'symbol'") - })? - .to_string(), - description: node - .attribute("description") - .with_context(|| { - format!( - "invalid node {node:?}, missing attribute 'description'" - ) - })? - .to_string(), - }, - "create" => EditOperationKind::Create { - description: node - .attribute("description") - .with_context(|| { - format!( - "invalid node {node:?}, missing attribute 'description'" - ) - })? - .to_string(), - }, - "insert_sibling_after" => EditOperationKind::InsertSiblingAfter { - symbol: node - .attribute("symbol") - .with_context(|| { - format!("invalid node {node:?}, missing attribute 'symbol'") - })? - .to_string(), - description: node - .attribute("description") - .with_context(|| { - format!( - "invalid node {node:?}, missing attribute 'description'" - ) - })? 
- .to_string(), - }, - "insert_sibling_before" => EditOperationKind::InsertSiblingBefore { - symbol: node - .attribute("symbol") - .with_context(|| { - format!("invalid node {node:?}, missing attribute 'symbol'") - })? - .to_string(), - description: node - .attribute("description") - .with_context(|| { - format!( - "invalid node {node:?}, missing attribute 'description'" - ) - })? - .to_string(), - }, - "prepend_child" => EditOperationKind::PrependChild { - symbol: node.attribute("symbol").map(String::from), - description: node - .attribute("description") - .with_context(|| { - format!( - "invalid node {node:?}, missing attribute 'description'" - ) - })? - .to_string(), - }, - "append_child" => EditOperationKind::AppendChild { - symbol: node.attribute("symbol").map(String::from), - description: node - .attribute("description") - .with_context(|| { - format!( - "invalid node {node:?}, missing attribute 'description'" - ) - })? - .to_string(), - }, - "delete" => EditOperationKind::Delete { - symbol: node - .attribute("symbol") - .with_context(|| { - format!("invalid node {node:?}, missing attribute 'symbol'") - })? - .to_string(), - }, - _ => return Err(anyhow!("invalid node {node:?}")), - }; - anyhow::Ok(EditOperation { path, kind }) - }) - .filter_map(|op| op.log_err()) - .collect() + this.update(&mut cx, |this, cx| { + let step_index = this + .edit_steps + .binary_search_by(|step| { + step.source_range + .cmp(&edit_step_range, this.buffer.read(cx)) + }) + .map_err(|_| anyhow!("edit step not found"))?; + if let Some(edit_step) = this.edit_steps.get_mut(step_index) { + edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations)); + cx.emit(ContextEvent::EditStepsChanged); + } + anyhow::Ok(()) + })? + } + .log_err() }) } @@ -3083,55 +2996,6 @@ mod tests { } } - #[test] - fn test_parse_edit_operations() { - let operations = indoc! 
{r#"
-            Here are the operations to make all fields of the Canvas struct private:
-
-            <operations>
-                <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub pixels" description="Remove pub keyword from pixels field"/>
-                <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub size" description="Remove pub keyword from size field"/>
-                <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub stride" description="Remove pub keyword from stride field"/>
-                <update path="font-kit/src/canvas.rs" symbol="pub struct Canvas pub format" description="Remove pub keyword from format field"/>
-            </operations>
-        "#};

-        let parsed_operations = Context::parse_edit_operations(operations);
-        assert_eq!(
-            parsed_operations,
-            vec![
-                EditOperation {
-                    path: "font-kit/src/canvas.rs".to_string(),
-                    kind: EditOperationKind::Update {
-                        symbol: "pub struct Canvas pub pixels".to_string(),
-                        description: "Remove pub keyword from pixels field".to_string(),
-                    },
-                },
-                EditOperation {
-                    path: "font-kit/src/canvas.rs".to_string(),
-                    kind: EditOperationKind::Update {
-                        symbol: "pub struct Canvas pub size".to_string(),
-                        description: "Remove pub keyword from size field".to_string(),
-                    },
-                },
-                EditOperation {
-                    path: "font-kit/src/canvas.rs".to_string(),
-                    kind: EditOperationKind::Update {
-                        symbol: "pub struct Canvas pub stride".to_string(),
-                        description: "Remove pub keyword from stride field".to_string(),
-                    },
-                },
-                EditOperation {
-                    path: "font-kit/src/canvas.rs".to_string(),
-                    kind: EditOperationKind::Update {
-                        symbol: "pub struct Canvas pub format".to_string(),
-                        description: "Remove pub keyword from format field".to_string(),
-                    },
-                },
-            ]
-        );
-    }
-
     #[gpui::test]
     async fn test_serialization(cx: &mut TestAppContext) {
         let settings_store = cx.update(SettingsStore::test);
diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs
index 88a8382a9731a..bedd7e610f656 100644
--- a/crates/assistant/src/inline_assistant.rs
+++ b/crates/assistant/src/inline_assistant.rs
@@ -17,7 +17,7 @@ use editor::{
 use fs::Fs;
 use futures::{
     channel::mpsc,
-    future::LocalBoxFuture,
+    future::{BoxFuture, LocalBoxFuture},
     stream::{self, BoxStream},
     SinkExt, Stream, StreamExt,
 };
@@ -36,7 +36,7 @@ use similar::TextDiff;
 use smol::future::FutureExt;
 use std::{
     cmp,
-    future::Future,
+    future::{self, Future},
     mem,
     ops::{Range, RangeInclusive},
     pin::Pin,
@@ -46,7 +46,7 @@ use std::{
 };
 use theme::ThemeSettings;
 use ui::{prelude::*, IconButtonShape, Tooltip};
-use util::RangeExt;
+use util::{RangeExt, ResultExt};
 use workspace::{notifications::NotificationId, Toast, Workspace};

 pub fn init(fs: Arc<dyn Fs>, telemetry: Arc<Telemetry>, cx: &mut AppContext) {
@@ -187,7 +187,13 @@ impl InlineAssistant {
             let [prompt_block_id, end_block_id] =
                 self.insert_assist_blocks(editor, &range, &prompt_editor, cx);

-            assists.push((assist_id, prompt_editor, prompt_block_id, end_block_id));
+            assists.push((
+                assist_id,
+                range,
+                prompt_editor,
+                prompt_block_id,
+                end_block_id,
+            ));
         }

         let editor_assists = self
@@ -195,7 +201,7 @@ impl InlineAssistant {
             .entry(editor.downgrade())
             .or_insert_with(|| EditorInlineAssists::new(&editor, cx));
         let mut assist_group = InlineAssistGroup::new();
-        for (assist_id, prompt_editor, prompt_block_id, end_block_id) in assists {
+        for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
             self.assists.insert(
                 assist_id,
                 InlineAssist::new(
@@ -206,6 +212,7 @@ impl InlineAssistant {
                     &prompt_editor,
                     prompt_block_id,
                     end_block_id,
+                    range,
                     prompt_editor.read(cx).codegen.clone(),
                     workspace.clone(),
                     cx,
@@ -227,7 +234,7 @@ impl InlineAssistant {
         editor: &View<Editor>,
         mut range: Range<Anchor>,
         initial_prompt: String,
-        initial_insertion: Option<String>,
+        initial_insertion: Option<InitialInsertion>,
         workspace: Option<WeakView<Workspace>>,
         assistant_panel: Option<&View<AssistantPanel>>,
         cx: &mut WindowContext,
@@ -239,22 +246,30 @@ impl InlineAssistant {
         let assist_id = self.next_assist_id.post_inc();
         let buffer = editor.read(cx).buffer().clone();
-        let prepend_transaction_id = initial_insertion.and_then(|initial_insertion| {
-            buffer.update(cx, |buffer, cx| {
-
buffer.start_transaction(cx); - buffer.edit([(range.start..range.start, initial_insertion)], None, cx); - buffer.end_transaction(cx) - }) - }); + { + let snapshot = buffer.read(cx).read(cx); + + let mut point_range = range.to_point(&snapshot); + if point_range.is_empty() { + point_range.start.column = 0; + point_range.end.column = 0; + } else { + point_range.start.column = 0; + if point_range.end.row > point_range.start.row && point_range.end.column == 0 { + point_range.end.row -= 1; + } + point_range.end.column = snapshot.line_len(MultiBufferRow(point_range.end.row)); + } - range.start = range.start.bias_left(&buffer.read(cx).read(cx)); - range.end = range.end.bias_right(&buffer.read(cx).read(cx)); + range.start = snapshot.anchor_before(point_range.start); + range.end = snapshot.anchor_after(point_range.end); + } let codegen = cx.new_model(|cx| { Codegen::new( editor.read(cx).buffer().clone(), range.clone(), - prepend_transaction_id, + initial_insertion, self.telemetry.clone(), cx, ) @@ -295,6 +310,7 @@ impl InlineAssistant { &prompt_editor, prompt_block_id, end_block_id, + range, prompt_editor.read(cx).codegen.clone(), workspace.clone(), cx, @@ -445,7 +461,7 @@ impl InlineAssistant { let buffer = editor.buffer().read(cx).snapshot(cx); for assist_id in &editor_assists.assist_ids { let assist = &self.assists[assist_id]; - let assist_range = assist.codegen.read(cx).range.to_offset(&buffer); + let assist_range = assist.range.to_offset(&buffer); if assist_range.contains(&selection.start) && assist_range.contains(&selection.end) { if matches!(assist.codegen.read(cx).status, CodegenStatus::Pending) { @@ -473,7 +489,7 @@ impl InlineAssistant { let buffer = editor.buffer().read(cx).snapshot(cx); for assist_id in &editor_assists.assist_ids { let assist = &self.assists[assist_id]; - let assist_range = assist.codegen.read(cx).range.to_offset(&buffer); + let assist_range = assist.range.to_offset(&buffer); if assist.decorations.is_some() && assist_range.contains(&selection.start) && assist_range.contains(&selection.end) @@ -551,7 +567,7 @@ impl InlineAssistant { assist.codegen.read(cx).status, CodegenStatus::Error(_) | CodegenStatus::Done ) { - let assist_range = assist.codegen.read(cx).range.to_offset(&snapshot); + let assist_range = assist.range.to_offset(&snapshot); if edited_ranges .iter() .any(|range| range.overlaps(&assist_range)) @@ -721,7 +737,7 @@ impl InlineAssistant { }); } - let position = assist.codegen.read(cx).range.start; + let position = assist.range.start; editor.update(cx, |editor, cx| { editor.change_selections(None, cx, |selections| { selections.select_anchor_ranges([position..position]) @@ -740,8 +756,7 @@ impl InlineAssistant { .0 as f32; } else { let snapshot = editor.snapshot(cx); - let codegen = assist.codegen.read(cx); - let start_row = codegen + let start_row = assist .range .start .to_display_point(&snapshot.display_snapshot) @@ -829,11 +844,7 @@ impl InlineAssistant { return; } - let Some(user_prompt) = assist - .decorations - .as_ref() - .map(|decorations| decorations.prompt_editor.read(cx).prompt(cx)) - else { + let Some(user_prompt) = assist.user_prompt(cx) else { return; }; @@ -843,139 +854,19 @@ impl InlineAssistant { self.prompt_history.pop_front(); } - let codegen = assist.codegen.clone(); - let telemetry_id = LanguageModelCompletionProvider::read_global(cx) - .active_model() - .map(|m| m.telemetry_id()) - .unwrap_or_default(); - let chunks: LocalBoxFuture>>> = - if user_prompt.trim().to_lowercase() == "delete" { - async { Ok(stream::empty().boxed()) 
}.boxed_local() - } else { - let request = self.request_for_inline_assist(assist_id, cx); - let mut cx = cx.to_async(); - async move { - let request = request.await?; - let chunks = cx - .update(|cx| { - LanguageModelCompletionProvider::read_global(cx) - .stream_completion(request, cx) - })? - .await?; - Ok(chunks.boxed()) - } - .boxed_local() - }; - codegen.update(cx, |codegen, cx| { - codegen.start(telemetry_id, chunks, cx); - }); - } - - fn request_for_inline_assist( - &self, - assist_id: InlineAssistId, - cx: &mut WindowContext, - ) -> Task> { - cx.spawn(|mut cx| async move { - let (user_prompt, context_request, project_name, buffer, range) = - cx.read_global(|this: &InlineAssistant, cx: &WindowContext| { - let assist = this.assists.get(&assist_id).context("invalid assist")?; - let decorations = assist.decorations.as_ref().context("invalid assist")?; - let editor = assist.editor.upgrade().context("invalid assist")?; - let user_prompt = decorations.prompt_editor.read(cx).prompt(cx); - let context_request = if assist.include_context { - assist.workspace.as_ref().and_then(|workspace| { - let workspace = workspace.upgrade()?.read(cx); - let assistant_panel = workspace.panel::(cx)?; - Some( - assistant_panel - .read(cx) - .active_context(cx)? - .read(cx) - .to_completion_request(cx), - ) - }) - } else { - None - }; - let project_name = assist.workspace.as_ref().and_then(|workspace| { - let workspace = workspace.upgrade()?; - Some( - workspace - .read(cx) - .project() - .read(cx) - .worktree_root_names(cx) - .collect::>() - .join("/"), - ) - }); - let buffer = editor.read(cx).buffer().read(cx).snapshot(cx); - let range = assist.codegen.read(cx).range.clone(); - anyhow::Ok((user_prompt, context_request, project_name, buffer, range)) - })??; - - let language = buffer.language_at(range.start); - let language_name = if let Some(language) = language.as_ref() { - if Arc::ptr_eq(language, &language::PLAIN_TEXT) { - None - } else { - Some(language.name()) - } - } else { - None - }; + let assistant_panel_context = assist.assistant_panel_context(cx); - // Higher Temperature increases the randomness of model outputs. 
- // If Markdown or No Language is Known, increase the randomness for more creative output - // If Code, decrease temperature to get more deterministic outputs - let temperature = if let Some(language) = language_name.clone() { - if language.as_ref() == "Markdown" { - 1.0 - } else { - 0.5 - } - } else { - 1.0 - }; - - let prompt = cx - .background_executor() - .spawn(async move { - let language_name = language_name.as_deref(); - let start = buffer.point_to_buffer_offset(range.start); - let end = buffer.point_to_buffer_offset(range.end); - let (buffer, range) = if let Some((start, end)) = start.zip(end) { - let (start_buffer, start_buffer_offset) = start; - let (end_buffer, end_buffer_offset) = end; - if start_buffer.remote_id() == end_buffer.remote_id() { - (start_buffer.clone(), start_buffer_offset..end_buffer_offset) - } else { - return Err(anyhow!("invalid transformation range")); - } - } else { - return Err(anyhow!("invalid transformation range")); - }; - generate_content_prompt(user_prompt, language_name, buffer, range, project_name) - }) - .await?; - - let mut messages = Vec::new(); - if let Some(context_request) = context_request { - messages = context_request.messages; - } - - messages.push(LanguageModelRequestMessage { - role: Role::User, - content: prompt, - }); - - Ok(LanguageModelRequest { - messages, - stop: vec!["|END|>".to_string()], - temperature, + assist + .codegen + .update(cx, |codegen, cx| { + codegen.start( + assist.range.clone(), + user_prompt, + assistant_panel_context, + cx, + ) }) - }) + .log_err(); } pub fn stop_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) { @@ -1006,12 +897,11 @@ impl InlineAssistant { let codegen = assist.codegen.read(cx); foreground_ranges.extend(codegen.last_equal_ranges().iter().cloned()); - if codegen.edit_position != codegen.range.end { - gutter_pending_ranges.push(codegen.edit_position..codegen.range.end); - } + gutter_pending_ranges + .push(codegen.edit_position.unwrap_or(assist.range.start)..assist.range.end); - if codegen.range.start != codegen.edit_position { - gutter_transformed_ranges.push(codegen.range.start..codegen.edit_position); + if let Some(edit_position) = codegen.edit_position { + gutter_transformed_ranges.push(assist.range.start..edit_position); } if assist.decorations.is_some() { @@ -1268,6 +1158,12 @@ fn build_assist_editor_renderer(editor: &View) -> RenderBlock { }) } +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum InitialInsertion { + NewlineBefore, + NewlineAfter, +} + #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] pub struct InlineAssistId(usize); @@ -1629,24 +1525,20 @@ impl PromptEditor { let assist_id = self.id; self.pending_token_count = cx.spawn(|this, mut cx| async move { cx.background_executor().timer(Duration::from_secs(1)).await; - let request = cx + let token_count = cx .update_global(|inline_assistant: &mut InlineAssistant, cx| { - inline_assistant.request_for_inline_assist(assist_id, cx) - })? + let assist = inline_assistant + .assists + .get(&assist_id) + .context("assist not found")?; + anyhow::Ok(assist.count_tokens(cx)) + })?? .await?; - if let Some(token_count) = cx.update(|cx| { - LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx) - })? 
{ - let token_count = token_count.await?; - - this.update(&mut cx, |this, cx| { - this.token_count = Some(token_count); - cx.notify(); - }) - } else { - Ok(()) - } + this.update(&mut cx, |this, cx| { + this.token_count = Some(token_count); + cx.notify(); + }) }) } @@ -1855,6 +1747,7 @@ impl PromptEditor { struct InlineAssist { group_id: InlineAssistGroupId, + range: Range, editor: WeakView, decorations: Option, codegen: Model, @@ -1873,6 +1766,7 @@ impl InlineAssist { prompt_editor: &View, prompt_block_id: CustomBlockId, end_block_id: CustomBlockId, + range: Range, codegen: Model, workspace: Option>, cx: &mut WindowContext, @@ -1888,6 +1782,7 @@ impl InlineAssist { removed_line_block_ids: HashSet::default(), end_block_id, }), + range, codegen: codegen.clone(), workspace: workspace.clone(), _subscriptions: vec![ @@ -1963,6 +1858,41 @@ impl InlineAssist { ], } } + + fn user_prompt(&self, cx: &AppContext) -> Option { + let decorations = self.decorations.as_ref()?; + Some(decorations.prompt_editor.read(cx).prompt(cx)) + } + + fn assistant_panel_context(&self, cx: &WindowContext) -> Option { + if self.include_context { + let workspace = self.workspace.as_ref()?; + let workspace = workspace.upgrade()?.read(cx); + let assistant_panel = workspace.panel::(cx)?; + Some( + assistant_panel + .read(cx) + .active_context(cx)? + .read(cx) + .to_completion_request(cx), + ) + } else { + None + } + } + + pub fn count_tokens(&self, cx: &WindowContext) -> BoxFuture<'static, Result> { + let Some(user_prompt) = self.user_prompt(cx) else { + return future::ready(Err(anyhow!("no user prompt"))).boxed(); + }; + let assistant_panel_context = self.assistant_panel_context(cx); + self.codegen.read(cx).count_tokens( + self.range.clone(), + user_prompt, + assistant_panel_context, + cx, + ) + } } struct InlineAssistDecorations { @@ -1982,16 +1912,15 @@ pub struct Codegen { buffer: Model, old_buffer: Model, snapshot: MultiBufferSnapshot, - range: Range, - edit_position: Anchor, + edit_position: Option, last_equal_ranges: Vec>, - prepend_transaction_id: Option, - generation_transaction_id: Option, + transaction_id: Option, status: CodegenStatus, generation: Task<()>, diff: Diff, telemetry: Option>, _subscription: gpui::Subscription, + initial_insertion: Option, } enum CodegenStatus { @@ -2015,7 +1944,7 @@ impl Codegen { pub fn new( buffer: Model, range: Range, - prepend_transaction_id: Option, + initial_insertion: Option, telemetry: Option>, cx: &mut ModelContext, ) -> Self { @@ -2044,17 +1973,16 @@ impl Codegen { Self { buffer: buffer.clone(), old_buffer, - edit_position: range.start, - range, + edit_position: None, snapshot, last_equal_ranges: Default::default(), - prepend_transaction_id, - generation_transaction_id: None, + transaction_id: None, status: CodegenStatus::Idle, generation: Task::ready(()), diff: Diff::default(), telemetry, _subscription: cx.subscribe(&buffer, Self::handle_buffer_event), + initial_insertion, } } @@ -2065,13 +1993,8 @@ impl Codegen { cx: &mut ModelContext, ) { if let multi_buffer::Event::TransactionUndone { transaction_id } = event { - if self.generation_transaction_id == Some(*transaction_id) { - self.generation_transaction_id = None; - self.generation = Task::ready(()); - cx.emit(CodegenEvent::Undone); - } else if self.prepend_transaction_id == Some(*transaction_id) { - self.prepend_transaction_id = None; - self.generation_transaction_id = None; + if self.transaction_id == Some(*transaction_id) { + self.transaction_id = None; self.generation = Task::ready(()); cx.emit(CodegenEvent::Undone); 
} @@ -2082,19 +2005,152 @@ impl Codegen { &self.last_equal_ranges } + pub fn count_tokens( + &self, + edit_range: Range, + user_prompt: String, + assistant_panel_context: Option, + cx: &AppContext, + ) -> BoxFuture<'static, Result> { + let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx); + LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx) + } + pub fn start( &mut self, - telemetry_id: String, + mut edit_range: Range, + user_prompt: String, + assistant_panel_context: Option, + cx: &mut ModelContext, + ) -> Result<()> { + self.undo(cx); + + // Handle initial insertion + self.transaction_id = if let Some(initial_insertion) = self.initial_insertion { + self.buffer.update(cx, |buffer, cx| { + buffer.start_transaction(cx); + let offset = edit_range.start.to_offset(&self.snapshot); + let edit_position; + match initial_insertion { + InitialInsertion::NewlineBefore => { + buffer.edit([(offset..offset, "\n\n")], None, cx); + self.snapshot = buffer.snapshot(cx); + edit_position = self.snapshot.anchor_after(offset + 1); + } + InitialInsertion::NewlineAfter => { + buffer.edit([(offset..offset, "\n")], None, cx); + self.snapshot = buffer.snapshot(cx); + edit_position = self.snapshot.anchor_after(offset); + } + } + self.edit_position = Some(edit_position); + edit_range = edit_position.bias_left(&self.snapshot)..edit_position; + buffer.end_transaction(cx) + }) + } else { + self.edit_position = Some(edit_range.start.bias_right(&self.snapshot)); + None + }; + + let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx) + .active_model_telemetry_id() + .context("no active model")?; + + let chunks: LocalBoxFuture>>> = if user_prompt + .trim() + .to_lowercase() + == "delete" + { + async { Ok(stream::empty().boxed()) }.boxed_local() + } else { + let request = + self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx); + let chunks = + LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx); + async move { Ok(chunks.await?.boxed()) }.boxed_local() + }; + self.handle_stream(model_telemetry_id, edit_range, chunks, cx); + Ok(()) + } + + fn build_request( + &self, + user_prompt: String, + assistant_panel_context: Option, + edit_range: Range, + cx: &AppContext, + ) -> LanguageModelRequest { + let buffer = self.buffer.read(cx).snapshot(cx); + let language = buffer.language_at(edit_range.start); + let language_name = if let Some(language) = language.as_ref() { + if Arc::ptr_eq(language, &language::PLAIN_TEXT) { + None + } else { + Some(language.name()) + } + } else { + None + }; + + // Higher Temperature increases the randomness of model outputs. 
+ // If Markdown or No Language is Known, increase the randomness for more creative output + // If Code, decrease temperature to get more deterministic outputs + let temperature = if let Some(language) = language_name.clone() { + if language.as_ref() == "Markdown" { + 1.0 + } else { + 0.5 + } + } else { + 1.0 + }; + + let language_name = language_name.as_deref(); + let start = buffer.point_to_buffer_offset(edit_range.start); + let end = buffer.point_to_buffer_offset(edit_range.end); + let (buffer, range) = if let Some((start, end)) = start.zip(end) { + let (start_buffer, start_buffer_offset) = start; + let (end_buffer, end_buffer_offset) = end; + if start_buffer.remote_id() == end_buffer.remote_id() { + (start_buffer.clone(), start_buffer_offset..end_buffer_offset) + } else { + panic!("invalid transformation range"); + } + } else { + panic!("invalid transformation range"); + }; + let prompt = generate_content_prompt(user_prompt, language_name, buffer, range); + + let mut messages = Vec::new(); + if let Some(context_request) = assistant_panel_context { + messages = context_request.messages; + } + + messages.push(LanguageModelRequestMessage { + role: Role::User, + content: prompt, + }); + + LanguageModelRequest { + messages, + stop: vec!["|END|>".to_string()], + temperature, + } + } + + pub fn handle_stream( + &mut self, + model_telemetry_id: String, + edit_range: Range, stream: impl 'static + Future>>>, cx: &mut ModelContext, ) { - let range = self.range.clone(); let snapshot = self.snapshot.clone(); let selected_text = snapshot - .text_for_range(range.start..range.end) + .text_for_range(edit_range.start..edit_range.end) .collect::(); - let selection_start = range.start.to_point(&snapshot); + let selection_start = edit_range.start.to_point(&snapshot); // Start with the indentation of the first line in the selection let mut suggested_line_indent = snapshot @@ -2105,7 +2161,7 @@ impl Codegen { // If the first line in the selection does not have indentation, check the following lines if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space { - for row in selection_start.row..=range.end.to_point(&snapshot).row { + for row in selection_start.row..=edit_range.end.to_point(&snapshot).row { let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row)); // Prefer tabs if a line in the selection uses tabs as indentation if line_indent.kind == IndentKind::Tab { @@ -2116,19 +2172,13 @@ impl Codegen { } let telemetry = self.telemetry.clone(); - self.edit_position = range.start; self.diff = Diff::default(); self.status = CodegenStatus::Pending; - if let Some(transaction_id) = self.generation_transaction_id.take() { - self.buffer - .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx)); - } + let mut edit_start = edit_range.start.to_offset(&snapshot); self.generation = cx.spawn(|this, mut cx| { async move { let chunks = stream.await; let generate = async { - let mut edit_start = range.start.to_offset(&snapshot); - let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1); let diff: Task> = cx.background_executor().spawn(async move { @@ -2218,7 +2268,7 @@ impl Codegen { telemetry.report_assistant_event( None, telemetry_events::AssistantKind::Inline, - telemetry_id, + model_telemetry_id, response_latency, error_message, ); @@ -2262,13 +2312,13 @@ impl Codegen { None, cx, ); - this.edit_position = snapshot.anchor_after(edit_start); + this.edit_position = Some(snapshot.anchor_after(edit_start)); buffer.end_transaction(cx) }); if let Some(transaction) = transaction { 
- if let Some(first_transaction) = this.generation_transaction_id { + if let Some(first_transaction) = this.transaction_id { // Group all assistant edits into the first transaction. this.buffer.update(cx, |buffer, cx| { buffer.merge_transactions( @@ -2278,14 +2328,14 @@ impl Codegen { ) }); } else { - this.generation_transaction_id = Some(transaction); + this.transaction_id = Some(transaction); this.buffer.update(cx, |buffer, cx| { buffer.finalize_last_transaction(cx) }); } } - this.update_diff(cx); + this.update_diff(edit_range.clone(), cx); cx.notify(); })?; } @@ -2321,27 +2371,22 @@ impl Codegen { } pub fn undo(&mut self, cx: &mut ModelContext) { - if let Some(transaction_id) = self.prepend_transaction_id.take() { - self.buffer - .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx)); - } - - if let Some(transaction_id) = self.generation_transaction_id.take() { + if let Some(transaction_id) = self.transaction_id.take() { self.buffer .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx)); } } - fn update_diff(&mut self, cx: &mut ModelContext) { + fn update_diff(&mut self, edit_range: Range, cx: &mut ModelContext) { if self.diff.task.is_some() { self.diff.should_update = true; } else { self.diff.should_update = false; let old_snapshot = self.snapshot.clone(); - let old_range = self.range.to_point(&old_snapshot); + let old_range = edit_range.to_point(&old_snapshot); let new_snapshot = self.buffer.read(cx).snapshot(cx); - let new_range = self.range.to_point(&new_snapshot); + let new_range = edit_range.to_point(&new_snapshot); self.diff.task = Some(cx.spawn(|this, mut cx| async move { let (deleted_row_ranges, inserted_row_ranges) = cx @@ -2422,7 +2467,7 @@ impl Codegen { this.diff.inserted_row_ranges = inserted_row_ranges; this.diff.task = None; if this.diff.should_update { - this.update_diff(cx); + this.update_diff(edit_range, cx); } cx.notify(); }) @@ -2629,12 +2674,14 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5)) }); - let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx)); + let codegen = + cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx)); let (chunks_tx, chunks_rx) = mpsc::unbounded(); codegen.update(cx, |codegen, cx| { - codegen.start( + codegen.handle_stream( String::new(), + range, future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())), cx, ) @@ -2690,12 +2737,14 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6)) }); - let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx)); + let codegen = + cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx)); let (chunks_tx, chunks_rx) = mpsc::unbounded(); codegen.update(cx, |codegen, cx| { - codegen.start( + codegen.handle_stream( String::new(), + range.clone(), future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())), cx, ) @@ -2755,12 +2804,14 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2)) }); - let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx)); + let codegen = + cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx)); let (chunks_tx, chunks_rx) = mpsc::unbounded(); codegen.update(cx, |codegen, cx| { - codegen.start( + codegen.handle_stream( String::new(), + range.clone(), 
future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())), cx, ) @@ -2819,12 +2870,14 @@ mod tests { let snapshot = buffer.snapshot(cx); snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2)) }); - let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, None, cx)); + let codegen = + cx.new_model(|cx| Codegen::new(buffer.clone(), range.clone(), None, None, cx)); let (chunks_tx, chunks_rx) = mpsc::unbounded(); codegen.update(cx, |codegen, cx| { - codegen.start( + codegen.handle_stream( String::new(), + range.clone(), future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())), cx, ) diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index 0fbac05ef56b6..cea5db2a6feb2 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -734,29 +734,27 @@ impl PromptLibrary { const DEBOUNCE_TIMEOUT: Duration = Duration::from_secs(1); cx.background_executor().timer(DEBOUNCE_TIMEOUT).await; - if let Some(token_count) = cx.update(|cx| { - LanguageModelCompletionProvider::read_global(cx).count_tokens( - LanguageModelRequest { - messages: vec![LanguageModelRequestMessage { - role: Role::System, - content: body.to_string(), - }], - stop: Vec::new(), - temperature: 1., - }, - cx, - ) - })? { - let token_count = token_count.await?; + let token_count = cx + .update(|cx| { + LanguageModelCompletionProvider::read_global(cx).count_tokens( + LanguageModelRequest { + messages: vec![LanguageModelRequestMessage { + role: Role::System, + content: body.to_string(), + }], + stop: Vec::new(), + temperature: 1., + }, + cx, + ) + })? + .await?; - this.update(&mut cx, |this, cx| { - let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap(); - prompt_editor.token_count = Some(token_count); - cx.notify(); - }) - } else { - Ok(()) - } + this.update(&mut cx, |this, cx| { + let prompt_editor = this.prompt_editors.get_mut(&prompt_id).unwrap(); + prompt_editor.token_count = Some(token_count); + cx.notify(); + }) } .log_err() }); diff --git a/crates/assistant/src/prompts.rs b/crates/assistant/src/prompts.rs index 4c99caca82312..3c955fbe7cbe0 100644 --- a/crates/assistant/src/prompts.rs +++ b/crates/assistant/src/prompts.rs @@ -6,8 +6,7 @@ pub fn generate_content_prompt( language_name: Option<&str>, buffer: BufferSnapshot, range: Range, - _project_name: Option, -) -> anyhow::Result { +) -> String { let mut prompt = String::new(); let content_type = match language_name { @@ -15,14 +14,16 @@ pub fn generate_content_prompt( writeln!( prompt, "Here's a file of text that I'm going to ask you to make an edit to." - )?; + ) + .unwrap(); "text" } Some(language_name) => { writeln!( prompt, "Here's a file of {language_name} that I'm going to ask you to make an edit to." 
-            )?;
+            )
+            .unwrap();
             "code"
         }
     };
@@ -70,7 +71,7 @@ pub fn generate_content_prompt(
     write!(prompt, "\n\n").unwrap();
 
     if is_truncated {
-        writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n")?;
+        writeln!(prompt, "The context around the relevant section has been truncated (possibly in the middle of a line) for brevity.\n").unwrap();
     }
 
     if range.is_empty() {
@@ -107,7 +108,7 @@ pub fn generate_content_prompt(
         prompt.push_str("\n\nImmediately start with the following format with no remarks:\n\n```\n{{REWRITTEN_CODE}}\n```");
     }
 
-    Ok(prompt)
+    prompt
 }
 
 pub fn generate_terminal_assistant_prompt(
diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs
index e9c4d0e73c9d8..bea35ea89bca9 100644
--- a/crates/assistant/src/terminal_inline_assistant.rs
+++ b/crates/assistant/src/terminal_inline_assistant.rs
@@ -707,18 +707,15 @@ impl PromptEditor {
                 inline_assistant.request_for_inline_assist(assist_id, cx)
             })??;
 
-            if let Some(token_count) = cx.update(|cx| {
-                LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
-            })? {
-                let token_count = token_count.await?;
-
-                this.update(&mut cx, |this, cx| {
-                    this.token_count = Some(token_count);
-                    cx.notify();
-                })
-            } else {
-                Ok(())
-            }
+            let token_count = cx
+                .update(|cx| {
+                    LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+                })?
+                .await?;
+            this.update(&mut cx, |this, cx| {
+                this.token_count = Some(token_count);
+                cx.notify();
+            })
         })
     }
diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs
index 92e5b1a58411d..f536f41aca686 100644
--- a/crates/collab/src/rpc.rs
+++ b/crates/collab/src/rpc.rs
@@ -10,7 +10,7 @@ use crate::{
         ServerId, UpdatedChannelMessage, User, UserId,
     },
     executor::Executor,
-    AppState, Error, RateLimit, RateLimiter, Result,
+    AppState, Config, Error, RateLimit, RateLimiter, Result,
 };
 use anyhow::{anyhow, bail, Context as _};
 use async_tungstenite::tungstenite::{
@@ -605,17 +605,39 @@ impl Server {
             ))
             .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
             .add_message_handler(update_context)
+            .add_request_handler({
+                let app_state = app_state.clone();
+                move |request, response, session| {
+                    let app_state = app_state.clone();
+                    async move {
+                        complete_with_language_model(request, response, session, &app_state.config)
+                            .await
+                    }
+                }
+            })
             .add_streaming_request_handler({
                 let app_state = app_state.clone();
                 move |request, response, session| {
-                    complete_with_language_model(
-                        request,
-                        response,
-                        session,
-                        app_state.config.openai_api_key.clone(),
-                        app_state.config.google_ai_api_key.clone(),
-                        app_state.config.anthropic_api_key.clone(),
-                    )
+                    let app_state = app_state.clone();
+                    async move {
+                        stream_complete_with_language_model(
+                            request,
+                            response,
+                            session,
+                            &app_state.config,
+                        )
+                        .await
+                    }
                 }
             })
+            .add_request_handler({
+                let app_state = app_state.clone();
+                move |request, response, session| {
+                    let app_state = app_state.clone();
+                    async move {
+                        count_language_model_tokens(request, response, session, &app_state.config)
+                            .await
+                    }
+                }
+            })
             .add_request_handler({
@@ -4503,103 +4525,119 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
 }
 
 async fn complete_with_language_model(
-    query: proto::QueryLanguageModel,
-    response: StreamingResponse<proto::QueryLanguageModel>,
+    request: proto::CompleteWithLanguageModel,
+    response: Response<proto::CompleteWithLanguageModel>,
     session: Session,
-    open_ai_api_key: Option<Arc<str>>,
-    google_ai_api_key: Option<Arc<str>>,
-    anthropic_api_key: Option<Arc<str>>,
+    config: &Config,
 ) -> Result<()> {
     let Some(session) = session.for_user() else {
         return Err(anyhow!("user not found"))?;
     };
     authorize_access_to_language_models(&session).await?;
 
-    match proto::LanguageModelRequestKind::from_i32(query.kind) {
-        Some(proto::LanguageModelRequestKind::Complete) => {
-            session
-                .rate_limiter
-                .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
-                .await?;
-        }
-        Some(proto::LanguageModelRequestKind::CountTokens) => {
-            session
-                .rate_limiter
-                .check::<CountTokensWithLanguageModelRateLimit>(session.user_id())
-                .await?;
+    session
+        .rate_limiter
+        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+        .await?;
+
+    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
+        Some(proto::LanguageModelProvider::Anthropic) => {
+            let api_key = config
+                .anthropic_api_key
+                .as_ref()
+                .context("no Anthropic AI API key configured on the server")?;
+            anthropic::complete(
+                session.http_client.as_ref(),
+                anthropic::ANTHROPIC_API_URL,
+                api_key,
+                serde_json::from_str(&request.request)?,
+            )
+            .await?
         }
-        None => Err(anyhow!("unknown request kind"))?,
-    }
+        _ => return Err(anyhow!("unsupported provider"))?,
+    };
+
+    response.send(proto::CompleteWithLanguageModelResponse {
+        completion: serde_json::to_string(&result)?,
+    })?;
+
+    Ok(())
+}
 
-    match proto::LanguageModelProvider::from_i32(query.provider) {
+async fn stream_complete_with_language_model(
+    request: proto::StreamCompleteWithLanguageModel,
+    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
+    session: Session,
+    config: &Config,
+) -> Result<()> {
+    let Some(session) = session.for_user() else {
+        return Err(anyhow!("user not found"))?;
+    };
+    authorize_access_to_language_models(&session).await?;
+
+    session
+        .rate_limiter
+        .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+        .await?;
+
+    match proto::LanguageModelProvider::from_i32(request.provider) {
         Some(proto::LanguageModelProvider::Anthropic) => {
-            let api_key =
-                anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
+            let api_key = config
+                .anthropic_api_key
+                .as_ref()
+                .context("no Anthropic AI API key configured on the server")?;
             let mut chunks = anthropic::stream_completion(
                 session.http_client.as_ref(),
                 anthropic::ANTHROPIC_API_URL,
-                &api_key,
-                serde_json::from_str(&query.request)?,
+                api_key,
+                serde_json::from_str(&request.request)?,
                 None,
             )
             .await?;
-            while let Some(chunk) = chunks.next().await {
-                let chunk = chunk?;
-                response.send(proto::QueryLanguageModelResponse {
-                    response: serde_json::to_string(&chunk)?,
+            while let Some(event) = chunks.next().await {
+                let chunk = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&chunk)?,
                 })?;
             }
         }
         Some(proto::LanguageModelProvider::OpenAi) => {
-            let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
-            let mut chunks = open_ai::stream_completion(
+            let api_key = config
+                .openai_api_key
+                .as_ref()
+                .context("no OpenAI API key configured on the server")?;
+            let mut events = open_ai::stream_completion(
                 session.http_client.as_ref(),
                 open_ai::OPEN_AI_API_URL,
-                &api_key,
-                serde_json::from_str(&query.request)?,
+                api_key,
+                serde_json::from_str(&request.request)?,
                 None,
             )
             .await?;
-            while let Some(chunk) = chunks.next().await {
-                let chunk = chunk?;
-                response.send(proto::QueryLanguageModelResponse {
-                    response: serde_json::to_string(&chunk)?,
+            while let Some(event) = events.next().await {
+                let event = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&event)?,
                 })?;
             }
         }
         Some(proto::LanguageModelProvider::Google) => {
-            let api_key =
-                google_ai_api_key.context("no Google AI API key configured on the server")?;
-
-            match proto::LanguageModelRequestKind::from_i32(query.kind) {
-                Some(proto::LanguageModelRequestKind::Complete) => {
-                    let mut chunks = google_ai::stream_generate_content(
-                        session.http_client.as_ref(),
-                        google_ai::API_URL,
-                        &api_key,
-                        serde_json::from_str(&query.request)?,
-                    )
-                    .await?;
-                    while let Some(chunk) = chunks.next().await {
-                        let chunk = chunk?;
-                        response.send(proto::QueryLanguageModelResponse {
-                            response: serde_json::to_string(&chunk)?,
-                        })?;
-                    }
-                }
-                Some(proto::LanguageModelRequestKind::CountTokens) => {
-                    let tokens_response = google_ai::count_tokens(
-                        session.http_client.as_ref(),
-                        google_ai::API_URL,
-                        &api_key,
-                        serde_json::from_str(&query.request)?,
-                    )
-                    .await?;
-                    response.send(proto::QueryLanguageModelResponse {
-                        response: serde_json::to_string(&tokens_response)?,
-                    })?;
-                }
-                None => Err(anyhow!("unknown request kind"))?,
+            let api_key = config
+                .google_ai_api_key
+                .as_ref()
+                .context("no Google AI API key configured on the server")?;
+            let mut events = google_ai::stream_generate_content(
+                session.http_client.as_ref(),
+                google_ai::API_URL,
+                api_key,
+                serde_json::from_str(&request.request)?,
+            )
+            .await?;
+            while let Some(event) = events.next().await {
+                let event = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&event)?,
+                })?;
             }
         }
         None => return Err(anyhow!("unknown provider"))?,
@@ -4608,11 +4646,51 @@ async fn complete_with_language_model(
     Ok(())
 }
 
-struct CountTokensWithLanguageModelRateLimit;
+async fn count_language_model_tokens(
+    request: proto::CountLanguageModelTokens,
+    response: Response<proto::CountLanguageModelTokens>,
+    session: Session,
+    config: &Config,
+) -> Result<()> {
+    let Some(session) = session.for_user() else {
+        return Err(anyhow!("user not found"))?;
+    };
+    authorize_access_to_language_models(&session).await?;
+
+    session
+        .rate_limiter
+        .check::<CountLanguageModelTokensRateLimit>(session.user_id())
+        .await?;
+
+    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
+        Some(proto::LanguageModelProvider::Google) => {
+            let api_key = config
+                .google_ai_api_key
+                .as_ref()
+                .context("no Google AI API key configured on the server")?;
+            google_ai::count_tokens(
+                session.http_client.as_ref(),
+                google_ai::API_URL,
+                api_key,
+                serde_json::from_str(&request.request)?,
+            )
+            .await?
+        }
+        _ => return Err(anyhow!("unsupported provider"))?,
+    };
+
+    response.send(proto::CountLanguageModelTokensResponse {
+        token_count: result.total_tokens as u32,
+    })?;
+
+    Ok(())
+}
+
+struct CountLanguageModelTokensRateLimit;
 
-impl RateLimit for CountTokensWithLanguageModelRateLimit {
+impl RateLimit for CountLanguageModelTokensRateLimit {
     fn capacity() -> usize {
-        std::env::var("COUNT_TOKENS_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
+        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
             .ok()
             .and_then(|v| v.parse().ok())
             .unwrap_or(600) // Picked arbitrarily
@@ -4623,7 +4701,7 @@ impl RateLimit for CountTokensWithLanguageModelRateLimit {
     }
 
     fn db_name() -> &'static str {
-        "count-tokens-with-language-model"
+        "count-language-model-tokens"
     }
 }
diff --git a/crates/completion/Cargo.toml b/crates/completion/Cargo.toml
index 9e3855676e01f..7224dc6b0dd89 100644
--- a/crates/completion/Cargo.toml
+++ b/crates/completion/Cargo.toml
@@ -26,7 +26,9 @@ anyhow.workspace = true
 futures.workspace = true
 gpui.workspace = true
 language_model.workspace = true
+schemars.workspace = true
 serde.workspace = true
+serde_json.workspace = true
 settings.workspace = true
 smol.workspace = true
 ui.workspace = true
diff --git a/crates/completion/src/completion.rs b/crates/completion/src/completion.rs
index 376f8d9f73466..f55818e2841f2 100644
--- a/crates/completion/src/completion.rs
+++ b/crates/completion/src/completion.rs
@@ -3,10 +3,13 @@ use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
 use gpui::{AppContext, Global, Model, ModelContext, Task};
 use language_model::{
     LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
-    LanguageModelRequest,
+    LanguageModelRequest, LanguageModelTool,
 };
-use smol::lock::{Semaphore, SemaphoreGuardArc};
-use std::{pin::Pin, sync::Arc, task::Poll};
+use smol::{
+    future::FutureExt,
+    lock::{Semaphore, SemaphoreGuardArc},
+};
+use std::{future, pin::Pin, sync::Arc, task::Poll};
 use ui::Context;
 
 pub fn init(cx: &mut AppContext) {
@@ -143,11 +146,11 @@
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
-    ) -> Option<BoxFuture<'static, Result<usize>>> {
+    ) -> BoxFuture<'static, Result<usize>> {
         if let Some(model) = self.active_model() {
-            Some(model.count_tokens(request, cx))
+            model.count_tokens(request, cx)
         } else {
-            None
+            future::ready(Err(anyhow!("no active model"))).boxed()
         }
     }
@@ -183,6 +186,29 @@
             Ok(completion)
         })
     }
+
+    pub fn use_tool<T: LanguageModelTool>(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> Task<Result<T>> {
+        if let Some(language_model) = self.active_model() {
+            cx.spawn(|cx| async move {
+                let schema = schemars::schema_for!(T);
+                let schema_json = serde_json::to_value(&schema).unwrap();
+                let request =
+                    language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
+                let response = request.await?;
+                Ok(serde_json::from_value(response)?)
+            })
+        } else {
+            Task::ready(Err(anyhow!("No active model set")))
+        }
+    }
+
+    pub fn active_model_telemetry_id(&self) -> Option<String> {
+        self.active_model.as_ref().map(|m| m.telemetry_id())
+    }
 }
 
 #[cfg(test)]
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 91edcab6f544f..b9f3262f30230 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -16,6 +16,8 @@ pub use model::*;
 pub use registry::*;
 pub use request::*;
 pub use role::*;
+use schemars::JsonSchema;
+use serde::de::DeserializeOwned;
 
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
     settings::init(cx);
@@ -42,6 +44,20 @@ pub trait LanguageModel: Send + Sync {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+
+    fn use_tool(
+        &self,
+        request: LanguageModelRequest,
+        name: String,
+        description: String,
+        schema: serde_json::Value,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>>;
+}
+
+pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
+    fn name() -> String;
+    fn description() -> String;
 }
 
 pub trait LanguageModelProvider: 'static {
diff --git a/crates/language_model/src/provider/anthropic.rs b/crates/language_model/src/provider/anthropic.rs
index 7cc9922546d4a..cfca9358a16a1 100644
--- a/crates/language_model/src/provider/anthropic.rs
+++ b/crates/language_model/src/provider/anthropic.rs
@@ -1,5 +1,9 @@
-use anthropic::stream_completion;
-use anyhow::{anyhow, Result};
+use crate::{
+    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
+    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelProviderState, LanguageModelRequest, Role,
+};
+use anyhow::{anyhow, Context as _, Result};
 use collections::BTreeMap;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -15,12 +19,6 @@ use theme::ThemeSettings;
 use ui::prelude::*;
 use util::ResultExt;
 
-use crate::{
-    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, Role,
-};
-
 const PROVIDER_ID: &str = "anthropic";
 const PROVIDER_NAME: &str = "Anthropic";
 
@@ -188,6 +186,61 @@ pub fn count_anthropic_tokens(
         .boxed()
 }
 
+impl AnthropicModel {
+    fn request_completion(
+        &self,
+        request: anthropic::Request,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<anthropic::Response>> {
+        let http_client = self.http_client.clone();
+
+        let Ok((api_key, api_url)) = cx.read_model(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
+            (state.api_key.clone(), settings.api_url.clone())
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        async move {
+            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            anthropic::complete(http_client.as_ref(), &api_url, &api_key, request).await
+        }
+        .boxed()
+    }
+
+    fn stream_completion(
+        &self,
+        request: anthropic::Request,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<anthropic::Event>>>> {
+        let http_client = self.http_client.clone();
+
+        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
+            (
+                state.api_key.clone(),
+                settings.api_url.clone(),
+                settings.low_speed_timeout,
+            )
+        }) else {
+            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+        };
+
+        async move {
+            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+            let request = anthropic::stream_completion(
+                http_client.as_ref(),
+                &api_url,
+                &api_key,
+                request,
+                low_speed_timeout,
+            );
+            request.await
+        }
+        .boxed()
+    }
+}
+
 impl LanguageModel for AnthropicModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
@@ -227,34 +280,53 @@ impl LanguageModel for AnthropicModel {
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let request = request.into_anthropic(self.model.id().into());
-
-        let http_client = self.http_client.clone();
-
-        let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
-            let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
-            (
-                state.api_key.clone(),
-                settings.api_url.clone(),
-                settings.low_speed_timeout,
-            )
-        }) else {
-            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
-        };
-
+        let request = self.stream_completion(request, cx);
         async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
-            let request = stream_completion(
-                http_client.as_ref(),
-                &api_url,
-                &api_key,
-                request,
-                low_speed_timeout,
-            );
             let response = request.await?;
             Ok(anthropic::extract_text_from_events(response).boxed())
         }
         .boxed()
     }
+
+    fn use_tool(
+        &self,
+        request: LanguageModelRequest,
+        tool_name: String,
+        tool_description: String,
+        input_schema: serde_json::Value,
+        cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        let mut request = request.into_anthropic(self.model.id().into());
+        request.tool_choice = Some(anthropic::ToolChoice::Tool {
+            name: tool_name.clone(),
+        });
+        request.tools = vec![anthropic::Tool {
+            name: tool_name.clone(),
+            description: tool_description,
+            input_schema,
+        }];
+
+        let response = self.request_completion(request, cx);
+        async move {
+            let response = response.await?;
+            response
+                .content
+                .into_iter()
+                .find_map(|content| {
+                    if let anthropic::Content::ToolUse { name, input, .. } = content {
+                        if name == tool_name {
+                            Some(input)
+                        } else {
+                            None
+                        }
+                    } else {
+                        None
+                    }
+                })
+                .context("tool not used")
+        }
+        .boxed()
+    }
 }
 
 struct AuthenticationPrompt {
diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index d290876ad9ccf..8c32c723c9bf8 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -4,7 +4,7 @@ use crate::{
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelRequest,
 };
-use anyhow::Result;
+use anyhow::{anyhow, Context as _, Result};
 use client::Client;
 use collections::BTreeMap;
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
@@ -12,7 +12,7 @@ use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
-use std::sync::Arc;
+use std::{future, sync::Arc};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
 
@@ -234,15 +234,13 @@ impl LanguageModel for CloudLanguageModel {
         };
         async move {
             let request = serde_json::to_string(&request)?;
-            let response = client.request(proto::QueryLanguageModel {
-                provider: proto::LanguageModelProvider::Google as i32,
-                kind: proto::LanguageModelRequestKind::CountTokens as i32,
-                request,
-            });
-            let response = response.await?;
-            let response =
-                serde_json::from_str::<google_ai::CountTokensResponse>(&response.response)?;
-            Ok(response.total_tokens)
+            let response = client
+                .request(proto::CountLanguageModelTokens {
+                    provider: proto::LanguageModelProvider::Google as i32,
+                    request,
+                })
+                .await?;
+            Ok(response.token_count as usize)
         }
         .boxed()
     }
@@ -260,14 +258,14 @@ impl LanguageModel for CloudLanguageModel {
             let request = request.into_anthropic(model.id().into());
             async move {
                 let request = serde_json::to_string(&request)?;
-                let response = client.request_stream(proto::QueryLanguageModel {
-                    provider: proto::LanguageModelProvider::Anthropic as i32,
-                    kind: proto::LanguageModelRequestKind::Complete as i32,
-                    request,
-                });
-                let chunks = response.await?;
+                let stream = client
+                    .request_stream(proto::StreamCompleteWithLanguageModel {
+                        provider: proto::LanguageModelProvider::Anthropic as i32,
+                        request,
+                    })
+                    .await?;
                 Ok(anthropic::extract_text_from_events(
-                    chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                    stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                 )
                 .boxed())
             }
@@ -278,14 +276,14 @@ impl LanguageModel for CloudLanguageModel {
             let request = request.into_open_ai(model.id().into());
             async move {
                 let request = serde_json::to_string(&request)?;
-                let response = client.request_stream(proto::QueryLanguageModel {
-                    provider: proto::LanguageModelProvider::OpenAi as i32,
-                    kind: proto::LanguageModelRequestKind::Complete as i32,
-                    request,
-                });
-                let chunks = response.await?;
+                let stream = client
+                    .request_stream(proto::StreamCompleteWithLanguageModel {
+                        provider: proto::LanguageModelProvider::OpenAi as i32,
+                        request,
+                    })
+                    .await?;
                 Ok(open_ai::extract_text_from_events(
-                    chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                    stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                 )
                 .boxed())
             }
@@ -296,14 +294,14 @@ impl LanguageModel for CloudLanguageModel {
             let request = request.into_google(model.id().into());
             async move {
                 let request = serde_json::to_string(&request)?;
-                let response = client.request_stream(proto::QueryLanguageModel {
-                    provider: proto::LanguageModelProvider::Google as i32,
-                    kind: proto::LanguageModelRequestKind::Complete as i32,
-                    request,
-                });
-                let chunks = response.await?;
+                let stream = client
+                    .request_stream(proto::StreamCompleteWithLanguageModel {
+                        provider: proto::LanguageModelProvider::Google as i32,
+                        request,
+                    })
+                    .await?;
                 Ok(google_ai::extract_text_from_events(
-                    chunks.map(|chunk| Ok(serde_json::from_str(&chunk?.response)?)),
+                    stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
                 )
                 .boxed())
                 }
             }
         }
     }
+
+    fn use_tool(
+        &self,
+        request: LanguageModelRequest,
+        tool_name: String,
+        tool_description: String,
+        input_schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        match &self.model {
+            CloudModel::Anthropic(model) => {
+                let client = self.client.clone();
+                let mut request = request.into_anthropic(model.id().into());
+                request.tool_choice = Some(anthropic::ToolChoice::Tool {
+                    name: tool_name.clone(),
+                });
+                request.tools = vec![anthropic::Tool {
+                    name: tool_name.clone(),
+                    description: tool_description,
+                    input_schema,
+                }];
+
+                async move {
+                    let request = serde_json::to_string(&request)?;
+                    let response = client
+                        .request(proto::CompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::Anthropic as i32,
+                            request,
+                        })
+                        .await?;
+                    let response: anthropic::Response = serde_json::from_str(&response.completion)?;
+                    response
+                        .content
+                        .into_iter()
+                        .find_map(|content| {
+                            if let anthropic::Content::ToolUse { name, input, .. } = content {
+                                if name == tool_name {
+                                    Some(input)
+                                } else {
+                                    None
+                                }
+                            } else {
+                                None
+                            }
+                        })
+                        .context("tool not used")
+                }
+                .boxed()
+            }
+            CloudModel::OpenAi(_) => {
+                future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
+            }
+            CloudModel::Google(_) => {
+                future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
+            }
+        }
+    }
 }
 
 struct AuthenticationPrompt {
diff --git a/crates/language_model/src/provider/fake.rs b/crates/language_model/src/provider/fake.rs
index 8f91155cd484b..7d5a6192a8ebe 100644
--- a/crates/language_model/src/provider/fake.rs
+++ b/crates/language_model/src/provider/fake.rs
@@ -1,15 +1,17 @@
-use std::sync::{Arc, Mutex};
-
-use collections::HashMap;
-use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
-
 use crate::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelRequest,
 };
+use anyhow::anyhow;
+use collections::HashMap;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Task};
 use http_client::Result;
+use std::{
+    future,
+    sync::{Arc, Mutex},
+};
 use ui::WindowContext;
 
 pub fn language_model_id() -> LanguageModelId {
@@ -170,4 +172,15 @@
             .insert(serde_json::to_string(&request).unwrap(), tx);
         async move { Ok(rx.map(Ok).boxed()) }.boxed()
     }
+
+    fn use_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _name: String,
+        _description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        future::ready(Err(anyhow!("not implemented"))).boxed()
+    }
 }
diff --git a/crates/language_model/src/provider/google.rs b/crates/language_model/src/provider/google.rs
index 3a0c0a3f7e25d..3a869773a3579 100644
--- a/crates/language_model/src/provider/google.rs
+++ b/crates/language_model/src/provider/google.rs
@@ -9,7 +9,7 @@ use gpui::{
 };
 use http_client::HttpClient;
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::prelude::*;
@@ -238,6 +238,17 @@ impl LanguageModel for GoogleLanguageModel {
         }
         .boxed()
     }
+
+    fn use_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _name: String,
+        _description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        future::ready(Err(anyhow!("not implemented"))).boxed()
+    }
 }
 
 struct AuthenticationPrompt {
diff --git a/crates/language_model/src/provider/ollama.rs b/crates/language_model/src/provider/ollama.rs
index c1896aafe9c8c..3502748e08fbc 100644
--- a/crates/language_model/src/provider/ollama.rs
+++ b/crates/language_model/src/provider/ollama.rs
@@ -6,7 +6,7 @@ use ollama::{
     get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
 };
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, ElevationIndex};
 
 use crate::{
@@ -298,6 +298,17 @@ impl LanguageModel for OllamaLanguageModel {
         }
         .boxed()
     }
+
+    fn use_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _name: String,
+        _description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        future::ready(Err(anyhow!("not implemented"))).boxed()
+    }
 }
 
 struct DownloadOllamaMessage {
diff --git a/crates/language_model/src/provider/open_ai.rs b/crates/language_model/src/provider/open_ai.rs
index 1b3bf18dd5a12..d2465e444661d 100644
--- a/crates/language_model/src/provider/open_ai.rs
+++ b/crates/language_model/src/provider/open_ai.rs
@@ -9,7 +9,7 @@ use gpui::{
 };
 use http_client::HttpClient;
 use open_ai::stream_completion;
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{future, sync::Arc, time::Duration};
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::prelude::*;
@@ -225,6 +225,17 @@ impl LanguageModel for OpenAiLanguageModel {
         }
         .boxed()
     }
+
+    fn use_tool(
+        &self,
+        _request: LanguageModelRequest,
+        _name: String,
+        _description: String,
+        _schema: serde_json::Value,
+        _cx: &AsyncAppContext,
+    ) -> BoxFuture<'static, Result<serde_json::Value>> {
+        future::ready(Err(anyhow!("not implemented"))).boxed()
+    }
 }
 
 pub fn count_open_ai_tokens(
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
index fc3b8c019282c..ca57706f15dcc 100644
--- a/crates/language_model/src/request.rs
+++ b/crates/language_model/src/request.rs
@@ -106,19 +106,27 @@ impl LanguageModelRequest {
             messages: new_messages
                 .into_iter()
                 .filter_map(|message| {
-                    Some(anthropic::RequestMessage {
+                    Some(anthropic::Message {
                         role: match message.role {
                             Role::User => anthropic::Role::User,
                             Role::Assistant => anthropic::Role::Assistant,
                             Role::System => return None,
                         },
-                        content: message.content,
+                        content: vec![anthropic::Content::Text {
+                            text: message.content,
+                        }],
                     })
                 })
                 .collect(),
-            stream: true,
             max_tokens: 4092,
-            system: system_message,
+            system: Some(system_message),
+            tools: Vec::new(),
+            tool_choice: None,
+            metadata: None,
+            stop_sequences: Vec::new(),
+            temperature: None,
+            top_k: None,
+            top_p: None,
         }
     }
 }
diff --git a/crates/proto/proto/zed.proto b/crates/proto/proto/zed.proto
index 658d552848b3d..404acb42e81eb 100644
--- a/crates/proto/proto/zed.proto
+++ b/crates/proto/proto/zed.proto
@@ -194,8 +194,12 @@ message Envelope {
 
         JoinHostedProject join_hosted_project = 164;
 
-        QueryLanguageModel query_language_model = 224;
-        QueryLanguageModelResponse query_language_model_response = 225; // current max
+        CompleteWithLanguageModel complete_with_language_model = 226;
+        CompleteWithLanguageModelResponse complete_with_language_model_response = 227;
+        StreamCompleteWithLanguageModel stream_complete_with_language_model = 228;
+        StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229;
+        CountLanguageModelTokens count_language_model_tokens = 230;
+        CountLanguageModelTokensResponse count_language_model_tokens_response = 231; // current max
         GetCachedEmbeddings get_cached_embeddings = 189;
         GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
         ComputeEmbeddings compute_embeddings = 191;
@@ -267,6 +271,7 @@ message Envelope {
 
     reserved 158 to 161;
     reserved 166 to 169;
+    reserved 224 to 225;
 }
 
 // Messages
@@ -2050,25 +2055,37 @@ enum LanguageModelRole {
     reserved 3;
 }
 
-message QueryLanguageModel {
+message CompleteWithLanguageModel {
     LanguageModelProvider provider = 1;
-    LanguageModelRequestKind kind = 2;
-    string request = 3;
+    string request = 2;
 }
 
-enum LanguageModelProvider {
-    Anthropic = 0;
-    OpenAI = 1;
-    Google = 2;
+message CompleteWithLanguageModelResponse {
+    string completion = 1;
+}
+
+message StreamCompleteWithLanguageModel {
+    LanguageModelProvider provider = 1;
+    string request = 2;
+}
+
+message StreamCompleteWithLanguageModelResponse {
+    string event = 1;
+}
+
+message CountLanguageModelTokens {
+    LanguageModelProvider provider = 1;
+    string request = 2;
 }
 
-enum LanguageModelRequestKind {
-    Complete = 0;
-    CountTokens = 1;
+message CountLanguageModelTokensResponse {
+    uint32 token_count = 1;
 }
 
-message QueryLanguageModelResponse {
-    string response = 1;
+enum LanguageModelProvider {
+    Anthropic = 0;
+    OpenAI = 1;
+    Google = 2;
 }
 
 message GetCachedEmbeddings {
diff --git a/crates/proto/src/proto.rs b/crates/proto/src/proto.rs
index 7ef1866acd3e2..632a6f69517ee 100644
--- a/crates/proto/src/proto.rs
+++ b/crates/proto/src/proto.rs
@@ -294,8 +294,12 @@ messages!(
     (PrepareRename, Background),
     (PrepareRenameResponse, Background),
     (ProjectEntryResponse, Foreground),
-    (QueryLanguageModel, Background),
-    (QueryLanguageModelResponse, Background),
+    (CompleteWithLanguageModel, Background),
+    (CompleteWithLanguageModelResponse, Background),
+    (StreamCompleteWithLanguageModel, Background),
+    (StreamCompleteWithLanguageModelResponse, Background),
+    (CountLanguageModelTokens, Background),
+    (CountLanguageModelTokensResponse, Background),
     (RefreshInlayHints, Foreground),
     (RejoinChannelBuffers, Foreground),
     (RejoinChannelBuffersResponse, Foreground),
@@ -463,7 +467,12 @@ request_messages!(
     (PerformRename, PerformRenameResponse),
     (Ping, Ack),
     (PrepareRename, PrepareRenameResponse),
-    (QueryLanguageModel, QueryLanguageModelResponse),
+    (CompleteWithLanguageModel, CompleteWithLanguageModelResponse),
+    (
+        StreamCompleteWithLanguageModel,
+        StreamCompleteWithLanguageModelResponse
+    ),
+    (CountLanguageModelTokens, CountLanguageModelTokensResponse),
     (RefreshInlayHints, Ack),
     (RejoinChannelBuffers, RejoinChannelBuffersResponse),
     (RejoinRoom, RejoinRoomResponse),
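
As a usage illustration: the patch replaces the old XML-parsing flow with a typed tool-calling API. `LanguageModelTool` ties a Rust type to a tool definition, and `LanguageModelCompletionProvider::use_tool` derives a JSON schema from that type via schemars, forces the model to invoke the tool, and deserializes the tool input back into the type. Below is a minimal sketch of a caller-defined tool; `PerformEditsTool`, `EditOperation`, and the `perform_edits` name are hypothetical examples invented for this sketch, not part of the patch. Only the `LanguageModelTool` trait, its bounds, and `use_tool` come from the code above.

    // Hypothetical example tool; the derives satisfy the
    // `LanguageModelTool: 'static + DeserializeOwned + JsonSchema` bound.
    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct PerformEditsTool {
        /// Edits the model proposes, as old/new text pairs.
        edits: Vec<EditOperation>,
    }

    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct EditOperation {
        old_text: String,
        new_text: String,
    }

    impl language_model::LanguageModelTool for PerformEditsTool {
        fn name() -> String {
            "perform_edits".into()
        }

        fn description() -> String {
            "Apply a series of text replacements to the buffer.".into()
        }
    }

    // At a call site with an `AppContext` and a `LanguageModelRequest` in scope,
    // the provider returns a typed task:
    //
    //     let edits: Task<Result<PerformEditsTool>> =
    //         LanguageModelCompletionProvider::read_global(cx)
    //             .use_tool::<PerformEditsTool>(request, cx);

For Anthropic, the only provider wired up in this patch, the call bottoms out in the `tool_choice: Some(ToolChoice::Tool { name })` request built in `provider/anthropic.rs` and `provider/cloud.rs`, so the model must answer with a `ToolUse` content block whose `input` parses against the generated schema; the other providers currently return "not implemented".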