From 7fc3537751350c564ee10c1a48ad36f4c982627b Mon Sep 17 00:00:00 2001 From: Eugenio <292452+eugenio@users.noreply.github.com> Date: Thu, 14 May 2026 07:49:37 +0200 Subject: [PATCH] fix: prevent tool-use marker leakage in toolshim output (#8310) Signed-off-by: Eugenio La Cava Signed-off-by: Michael Neale Co-authored-by: Claude Opus 4.6 (1M context) Co-authored-by: Michael Neale --- .gitignore | 6 + crates/goose-cli/src/main.rs | 21 +- crates/goose/src/agents/reply_parts.rs | 87 +- crates/goose/src/providers/toolshim.rs | 1016 +++++++++++++++++++++++- crates/goose/src/providers/utils.rs | 79 +- ui/desktop/src/types/message.ts | 12 +- 6 files changed, 1161 insertions(+), 60 deletions(-) diff --git a/.gitignore b/.gitignore index 23f1a06202..b82cf5c753 100644 --- a/.gitignore +++ b/.gitignore @@ -64,6 +64,12 @@ do_not_version/ /working_dir +# Local build scripts and generated snapshot artifacts +/build.bat +/build.ps1 +/build_check.ps1 +/crates/goose/src/agents/snapshots/*.snap.new + # Error log artifacts from mcp replay tests crates/goose/tests/mcp_replays/*errors.txt diff --git a/crates/goose-cli/src/main.rs b/crates/goose-cli/src/main.rs index 3530f5a37c..51a7eccafb 100644 --- a/crates/goose-cli/src/main.rs +++ b/crates/goose-cli/src/main.rs @@ -1,8 +1,7 @@ use anyhow::Result; use goose_cli::cli::cli; -#[tokio::main] -async fn main() -> Result<()> { +async fn run() -> Result<()> { if let Err(e) = goose_cli::logging::setup_logging(None) { eprintln!("Warning: Failed to initialize logging: {}", e); } @@ -17,3 +16,21 @@ async fn main() -> Result<()> { result } + +fn main() -> Result<()> { + let handle = std::thread::Builder::new() + .name("goose-cli-main".to_string()) + .stack_size(8 * 1024 * 1024) + .spawn(|| { + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .expect("Failed to build Tokio runtime"); + runtime.block_on(run()) + }) + .map_err(|e| anyhow::anyhow!("Failed to spawn goose-cli main thread: {}", e))?; + + handle + 
.join() + .map_err(|_| anyhow::anyhow!("goose-cli main thread panicked"))? +} diff --git a/crates/goose/src/agents/reply_parts.rs b/crates/goose/src/agents/reply_parts.rs index 07e65e18c7..f4939bc9f4 100644 --- a/crates/goose/src/agents/reply_parts.rs +++ b/crates/goose/src/agents/reply_parts.rs @@ -17,10 +17,11 @@ use crate::providers::base::stream_from_single_message; use crate::providers::base::{MessageStream, Provider, ProviderUsage}; use crate::providers::errors::ProviderError; use crate::providers::toolshim::{ - augment_message_with_tool_calls, convert_tool_messages_to_text, - modify_system_prompt_for_tool_json, OllamaInterpreter, + augment_message_with_selected_tool_interpreter, convert_tool_messages_to_text, + modify_system_prompt_for_tool_json, sanitize_residual_markers, }; use rmcp::model::Tool; +use tracing::warn; async fn enhance_model_error(error: ProviderError, provider: &Arc) -> ProviderError { let ProviderError::RequestFailed(ref msg) = error else { @@ -123,13 +124,16 @@ async fn toolshim_postprocess( response: Message, toolshim_tools: &[Tool], ) -> Result { - let interpreter = OllamaInterpreter::new().map_err(|e| { - ProviderError::ExecutionError(format!("Failed to create OllamaInterpreter: {}", e)) - })?; - - augment_message_with_tool_calls(&interpreter, response, toolshim_tools) - .await - .map_err(|e| ProviderError::ExecutionError(format!("Failed to augment message: {}", e))) + match augment_message_with_selected_tool_interpreter(response.clone(), toolshim_tools).await { + Ok(message) => Ok(message), + Err(e) => { + warn!( + "Toolshim augmentation failed, skipping tool augmentation: {}", + e + ); + Ok(sanitize_residual_markers(response)) + } + } } impl Agent { @@ -302,20 +306,67 @@ impl Agent { }; Ok(Box::pin(try_stream! 
{ - while let Some(result) = stream.next().await { - let (mut message, usage) = result?; + if config.toolshim { + // Toolshim mode: accumulate the full response before processing + // so that tool-use markers spanning multiple chunks are detected + // and stripped before any output reaches the UI. + let mut accumulated_message: Option = None; + let mut final_usage: Option = None; - // Store the model information in the global store - if let Some(usage) = usage.as_ref() { - crate::providers::base::set_current_model(&usage.model); + while let Some(result) = stream.next().await { + let (msg_opt, usage_opt) = result?; + + if let Some(usage) = usage_opt.as_ref() { + crate::providers::base::set_current_model(&usage.model); + } + + if let Some(msg) = msg_opt { + accumulated_message = Some(match accumulated_message { + Some(mut prev) => { + for new_content in msg.content { + match (&mut prev.content.last_mut(), &new_content) { + ( + Some(MessageContent::Text(last_text)), + MessageContent::Text(new_text), + ) => { + last_text.text.push_str(&new_text.text); + } + _ => { + prev.content.push(new_content); + } + } + } + prev + } + None => msg, + }); + } + + if let Some(usage) = usage_opt { + final_usage = Some(usage); + } + + // Yield empty item so the agent loop can check cancellation + yield (None, None); } - // Post-process / structure the response only if tool interpretation is enabled - if message.is_some() && config.toolshim { - message = Some(toolshim_postprocess(message.unwrap(), &toolshim_tools).await?); + if let Some(msg) = accumulated_message { + let processed = toolshim_postprocess(msg, &toolshim_tools).await?; + yield (Some(processed), final_usage); + } else if final_usage.is_some() { + // Preserve usage-only responses (no message content) + yield (None, final_usage); } + } else { + while let Some(result) = stream.next().await { + let (message, usage) = result?; - yield (message, usage); + if let Some(usage) = usage.as_ref() { + 
crate::providers::base::set_current_model(&usage.model); + } + + yield (message, usage); + } } })) } diff --git a/crates/goose/src/providers/toolshim.rs b/crates/goose/src/providers/toolshim.rs index d0be2dd007..92c174d628 100644 --- a/crates/goose/src/providers/toolshim.rs +++ b/crates/goose/src/providers/toolshim.rs @@ -31,6 +31,7 @@ //! use super::errors::ProviderError; +use super::local_inference::LOCAL_LLM_MODEL_CONFIG_KEY; use super::ollama::OLLAMA_DEFAULT_PORT; use super::ollama::OLLAMA_HOST; use crate::conversation::message::{Message, MessageContent}; @@ -39,6 +40,7 @@ use crate::model::ModelConfig; use crate::providers::base::DEFAULT_PROVIDER_TIMEOUT_SECS; use crate::providers::formats::openai::create_request; use anyhow::Result; +use futures::StreamExt; use reqwest::Client; use rmcp::model::{object, CallToolRequestParams, RawContent, Tool}; use serde_json::{json, Value}; @@ -48,6 +50,481 @@ use uuid::Uuid; /// Default model to use for tool interpretation pub const DEFAULT_INTERPRETER_MODEL_OLLAMA: &str = "mistral-nemo"; +pub const TOOLSHIM_BACKEND_ENV_VAR: &str = "GOOSE_TOOLSHIM_BACKEND"; +pub const TOOLSHIM_LOCAL_MODEL_ENV_VAR: &str = "GOOSE_TOOLSHIM_MODEL"; + +const TOOL_CALLS_SECTION_BEGIN: &str = "<|tool_calls_section_begin|>"; +const TOOL_CALLS_SECTION_END: &str = "<|tool_calls_section_end|>"; +const TOOL_CALL_BEGIN: &str = "<|tool_call_begin|>"; +const TOOL_CALL_ARGUMENT_BEGIN: &str = "<|tool_call_argument_begin|>"; +const TOOL_CALL_ARGUMENT_END: &str = "<|tool_call_argument_end|>"; +const TOOL_CALL_END: &str = "<|tool_call_end|>"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum ToolshimBackend { + Ollama, + Local, +} + +fn parse_toolshim_backend(value: &str) -> Result { + match value.trim().to_ascii_lowercase().as_str() { + "" | "ollama" => Ok(ToolshimBackend::Ollama), + "local" | "llama.cpp" | "llama_cpp" => Ok(ToolshimBackend::Local), + other => Err(ProviderError::RequestFailed(format!( + "Invalid {} value '{}'. 
Expected one of: ollama, local, llama.cpp", + TOOLSHIM_BACKEND_ENV_VAR, other + ))), + } +} + +fn get_toolshim_backend() -> Result { + match std::env::var(TOOLSHIM_BACKEND_ENV_VAR) { + Ok(value) => parse_toolshim_backend(&value), + Err(_) => Ok(ToolshimBackend::Ollama), + } +} + +fn resolve_local_interpreter_model() -> Result { + let env_model = std::env::var(TOOLSHIM_LOCAL_MODEL_ENV_VAR) + .ok() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()); + let config_model = crate::config::Config::global() + .get_param::(LOCAL_LLM_MODEL_CONFIG_KEY) + .ok() + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()); + + resolve_local_interpreter_model_from_sources(env_model, config_model) +} + +fn resolve_local_interpreter_model_from_sources( + env_model: Option, + config_model: Option, +) -> Result { + env_model.or(config_model).ok_or_else(|| { + ProviderError::RequestFailed(format!( + "Local toolshim backend requires {} or {} to be set", + TOOLSHIM_LOCAL_MODEL_ENV_VAR, LOCAL_LLM_MODEL_CONFIG_KEY + )) + }) +} + +fn resolve_tool_name(raw_tool_name: &str, tools: &[Tool]) -> Option { + let trimmed = raw_tool_name.trim(); + let without_index = trimmed.split(':').next().unwrap_or(trimmed).trim(); + let without_functions_prefix = without_index + .strip_prefix("functions.") + .unwrap_or(without_index) + .trim(); + let short_name = without_functions_prefix + .rsplit('.') + .next() + .unwrap_or(without_functions_prefix) + .trim(); + + // Also try replacing dots with double-underscores (goose tool name convention) + let with_dunder = without_functions_prefix.replace('.', "__"); + + let mut candidates = vec![ + trimmed.to_string(), + without_index.to_string(), + without_functions_prefix.to_string(), + with_dunder, + short_name.to_string(), + ]; + candidates.dedup(); + + for candidate in &candidates { + if tools.iter().any(|tool| tool.name == *candidate) { + return Some(candidate.clone()); + } + } + + for candidate in &candidates { + let mut matches: Vec = tools + 
.iter() + .filter(|tool| tool.name.ends_with(&format!("__{}", candidate))) + .map(|tool| tool.name.to_string()) + .collect(); + matches.sort(); + matches.dedup(); + + if matches.len() == 1 { + return Some(matches[0].clone()); + } + } + + None +} + +fn normalized_tool_alias(raw_tool_name: &str) -> String { + let trimmed = raw_tool_name.trim(); + let without_index = trimmed.split(':').next().unwrap_or(trimmed).trim(); + let without_functions_prefix = without_index + .strip_prefix("functions.") + .unwrap_or(without_index) + .trim(); + + without_functions_prefix + .rsplit('.') + .next() + .unwrap_or(without_functions_prefix) + .trim() + .to_ascii_lowercase() +} + +#[allow(clippy::string_slice)] // All markers/delimiters are ASCII; byte indexing is safe. +fn extract_shell_command_from_execute_code(code: &str) -> Option { + let marker = "command"; + let marker_idx = code.find(marker)?; + let after_marker = &code[marker_idx + marker.len()..]; + let colon_idx = after_marker.find(':')?; + let after_colon = after_marker[colon_idx + 1..].trim_start(); + + let quote = after_colon.chars().next()?; + if quote != '"' && quote != '\'' { + return None; + } + + let mut escaped = false; + let mut command = String::new(); + for ch in after_colon[1..].chars() { + if escaped { + command.push(ch); + escaped = false; + continue; + } + + if ch == '\\' { + escaped = true; + continue; + } + + if ch == quote { + return Some(command); + } + + command.push(ch); + } + + None +} + +fn maybe_convert_execute_to_shell_tool_call( + raw_tool_name: &str, + arguments_value: &Value, + tools: &[Tool], +) -> Option { + let alias = normalized_tool_alias(raw_tool_name); + if alias != "execute" && alias != "execute_code" { + return None; + } + + let shell_tool_name = resolve_tool_name("shell", tools)?; + let code = arguments_value.get("code")?.as_str()?; + let command = extract_shell_command_from_execute_code(code)?; + + let shell_args = json!({ "command": command }); + 
Some(CallToolRequestParams::new(shell_tool_name).with_arguments(object(shell_args))) +} + +fn escape_invalid_backslashes_in_json_strings(input: &str) -> String { + let mut out = String::with_capacity(input.len() + 8); + let mut in_string = false; + let mut escaped = false; + + for ch in input.chars() { + if in_string { + if escaped { + if !matches!(ch, '"' | '\\' | '/' | 'b' | 'f' | 'n' | 'r' | 't' | 'u') { + out.push('\\'); + } + out.push(ch); + escaped = false; + continue; + } + + match ch { + '\\' => { + out.push('\\'); + escaped = true; + } + '"' => { + out.push('"'); + in_string = false; + } + _ => out.push(ch), + } + continue; + } + + if ch == '"' { + in_string = true; + } + out.push(ch); + } + + if escaped { + out.push('\\'); + } + + out +} + +fn parse_json_value_tolerant(input: &str) -> Option { + serde_json::from_str::(input).ok().or_else(|| { + let escaped = escape_invalid_backslashes_in_json_strings(input); + serde_json::from_str::(&escaped).ok() + }) +} + +#[allow(clippy::string_slice)] // All markers are ASCII; byte indexing is safe. 
+fn parse_tokenized_tool_calls(content: &str, tools: &[Tool]) -> Vec { + let mut calls = Vec::new(); + let mut remainder = content; + + while let Some(begin_idx) = remainder.find(TOOL_CALL_BEGIN) { + let after_begin = &remainder[begin_idx + TOOL_CALL_BEGIN.len()..]; + + // Find the end of this tool call first + let Some(call_end_offset) = after_begin.find(TOOL_CALL_END) else { + break; + }; + let call_body = &after_begin[..call_end_offset]; + + // Try standard format: name <|tool_call_argument_begin|> {json} + // Fall back to: name {json} (no argument marker) + let (raw_tool_name, raw_args) = + if let Some(arg_idx) = call_body.find(TOOL_CALL_ARGUMENT_BEGIN) { + let name = call_body[..arg_idx].trim(); + let args = call_body[arg_idx + TOOL_CALL_ARGUMENT_BEGIN.len()..].trim(); + (name, args) + } else if let Some(json_start) = call_body.find('{') { + let name = call_body[..json_start].trim(); + let args = call_body[json_start..].trim(); + (name, args) + } else { + remainder = &after_begin[call_end_offset + TOOL_CALL_END.len()..]; + continue; + }; + + if let Some(arguments_value) = parse_json_value_tolerant(raw_args) { + if let Some(tool_name) = resolve_tool_name(raw_tool_name, tools) { + if arguments_value.is_object() { + calls.push( + CallToolRequestParams::new(tool_name) + .with_arguments(object(arguments_value.clone())), + ); + } + } else if let Some(shell_call) = + maybe_convert_execute_to_shell_tool_call(raw_tool_name, &arguments_value, tools) + { + calls.push(shell_call); + } + } + + remainder = &after_begin[call_end_offset + TOOL_CALL_END.len()..]; + } + + calls +} + +#[allow(clippy::string_slice)] // Indices come from char_indices(); slicing is safe. 
+fn extract_first_json_object(input: &str) -> Option<(&str, usize)> { + if !input.starts_with('{') { + return None; + } + + let mut depth = 0usize; + let mut in_string = false; + let mut escaped = false; + + for (idx, ch) in input.char_indices() { + if in_string { + if escaped { + escaped = false; + continue; + } + match ch { + '\\' => escaped = true, + '"' => in_string = false, + _ => {} + } + continue; + } + + match ch { + '"' => in_string = true, + '{' => depth += 1, + '}' => { + depth = depth.saturating_sub(1); + if depth == 0 { + let end = idx + ch.len_utf8(); + return Some((&input[..end], end)); + } + } + _ => {} + } + } + + None +} + +#[allow(clippy::string_slice)] // Indices from find('{') on ASCII; byte slicing is safe. +fn parse_inline_json_tool_calls(content: &str, tools: &[Tool]) -> Vec { + let mut calls = Vec::new(); + let mut remainder = content; + + while let Some(start_idx) = remainder.find('{') { + let maybe_json = &remainder[start_idx..]; + let Some((json_obj, consumed_len)) = extract_first_json_object(maybe_json) else { + break; + }; + + if let Some(value) = parse_json_value_tolerant(json_obj) { + let maybe_name = value.get("name").and_then(Value::as_str); + let maybe_args = value.get("arguments").and_then(Value::as_object); + if let (Some(raw_name), Some(arguments)) = (maybe_name, maybe_args) { + if let Some(tool_name) = resolve_tool_name(raw_name, tools) { + calls.push( + CallToolRequestParams::new(tool_name).with_arguments(arguments.clone()), + ); + } + } + } + + remainder = &maybe_json[consumed_len..]; + } + + calls +} + +#[allow(clippy::string_slice)] // Marker constants are ASCII; byte indexing is safe. 
+fn strip_tokenized_tool_markup(content: &str) -> String { + let mut stripped = content.to_string(); + + while let Some(section_start) = stripped.find(TOOL_CALLS_SECTION_BEGIN) { + let after_start = section_start + TOOL_CALLS_SECTION_BEGIN.len(); + if let Some(section_end_rel) = stripped[after_start..].find(TOOL_CALLS_SECTION_END) { + let section_end = after_start + section_end_rel + TOOL_CALLS_SECTION_END.len(); + stripped.replace_range(section_start..section_end, ""); + } else { + stripped.replace_range(section_start..stripped.len(), ""); + break; + } + } + + for marker in [ + TOOL_CALL_BEGIN, + TOOL_CALL_ARGUMENT_BEGIN, + TOOL_CALL_ARGUMENT_END, + TOOL_CALL_END, + TOOL_CALLS_SECTION_BEGIN, + TOOL_CALLS_SECTION_END, + ] { + stripped = stripped.replace(marker, " "); + } + + stripped.trim().to_string() +} + +fn append_tool_calls_to_message( + mut message: Message, + tool_calls: Vec, +) -> Message { + for tool_call in tool_calls { + if tool_call.name != "noop" { + let id = Uuid::new_v4().to_string(); + message = message.with_tool_request(id, Ok(tool_call)); + } + } + message +} + +fn sanitize_message_after_tokenized_parse(mut message: Message) -> Message { + for content in &mut message.content { + if let MessageContent::Text(text) = content { + text.text = strip_tokenized_tool_markup(&text.text); + } + } + + message.content.retain(|content| match content { + MessageContent::Text(text) => !text.text.trim().is_empty(), + _ => true, + }); + + message +} + +fn sanitize_message_after_json_tool_parse(mut message: Message) -> Message { + for content in &mut message.content { + if let MessageContent::Text(text) = content { + let lower = text.text.to_ascii_lowercase(); + let looks_like_tool_directive = lower.contains("using tool:") + || (text.text.contains("\"name\"") && text.text.contains("\"arguments\"")); + + if looks_like_tool_directive { + text.text.clear(); + } + } + } + + message.content.retain(|content| match content { + MessageContent::Text(text) => 
!text.text.trim().is_empty(), + _ => true, + }); + + message +} + +/// Returns `true` if the text contains any raw tool-use markers that should +/// never appear in final assistant output. +fn has_tool_markers(text: &str) -> bool { + let lower = text.to_ascii_lowercase(); + for marker in [ + TOOL_CALLS_SECTION_BEGIN, + TOOL_CALLS_SECTION_END, + TOOL_CALL_BEGIN, + TOOL_CALL_ARGUMENT_BEGIN, + TOOL_CALL_ARGUMENT_END, + TOOL_CALL_END, + ] { + if text.contains(marker) { + return true; + } + } + lower.contains("using tool:") || (text.contains("\"name\"") && text.contains("\"arguments\"")) +} + +/// Catch-all sanitization applied to every message leaving the toolshim +/// pipeline, regardless of whether tool-call parsing succeeded. +pub fn sanitize_residual_markers(mut message: Message) -> Message { + let mut changed = false; + for content in &mut message.content { + if let MessageContent::Text(text) = content { + if has_tool_markers(&text.text) { + // Strip tokenized markers first (handles section blocks) + text.text = strip_tokenized_tool_markup(&text.text); + // Then clear any remaining JSON-style tool directives + let lower = text.text.to_ascii_lowercase(); + if lower.contains("using tool:") + || (text.text.contains("\"name\"") && text.text.contains("\"arguments\"")) + { + text.text.clear(); + } + changed = true; + } + } + } + if changed { + message.content.retain(|content| match content { + MessageContent::Text(text) => !text.text.trim().is_empty(), + _ => true, + }); + } + message +} /// Environment variables that affect behavior: /// - GOOSE_TOOLSHIM: When set to "true" or "1", enables using the tool shim in the standard OllamaProvider (default: false) @@ -69,6 +546,63 @@ pub struct OllamaInterpreter { base_url: String, } +/// Local llama.cpp implementation of the ToolInterpreter trait. 
+pub struct LocalInterpreter { + model: String, +} + +impl LocalInterpreter { + pub fn new() -> Result { + Ok(Self { + model: resolve_local_interpreter_model()?, + }) + } + + async fn infer_structured_response( + &self, + format_instruction: &str, + ) -> Result { + let model_config = ModelConfig::new(&self.model) + .map_err(|e| ProviderError::RequestFailed(format!("Model config error: {e}")))? + .with_canonical_limits("local") + .with_toolshim(false) + .with_toolshim_model(None); + + let provider = crate::providers::init::create("local", model_config, vec![]) + .await + .map_err(|e| { + ProviderError::RequestFailed(format!( + "Failed to create local interpreter provider: {e}" + )) + })?; + + let request_messages = vec![Message::user().with_text(format_instruction)]; + let mut stream = provider + .stream( + &provider.get_model_config(), + "toolshim-local", + "", + &request_messages, + &[], + ) + .await?; + + let mut content = String::new(); + while let Some(chunk) = stream.next().await { + let (message, _) = chunk?; + if let Some(message) = message { + for part in message.content { + if let MessageContent::Text(text) = part { + content.push_str(&text.text); + } + } + } + } + + Ok(content) + } +} + impl OllamaInterpreter { pub fn new() -> Result { let client = Client::builder() @@ -300,6 +834,49 @@ Otherwise, if no JSON tool requests are provided, use the no-op tool: } } +#[async_trait::async_trait] +impl ToolInterpreter for LocalInterpreter { + async fn interpret_to_tool_calls( + &self, + last_assistant_msg: &str, + tools: &[Tool], + ) -> Result, ProviderError> { + if tools.is_empty() { + return Ok(vec![]); + } + + let system_prompt = "If there is detectable JSON-formatted tool requests, write them into valid JSON tool calls in the following format: +{{ + \"tool_calls\": [ + {{ + \"name\": \"tool_name\", + \"arguments\": {{ + \"param1\": \"value1\", + \"param2\": \"value2\" + }} + }} + ] +}} + +Otherwise, if no JSON tool requests are provided, use the no-op tool: 
+{{ + \"tool_calls\": [ + {{ + \"name\": \"noop\", + \"arguments\": {{ + }} + }}] +}} +"; + + let format_instruction = format!("{}\nRequest: {}\n\n", system_prompt, last_assistant_msg); + let content = self.infer_structured_response(&format_instruction).await?; + let response = json!({ "message": { "content": content } }); + + OllamaInterpreter::process_interpreter_response(&response) + } +} + /// Creates a string containing formatted tool information pub fn format_tool_info(tools: &[Tool]) -> String { let mut tool_info = String::new(); @@ -385,8 +962,7 @@ pub fn modify_system_prompt_for_tool_json(system_prompt: &str, tools: &[Tool]) - format!( "{}\n\n{}\n\nBreak down your task into smaller steps and do one step and tool call at a time. Do not try to use multiple tools at once. If you want to use a tool, tell the user what tool to use by specifying the tool in this JSON format\n{{\n \"name\": \"tool_name\",\n \"arguments\": {{\n \"parameter1\": \"value1\",\n \"parameter2\": \"value2\"\n }}\n}}. After you get the tool result back, consider the result and then proceed to do the next step and tool call if required.", - system_prompt, - tool_info + system_prompt, tool_info ) } @@ -401,47 +977,423 @@ pub async fn augment_message_with_tool_calls( return Ok(message); } - // Extract content from the message - let content_opt = message.content.iter().find_map(|content| { - if let MessageContent::Text(text) = content { - Some(text.text.as_str()) - } else { - None - } - }); - - // If there's no text content or it's already a tool request, return the original message - let content = match content_opt { - Some(text) => text, - None => return Ok(message), - }; - - // Check if there's already a tool request - if message + // Extract and combine all text content blocks from the message. 
+ let content = message .content .iter() - .any(|content| matches!(content, MessageContent::ToolRequest(_))) - { + .filter_map(|content| { + if let MessageContent::Text(text) = content { + Some(text.text.as_str()) + } else { + None + } + }) + .collect::>() + .join("\n"); + + if content.trim().is_empty() { return Ok(message); } + let has_existing_tool_request = message + .content + .iter() + .any(|content| matches!(content, MessageContent::ToolRequest(_))); + + let direct_tool_calls = parse_tokenized_tool_calls(&content, tools); + if !direct_tool_calls.is_empty() { + let cleaned = sanitize_message_after_tokenized_parse(message); + return Ok(append_tool_calls_to_message(cleaned, direct_tool_calls)); + } + + let inline_json_tool_calls = parse_inline_json_tool_calls(&content, tools); + if !inline_json_tool_calls.is_empty() { + let cleaned = sanitize_message_after_json_tool_parse(message); + return Ok(append_tool_calls_to_message( + cleaned, + inline_json_tool_calls, + )); + } + + if has_existing_tool_request { + return Ok(sanitize_residual_markers(message)); + } + // Use the interpreter to convert the content to tool calls - let tool_calls = interpreter.interpret_to_tool_calls(content, tools).await?; + let tool_calls = interpreter.interpret_to_tool_calls(&content, tools).await?; - // If no tool calls were detected, return the original message + // If no tool calls were detected, sanitize any residual markers if tool_calls.is_empty() { - return Ok(message); + return Ok(sanitize_residual_markers(message)); } - // Add each tool call to the message - let mut final_message = message; - for tool_call in tool_calls { - if tool_call.name != "noop" { - // do not actually execute noop tool - let id = Uuid::new_v4().to_string(); - final_message = final_message.with_tool_request(id, Ok(tool_call)); + Ok(sanitize_residual_markers(append_tool_calls_to_message( + message, tool_calls, + ))) +} + +pub async fn augment_message_with_selected_tool_interpreter( + message: Message, + tools: 
&[Tool], +) -> Result { + match get_toolshim_backend()? { + ToolshimBackend::Ollama => { + let interpreter = OllamaInterpreter::new()?; + augment_message_with_tool_calls(&interpreter, message, tools).await + } + ToolshimBackend::Local => { + let interpreter = LocalInterpreter::new()?; + augment_message_with_tool_calls(&interpreter, message, tools).await + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct FailingInterpreter; + + #[async_trait::async_trait] + impl ToolInterpreter for FailingInterpreter { + async fn interpret_to_tool_calls( + &self, + _content: &str, + _tools: &[Tool], + ) -> Result, ProviderError> { + Err(ProviderError::RequestFailed( + "interpreter should not be called".to_string(), + )) } } - Ok(final_message) + #[test] + fn parses_toolshim_backend_values() { + assert_eq!( + parse_toolshim_backend("ollama").unwrap(), + ToolshimBackend::Ollama + ); + assert_eq!( + parse_toolshim_backend("local").unwrap(), + ToolshimBackend::Local + ); + assert_eq!( + parse_toolshim_backend("llama.cpp").unwrap(), + ToolshimBackend::Local + ); + assert!(parse_toolshim_backend("something-else").is_err()); + } + + #[test] + fn resolves_local_interpreter_model_prefers_env() { + let model = resolve_local_interpreter_model_from_sources( + Some("env-model".to_string()), + Some("config-model".to_string()), + ) + .unwrap(); + assert_eq!(model, "env-model"); + } + + #[test] + fn resolves_local_interpreter_model_uses_config_fallback() { + let model = + resolve_local_interpreter_model_from_sources(None, Some("config-model".to_string())) + .unwrap(); + assert_eq!(model, "config-model"); + } + + #[test] + fn resolves_local_interpreter_model_requires_source() { + assert!(resolve_local_interpreter_model_from_sources(None, None).is_err()); + } + + #[test] + fn parses_tokenized_tool_calls() { + let tools = vec![Tool::new( + "shell".to_string(), + "Shell command execution".to_string(), + serde_json::Map::new(), + )]; + + let content = "<|tool_calls_section_begin|> 
<|tool_call_begin|> functions.shell:0 <|tool_call_argument_begin|> {\"command\":\"cat Cargo.toml\"} <|tool_call_end|> <|tool_calls_section_end|>"; + let calls = parse_tokenized_tool_calls(content, &tools); + + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0] + .arguments + .as_ref() + .and_then(|a| a.get("command")) + .and_then(|v| v.as_str()), + Some("cat Cargo.toml") + ); + } + + #[test] + fn parses_execute_marker_and_converts_to_shell_call() { + let tools = vec![Tool::new( + "shell".to_string(), + "Shell command execution".to_string(), + serde_json::Map::new(), + )]; + + let content = "<|tool_calls_section_begin|> <|tool_call_begin|> functions.execute:0 <|tool_call_argument_begin|> {\"code\":\"async function run() { const result = await Developer.shell({ command: \\\"cat Cargo.toml\\\" }); return result; }\"} <|tool_call_end|> <|tool_calls_section_end|>"; + + let calls = parse_tokenized_tool_calls(content, &tools); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0] + .arguments + .as_ref() + .and_then(|a| a.get("command")) + .and_then(|v| v.as_str()), + Some("cat Cargo.toml") + ); + } + + #[test] + fn parses_inline_json_tool_directive() { + let tools = vec![Tool::new( + "shell".to_string(), + "Shell command execution".to_string(), + serde_json::Map::new(), + )]; + + let content = "Using tool: shell\n{\n \"name\": \"shell\",\n \"arguments\": {\n \"command\": \"type Cargo.toml\"\n }\n}"; + let calls = parse_inline_json_tool_calls(content, &tools); + + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0] + .arguments + .as_ref() + .and_then(|a| a.get("command")) + .and_then(|v| v.as_str()), + Some("type Cargo.toml") + ); + } + + #[test] + fn parses_tokenized_tool_call_with_windows_path_arguments() { + let tools = vec![Tool::new( + "tree".to_string(), + "Directory tree".to_string(), + serde_json::Map::new(), + )]; + + let content = 
"<|tool_calls_section_begin|> <|tool_call_begin|> functions.tree:0 <|tool_call_argument_begin|> {\"path\": \"C:\\Users\\eugen\\programmazione\\goose-fork\", \"depth\": 1} <|tool_call_end|> <|tool_calls_section_end|>"; + let calls = parse_tokenized_tool_calls(content, &tools); + + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "tree"); + assert_eq!( + calls[0] + .arguments + .as_ref() + .and_then(|a| a.get("path")) + .and_then(|v| v.as_str()), + Some("C:\\Users\\eugen\\programmazione\\goose-fork") + ); + } + + #[tokio::test] + async fn augment_uses_direct_tokenized_parser_before_interpreter() { + let tools = vec![Tool::new( + "shell".to_string(), + "Shell command execution".to_string(), + serde_json::Map::new(), + )]; + + let message = Message::assistant().with_text( + "<|tool_calls_section_begin|> <|tool_call_begin|> functions.shell:0 <|tool_call_argument_begin|> {\"command\":\"cat Cargo.toml\"} <|tool_call_end|> <|tool_calls_section_end|>", + ); + + let augmented = augment_message_with_tool_calls(&FailingInterpreter, message, &tools) + .await + .unwrap(); + + assert!(augmented + .content + .iter() + .any(|c| matches!(c, MessageContent::ToolRequest(_)))); + assert!(!augmented.as_concat_text().contains("<|tool_call_begin|>")); + } + + #[tokio::test] + async fn augment_parses_inline_json_even_with_existing_tool_request() { + let tools = vec![ + Tool::new( + "analyze".to_string(), + "Analyze files".to_string(), + serde_json::Map::new(), + ), + Tool::new( + "shell".to_string(), + "Shell command execution".to_string(), + serde_json::Map::new(), + ), + ]; + + let message = Message::assistant() + .with_tool_request("existing", Ok(CallToolRequestParams::new("analyze"))) + .with_text( + "Using tool: shell\n{\n \"name\": \"shell\",\n \"arguments\": {\n \"command\": \"type Cargo.toml\"\n }\n}", + ); + + let augmented = augment_message_with_tool_calls(&FailingInterpreter, message, &tools) + .await + .unwrap(); + + let tool_request_count = augmented + .content + 
.iter()
+            .filter(|c| matches!(c, MessageContent::ToolRequest(_)))
+            .count();
+        assert_eq!(tool_request_count, 2);
+    }
+
+    #[tokio::test]
+    async fn augment_parses_tokenized_tool_call_from_later_text_chunk() {
+        let tools = vec![Tool::new(
+            "shell".to_string(),
+            "Shell command execution".to_string(),
+            serde_json::Map::new(),
+        )];
+
+        let message = Message::assistant()
+            .with_text("I will inspect the file now.")
+            .with_text(
+                "<|tool_calls_section_begin|> <|tool_call_begin|> functions.shell:0 <|tool_call_argument_begin|> {\"command\":\"type Cargo.toml\"} <|tool_call_end|> <|tool_calls_section_end|>",
+            );
+
+        let augmented = augment_message_with_tool_calls(&FailingInterpreter, message, &tools)
+            .await
+            .unwrap();
+
+        assert!(augmented
+            .content
+            .iter()
+            .any(|c| matches!(c, MessageContent::ToolRequest(_))));
+    }
+
+    // ── Regression tests: malformed marker leakage ──────────────────────
+
+    /// Malformed tokenized markers (incomplete/garbled) must be stripped
+    /// from the final text even when parsing yields zero tool calls.
+    #[tokio::test]
+    async fn malformed_tokenized_markers_stripped_from_text_output() {
+        let tools = vec![Tool::new(
+            "shell".to_string(),
+            "Shell command execution".to_string(),
+            serde_json::Map::new(),
+        )];
+
+        // Marker sequence is malformed — "GARBAGE" stands where the JSON payload
+        // belongs, so parse_tokenized_tool_calls is expected to return empty.
+        let message = Message::assistant().with_text(
+            "Here is the result.\n<|tool_calls_section_begin|> <|tool_call_begin|> functions.shell:0 GARBAGE <|tool_call_end|> <|tool_calls_section_end|>",
+        );
+
+        // Use an interpreter that returns empty (simulates no-match fallback)
+        struct EmptyInterpreter;
+        #[async_trait::async_trait]
+        impl ToolInterpreter for EmptyInterpreter {
+            async fn interpret_to_tool_calls(
+                &self,
+                _content: &str,
+                _tools: &[Tool],
+            ) -> Result<Vec<CallToolRequestParams>, ProviderError> {
+                Ok(vec![])
+            }
+        }
+
+        let result = augment_message_with_tool_calls(&EmptyInterpreter, message, &tools)
+            .await
+            .unwrap();
+
+        let text = result.as_concat_text();
+        assert!(
+            !has_tool_markers(&text),
+            "Residual tokenized markers leaked into output: {text}"
+        );
+    }
+
+    /// Malformed JSON-style tool directives ("Using tool: …" without valid
+    /// JSON) must be stripped from the final text.
+    #[tokio::test]
+    async fn malformed_json_directive_stripped_from_text_output() {
+        let tools = vec![Tool::new(
+            "shell".to_string(),
+            "Shell command execution".to_string(),
+            serde_json::Map::new(),
+        )];
+
+        // "Using tool:" present but no valid JSON follows
+        let message = Message::assistant().with_text(
+            "I will run the command.\nUsing tool: shell\n{invalid json that won't parse}",
+        );
+
+        struct EmptyInterpreter;
+        #[async_trait::async_trait]
+        impl ToolInterpreter for EmptyInterpreter {
+            async fn interpret_to_tool_calls(
+                &self,
+                _content: &str,
+                _tools: &[Tool],
+            ) -> Result<Vec<CallToolRequestParams>, ProviderError> {
+                Ok(vec![])
+            }
+        }
+
+        let result = augment_message_with_tool_calls(&EmptyInterpreter, message, &tools)
+            .await
+            .unwrap();
+
+        let text = result.as_concat_text();
+        assert!(
+            !has_tool_markers(&text),
+            "Residual JSON tool directive leaked into output: {text}"
+        );
+    }
+
+    #[test]
+    fn has_tool_markers_detects_tokenized_markers() {
+        assert!(has_tool_markers("hello <|tool_calls_section_begin|> world"));
+        assert!(has_tool_markers("text <|tool_call_begin|> more"));
+        assert!(!has_tool_markers("clean assistant text with no markers"));
+    }
+
+    #[test]
+    fn has_tool_markers_detects_json_directive() {
+        assert!(has_tool_markers("Using tool: shell\n{...}"));
+        assert!(has_tool_markers("blah \"name\" blah \"arguments\" blah"));
+        assert!(!has_tool_markers("just normal text mentioning a name"));
+    }
+
+    #[test]
+    fn parses_tokenized_tool_call_without_argument_marker() {
+        let tools = vec![Tool::new(
+            "Nadirclawusage__usageSummary".to_string(),
+            "Usage summary".to_string(),
+            serde_json::Map::new(),
+        )];
+
+        // Model emits tool call without <|tool_call_argument_begin|>
+        let content = "<|tool_calls_section_begin|> <|tool_call_begin|> functions.Nadirclawusage.usageSummary:1 {\"period\": \"24h\"} <|tool_call_end|> <|tool_calls_section_end|>";
+        let calls = parse_tokenized_tool_calls(content, &tools);
+
+        assert_eq!(calls.len(), 1);
+        assert_eq!(calls[0].name, "Nadirclawusage__usageSummary");
+        assert_eq!(
+            calls[0]
+                .arguments
+                .as_ref()
+                .and_then(|a| a.get("period"))
+                .and_then(|v| v.as_str()),
+            Some("24h")
+        );
+    }
 }
diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs
index 3a99b88e60..81d15f5fc0 100644
--- a/crates/goose/src/providers/utils.rs
+++ b/crates/goose/src/providers/utils.rs
@@ -479,7 +479,7 @@ impl Drop for RequestLog {
 
 /// Safely parse a JSON string that may contain doubly-encoded or malformed JSON.
 /// This function first attempts to parse the input string as-is. If that fails,
-/// it applies control character escaping and tries again.
+/// it applies control character escaping and truncated JSON repair and tries again.
///
 /// This approach preserves valid JSON like `{"key1": "value1",\n"key2": "value"}`
 /// (which contains a literal \n but is perfectly valid JSON) while still fixing
@@ -490,13 +490,81 @@ pub fn safely_parse_json(s: &str) -> Result<serde_json::Value, serde_json::Error
     match serde_json::from_str(s) {
         Ok(value) => Ok(value),
         Err(_) => {
-            // If that fails, try with control character escaping
-            let escaped = json_escape_control_chars_in_string(s);
-            serde_json::from_str(&escaped)
+            // Cheapest repair first: close whatever a truncated payload left open.
+            if let Ok(value) = serde_json::from_str(&repair_truncated_json(s)) {
+                return Ok(value);
+            }
+
+            // Otherwise escape raw control characters, and only if that alone is
+            // still unparseable apply the truncation repair on top of it.
+            let escaped = json_escape_control_chars_in_string(s);
+            serde_json::from_str(&escaped)
+                .or_else(|_| serde_json::from_str(&repair_truncated_json(&escaped)))
         }
     }
 }
 
