Fix conversations before they hit the LLM (#3660)

Co-authored-by: Douwe Osinga <douwe@squareup.com>
This commit is contained in:
Douwe Osinga
2025-07-28 18:23:10 +02:00
committed by GitHub
parent 1d3a91912a
commit 4442ccfe4a
4 changed files with 444 additions and 6 deletions
+11 -3
View File
@@ -56,6 +56,7 @@ use super::platform_tools;
use super::router_tools;
use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
use crate::agents::subagent_task_config::TaskConfig;
use crate::conversation_fixer::{debug_conversation_fix, ConversationFixer};
const DEFAULT_MAX_TURNS: u32 = 1000;
@@ -706,14 +707,21 @@ impl Agent {
}
}
#[instrument(skip(self, messages, session), fields(user_message))]
#[instrument(skip(self, unfixed_messages, session), fields(user_message))]
pub async fn reply(
&self,
messages: &[Message],
unfixed_messages: &[Message],
session: Option<SessionConfig>,
cancel_token: Option<CancellationToken>,
) -> Result<BoxStream<'_, Result<AgentEvent>>> {
let mut messages = messages.to_vec();
let (mut messages, issues) =
ConversationFixer::fix_conversation(Vec::from(unfixed_messages));
if !issues.is_empty() {
tracing::warn!(
"Conversation issue fixed: {}",
debug_conversation_fix(&messages, unfixed_messages, &issues)
);
}
let initial_messages = messages.clone();
let reply_span = tracing::Span::current();
self.reset_retry_attempts().await;
+378
View File
@@ -0,0 +1,378 @@
use crate::message::{Message, MessageContent};
use rmcp::model::Role;
use std::collections::HashSet;
pub struct ConversationFixer;
const PLACEHOLDER_USER_MESSAGE: &str = "Hello";
impl ConversationFixer {
/// Fix a conversation that we're about to send to an LLM. So the last and first
/// messages should always be from the user.
pub fn fix_conversation(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let (messages, empty_removed) = Self::remove_empty_messages(messages);
let (messages, tool_calling_fixed) = Self::fix_tool_calling(messages);
let (messages, messages_merged) = Self::merge_consecutive_messages(messages);
let (messages, lead_trail_fixed) = Self::fix_lead_trail(messages);
let (messages, populated_if_empty) = Self::populate_if_empty(messages);
let mut issues = Vec::new();
issues.extend(empty_removed);
issues.extend(tool_calling_fixed);
issues.extend(messages_merged);
issues.extend(lead_trail_fixed);
issues.extend(populated_if_empty);
(messages, issues)
}
fn remove_empty_messages(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let mut issues = Vec::new();
let filtered_messages = messages
.into_iter()
.filter(|msg| {
if msg.content.is_empty() {
issues.push("Removed empty message".to_string());
false
} else {
true
}
})
.collect();
(filtered_messages, issues)
}
fn fix_tool_calling(mut messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let mut issues = Vec::new();
let mut pending_tool_requests: HashSet<String> = HashSet::new();
for message in &mut messages {
let mut content_to_remove = Vec::new();
match message.role {
Role::User => {
for (idx, content) in message.content.iter().enumerate() {
match content {
MessageContent::ToolRequest(req) => {
content_to_remove.push(idx);
issues.push(format!(
"Removed tool request '{}' from user message",
req.id
));
}
MessageContent::ToolConfirmationRequest(req) => {
content_to_remove.push(idx);
issues.push(format!(
"Removed tool confirmation request '{}' from user message",
req.id
));
}
MessageContent::Thinking(_) | MessageContent::RedactedThinking(_) => {
content_to_remove.push(idx);
issues
.push("Removed thinking content from user message".to_string());
}
MessageContent::ToolResponse(resp) => {
if pending_tool_requests.contains(&resp.id) {
pending_tool_requests.remove(&resp.id);
} else {
content_to_remove.push(idx);
issues.push(format!(
"Removed orphaned tool response '{}'",
resp.id
));
}
}
_ => {}
}
}
}
Role::Assistant => {
for (idx, content) in message.content.iter().enumerate() {
match content {
MessageContent::ToolResponse(resp) => {
content_to_remove.push(idx);
issues.push(format!(
"Removed tool response '{}' from assistant message",
resp.id
));
}
MessageContent::FrontendToolRequest(req) => {
content_to_remove.push(idx);
issues.push(format!(
"Removed frontend tool request '{}' from assistant message",
req.id
));
}
MessageContent::ToolRequest(req) => {
pending_tool_requests.insert(req.id.clone());
}
_ => {}
}
}
}
}
for &idx in content_to_remove.iter().rev() {
message.content.remove(idx);
}
}
for message in &mut messages {
if message.role == Role::Assistant {
let mut content_to_remove = Vec::new();
for (idx, content) in message.content.iter().enumerate() {
if let MessageContent::ToolRequest(req) = content {
if pending_tool_requests.contains(&req.id) {
content_to_remove.push(idx);
issues.push(format!("Removed orphaned tool request '{}'", req.id));
}
}
}
for &idx in content_to_remove.iter().rev() {
message.content.remove(idx);
}
}
}
let (messages, empty_removed) = Self::remove_empty_messages(messages);
issues.extend(empty_removed);
(messages, issues)
}
fn merge_consecutive_messages(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let mut issues = Vec::new();
let mut merged_messages: Vec<Message> = Vec::new();
for message in messages {
if let Some(last) = merged_messages.last_mut() {
if last.role == message.role {
last.content.extend(message.content);
let role_name = match message.role {
Role::User => "user",
Role::Assistant => "assistant",
};
issues.push(format!("Merged consecutive {} messages", role_name));
continue;
}
}
merged_messages.push(message);
}
(merged_messages, issues)
}
fn fix_lead_trail(mut messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let mut issues = Vec::new();
if let Some(first) = messages.first() {
if first.role == Role::Assistant {
messages.remove(0);
issues.push("Removed leading assistant message".to_string());
}
}
if let Some(last) = messages.last() {
if last.role == Role::Assistant {
messages.pop();
issues.push("Removed trailing assistant message".to_string());
}
}
(messages, issues)
}
fn populate_if_empty(mut messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let mut issues = Vec::new();
if messages.is_empty() {
issues.push("Added placeholder user message to empty conversation".to_string());
messages.push(Message::user().with_text(PLACEHOLDER_USER_MESSAGE));
}
(messages, issues)
}
}
pub fn debug_conversation_fix(
messages: &[Message],
fixed: &[Message],
issues: &[String],
) -> String {
let mut output = String::new();
output.push_str("=== CONVERSATION FIX DEBUG ===\n\n");
output.push_str("BEFORE:\n");
for (i, msg) in messages.iter().enumerate() {
output.push_str(&format!(" [{}] {}\n", i, msg.debug()));
}
output.push_str("\nISSUES FOUND:\n");
if issues.is_empty() {
output.push_str(" (none)\n");
} else {
for issue in issues {
output.push_str(&format!(" - {}\n", issue));
}
}
output.push_str("\nAFTER:\n");
for (i, msg) in fixed.iter().enumerate() {
output.push_str(&format!(" [{}] {}\n", i, msg.debug()));
}
output.push_str("\n==============================\n");
output
}
#[cfg(test)]
mod tests {
use super::*;
use mcp_core::tool::ToolCall;
use serde_json::json;
fn run_verify(messages: Vec<Message>) -> (Vec<Message>, Vec<String>) {
let (fixed, issues) = ConversationFixer::fix_conversation(messages.clone());
// Uncomment the following line to print the debug report
// let report = debug_conversation_fix(&messages, &fixed, &issues);
// print!("\n{}", report);
let (_fixed, issues_with_fixed) = ConversationFixer::fix_conversation(fixed.clone());
assert_eq!(
issues_with_fixed.len(),
0,
"Fixed conversation should have no issues, but found: {:?}\n\n{}",
issues_with_fixed,
debug_conversation_fix(&messages, &fixed, &issues)
);
(fixed, issues)
}
#[test]
fn test_valid_conversation() {
let all_messages = vec![
Message::user().with_text("Can you help me search for something?"),
Message::assistant()
.with_text("I'll help you search.")
.with_tool_request(
"search_1",
Ok(ToolCall::new(
"web_search",
json!({"query": "rust programming"}),
)),
),
Message::user().with_tool_response("search_1", Ok(vec![])),
Message::assistant().with_text("Based on the search results, here's what I found..."),
];
for i in 1..=all_messages.len() {
let messages = all_messages[..i].to_vec();
if messages.last().unwrap().role == Role::User {
let (fixed, issues) = ConversationFixer::fix_conversation(messages.clone());
assert_eq!(
fixed.len(),
messages.len(),
"Step {}: Length should match",
i
);
assert!(
issues.is_empty(),
"Step {}: Should have no issues, but found: {:?}",
i,
issues
);
assert_eq!(fixed, messages, "Step {}: Messages should be unchanged", i);
}
}
}
#[test]
fn test_role_alternation_and_content_placement_issues() {
let messages = vec![
Message::user().with_text("Hello"),
Message::user().with_text("Another user message"),
Message::assistant()
.with_text("Response")
.with_tool_response("orphan_1", Ok(vec![])), // Wrong role
Message::assistant().with_thinking("Let me think", "sig"),
Message::user()
.with_tool_request("bad_req", Ok(ToolCall::new("search", json!({}))))
.with_text("User with bad tool request"),
];
let (fixed, issues) = run_verify(messages);
assert_eq!(fixed.len(), 3);
assert_eq!(issues.len(), 4);
assert!(issues
.iter()
.any(|i| i.contains("Merged consecutive user messages")));
assert!(issues
.iter()
.any(|i| i.contains("Removed tool response 'orphan_1' from assistant message")));
assert!(issues
.iter()
.any(|i| i.contains("Removed tool request 'bad_req' from user message")));
assert_eq!(fixed[0].role, Role::User);
assert_eq!(fixed[1].role, Role::Assistant);
assert_eq!(fixed[2].role, Role::User);
assert_eq!(fixed[0].content.len(), 2);
}
#[test]
fn test_orphaned_tools_and_empty_messages() {
// This conversation completely collapses. the first user message is invalid
// then we remove the empty user message and the wrong tool response
// then we collapse the assistant messages
// which we then remove because you can't end a conversation with an assistant message
let messages = vec![
Message::assistant()
.with_text("I'll search for you")
.with_tool_request("search_1", Ok(ToolCall::new("search", json!({})))),
Message::user(),
Message::user().with_tool_response("wrong_id", Ok(vec![])),
Message::assistant()
.with_tool_request("search_2", Ok(ToolCall::new("search", json!({})))),
];
let (fixed, issues) = run_verify(messages);
assert_eq!(fixed.len(), 1);
assert!(issues.iter().any(|i| i.contains("Removed empty message")));
assert!(issues
.iter()
.any(|i| i.contains("Removed orphaned tool response 'wrong_id'")));
assert_eq!(fixed[0].role, Role::User);
assert_eq!(fixed[0].as_concat_text(), "Hello");
}
#[test]
fn test_real_world_consecutive_assistant_messages() {
let messages = vec![
Message::user().with_text("run ls in the current directory and then run a word count on the smallest file"),
Message::assistant()
.with_text("I'll help you run `ls` in the current directory and then perform a word count on the smallest file. Let me start by listing the directory contents.")
.with_tool_request("toolu_bdrk_018adWbP4X26CfoJU5hkhu3i", Ok(ToolCall::new("developer__shell", json!({"command": "ls -la"})))),
Message::assistant()
.with_text("Now I'll identify the smallest file by size. Looking at the output, I can see that both `slack.yaml` and `subrecipes.yaml` have a size of 0 bytes, making them the smallest files. I'll run a word count on one of them:")
.with_tool_request("toolu_bdrk_01KgDYHs4fAodi22NqxRzmwx", Ok(ToolCall::new("developer__shell", json!({"command": "wc slack.yaml"})))),
Message::user()
.with_tool_response("toolu_bdrk_01KgDYHs4fAodi22NqxRzmwx", Ok(vec![])),
Message::assistant()
.with_text("I ran `ls -la` in the current directory and found several files. Looking at the file sizes, I can see that both `slack.yaml` and `subrecipes.yaml` are 0 bytes (the smallest files). I ran a word count on `slack.yaml` which shows: **0 lines**, **0 words**, **0 characters**"),
Message::user().with_text("thanks!"),
];
let (fixed, issues) = ConversationFixer::fix_conversation(messages);
assert_eq!(fixed.len(), 5);
assert_eq!(issues.len(), 2);
assert!(issues[0].contains("Removed orphaned tool request"));
assert!(issues[1].contains("Merged consecutive assistant messages"));
}
}
+1
View File
@@ -1,6 +1,7 @@
pub mod agents;
pub mod config;
pub mod context_mgmt;
mod conversation_fixer;
pub mod message;
pub mod model;
pub mod permission;
+54 -3
View File
@@ -1,5 +1,3 @@
use std::collections::HashSet;
/// Messages which represent the content sent back and forth to LLM provider
///
/// We use these messages in the agent code, and interfaces which interact with
@@ -18,6 +16,8 @@ use rmcp::model::{
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashSet;
use std::fmt;
use utoipa::ToSchema;
mod tool_result_serde;
@@ -114,6 +114,41 @@ pub enum MessageContent {
SummarizationRequested(SummarizationRequested),
}
impl fmt::Display for MessageContent {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MessageContent::Text(t) => write!(f, "{}", t.text),
MessageContent::Image(i) => write!(f, "[Image: {}]", i.mime_type),
MessageContent::ToolRequest(r) => {
write!(f, "[ToolRequest: {}]", r.to_readable_string())
}
MessageContent::ToolResponse(r) => write!(
f,
"[ToolResponse: {}]",
match &r.tool_result {
Ok(contents) => format!("{} content item(s)", contents.len()),
Err(e) => format!("Error: {e}"),
}
),
MessageContent::ToolConfirmationRequest(r) => {
write!(f, "[ToolConfirmationRequest: {}]", r.tool_name)
}
MessageContent::FrontendToolRequest(r) => match &r.tool_call {
Ok(tool_call) => write!(f, "[FrontendToolRequest: {}]", tool_call.name),
Err(e) => write!(f, "[FrontendToolRequest: Error: {}]", e),
},
MessageContent::Thinking(t) => write!(f, "[Thinking: {}]", t.thinking),
MessageContent::RedactedThinking(_r) => write!(f, "[RedactedThinking]"),
MessageContent::ContextLengthExceeded(r) => {
write!(f, "[ContextLengthExceeded: {}]", r.msg)
}
MessageContent::SummarizationRequested(r) => {
write!(f, "[SummarizationRequested: {}]", r.msg)
}
}
}
}
impl MessageContent {
pub fn text<S: Into<String>>(text: S) -> Self {
MessageContent::Text(RawTextContent { text: text.into() }.no_annotation())
@@ -312,7 +347,7 @@ impl From<PromptMessage> for Message {
}
}
#[derive(ToSchema, Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(ToSchema, Clone, PartialEq, Serialize, Deserialize)]
/// A message to or from an LLM
#[serde(rename_all = "camelCase")]
pub struct Message {
@@ -322,6 +357,19 @@ pub struct Message {
pub content: Vec<MessageContent>,
}
impl fmt::Debug for Message {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let joined_content: String = self
.content
.iter()
.map(|c| format!("{c}"))
.collect::<Vec<_>>()
.join(" ");
write!(f, "{:?}: {}", self.role, joined_content)
}
}
pub fn push_message(messages: &mut Vec<Message>, message: Message) {
if let Some(last) = messages
.last_mut()
@@ -351,6 +399,9 @@ impl Message {
content,
}
}
pub fn debug(&self) -> String {
format!("{:?}", self)
}
/// Create a new user message with the current timestamp
pub fn user() -> Self {