Skip to content

Commit

Permalink
Merge branch 'v1.0' into micn/initial-dirs
Browse files Browse the repository at this point in the history
* v1.0:
  [app] Fix message ordering + better UI for ToolInvocations
  fix: server state persists (#378)
  cli: review command naming and help descriptions (#347)
  fix: Handle interrupt after tool response but before assistant message (#374)
  fix: set the path when running from app (#375)
  • Loading branch information
michaelneale committed Dec 1, 2024
2 parents ef4ab5c + a8b6bb3 commit 7d6e6e0
Show file tree
Hide file tree
Showing 28 changed files with 1,919 additions and 139 deletions.
117 changes: 103 additions & 14 deletions crates/goose-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct Cli {
#[arg(long)]
databricks_token: Option<String>,

/// Model to use
/// The machine learning model to use for operations. Use 'gpt-4o' for enhanced competence.
#[arg(short, long, default_value = "gpt-4o")]
model: String,

Expand All @@ -62,45 +62,134 @@ struct Cli {

#[derive(Subcommand)]
enum Command {
/// Configure Goose settings and profiles
#[command(about = "Configure Goose settings and profiles")]
Configure {
/// Name of the profile to configure
#[arg(
help = "Profile name to configure",
long_help = "Create or modify a named configuration profile. Use 'default' for the default profile."
)]
profile_name: Option<String>,
},

/// Manage system prompts and behaviors
#[command(about = "Manage system prompts and behaviors")]
System {
#[command(subcommand)]
action: SystemCommands,
},
/// Start or resume sessions with an optional session name

/// Start or resume interactive chat sessions
#[command(
about = "Start or resume interactive chat sessions",
alias = "s",
)]
Session {
#[arg(short, long)]
/// Name for the chat session
#[arg(
short,
long,
value_name = "NAME",
help = "Name for the chat session (e.g., 'project-x')",
long_help = "Specify a name for your chat session. When used with --resume, will resume this specific session if it exists."
)]
session: Option<String>,
#[arg(short, long)]

/// Configuration profile to use
#[arg(
short,
long,
value_name = "PROFILE",
help = "Configuration profile to use (e.g., 'default')",
long_help = "Use a specific configuration profile. Profiles contain settings like API keys and model preferences."
)]
profile: Option<String>,
#[arg(short, long, action = clap::ArgAction::SetTrue)]

/// Resume a previous session
#[arg(
short,
long,
help = "Resume a previous session (last used or specified by --session)",
long_help = "Continue from a previous chat session. If --session is provided, resumes that specific session. Otherwise resumes the last used session."
)]
resume: bool,
},
/// Run goose once-off with instructions from a file

/// Execute commands from an instruction file
#[command(about = "Execute commands from an instruction file")]
Run {
#[arg(short, long)]
/// Path to instruction file containing commands
#[arg(
short,
long,
required = true,
value_name = "FILE",
help = "Path to instruction file containing commands",
)]
instructions: Option<String>,
#[arg(short = 't', long = "text")]
input_text: Option<String>,
#[arg(short, long)]

/// Configuration profile to use
#[arg(
short,
long,
value_name = "PROFILE",
help = "Configuration profile to use (e.g., 'default')",
long_help = "Use a specific configuration profile. Profiles contain settings like API keys and model preferences."
)]
profile: Option<String>,
#[arg(short, long)]

/// Input text containing commands
#[arg(
short = 't',
long = "text",
value_name = "TEXT",
help = "Input text to provide to Goose directly",
long_help = "Input text containing commands for Goose. Use this in lieu of the instructions argument."
)]
input_text: Option<String>,

/// Name for this run session
#[arg(
short,
long,
value_name = "NAME",
help = "Name for this run session (e.g., 'daily-tasks')",
long_help = "Specify a name for this run session. This helps identify and resume specific runs later."
)]
session: Option<String>,
#[arg(short, long, action = clap::ArgAction::SetTrue)]

/// Resume a previous run
#[arg(
short,
long,
action = clap::ArgAction::SetTrue,
help = "Resume from a previous run",
long_help = "Continue from a previous run, maintaining the execution state and context."
)]
resume: bool,
},
}