+/// Best-effort completion of JSON text that was cut off mid-document.
+///
+/// Scans `s` once, tracking whether the cursor is inside a string literal and
+/// which `{` / `[` containers are still open, then appends what is missing: a
+/// closing quote for an unterminated string and one closer per open container,
+/// innermost first. Every input character is copied through unchanged.
+///
+/// Nothing is validated: a `}` / `]` that does not match the innermost open
+/// container is ignored, so the result may still fail to parse.
+fn repair_truncated_json(s: &str) -> String {
+    let mut repaired = String::with_capacity(s.len() + 8);
+    let mut in_string = false;
+    let mut escape_next = false;
+    // Closers owed for each container still open, innermost last.
+    let mut closers = Vec::new();
+
+    for c in s.chars() {
+        repaired.push(c);
+
+        if in_string {
+            if escape_next {
+                escape_next = false;
+                continue;
+            }
+
+            match c {
+                '\\' => escape_next = true,
+                '"' => in_string = false,
+                _ => {}
+            }
+            continue;
+        }
+
+        match c {
+            '"' => in_string = true,
+            '{' => closers.push('}'),
+            '[' => closers.push(']'),
+            '}' | ']' => {
+                if closers.last() == Some(&c) {
+                    closers.pop();
+                }
+            }
+            _ => {}
+        }
+    }
+
+    if in_string {
+        // A dangling backslash would escape the quote added below; double it.
+        if escape_next {
+            repaired.push('\\');
+        }
+        repaired.push('"');
+    }
+
+    while let Some(closer) = closers.pop() {
+        repaired.push(closer);
+    }
+
+    repaired
+}
+
 /// Helper to escape control characters in a string that is supposed to be a JSON document.
