mirror of
https://github.com/aaif-goose/goose.git
synced 2026-06-01 22:09:18 +02:00
fix(agents): serialize per-session agent creation to stop duplicate MCP init (#9357)
Signed-off-by: fresh3nough <anonwurcod@proton.me>
This commit is contained in:
@@ -10,7 +10,7 @@ use lru::LruCache;
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{OnceCell, RwLock};
|
||||
use tokio::sync::{Mutex, OnceCell, RwLock};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{debug, info};
|
||||
|
||||
@@ -25,6 +25,15 @@ pub struct AgentManager {
|
||||
default_provider: Arc<RwLock<Option<Arc<dyn crate::providers::base::Provider>>>>,
|
||||
default_mode: GooseMode,
|
||||
cancel_tokens: Arc<RwLock<HashMap<String, CancellationToken>>>,
|
||||
/// Per-session creation locks. When `get_or_create_agent` misses the
|
||||
/// `sessions` cache it acquires the per-session lock before doing the
|
||||
/// expensive work (provider restore, MCP extension initialization) so
|
||||
/// concurrent callers for the same session never race into doing the
|
||||
/// work twice. Entries are inserted on demand and pruned when the
|
||||
/// session is removed *or* evicted by the LRU; the underlying
|
||||
/// `Arc<Mutex<()>>` stays alive as long as any caller still holds it,
|
||||
/// even after the HashMap entry is removed.
|
||||
creation_locks: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>,
|
||||
}
|
||||
|
||||
impl AgentManager {
|
||||
@@ -46,6 +55,7 @@ impl AgentManager {
|
||||
default_provider: Arc::new(RwLock::new(None)),
|
||||
default_mode,
|
||||
cancel_tokens: Arc::new(RwLock::new(HashMap::new())),
|
||||
creation_locks: Arc::new(Mutex::new(HashMap::new())),
|
||||
};
|
||||
|
||||
Ok(manager)
|
||||
@@ -89,6 +99,7 @@ impl AgentManager {
|
||||
}
|
||||
|
||||
pub async fn get_or_create_agent(&self, session_id: String) -> Result<Arc<Agent>> {
|
||||
// Fast path: agent already cached.
|
||||
{
|
||||
let mut sessions = self.sessions.write().await;
|
||||
if let Some(existing) = sessions.get(&session_id) {
|
||||
@@ -96,10 +107,62 @@ impl AgentManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Slow path: serialize creation per session so concurrent callers
|
||||
// (e.g. start_agent's background extension-loading task and a
|
||||
// resume_agent request racing through the frontend) cannot each
|
||||
// construct their own Agent and independently send `initialize` to
|
||||
// every MCP server. See issue #9031.
|
||||
let creation_lock = {
|
||||
let mut locks = self.creation_locks.lock().await;
|
||||
Arc::clone(
|
||||
locks
|
||||
.entry(session_id.clone())
|
||||
.or_insert_with(|| Arc::new(Mutex::new(()))),
|
||||
)
|
||||
};
|
||||
let creation_guard = creation_lock.lock().await;
|
||||
|
||||
// Funnel the fallible work through a helper so we can prune the
|
||||
// per-session creation lock on every error exit. Without this
|
||||
// the provider-setup path (update_provider / update_mode) could
|
||||
// bail out via `?`, leaving a permanent `creation_locks` entry
|
||||
// for a session that never made it into the LRU cache and that
|
||||
// no one will ever call `remove_session` on.
|
||||
let result = self.create_agent_locked(&session_id).await;
|
||||
|
||||
if result.is_err() {
|
||||
// Release BOTH the guard and our local Arc clone of the
|
||||
// creation lock before pruning. `prune_creation_lock`
|
||||
// gates removal on `Arc::strong_count == 1`; if we kept
|
||||
// `creation_lock` alive the count would still be at least
|
||||
// two (HashMap + this local) and the failed session would
|
||||
// leak its lock entry forever. In-flight waiters keep the
|
||||
// Arc alive on their own and prune correctly skips while
|
||||
// they hold it.
|
||||
drop(creation_guard);
|
||||
drop(creation_lock);
|
||||
self.prune_creation_lock(&session_id).await;
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Slow-path body for `get_or_create_agent`. Must be called with the
|
||||
/// per-session creation lock held by the caller.
|
||||
async fn create_agent_locked(&self, session_id: &str) -> Result<Arc<Agent>> {
|
||||
// Re-check under the creation lock: another caller may have
|
||||
// finished creating the agent while we were waiting.
|
||||
{
|
||||
let mut sessions = self.sessions.write().await;
|
||||
if let Some(existing) = sessions.get(session_id) {
|
||||
return Ok(Arc::clone(existing));
|
||||
}
|
||||
}
|
||||
|
||||
let mut mode = self.default_mode;
|
||||
let permission_manager = PermissionManager::instance();
|
||||
|
||||
if let Ok(session) = self.session_manager.get_session(&session_id, false).await {
|
||||
if let Ok(session) = self.session_manager.get_session(session_id, false).await {
|
||||
mode = session.goose_mode;
|
||||
info!(goose_mode = %mode, session_id = %session_id, "Session loaded");
|
||||
}
|
||||
@@ -116,7 +179,7 @@ impl AgentManager {
|
||||
);
|
||||
let agent = Arc::new(Agent::with_config(config));
|
||||
|
||||
if let Ok(session) = self.session_manager.get_session(&session_id, false).await {
|
||||
if let Ok(session) = self.session_manager.get_session(session_id, false).await {
|
||||
if session.provider_name.is_some() {
|
||||
info!(
|
||||
"Restoring evicted session {} (provider: {:?})",
|
||||
@@ -136,21 +199,53 @@ impl AgentManager {
|
||||
if agent.provider().await.is_err() {
|
||||
if let Some(provider) = &*self.default_provider.read().await {
|
||||
agent
|
||||
.update_provider(Arc::clone(provider), &session_id)
|
||||
.update_provider(Arc::clone(provider), session_id)
|
||||
.await?;
|
||||
provider
|
||||
.update_mode(&session_id, mode)
|
||||
.update_mode(session_id, mode)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to propagate mode to provider: {}", e))?;
|
||||
}
|
||||
}
|
||||
|
||||
let mut sessions = self.sessions.write().await;
|
||||
if let Some(existing) = sessions.get(&session_id) {
|
||||
Ok(Arc::clone(existing))
|
||||
} else {
|
||||
sessions.put(session_id, agent.clone());
|
||||
Ok(agent)
|
||||
if let Some(existing) = sessions.get(session_id) {
|
||||
return Ok(Arc::clone(existing));
|
||||
}
|
||||
// `push` returns the LRU-evicted entry when the cache is at
|
||||
// capacity, which `put` does not surface. We need the evicted
|
||||
// key so we can also drop its creation lock below, otherwise the
|
||||
// `creation_locks` HashMap would grow without bound in long-lived
|
||||
// processes that churn through many sessions.
|
||||
let evicted = sessions
|
||||
.push(session_id.to_string(), agent.clone())
|
||||
.map(|(k, _)| k);
|
||||
drop(sessions);
|
||||
|
||||
if let Some(evicted_id) = evicted {
|
||||
self.prune_creation_lock(&evicted_id).await;
|
||||
}
|
||||
|
||||
Ok(agent)
|
||||
}
|
||||
|
||||
/// Drop the per-session creation lock for `session_id` if no other
|
||||
/// caller is currently holding a clone of its `Arc`. Holding the
|
||||
/// `creation_locks` mutex while we both check `Arc::strong_count` and
|
||||
/// remove guarantees no new waiter can race in between the check and
|
||||
/// the removal: any new caller would need to acquire the outer mutex
|
||||
/// first to clone the inner `Arc`.
|
||||
///
|
||||
/// If a waiter is still in flight (strong_count > 1) we leave the
|
||||
/// entry in place so the in-flight callers continue to serialize
|
||||
/// through the same lock; a later removal or eviction will sweep it.
|
||||
async fn prune_creation_lock(&self, session_id: &str) {
|
||||
let mut locks = self.creation_locks.lock().await;
|
||||
let in_use = locks
|
||||
.get(session_id)
|
||||
.is_some_and(|lock| Arc::strong_count(lock) > 1);
|
||||
if !in_use {
|
||||
locks.remove(session_id);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,6 +257,12 @@ impl AgentManager {
|
||||
sessions
|
||||
.pop(session_id)
|
||||
.ok_or_else(|| anyhow::anyhow!("Session {} not found", session_id))?;
|
||||
drop(sessions);
|
||||
// Best-effort prune of the per-session creation lock so the
|
||||
// HashMap doesn't grow unbounded. Any caller still holding a
|
||||
// clone of the Arc keeps the underlying Mutex alive until it
|
||||
// releases its guard.
|
||||
self.prune_creation_lock(session_id).await;
|
||||
info!("Removed session {}", session_id);
|
||||
Ok(())
|
||||
}
|
||||
@@ -428,6 +529,140 @@ mod tests {
|
||||
assert!(result.unwrap_err().to_string().contains("not found"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_remove_session_prunes_creation_lock() {
|
||||
// remove_session must drop the per-session creation lock so the
|
||||
// HashMap doesn't grow unboundedly.
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = create_test_manager(&temp_dir).await;
|
||||
let session = String::from("to-be-removed");
|
||||
|
||||
manager.get_or_create_agent(session.clone()).await.unwrap();
|
||||
assert_eq!(manager.creation_locks.lock().await.len(), 1);
|
||||
|
||||
manager.remove_session(&session).await.unwrap();
|
||||
assert!(
|
||||
manager.creation_locks.lock().await.is_empty(),
|
||||
"remove_session must prune the creation lock for the removed session"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_failed_creation_prunes_creation_lock() {
|
||||
// Regression test for the Codex review note on PR #9357: when the
|
||||
// provider-setup path in `create_agent_locked` returns Err, the
|
||||
// outer `get_or_create_agent` must also drop its local Arc clone
|
||||
// of the creation lock before pruning. Otherwise
|
||||
// `Arc::strong_count` stays > 1 and the failed session leaks a
|
||||
// permanent entry in `creation_locks`.
|
||||
use async_trait::async_trait;
|
||||
use rmcp::model::Tool;
|
||||
|
||||
use crate::conversation::message::Message;
|
||||
use crate::model::ModelConfig;
|
||||
use crate::providers::base::{MessageStream, Provider, ProviderUsage, Usage};
|
||||
use crate::providers::errors::ProviderError;
|
||||
|
||||
struct FailingProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for FailingProvider {
|
||||
fn get_name(&self) -> &str {
|
||||
"failing-test-provider"
|
||||
}
|
||||
|
||||
fn get_model_config(&self) -> ModelConfig {
|
||||
ModelConfig::new_or_fail("test-model")
|
||||
}
|
||||
|
||||
async fn stream(
|
||||
&self,
|
||||
_model_config: &ModelConfig,
|
||||
_session_id: &str,
|
||||
_system: &str,
|
||||
_messages: &[Message],
|
||||
_tools: &[Tool],
|
||||
) -> std::result::Result<MessageStream, ProviderError> {
|
||||
Ok(crate::providers::base::stream_from_single_message(
|
||||
Message::assistant().with_text("unused"),
|
||||
ProviderUsage::new("failing-test-provider".into(), Usage::default()),
|
||||
))
|
||||
}
|
||||
|
||||
async fn update_mode(
|
||||
&self,
|
||||
_session_id: &str,
|
||||
_mode: GooseMode,
|
||||
) -> std::result::Result<(), ProviderError> {
|
||||
Err(ProviderError::ExecutionError(
|
||||
"intentional failure for test".into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let manager = create_test_manager(&temp_dir).await;
|
||||
manager
|
||||
.set_default_provider(Arc::new(FailingProvider))
|
||||
.await;
|
||||
|
||||
let session_id = String::from("failed-creation-test");
|
||||
let result = manager.get_or_create_agent(session_id.clone()).await;
|
||||
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"expected provider mode-update failure to propagate"
|
||||
);
|
||||
assert!(
|
||||
manager.creation_locks.lock().await.is_empty(),
|
||||
"creation_locks must be empty after a failed agent creation"
|
||||
);
|
||||
assert!(
|
||||
!manager.has_session(&session_id).await,
|
||||
"failed creation must not insert into the LRU cache"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_lru_eviction_prunes_creation_lock() {
|
||||
// Sessions can disappear from the LRU cache without going through
|
||||
// remove_session. When that happens the matching creation lock
|
||||
// must also be pruned, otherwise long-lived processes that churn
|
||||
// through many session IDs would accumulate stale lock entries
|
||||
// even though only `max_sessions` agents remain cached.
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let session_manager = Arc::new(SessionManager::new(temp_dir.path().to_path_buf()));
|
||||
let schedule_path = temp_dir.path().join("schedule.json");
|
||||
let manager = AgentManager::new(
|
||||
session_manager,
|
||||
schedule_path,
|
||||
Some(2),
|
||||
GooseMode::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
manager.get_or_create_agent("a".into()).await.unwrap();
|
||||
manager.get_or_create_agent("b".into()).await.unwrap();
|
||||
assert_eq!(manager.creation_locks.lock().await.len(), 2);
|
||||
|
||||
// Inserting a third session evicts the LRU entry ("a").
|
||||
manager.get_or_create_agent("c".into()).await.unwrap();
|
||||
|
||||
let locks = manager.creation_locks.lock().await;
|
||||
assert_eq!(
|
||||
locks.len(),
|
||||
2,
|
||||
"creation_locks must stay bounded by max_sessions after LRU eviction"
|
||||
);
|
||||
assert!(
|
||||
!locks.contains_key("a"),
|
||||
"LRU-evicted session's creation lock should be pruned"
|
||||
);
|
||||
assert!(locks.contains_key("b"));
|
||||
assert!(locks.contains_key("c"));
|
||||
}
|
||||
|
||||
#[test_case(GooseMode::Approve ; "approve")]
|
||||
#[test_case(GooseMode::Chat ; "chat")]
|
||||
#[test_case(GooseMode::SmartApprove ; "smart_approve")]
|
||||
|
||||
Reference in New Issue
Block a user