diff --git a/crates/goose-cli/src/agents/mock_agent.rs b/crates/goose-cli/src/agents/mock_agent.rs index bb7e1b09..e19b0eb8 100644 --- a/crates/goose-cli/src/agents/mock_agent.rs +++ b/crates/goose-cli/src/agents/mock_agent.rs @@ -9,9 +9,7 @@ pub struct MockAgent; #[async_trait] impl Agent for MockAgent { - fn add_system(&mut self, _system: Box) { - - } + fn add_system(&mut self, _system: Box) {} async fn reply(&self, _messages: &[Message]) -> Result>> { Ok(Box::pin(futures::stream::empty())) diff --git a/crates/goose-cli/src/prompt.rs b/crates/goose-cli/src/prompt.rs index b564acc0..09f5ef90 100644 --- a/crates/goose-cli/src/prompt.rs +++ b/crates/goose-cli/src/prompt.rs @@ -16,6 +16,9 @@ pub trait Prompt { println!("Goose is running! Enter your instructions, or try asking what goose can do."); println!("\n"); } + // Used for testing. Allows us to downcast to any type. + #[cfg(test)] + fn as_any(&self) -> &dyn std::any::Any; } pub struct Input { diff --git a/crates/goose-cli/src/prompt/cliclack.rs b/crates/goose-cli/src/prompt/cliclack.rs index b5206c72..7dda7da7 100644 --- a/crates/goose-cli/src/prompt/cliclack.rs +++ b/crates/goose-cli/src/prompt/cliclack.rs @@ -271,6 +271,11 @@ impl Prompt for CliclackPrompt { fn close(&self) { // No cleanup required } + + #[cfg(test)] + fn as_any(&self) -> &dyn std::any::Any { + panic!("Not implemented"); + } } ////// diff --git a/crates/goose-cli/src/prompt/rustyline.rs b/crates/goose-cli/src/prompt/rustyline.rs index f33d908b..9305b61a 100644 --- a/crates/goose-cli/src/prompt/rustyline.rs +++ b/crates/goose-cli/src/prompt/rustyline.rs @@ -360,4 +360,9 @@ impl Prompt for RustylinePrompt { fn close(&self) { // No cleanup required } + + #[cfg(test)] + fn as_any(&self) -> &dyn std::any::Any { + panic!("Not implemented"); + } } diff --git a/crates/goose-cli/src/session.rs b/crates/goose-cli/src/session.rs index 6f0ec450..1a94046d 100644 --- a/crates/goose-cli/src/session.rs +++ b/crates/goose-cli/src/session.rs @@ -158,8 +158,14 @@ impl<'a> Session<'a> { self.prompt.show_busy(); } Some(Err(e)) => { - // TODO: Handle error display through prompt eprintln!("Error: {}", e); + drop(stream); + self.rewind_messages(); + self.prompt.render(raw_message(r#" +\x1b[31mThe error above was an exception we were not able to handle.\n\n\x1b[0m +These errors are often related to connection or authentication\n +We've removed the conversation up to the most recent user message + - \x1b[33mdepending on the error you may be able to continue\x1b[0m"#)); break; } None => break, @@ -167,8 +173,7 @@ impl<'a> Session<'a> { } _ = tokio::signal::ctrl_c() => { drop(stream); - self.rewind_messages(); - self.prompt.render(raw_message(" Interrupt: Resetting conversation to before the last sent message...\n")); + self.handle_interrupted_messages(); break; } } @@ -176,7 +181,7 @@ impl<'a> Session<'a> { } /// Rewind the messages to before the last user message (they have cancelled it). - pub fn rewind_messages(&mut self) { + fn rewind_messages(&mut self) { if self.messages.is_empty() { return; } @@ -200,6 +205,65 @@ impl<'a> Session<'a> { } } + fn handle_interrupted_messages(&mut self) { + // First, get any tool requests from the last message if it exists + let tool_requests = self + .messages + .last() + .filter(|msg| msg.role == Role::Assistant) + .map_or(Vec::new(), |msg| { + msg.content + .iter() + .filter_map(|content| { + if let MessageContent::ToolRequest(req) = content { + Some((req.id.clone(), req.tool_call.clone())) + } else { + None + } + }) + .collect() + }); + + // Handle the interruption based on whether we were in a tool request + if !tool_requests.is_empty() { + // Create a message with tool responses for all interrupted requests + let mut response_message = Message::user(); + let last_tool_name = tool_requests + .last() + .and_then(|(_, tool_call)| tool_call.as_ref().ok().map(|tool| tool.name.clone())) + .unwrap_or_else(|| "tool".to_string()); + + for (req_id, _) in &tool_requests { + response_message.content.push(MessageContent::tool_response( + req_id.clone(), + Err(goose::errors::AgentError::ExecutionError( + "Interrupted by the user to make a correction".to_string(), + )), + )); + } + self.messages.push(response_message); + + // Add assistant message about the interruption using the last tool name + let prompt_response = &format!( + "We interrupted the existing call to {}. How would you like to proceed?", + last_tool_name + ); + self.messages + .push(Message::assistant().with_text(prompt_response)); + self.prompt.render(raw_message(prompt_response)); + } else { + // Default behavior for non-tool interruptions, remove the last user message + if let Some(last_msg) = self.messages.last() { + if last_msg.role == Role::User { + self.messages.pop(); + } + } + let prompt_response = + "We interrupted before the model replied and removed the last message."; + self.prompt.render(raw_message(prompt_response)); + } + } + fn setup_session(&mut self) { let system = Box::new(DeveloperSystem::new()); self.agent.add_system(system); @@ -225,10 +289,15 @@ fn raw_message(content: &str) -> Box { #[cfg(test)] mod tests { + use std::any::Any; + use std::sync::{Arc, Mutex}; + use crate::agents::mock_agent::MockAgent; use crate::prompt::{self, Input}; use super::*; + use goose::models::content::Content; + use goose::models::tool; use goose::{errors::AgentResult, models::tool::ToolCall}; use tempfile::NamedTempFile; @@ -236,12 +305,39 @@ mod tests { fn create_test_session() -> Session<'static> { let temp_file = NamedTempFile::new().unwrap(); let agent = Box::new(MockAgent {}); - let prompt = Box::new(MockPrompt {}); + let prompt = Box::new(MockPrompt::new()); + Session::new(agent, prompt, temp_file.path().to_path_buf()) + } + + fn create_test_session_with_prompt<'a>(prompt: Box) -> Session<'a> { + let temp_file = NamedTempFile::new().unwrap(); + let agent = Box::new(MockAgent {}); Session::new(agent, prompt, temp_file.path().to_path_buf()) } // Mock prompt implementation for testing - struct MockPrompt {} + pub struct MockPrompt { + messages: Arc>>, // Thread-safe, owned storage + } + + impl MockPrompt { + pub fn new() -> Self { + Self { + messages: Arc::new(Mutex::new(Vec::new())), + } + } + + pub fn add_message(&self, message: Message) { + let mut messages = self.messages.lock().unwrap(); // Lock to safely modify + messages.push(message); + } + + pub fn get_messages(&self) -> Vec { + let messages = self.messages.lock().unwrap(); // Lock to safely read + messages.clone() // Return a clone to avoid borrowing issues + } + } + impl Prompt for MockPrompt { fn get_input(&mut self) -> std::result::Result { Ok(Input { @@ -249,11 +345,16 @@ mod tests { content: Some("Msg:".to_string()), }) } - fn render(&mut self, _: Box) {} + fn render(&mut self, message: Box) { + self.add_message(message.as_ref().clone()); + } fn show_busy(&mut self) {} fn hide_busy(&self) {} fn goose_ready(&self) {} fn close(&self) {} + fn as_any(&self) -> &dyn Any { + self + } } #[test] @@ -336,4 +437,234 @@ mod tests { MessageContent::text("Response 1") ); } + + #[test] + fn test_interrupted_messages_only_1_user_msg() { + let mut session = create_test_session_with_prompt(Box::new(MockPrompt::new())); + session.messages.push(Message::user().with_text("Hello")); + + session.handle_interrupted_messages(); + + assert!(session.messages.is_empty()); + + assert_last_prompt_text( + &session, + "We interrupted before the model replied and removed the last message.", + ); + } + + #[test] + fn test_interrupted_messages_removes_last_user_msg() { + let mut session = create_test_session_with_prompt(Box::new(MockPrompt::new())); + session.messages.push(Message::user().with_text("Hello")); + session.messages.push(Message::assistant().with_text("Hi")); + session + .messages + .push(Message::user().with_text("How are you?")); + + session.handle_interrupted_messages(); + + assert_eq!(session.messages.len(), 2); + assert_eq!(session.messages[0].role, Role::User); + assert_eq!( + session.messages[0].content[0], + MessageContent::text("Hello") + ); + assert_eq!(session.messages[1].role, Role::Assistant); + assert_eq!(session.messages[1].content[0], MessageContent::text("Hi")); + + assert_last_prompt_text( + &session, + "We interrupted before the model replied and removed the last message.", + ); + } + + #[test] + fn test_interrupted_tool_use_resolves_with_last_tool_use_interrupted() { + let tool_name1 = "test"; + let tool_call1 = tool::ToolCall::new(tool_name1, "test".into()); + let tool_result1 = AgentResult::Ok(vec![Content::text("Task 1 done")]); + + let tool_name2 = "test2"; + let tool_call2 = tool::ToolCall::new(tool_name2, "test2".into()); + let mut session = create_test_session_with_prompt(Box::new(MockPrompt::new())); + session + .messages + .push(Message::user().with_text("Do something")); + session.messages.push( + Message::assistant() + .with_text("Doing it") + .with_tool_request("1", Ok(tool_call1.clone())), + ); + session.messages.push( + Message::user() + .with_text("Did Task 1") + .with_tool_response("1", tool_result1.clone()), + ); + session + .messages + .push(Message::user().with_text("Do something else")); + session.messages.push( + Message::assistant() + .with_text("Doing task 2") + .with_tool_request("2", Ok(tool_call2.clone())), + ); + + session.handle_interrupted_messages(); + + assert_eq!(session.messages.len(), 7); + assert_eq!(session.messages[0].role, Role::User); + assert_eq!( + session.messages[0].content[0], + MessageContent::text("Do something") + ); + assert_eq!(session.messages[1].role, Role::Assistant); + assert_eq!( + session.messages[1].content[0], + MessageContent::text("Doing it") + ); + assert_eq!( + session.messages[1].content[1], + MessageContent::tool_request("1", Ok(tool_call1)) + ); + + assert_eq!(session.messages[2].role, Role::User); + assert_eq!( + session.messages[2].content[0], + MessageContent::text("Did Task 1") + ); + assert_eq!( + session.messages[2].content[1], + MessageContent::tool_response("1", tool_result1) + ); + + assert_eq!(session.messages[3].role, Role::User); + assert_eq!( + session.messages[3].content[0], + MessageContent::text("Do something else") + ); + + assert_eq!( + session.messages[4].content[0], + MessageContent::text("Doing task 2") + ); + assert_eq!( + session.messages[4].content[1], + MessageContent::tool_request("2", Ok(tool_call2)) + ); + // Check the interrupted tool response message + assert_eq!(session.messages[5].role, Role::User); + let tool_result = Err(goose::errors::AgentError::ExecutionError( + "Interrupted by the user to make a correction".to_string(), + )); + assert_eq!( + session.messages[5].content[0], + MessageContent::tool_response("2", tool_result) + ); + + // Check the follow-up assistant message + assert_eq!(session.messages[6].role, Role::Assistant); + assert_eq!( + session.messages[6].content[0], + MessageContent::text(format!( + "We interrupted the existing call to {}. How would you like to proceed?", + tool_name2 + )) + ); + + assert_last_prompt_text( + &session, + format!( + "We interrupted the existing call to {}. How would you like to proceed?", + tool_name2 + ) + .as_str(), + ); + } + + #[test] + fn test_interrupted_tool_use_interrupts_multiple_tools() { + let tool_name1 = "test"; + let tool_call1 = tool::ToolCall::new(tool_name1, "test".into()); + + let tool_name2 = "test2"; + let tool_call2 = tool::ToolCall::new(tool_name2, "test2".into()); + let mut session = create_test_session_with_prompt(Box::new(MockPrompt::new())); + session + .messages + .push(Message::user().with_text("Do something")); + session.messages.push( + Message::assistant() + .with_text("Doing it") + .with_tool_request("1", Ok(tool_call1.clone())) + .with_tool_request("2", Ok(tool_call2.clone())), + ); + + session.handle_interrupted_messages(); + + assert_eq!(session.messages.len(), 4); + assert_eq!(session.messages[0].role, Role::User); + assert_eq!( + session.messages[0].content[0], + MessageContent::text("Do something") + ); + assert_eq!(session.messages[1].role, Role::Assistant); + assert_eq!( + session.messages[1].content[0], + MessageContent::text("Doing it") + ); + assert_eq!( + session.messages[1].content[1], + MessageContent::tool_request("1", Ok(tool_call1)) + ); + assert_eq!( + session.messages[1].content[2], + MessageContent::tool_request("2", Ok(tool_call2)) + ); + + // Check the interrupted tool response message + assert_eq!(session.messages[2].role, Role::User); + let tool_result = Err(goose::errors::AgentError::ExecutionError( + "Interrupted by the user to make a correction".to_string(), + )); + assert_eq!( + session.messages[2].content[0], + MessageContent::tool_response("1", tool_result.clone()) + ); + assert_eq!( + session.messages[2].content[1], + MessageContent::tool_response("2", tool_result) + ); + + // Check the follow-up assistant message + assert_eq!(session.messages[3].role, Role::Assistant); + assert_eq!( + session.messages[3].content[0], + MessageContent::text(format!( + "We interrupted the existing call to {}. How would you like to proceed?", + tool_name2 + )) + ); + + assert_last_prompt_text( + &session, + format!( + "We interrupted the existing call to {}. How would you like to proceed?", + tool_name2 + ) + .as_str(), + ); + } + + fn assert_last_prompt_text(session: &Session, expected_text: &str) { + let prompt = session + .prompt + .as_any() + .downcast_ref::() + .expect("Failed to downcast"); + let messages = prompt.get_messages(); + let msg = messages.last().unwrap(); + assert_eq!(msg.role, Role::Assistant); + assert_eq!(msg.content[0], MessageContent::text(expected_text)); + } }