/// This function iterates through the input string `s` and replaces any literal
 /// control characters (U+0000 to U+001F) with their JSON-escaped equivalents
@@ -809,9 +867,16 @@ mod tests {
         let result = safely_parse_json(good_json).unwrap();
         assert_eq!(result["test"], "value");
 
-        // Test completely invalid JSON that can't be fixed
-        let broken_json = r#"{"key": "unclosed_string"#;
-        assert!(safely_parse_json(broken_json).is_err());
+        // Test truncated JSON with unclosed string, object, and array
+        let truncated_json = r#"{"key": "unclosed_string","nested": {"items": [1, 2, 3"#;
+        let result = safely_parse_json(truncated_json).unwrap();
+        assert_eq!(result["key"], "unclosed_string");
+        assert_eq!(result["nested"]["items"], json!([1, 2, 3]));
+
+        // Test dangling backslash at end of a truncated string
+        let dangling_escape_json = String::from(r#"{"path":"abc\"#);
+        let result = safely_parse_json(&dangling_escape_json).unwrap();
+        assert_eq!(result["path"], "abc\\");
 
         // Test empty object
         let empty_json = "{}";
diff --git a/ui/desktop/src/types/message.ts b/ui/desktop/src/types/message.ts
index d63bdc7683..dd41bc0a9e 100644
--- a/ui/desktop/src/types/message.ts
+++ b/ui/desktop/src/types/message.ts
@@ -94,14 +94,29 @@ export function getTextAndImageContent(message: Message): {
     }
   }
 
-  // Strip <think> tags from assistant text — the thinking is surfaced via getThinkingContent
+  // Strip assistant-only markup that shouldn't appear in rendered text
   if (message.role === 'assistant') {
+    textContent = stripToolCallMarkers(textContent);
     textContent = textContent.replace(/<think>[\s\S]*?<\/think>/gi, '');
   }
 
   return { textContent, imagePaths };
 }
 
+function stripToolCallMarkers(text: string): string {
+  // Tool-call markers are model control tokens (`<|...|>`), not XML. Remove each
+  // opener together with its payload up to the matching closer. When the closer
+  // is missing (truncated or still-streaming output), remove through the end of
+  // the text so a half-emitted call cannot leak into the rendered message.
+  return text
+    .replace(/<\|tool_calls_section_begin\|>[\s\S]*?(?:<\|tool_calls_section_end\|>|$)/g, '')
+    .replace(/<\|tool_call_begin\|>[\s\S]*?(?:<\|tool_call_end\|>|$)/g, '')
+    .replace(/<\|tool_call_argument_begin\|>[\s\S]*?(?:<\|tool_call_argument_end\|>|$)/g, '')
+    // Finally drop stray tokens, e.g. a closer whose opener was never emitted.
+    .replace(/<\|tool_calls?_(?:section_|argument_)?(?:begin|end)\|>/g, '')
+    .trim();
+}
+
 export function getThinkingContent(message: Message): string | null {
   const parts: string[] = [];