diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 7456ae0afc..91f2e87480 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -65,6 +65,7 @@ use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, instrument, warn}; const DEFAULT_MAX_TURNS: u32 = 1000; +const DEFAULT_STOP_HOOK_BLOCK_CAP: u32 = 8; const COMPACTION_THINKING_TEXT: &str = "goose is compacting the conversation..."; const DEFAULT_FRONTEND_INSTRUCTIONS: &str = "The following tools are provided directly by the frontend and will be executed by the frontend when called."; @@ -98,6 +99,35 @@ fn extract_string_arg(input: &Value, keys: &[&str]) -> Option { None } +fn stop_hook_denial_context_message(plugin: &str, reason: &str) -> Message { + let nudge = format!( + "Stop hook `{plugin}` blocked ending this turn: + +{reason} + +Address this policy hook denial before trying to stop again." + ); + Message::user() + .with_text(nudge) + .with_visibility(false, true) +} + +fn stop_hook_denial_notification(plugin: &str) -> Message { + Message::assistant().with_system_notification( + SystemNotificationType::InlineMessage, + format!("Stop hook `{plugin}` blocked ending this turn."), + ) +} + +fn stop_hook_block_cap_warning(plugin: &str, cap: u32) -> Message { + Message::assistant().with_system_notification( + SystemNotificationType::InlineMessage, + format!( + "Stop hook `{plugin}` blocked the turn from ending more than {cap} consecutive times — overriding and ending turn to avoid an infinite loop. Set GOOSE_STOP_HOOK_BLOCK_CAP to raise this limit." + ), + ) +} + /// Context needed for the reply function pub struct ReplyContext { pub conversation: Conversation, @@ -211,6 +241,8 @@ pub struct Agent { pub(super) retry_manager: RetryManager, pub(super) tool_inspection_manager: ToolInspectionManager, pub(super) hook_manager: crate::hooks::HookManager, + #[cfg(test)] + stop_hook_block_cap_override: Option, container: Mutex>, goal: Mutex>, grind: Mutex>, @@ -330,6 +362,8 @@ impl Agent { provider.clone(), ), hook_manager: crate::hooks::HookManager::load(std::env::current_dir().ok().as_deref()), + #[cfg(test)] + stop_hook_block_cap_override: None, container: Mutex::new(None), goal: Mutex::new(None), grind: Mutex::new(None), @@ -338,6 +372,27 @@ impl Agent { /// Emit a lifecycle hook event with no extra context. Useful for events /// that have no matcher (e.g. `SessionStart`, `SessionEnd`). + #[cfg(test)] + pub(crate) fn set_hook_manager_for_test(&mut self, hook_manager: crate::hooks::HookManager) { + self.hook_manager = hook_manager; + } + + #[cfg(test)] + pub(crate) fn set_stop_hook_block_cap_for_test(&mut self, cap: u32) { + self.stop_hook_block_cap_override = Some(cap); + } + + fn stop_hook_block_cap(&self) -> u32 { + #[cfg(test)] + if let Some(cap) = self.stop_hook_block_cap_override { + return cap; + } + + Config::global() + .get_param::("GOOSE_STOP_HOOK_BLOCK_CAP") + .unwrap_or(DEFAULT_STOP_HOOK_BLOCK_CAP) + } + pub async fn emit_hook(&self, event: crate::hooks::HookEvent, session_id: &str) { if !self.hook_manager.has_hooks(event) { return; @@ -1643,21 +1698,63 @@ impl Agent { let mut last_assistant_text = String::new(); let mut goal_check_pending = false; let mut tool_pair_summarization_done = false; + let mut stop_hook_handled_for_exit = false; + let mut retrying_after_stop_hook_denial = false; + let mut consecutive_stop_hook_blocks = 0u32; + let stop_hook_block_cap = self.stop_hook_block_cap(); loop { if is_token_cancelled(&cancel_token) { break; } - { - let guard = self.final_output_tool.lock().await; - if let Some(ref output) = guard.as_ref().and_then(|fot| fot.final_output.clone()) { - yield AgentEvent::Message(Message::assistant().with_text(output)); - break; + let final_output = { + let mut guard = self.final_output_tool.lock().await; + guard.as_mut().and_then(|fot| fot.final_output.take()) + }; + if let Some(output) = final_output { + let message = Message::assistant().with_text(output); + yield AgentEvent::Message(message.clone()); + session_manager.add_message(&session_config.id, &message).await?; + conversation.push(message); + + let ctx = crate::hooks::HookContext::new( + crate::hooks::HookEvent::Stop, + &session_config.id, + ); + match self + .hook_manager + .emit_blocking(crate::hooks::HookEvent::Stop, ctx) + .await + { + crate::hooks::HookDecision::Allow => { + stop_hook_handled_for_exit = true; + break; + } + crate::hooks::HookDecision::Deny { reason, plugin } => { + consecutive_stop_hook_blocks += 1; + if consecutive_stop_hook_blocks > stop_hook_block_cap { + let message = stop_hook_block_cap_warning(&plugin, stop_hook_block_cap); + session_manager.add_message(&session_config.id, &message).await?; + yield AgentEvent::Message(message); + stop_hook_handled_for_exit = true; + break; + } + let message = stop_hook_denial_context_message(&plugin, &reason); + session_manager.add_message(&session_config.id, &message).await?; + conversation.push(message); + yield AgentEvent::Message(stop_hook_denial_notification(&plugin)); + retrying_after_stop_hook_denial = true; + continue; + } } } - turns_taken += 1; + if retrying_after_stop_hook_denial { + retrying_after_stop_hook_denial = false; + } else { + turns_taken += 1; + } if turns_taken > max_turns { yield AgentEvent::Message( Message::assistant().with_text( @@ -1706,6 +1803,7 @@ impl Agent { let mut tools_updated = false; let mut did_recovery_compact_this_iteration = false; let mut exit_chat = false; + let mut pending_final_output: Option = None; // Track whether this provider turn has already emitted visible // thinking so a later tool-call chunk can suppress replayed @@ -2135,8 +2233,8 @@ impl Agent { // Lock, extract state, drop guard before branching — handle_retry_logic // also locks final_output_tool and tokio::sync::Mutex is not reentrant. let final_output = { - let guard = self.final_output_tool.lock().await; - guard.as_ref().map(|fot| fot.final_output.clone()) + let mut guard = self.final_output_tool.lock().await; + guard.as_mut().map(|fot| fot.final_output.take()) }; match final_output { @@ -2147,9 +2245,7 @@ impl Agent { yield AgentEvent::Message(message); } Some(Some(output)) => { - let message = Message::assistant().with_text(output); - messages_to_add.push(message.clone()); - yield AgentEvent::Message(message); + pending_final_output = Some(output); exit_chat = true; } None if did_recovery_compact_this_iteration => { @@ -2257,6 +2353,12 @@ impl Agent { } } + if let Some(output) = pending_final_output.take() { + let message = Message::assistant().with_text(output); + messages_to_add.push(message.clone()); + yield AgentEvent::Message(message); + } + let messages_to_add = if let Some(ref inference) = inference { Conversation::new_unvalidated( messages_to_add @@ -2271,8 +2373,37 @@ impl Agent { session_manager.add_message(&session_config.id, msg).await?; } conversation.extend(messages_to_add); + if exit_chat { - break; + let ctx = crate::hooks::HookContext::new( + crate::hooks::HookEvent::Stop, + &session_config.id, + ); + match self + .hook_manager + .emit_blocking(crate::hooks::HookEvent::Stop, ctx) + .await + { + crate::hooks::HookDecision::Allow => { + stop_hook_handled_for_exit = true; + break; + } + crate::hooks::HookDecision::Deny { reason, plugin } => { + consecutive_stop_hook_blocks += 1; + if consecutive_stop_hook_blocks > stop_hook_block_cap { + let message = stop_hook_block_cap_warning(&plugin, stop_hook_block_cap); + session_manager.add_message(&session_config.id, &message).await?; + yield AgentEvent::Message(message); + stop_hook_handled_for_exit = true; + break; + } + let message = stop_hook_denial_context_message(&plugin, &reason); + session_manager.add_message(&session_config.id, &message).await?; + conversation.push(message); + yield AgentEvent::Message(stop_hook_denial_notification(&plugin)); + retrying_after_stop_hook_denial = true; + } + } } tokio::task::yield_now().await; @@ -2282,7 +2413,9 @@ impl Agent { tracing::Span::current().record("trace_output", last_assistant_text.as_str()); } - self.emit_hook(crate::hooks::HookEvent::Stop, &session_config.id).await; + if !stop_hook_handled_for_exit { + self.emit_hook(crate::hooks::HookEvent::Stop, &session_config.id).await; + } }.instrument(reply_stream_span)); Ok(inner) } @@ -2788,8 +2921,17 @@ impl Agent { mod tests { use super::*; use crate::permission::permission_confirmation::PrincipalType; - use crate::providers::base::PermissionRouting; + use crate::plugins::discovery::{DiscoveredPlugin, PluginScope}; + use crate::providers::base::{ + stream_from_single_message, MessageStream, PermissionRouting, ProviderUsage, Usage, + }; + use crate::providers::errors::ProviderError; use crate::recipe::Response; + use crate::session::session_manager::SessionType; + use rmcp::model::Tool; + use std::path::PathBuf; + use std::sync::atomic::{AtomicUsize, Ordering}; + use tempfile::TempDir; struct ActionRequiredProvider { handled: tokio::sync::Mutex>, @@ -2907,6 +3049,252 @@ mod tests { assert_eq!(conf.permission, crate::permission::Permission::AllowOnce); } + const ALWAYS_BLOCK_SCRIPT: &str = r#"#!/bin/sh +echo blocked >> "$PLUGIN_ROOT/hook.log" +echo "always block" >&2 +exit 2 +"#; + + const ALTERNATE_BLOCK_ALLOW_SCRIPT: &str = r#"#!/bin/sh +count_file="$PLUGIN_ROOT/count" +count=0 +if [ -f "$count_file" ]; then + count=$(cat "$count_file") +fi +count=$((count + 1)) +echo "$count" > "$count_file" +echo "$count" >> "$PLUGIN_ROOT/hook.log" +if [ $((count % 2)) -eq 1 ]; then + echo "block $count" >&2 + exit 2 +fi +exit 0 +"#; + + struct StopHookTestEnv { + temp_dir: TempDir, + hook_log: PathBuf, + } + + impl StopHookTestEnv { + fn new(script: &str) -> Result { + let temp_dir = tempfile::tempdir()?; + let plugin_dir = temp_dir.path().join("stop-blocker"); + std::fs::create_dir_all(plugin_dir.join("hooks"))?; + std::fs::write( + plugin_dir.join("hooks/hooks.json"), + r#"{ + "hooks": { + "Stop": [ + { + "hooks": [ + { "type": "command", "command": "sh ${PLUGIN_ROOT}/block.sh" } + ] + } + ] + } +} +"#, + )?; + std::fs::write(plugin_dir.join("block.sh"), script)?; + + Ok(Self { + temp_dir, + hook_log: plugin_dir.join("hook.log"), + }) + } + + fn hook_manager(&self) -> crate::hooks::HookManager { + crate::hooks::HookManager::from_plugins_for_test(vec![DiscoveredPlugin { + name: "stop-blocker".into(), + root: self.temp_dir.path().join("stop-blocker"), + scope: PluginScope::Project, + }]) + } + + fn data_dir(&self) -> PathBuf { + self.temp_dir.path().join("data") + } + + fn hook_invocations(&self) -> usize { + std::fs::read_to_string(&self.hook_log) + .unwrap_or_default() + .lines() + .count() + } + } + + struct CountingTextProvider { + call_count: AtomicUsize, + } + + impl CountingTextProvider { + fn new() -> Self { + Self { + call_count: AtomicUsize::new(0), + } + } + + fn call_count(&self) -> usize { + self.call_count.load(Ordering::SeqCst) + } + } + + #[async_trait::async_trait] + impl crate::providers::base::Provider for CountingTextProvider { + async fn stream( + &self, + _model_config: &crate::model::ModelConfig, + _session_id: &str, + _system_prompt: &str, + _messages: &[Message], + _tools: &[Tool], + ) -> Result { + let call = self.call_count.fetch_add(1, Ordering::SeqCst); + let message = Message::assistant().with_text(format!("provider response {call}")); + let usage = ProviderUsage::new("mock-model".to_string(), Usage::default()); + Ok(stream_from_single_message(message, usage)) + } + + fn get_model_config(&self) -> crate::model::ModelConfig { + crate::model::ModelConfig::new("mock-model").unwrap() + } + + fn get_name(&self) -> &str { + "counting-text" + } + } + + async fn create_stop_hook_test_agent( + env: &StopHookTestEnv, + stop_hook_block_cap: u32, + ) -> Result<(Agent, String, Arc)> { + let session_manager = Arc::new(SessionManager::new(env.data_dir())); + let permission_manager = Arc::new(PermissionManager::new(env.data_dir())); + let config = AgentConfig::new( + session_manager.clone(), + permission_manager, + None, + GooseMode::Auto, + true, + GoosePlatform::GooseCli, + ); + let mut agent = Agent::with_config(config); + agent.set_hook_manager_for_test(env.hook_manager()); + agent.set_stop_hook_block_cap_for_test(stop_hook_block_cap); + let provider = Arc::new(CountingTextProvider::new()); + let session = session_manager + .create_session( + PathBuf::default(), + "stop-hook-test".to_string(), + SessionType::Hidden, + GooseMode::Auto, + ) + .await?; + agent.update_provider(provider.clone(), &session.id).await?; + Ok((agent, session.id, provider)) + } + + async fn run_stop_hook_test_turn( + agent: &Agent, + session_id: &str, + text: &str, + ) -> Result> { + let session_config = SessionConfig { + id: session_id.to_string(), + schedule_id: None, + max_turns: Some(10), + retry_config: None, + }; + let reply_stream = agent + .reply(Message::user().with_text(text), session_config, None) + .await?; + tokio::pin!(reply_stream); + + let mut messages = Vec::new(); + while let Some(event) = reply_stream.next().await { + match event? { + AgentEvent::Message(message) => messages.push(message), + AgentEvent::McpNotification(_) | AgentEvent::HistoryReplaced(_) => {} + } + } + Ok(messages) + } + + fn visible_texts(messages: &[Message]) -> Vec { + messages + .iter() + .map(Message::as_concat_text) + .filter(|text| !text.is_empty()) + .collect() + } + + #[tokio::test] + async fn stop_hook_block_cap_allows_configured_consecutive_blocks_then_overrides() -> Result<()> + { + let env = StopHookTestEnv::new(ALWAYS_BLOCK_SCRIPT)?; + let (agent, session_id, provider) = create_stop_hook_test_agent(&env, 2).await?; + + let messages = run_stop_hook_test_turn(&agent, &session_id, "hello").await?; + let texts = visible_texts(&messages); + + assert_eq!( + provider.call_count(), + 3, + "cap=2 should allow two blocked retries, then override on the third block" + ); + assert_eq!( + env.hook_invocations(), + 3, + "Stop hook should run for the initial response plus the two honored retries" + ); + assert!(texts.iter().any(|text| text == "provider response 0")); + assert!(texts.iter().any(|text| text == "provider response 1")); + assert!(texts.iter().any(|text| text == "provider response 2")); + assert!(messages.iter().any(|message| { + message.content.iter().any(|content| { + matches!( + content, + MessageContent::SystemNotification(notification) + if notification.msg.contains("more than 2 consecutive times") + && notification.msg.contains("GOOSE_STOP_HOOK_BLOCK_CAP") + ) + }) + })); + + Ok(()) + } + + #[tokio::test] + async fn stop_hook_block_cap_counts_only_consecutive_blocks() -> Result<()> { + let env = StopHookTestEnv::new(ALTERNATE_BLOCK_ALLOW_SCRIPT)?; + let (agent, session_id, provider) = create_stop_hook_test_agent(&env, 1).await?; + + let first_turn = run_stop_hook_test_turn(&agent, &session_id, "first").await?; + let second_turn = run_stop_hook_test_turn(&agent, &session_id, "second").await?; + let mut texts = visible_texts(&first_turn); + texts.extend(visible_texts(&second_turn)); + + assert_eq!( + provider.call_count(), + 4, + "each turn should honor one block, retry, then stop when the next Stop hook allows" + ); + assert_eq!(env.hook_invocations(), 4); + assert!(texts.iter().any(|text| text == "provider response 0")); + assert!(texts.iter().any(|text| text == "provider response 1")); + assert!(texts.iter().any(|text| text == "provider response 2")); + assert!(texts.iter().any(|text| text == "provider response 3")); + assert!( + !texts + .iter() + .any(|text| text.contains("overriding and ending turn")), + "non-consecutive Stop hook blocks should not trip the cap warning" + ); + + Ok(()) + } + #[tokio::test] async fn test_add_final_output_tool() -> Result<()> { let agent = Agent::new(); diff --git a/crates/goose/src/hooks/mod.rs b/crates/goose/src/hooks/mod.rs index ae1a2d48bc..590b2adcf8 100644 --- a/crates/goose/src/hooks/mod.rs +++ b/crates/goose/src/hooks/mod.rs @@ -233,6 +233,11 @@ impl HookManager { Self::from_plugins(plugins) } + #[cfg(test)] + pub(crate) fn from_plugins_for_test(plugins: Vec) -> Self { + Self::from_plugins(plugins) + } + fn from_plugins(plugins: Vec) -> Self { let mut rules: HashMap> = HashMap::new(); let mut total = 0usize; @@ -620,6 +625,33 @@ mod tests { assert_eq!(written.trim(), root.to_string_lossy()); } + #[tokio::test] + async fn stop_hook_emit_blocking_returns_denial() { + let tmp = tempfile::tempdir().unwrap(); + let root = write_plugin( + tmp.path(), + "p", + r#"{"hooks":{"Stop":[{"hooks":[{"type":"command","command":"printf '%s' '{\"decision\":\"block\",\"reason\":\"say something first\"}'"}]}]}}"#, + ); + let mgr = make_manager(vec![DiscoveredPlugin { + name: "p".into(), + root, + scope: PluginScope::User, + }]); + + let decision = mgr + .emit_blocking(HookEvent::Stop, HookContext::new(HookEvent::Stop, "s")) + .await; + + assert_eq!( + decision, + HookDecision::Deny { + reason: "say something first".into(), + plugin: "p".into(), + } + ); + } + #[tokio::test] async fn matcher_filters_by_tool_name() { let tmp = tempfile::tempdir().unwrap();