#[derive(Subcommand)]
enum SystemCommands {
/// Add a new system prompt
#[command(about = "Add a new system prompt from URL")]
Add {
#[arg(help = "The URL to add system")]
#[arg(
help = "URL of the system prompt to add",
long_help = "URL pointing to a file containing the system prompt to be added."
)]
url: String,
},

/// Remove an existing system prompt
#[command(about = "Remove an existing system prompt")]
Remove {
#[arg(help = "The URL to remove system")]
#[arg(
help = "URL of the system prompt to remove",
long_help = "URL of the system prompt that should be removed from the configuration."
)]
url: String,
},
}
Expand Down
86 changes: 78 additions & 8 deletions crates/goose-cli/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use anyhow::Result;
use core::panic;
use futures::StreamExt;
use serde_json;
use std::fs::{self, File};
Expand Down Expand Up @@ -224,9 +225,9 @@ We've removed the conversation up to the most recent user message
.collect()
});

// Handle the interruption based on whether we were in a tool request
if !tool_requests.is_empty() {
// Create a message with tool responses for all interrupted requests
// Interrupted during a tool request
// Create tool responses for all interrupted tool requests
let mut response_message = Message::user();
let last_tool_name = tool_requests
.last()
Expand All @@ -243,7 +244,6 @@ We've removed the conversation up to the most recent user message
}
self.messages.push(response_message);

// Add assistant message about the interruption using the last tool name
let prompt_response = &format!(
"We interrupted the existing call to {}. How would you like to proceed?",
last_tool_name
Expand All @@ -252,15 +252,27 @@ We've removed the conversation up to the most recent user message
.push(Message::assistant().with_text(prompt_response));
self.prompt.render(raw_message(prompt_response));
} else {
// Default behavior for non-tool interruptions, remove the last user message
// An interruption occurred outside of a tool request-response.
if let Some(last_msg) = self.messages.last() {
if last_msg.role == Role::User {
self.messages.pop();
match last_msg.content.first() {
Some(MessageContent::ToolResponse(_)) => {
// Interruption occurred after a tool had completed but not assistant reply
let prompt_response = "We interrupted the existing calls to tools. How would you like to proceed?";
self.messages
.push(Message::assistant().with_text(prompt_response));
self.prompt.render(raw_message(prompt_response));
}
Some(_) => {
// A real users message
self.messages.pop();
let prompt_response = "We interrupted before the model replied and removed the last message.";
self.prompt.render(raw_message(prompt_response));
}
None => panic!("No content in last message"),
}
}
}
let prompt_response =
"We interrupted before the model replied and removed the last message.";
self.prompt.render(raw_message(prompt_response));
}
}

Expand Down Expand Up @@ -656,6 +668,64 @@ mod tests {
);
}

#[test]
fn test_interrupted_tool_use_interrupts_completed_tool_result_but_no_assistant_msg_yet() {
let tool_name1 = "test";
let tool_call1 = tool::ToolCall::new(tool_name1, "test".into());
let tool_result1 = AgentResult::Ok(vec![Content::text("Task 1 done")]);

let mut session = create_test_session_with_prompt(Box::new(MockPrompt::new()));
session
.messages
.push(Message::user().with_text("Do something"));
session.messages.push(
Message::assistant()
.with_text("Doing part 1")
.with_tool_request("1", Ok(tool_call1.clone())),
);
session
.messages
.push(Message::user().with_tool_response("1", tool_result1.clone()));

session.handle_interrupted_messages();

assert_eq!(session.messages.len(), 4);
assert_eq!(session.messages[0].role, Role::User);
assert_eq!(
session.messages[0].content[0],
MessageContent::text("Do something")
);
assert_eq!(session.messages[1].role, Role::Assistant);
assert_eq!(
session.messages[1].content[0],
MessageContent::text("Doing part 1")
);
assert_eq!(
session.messages[1].content[1],
MessageContent::tool_request("1", Ok(tool_call1))
);

assert_eq!(session.messages[2].role, Role::User);
assert_eq!(
session.messages[2].content[0],
MessageContent::tool_response("1", tool_result1.clone())
);

// Check the follow-up assistant message
assert_eq!(session.messages[3].role, Role::Assistant);
assert_eq!(
session.messages[3].content[0],
MessageContent::text(
"We interrupted the existing calls to tools. How would you like to proceed?",
)
);

assert_last_prompt_text(
&session,
"We interrupted the existing calls to tools. How would you like to proceed?",
);
}

fn assert_last_prompt_text(session: &Session, expected_text: &str) {
let prompt = session
.prompt
Expand Down
4 changes: 1 addition & 3 deletions crates/goose-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ async fn main() -> anyhow::Result<()> {
let settings = configuration::Settings::new()?;

// Create app state
let state = state::AppState {
provider_config: settings.provider.into_config(),
};
let state = state::AppState::new(settings.provider.into_config())?;

// Create router with CORS support
let cors = CorsLayer::new()
Expand Down
26 changes: 8 additions & 18 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,9 @@ use axum::{
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use goose::{
agent::Agent,
developer::DeveloperSystem,
models::content::Content,
models::message::{Message, MessageContent},
models::role::Role,
providers::factory,
};
use serde::Deserialize;
use serde_json::{json, Value};
Expand Down Expand Up @@ -201,7 +198,8 @@ async fn stream_message(
.await?;
}
Err(err) => {
let result = vec![Content::text(format!("Error {}", err))];
let result =
vec![Content::text(format!("Error {}", err)).with_priority(0.0)];
tx.send(ProtocolFormatter::format_tool_response(
&response.id,
&result,
Expand Down Expand Up @@ -272,19 +270,15 @@ async fn handler(
let (tx, rx) = mpsc::channel(100);
let stream = ReceiverStream::new(rx);

// Setup agent with developer system
let system = Box::new(DeveloperSystem::new());
let provider = factory::get_provider(state.provider_config)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let mut agent = Agent::new(provider);
agent.add_system(system);

// Convert incoming messages
let messages = convert_messages(request.messages);

// Get a lock on the shared agent
let agent = state.agent.clone();

// Spawn task to handle streaming
tokio::spawn(async move {
let agent = agent.lock().await;
let mut stream = match agent.reply(&messages).await {
Ok(stream) => stream,
Err(e) => {
Expand Down Expand Up @@ -345,12 +339,8 @@ async fn ask_handler(
State(state): State<AppState>,
Json(request): Json<AskRequest>,
) -> Result<Json<AskResponse>, StatusCode> {
let provider = factory::get_provider(state.provider_config)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let system = Box::new(DeveloperSystem::new());

let mut agent = Agent::new(provider);
agent.add_system(system);
let agent = state.agent.clone();
let agent = agent.lock().await;

// Create a single message for the prompt
let messages = vec![Message::user().with_text(request.prompt)];
Expand Down
24 changes: 23 additions & 1 deletion crates/goose-server/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,29 @@
use goose::providers::configs::ProviderConfig;
use anyhow::Result;
use goose::{
agent::Agent,
developer::DeveloperSystem,
providers::{configs::ProviderConfig, factory},
};
use std::sync::Arc;
use tokio::sync::Mutex;

/// Shared application state
pub struct AppState {
pub provider_config: ProviderConfig,
pub agent: Arc<Mutex<Agent>>,
}

impl AppState {
pub fn new(provider_config: ProviderConfig) -> Result<Self> {
let provider = factory::get_provider(provider_config.clone())?;
let mut agent = Agent::new(provider);
agent.add_system(Box::new(DeveloperSystem::new()));

Ok(Self {
provider_config,
agent: Arc::new(Mutex::new(agent)),
})
}
}

// Manual Clone implementation since we know ProviderConfig variants can be cloned
Expand Down Expand Up @@ -38,6 +59,7 @@ impl Clone for AppState {
})
}
},
agent: self.agent.clone(),
}
}
}
Loading

0 comments on commit 7d6e6e0

Please sign in to comment.