mirror of
https://github.com/aaif-goose/goose.git
synced 2026-06-02 06:14:27 +02:00
speed up agent switching with provider caching and spinning-up indicator
Cache external ACP providers (Claude Code, Codex) in memory so switching back to a recently-used agent reuses the warm subprocess, persist model lists to disk for instant picker fill on cold starts, pre-warm the last-used provider on server boot, and surface a dedicated "Spinning up …" state in the chat UI bracketed around the prepare call. Also de-dupes concurrent prepareSession calls so a background prepare from the agent picker and a foreground one from sendMessage join the same in-flight promise instead of racing duplicate newSession round-trips. Signed-off-by: morgmart <98432065+morgmart@users.noreply.github.com>
This commit is contained in:
Generated
+1
@@ -4454,6 +4454,7 @@ dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"chrono",
|
||||
"fs-err",
|
||||
"futures",
|
||||
"goose",
|
||||
|
||||
@@ -48,6 +48,7 @@ uuid = { workspace = true, features = ["v7"] }
|
||||
schemars = { workspace = true, features = ["derive"] }
|
||||
goose-acp-macros = { path = "../goose-acp-macros" }
|
||||
goose-sdk = { path = "../goose-sdk" }
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
|
||||
[dev-dependencies]
|
||||
async-trait = { workspace = true }
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
mod adapters;
|
||||
pub use goose_sdk::custom_requests;
|
||||
mod fs;
|
||||
pub mod model_cache;
|
||||
pub mod server;
|
||||
pub mod server_factory;
|
||||
pub(crate) mod tools;
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
//! Disk-persisted cache of per-provider model lists.
|
||||
//!
|
||||
//! When a user picks an ACP provider (e.g. Claude Code) the backend has to
|
||||
//! spawn the external agent and wait for its initial `NewSession` before it
|
||||
//! can answer "what models do you have?". That round-trip is ~25s for
|
||||
//! claude-acp on cold start. This cache lets the UI fill the model picker
|
||||
//! instantly with the last-known list while the real `update_provider` call
|
||||
//! continues in the background.
|
||||
//!
|
||||
//! The cache stores the raw `Vec<SessionConfigOption>` so it can be replayed
|
||||
//! verbatim through the existing `ConfigOptionUpdate` notification path.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use chrono::{DateTime, Utc};
|
||||
use fs_err as fs;
|
||||
use goose::config::paths::Paths;
|
||||
use sacp::schema::SessionConfigOption;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
const CACHE_FILE_NAME: &str = "acp_model_cache.json";
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ModelCache {
|
||||
#[serde(default)]
|
||||
pub providers: HashMap<String, ProviderEntry>,
|
||||
#[serde(skip)]
|
||||
path: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ProviderEntry {
|
||||
pub options: Vec<SessionConfigOption>,
|
||||
pub cached_at: DateTime<Utc>,
|
||||
pub last_used_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl ModelCache {
|
||||
pub fn default_path() -> PathBuf {
|
||||
Paths::in_state_dir(CACHE_FILE_NAME)
|
||||
}
|
||||
|
||||
pub fn load() -> Self {
|
||||
Self::load_from(Self::default_path())
|
||||
}
|
||||
|
||||
pub fn load_from(path: PathBuf) -> Self {
|
||||
match fs::read(&path) {
|
||||
Ok(bytes) => match serde_json::from_slice::<ModelCache>(&bytes) {
|
||||
Ok(mut cache) => {
|
||||
cache.path = path;
|
||||
cache
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
path = %path.display(),
|
||||
error = %e,
|
||||
"ACP model cache could not be parsed, starting empty",
|
||||
);
|
||||
Self {
|
||||
providers: HashMap::new(),
|
||||
path,
|
||||
}
|
||||
}
|
||||
},
|
||||
Err(_) => Self {
|
||||
providers: HashMap::new(),
|
||||
path,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get(&self, provider_name: &str) -> Option<&ProviderEntry> {
|
||||
self.providers.get(provider_name)
|
||||
}
|
||||
|
||||
/// Returns the provider name with the most recent `last_used_at`.
|
||||
pub fn last_used_provider(&self) -> Option<String> {
|
||||
self.providers
|
||||
.iter()
|
||||
.max_by_key(|(_, entry)| entry.last_used_at)
|
||||
.map(|(name, _)| name.clone())
|
||||
}
|
||||
|
||||
pub fn upsert(&mut self, provider_name: &str, options: Vec<SessionConfigOption>) {
|
||||
let now = Utc::now();
|
||||
self.providers.insert(
|
||||
provider_name.to_string(),
|
||||
ProviderEntry {
|
||||
options,
|
||||
cached_at: now,
|
||||
last_used_at: now,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub fn save(&self) -> Result<()> {
|
||||
Self::save_to(&self.path, self)
|
||||
}
|
||||
|
||||
fn save_to(path: &Path, cache: &ModelCache) -> Result<()> {
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent).with_context(|| {
|
||||
format!("creating model cache parent dir {}", parent.display())
|
||||
})?;
|
||||
}
|
||||
let tmp = path.with_extension("json.tmp");
|
||||
let bytes =
|
||||
serde_json::to_vec_pretty(cache).context("serializing model cache to json")?;
|
||||
fs::write(&tmp, &bytes)
|
||||
.with_context(|| format!("writing temp model cache {}", tmp.display()))?;
|
||||
fs::rename(&tmp, path)
|
||||
.with_context(|| format!("renaming model cache to {}", path.display()))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use sacp::schema::SessionConfigOption;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn sample_options() -> Vec<SessionConfigOption> {
|
||||
vec![SessionConfigOption::select(
|
||||
"model",
|
||||
"Model",
|
||||
"opus".to_string(),
|
||||
vec![],
|
||||
)]
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn round_trip_through_disk() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let path = tmp.path().join("acp_model_cache.json");
|
||||
|
||||
let mut cache = ModelCache::load_from(path.clone());
|
||||
assert!(cache.providers.is_empty());
|
||||
|
||||
cache.upsert("claude-acp", sample_options());
|
||||
cache.save().unwrap();
|
||||
|
||||
let reloaded = ModelCache::load_from(path);
|
||||
let entry = reloaded.get("claude-acp").expect("entry exists");
|
||||
assert_eq!(entry.options.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_file_yields_empty_cache() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let path = tmp.path().join("does_not_exist.json");
|
||||
let cache = ModelCache::load_from(path);
|
||||
assert!(cache.providers.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn last_used_provider_picks_most_recent() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let path = tmp.path().join("acp_model_cache.json");
|
||||
let mut cache = ModelCache::load_from(path);
|
||||
cache.upsert("claude-acp", sample_options());
|
||||
std::thread::sleep(std::time::Duration::from_millis(5));
|
||||
cache.upsert("codex", sample_options());
|
||||
assert_eq!(cache.last_used_provider().as_deref(), Some("codex"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn corrupt_file_yields_empty_cache() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let path = tmp.path().join("acp_model_cache.json");
|
||||
fs::write(&path, b"not json").unwrap();
|
||||
let cache = ModelCache::load_from(path);
|
||||
assert!(cache.providers.is_empty());
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::custom_requests::*;
|
||||
use crate::fs::AcpTools;
|
||||
use crate::model_cache::ModelCache;
|
||||
use crate::tools::AcpAwareToolMeta;
|
||||
use anyhow::Result;
|
||||
use fs_err as fs;
|
||||
@@ -126,6 +127,16 @@ pub struct GooseAcpAgent {
|
||||
permission_manager: Arc<PermissionManager>,
|
||||
goose_mode: GooseMode,
|
||||
disable_session_naming: bool,
|
||||
/// Keeps successfully-created providers alive across session/provider switches.
|
||||
/// External ACP providers (claude-acp, codex, cursor-agent) own a long-lived
|
||||
/// subprocess that takes ~25s to spin up on cold start; reusing the same
|
||||
/// `Arc<dyn Provider>` skips that entirely. Keyed by provider name; entries
|
||||
/// are never evicted for the lifetime of the agent.
|
||||
provider_cache: Arc<Mutex<HashMap<String, Arc<dyn Provider>>>>,
|
||||
/// Disk-persisted snapshot of the most recent model lists per provider.
|
||||
/// Lets the UI fill the model picker instantly while the real
|
||||
/// `update_provider` continues in the background.
|
||||
model_cache: Arc<Mutex<ModelCache>>,
|
||||
}
|
||||
|
||||
fn extract_timeout_from_meta(meta: &Option<Meta>) -> Option<u64> {
|
||||
@@ -615,6 +626,7 @@ impl GooseAcpAgent {
|
||||
));
|
||||
let permission_manager = Arc::new(PermissionManager::new(config_dir.clone()));
|
||||
|
||||
let model_cache = ModelCache::load();
|
||||
Ok(Self {
|
||||
sessions: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_factory,
|
||||
@@ -627,9 +639,27 @@ impl GooseAcpAgent {
|
||||
permission_manager,
|
||||
goose_mode,
|
||||
disable_session_naming,
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
model_cache: Arc::new(Mutex::new(model_cache)),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a snapshot of the disk-cached model list for `provider_name`,
|
||||
/// or `None` if there's no cached entry yet.
|
||||
pub async fn cached_model_options(
|
||||
&self,
|
||||
provider_name: &str,
|
||||
) -> Option<Vec<SessionConfigOption>> {
|
||||
let cache = self.model_cache.lock().await;
|
||||
cache.get(provider_name).map(|e| e.options.clone())
|
||||
}
|
||||
|
||||
/// Returns the most-recently-used provider name from the disk cache.
|
||||
pub async fn last_used_provider(&self) -> Option<String> {
|
||||
let cache = self.model_cache.lock().await;
|
||||
cache.last_used_provider()
|
||||
}
|
||||
|
||||
fn load_config(&self) -> Result<Config> {
|
||||
Config::new(self.config_dir.join(CONFIG_YAML_NAME), "goose").map_err(Into::into)
|
||||
}
|
||||
@@ -643,6 +673,31 @@ impl GooseAcpAgent {
|
||||
(self.provider_factory)(provider_name.to_string(), model_config, extensions).await
|
||||
}
|
||||
|
||||
/// Returns a cached provider when available, otherwise constructs a new one
|
||||
/// via `create_provider` and stores it. The cache is keyed by provider name
|
||||
/// only — see the `provider_cache` field doc for the trade-off.
|
||||
pub(crate) async fn get_or_create_provider(
|
||||
&self,
|
||||
provider_name: &str,
|
||||
model_config: goose::model::ModelConfig,
|
||||
extensions: Vec<ExtensionConfig>,
|
||||
) -> Result<Arc<dyn Provider>> {
|
||||
{
|
||||
let cache = self.provider_cache.lock().await;
|
||||
if let Some(provider) = cache.get(provider_name) {
|
||||
return Ok(Arc::clone(provider));
|
||||
}
|
||||
}
|
||||
let provider = self
|
||||
.create_provider(provider_name, model_config, extensions)
|
||||
.await?;
|
||||
let mut cache = self.provider_cache.lock().await;
|
||||
let entry = cache
|
||||
.entry(provider_name.to_string())
|
||||
.or_insert_with(|| Arc::clone(&provider));
|
||||
Ok(Arc::clone(entry))
|
||||
}
|
||||
|
||||
fn spawn_agent_setup(
|
||||
&self,
|
||||
cx: &ConnectionTo<Client>,
|
||||
@@ -673,6 +728,7 @@ impl GooseAcpAgent {
|
||||
.unwrap_or_default();
|
||||
let client_terminal = self.client_terminal.get().copied().unwrap_or(false);
|
||||
let provider_factory = Arc::clone(&self.provider_factory);
|
||||
let provider_cache = Arc::clone(&self.provider_cache);
|
||||
let disable_session_naming = self.disable_session_naming;
|
||||
|
||||
tokio::spawn(async move {
|
||||
@@ -787,9 +843,25 @@ impl GooseAcpAgent {
|
||||
Some(&goose_session.extension_data),
|
||||
&config,
|
||||
);
|
||||
let provider = provider_factory(provider_name.to_string(), model_config, ext_state)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
let provider = {
|
||||
let cached = {
|
||||
let cache = provider_cache.lock().await;
|
||||
cache.get(&provider_name).map(Arc::clone)
|
||||
};
|
||||
match cached {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
let p = provider_factory(provider_name.to_string(), model_config, ext_state)
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
let mut cache = provider_cache.lock().await;
|
||||
cache
|
||||
.entry(provider_name.clone())
|
||||
.or_insert_with(|| Arc::clone(&p));
|
||||
p
|
||||
}
|
||||
}
|
||||
};
|
||||
agent
|
||||
.update_provider(provider.clone(), &goose_session.id)
|
||||
.await
|
||||
@@ -1810,7 +1882,7 @@ impl GooseAcpAgent {
|
||||
})?
|
||||
.with_canonical_limits(&provider_name);
|
||||
let provider = self
|
||||
.create_provider(&provider_name, model_config, extensions)
|
||||
.get_or_create_provider(&provider_name, model_config, extensions)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
sacp::Error::internal_error().data(format!("Failed to create provider: {}", e))
|
||||
@@ -1979,7 +2051,7 @@ impl GooseAcpAgent {
|
||||
let extensions =
|
||||
EnabledExtensionsState::for_session(&self.session_manager, &internal_id, &config).await;
|
||||
let new_provider = self
|
||||
.create_provider(&resolved_provider_name, model_config, extensions)
|
||||
.get_or_create_provider(&resolved_provider_name, model_config, extensions)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
sacp::Error::internal_error().data(format!("Failed to create provider: {}", e))
|
||||
@@ -2040,6 +2112,19 @@ impl GooseAcpAgent {
|
||||
let (_, config_options) = self
|
||||
.build_config_update(&SessionId::new(thread_id.to_string()))
|
||||
.await?;
|
||||
|
||||
{
|
||||
let mut cache = self.model_cache.lock().await;
|
||||
cache.upsert(&resolved_provider_name, config_options.clone());
|
||||
if let Err(e) = cache.save() {
|
||||
warn!(
|
||||
provider = %resolved_provider_name,
|
||||
error = %e,
|
||||
"failed to persist ACP model cache",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(config_options)
|
||||
}
|
||||
|
||||
@@ -2742,9 +2827,21 @@ impl HandleDispatchFrom<Client> for GooseAcpHandler {
|
||||
let session_id = req.session_id.clone();
|
||||
match req.config_id.0.as_ref() {
|
||||
"provider" => {
|
||||
if let Some(cached_options) = agent.cached_model_options(&value_id.0).await {
|
||||
let cached_notification = SessionNotification::new(
|
||||
session_id.clone(),
|
||||
SessionUpdate::ConfigOptionUpdate(
|
||||
ConfigOptionUpdate::new(cached_options),
|
||||
),
|
||||
);
|
||||
cx.send_notification(cached_notification)?;
|
||||
}
|
||||
match agent.update_provider(&session_id.0, &value_id.0, None, None, None).await {
|
||||
Ok(_) => {}
|
||||
Err(e) => { responder.respond_with_error(e)?; return Ok(()); }
|
||||
Err(e) => {
|
||||
responder.respond_with_error(e)?;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
"mode" => {
|
||||
@@ -3404,4 +3501,58 @@ print(\"hello, world\")
|
||||
) -> Vec<SessionConfigOption> {
|
||||
build_config_options(&mode_state, &model_state, provider_name, provider_options)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn provider_cache_only_invokes_factory_once_per_name() {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let factory_calls = Arc::clone(&calls);
|
||||
let provider_factory: AcpProviderFactory =
|
||||
Arc::new(move |_name, _model_config, _extensions| {
|
||||
let factory_calls = Arc::clone(&factory_calls);
|
||||
Box::pin(async move {
|
||||
factory_calls.fetch_add(1, Ordering::SeqCst);
|
||||
let provider: Arc<dyn Provider> = Arc::new(MockModelProvider {
|
||||
models: Ok(vec!["model-a".into()]),
|
||||
});
|
||||
Ok(provider)
|
||||
})
|
||||
});
|
||||
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let agent = GooseAcpAgent::new(
|
||||
provider_factory,
|
||||
Vec::new(),
|
||||
tmp.path().to_path_buf(),
|
||||
tmp.path().to_path_buf(),
|
||||
GooseMode::Auto,
|
||||
true,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cfg = goose::model::ModelConfig::new_or_fail("model-a");
|
||||
let p1 = agent
|
||||
.get_or_create_provider("claude-acp", cfg.clone(), Vec::new())
|
||||
.await
|
||||
.unwrap();
|
||||
let p2 = agent
|
||||
.get_or_create_provider("claude-acp", cfg.clone(), Vec::new())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1, "factory should be invoked exactly once");
|
||||
assert!(Arc::ptr_eq(&p1, &p2), "second call should return the cached Arc");
|
||||
|
||||
let _p3 = agent
|
||||
.get_or_create_provider("codex", cfg, Vec::new())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
calls.load(Ordering::SeqCst),
|
||||
2,
|
||||
"different provider name should trigger a fresh factory call"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +49,35 @@ impl AcpServer {
|
||||
.await?;
|
||||
info!("Created new ACP agent");
|
||||
|
||||
Ok(Arc::new(agent))
|
||||
let agent = Arc::new(agent);
|
||||
spawn_provider_prewarm(Arc::clone(&agent));
|
||||
Ok(agent)
|
||||
}
|
||||
}
|
||||
|
||||
/// Best-effort background warm-up of the most-recently-used provider so the
|
||||
/// user doesn't pay the cold-start cost on the first agent click after launch.
|
||||
fn spawn_provider_prewarm(agent: Arc<GooseAcpAgent>) {
|
||||
tokio::spawn(async move {
|
||||
let Some(provider_name) = agent.last_used_provider().await else {
|
||||
return;
|
||||
};
|
||||
|
||||
let providers = goose::providers::providers().await;
|
||||
let Some((metadata, _)) = providers
|
||||
.into_iter()
|
||||
.find(|(m, _)| m.name == provider_name)
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
let Ok(model_config) = goose::model::ModelConfig::new(&metadata.default_model) else {
|
||||
return;
|
||||
};
|
||||
let model_config = model_config.with_canonical_limits(&provider_name);
|
||||
|
||||
let _ = agent
|
||||
.get_or_create_provider(&provider_name, model_config, Vec::new())
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1308,7 +1308,7 @@ fn resolve_model_info(
|
||||
))
|
||||
})?;
|
||||
let current = models.current_model_id.0.to_string();
|
||||
let available = models
|
||||
let available: Vec<String> = models
|
||||
.available_models
|
||||
.iter()
|
||||
.map(|am| am.model_id.0.to_string())
|
||||
|
||||
@@ -14,6 +14,8 @@ vi.mock("@/shared/api/acp", () => ({
|
||||
acpCancelSession: (...args: unknown[]) => mockAcpCancelSession(...args),
|
||||
acpPrepareSession: (...args: unknown[]) => mockAcpPrepareSession(...args),
|
||||
acpSetModel: (...args: unknown[]) => mockAcpSetModel(...args),
|
||||
acpIsPrepareInFlight: () => false,
|
||||
acpIsSessionPrepared: () => false,
|
||||
}));
|
||||
|
||||
import { useChat } from "../useChat";
|
||||
|
||||
@@ -9,12 +9,16 @@ const mockAcpSendMessage = vi.fn();
|
||||
const mockAcpCancelSession = vi.fn();
|
||||
const mockAcpPrepareSession = vi.fn();
|
||||
const mockAcpSetModel = vi.fn();
|
||||
const mockAcpIsPrepareInFlight = vi.fn(() => false);
|
||||
const mockAcpIsSessionPrepared = vi.fn(() => false);
|
||||
|
||||
vi.mock("@/shared/api/acp", () => ({
|
||||
acpSendMessage: (...args: unknown[]) => mockAcpSendMessage(...args),
|
||||
acpCancelSession: (...args: unknown[]) => mockAcpCancelSession(...args),
|
||||
acpPrepareSession: (...args: unknown[]) => mockAcpPrepareSession(...args),
|
||||
acpSetModel: (...args: unknown[]) => mockAcpSetModel(...args),
|
||||
acpIsPrepareInFlight: () => mockAcpIsPrepareInFlight(),
|
||||
acpIsSessionPrepared: () => mockAcpIsSessionPrepared(),
|
||||
}));
|
||||
|
||||
import { useChat } from "../useChat";
|
||||
@@ -401,6 +405,56 @@ describe("useChat", () => {
|
||||
expect(runtime.chatState).toBe("idle");
|
||||
});
|
||||
|
||||
it("transitions chatState from spinning_up to streaming while awaiting prepare", async () => {
|
||||
useChatSessionStore.setState({
|
||||
sessions: [
|
||||
{
|
||||
id: "session-1",
|
||||
title: "New Chat",
|
||||
providerId: "claude-acp",
|
||||
modelId: "opus",
|
||||
modelName: "Opus",
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
messageCount: 0,
|
||||
draft: true,
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const prepareDeferred = createDeferredPromise();
|
||||
mockAcpPrepareSession.mockReturnValueOnce(prepareDeferred.promise);
|
||||
const sendDeferred = createDeferredPromise();
|
||||
mockAcpSendMessage.mockReturnValueOnce(sendDeferred.promise);
|
||||
|
||||
const { result } = renderHook(() => useChat("session-1", "claude-acp"));
|
||||
|
||||
let sendPromise!: Promise<void>;
|
||||
await act(async () => {
|
||||
sendPromise = result.current.sendMessage("Hello");
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(useChatStore.getState().getSessionRuntime("session-1").chatState).toBe(
|
||||
"spinning_up",
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
prepareDeferred.resolve();
|
||||
await Promise.resolve();
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(useChatStore.getState().getSessionRuntime("session-1").chatState).toBe(
|
||||
"streaming",
|
||||
);
|
||||
|
||||
sendDeferred.resolve();
|
||||
await act(async () => {
|
||||
await sendPromise;
|
||||
});
|
||||
});
|
||||
|
||||
it("shows string-shaped invoke errors instead of falling back to unknown error", async () => {
|
||||
mockAcpSendMessage.mockRejectedValue("Working directory missing");
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ import {
|
||||
acpCancelSession,
|
||||
acpPrepareSession,
|
||||
acpSetModel,
|
||||
acpIsPrepareInFlight,
|
||||
acpIsSessionPrepared,
|
||||
} from "@/shared/api/acp";
|
||||
import { useAgentStore } from "@/features/agents/stores/agentStore";
|
||||
import {
|
||||
@@ -133,6 +135,7 @@ export function useChat(
|
||||
const hasAttachments = (attachments?.length ?? 0) > 0;
|
||||
if (
|
||||
(!text.trim() && !hasAttachments) ||
|
||||
chatState === "spinning_up" ||
|
||||
chatState === "streaming" ||
|
||||
chatState === "thinking"
|
||||
)
|
||||
@@ -177,7 +180,6 @@ export function useChat(
|
||||
}
|
||||
}
|
||||
store.addMessage(sessionId, userMessage);
|
||||
store.setChatState(sessionId, "thinking");
|
||||
store.setError(sessionId, null);
|
||||
|
||||
// Promote draft to real backend session before first send
|
||||
@@ -186,6 +188,14 @@ export function useChat(
|
||||
const wasDraft = !!session?.draft;
|
||||
const selectedModelId = session?.modelId;
|
||||
|
||||
const personaId = effectivePersonaInfo?.id;
|
||||
const needsPrepare =
|
||||
wasDraft || !acpIsSessionPrepared(sessionId, personaId);
|
||||
const prepareInFlight = acpIsPrepareInFlight(sessionId, personaId);
|
||||
const initialState =
|
||||
needsPrepare || prepareInFlight ? "spinning_up" : "thinking";
|
||||
store.setChatState(sessionId, initialState);
|
||||
|
||||
if (wasDraft) {
|
||||
sessionStore.promoteDraft(sessionId);
|
||||
}
|
||||
@@ -217,10 +227,10 @@ export function useChat(
|
||||
streamingPersonaIdRef.current = effectivePersonaInfo?.id ?? null;
|
||||
|
||||
try {
|
||||
if (wasDraft || selectedModelId) {
|
||||
if (needsPrepare || prepareInFlight || selectedModelId) {
|
||||
await acpPrepareSession(sessionId, providerId, {
|
||||
workingDir: workingDirOverride,
|
||||
personaId: effectivePersonaInfo?.id,
|
||||
personaId,
|
||||
});
|
||||
if (selectedModelId) {
|
||||
await acpSetModel(sessionId, selectedModelId);
|
||||
|
||||
@@ -2,6 +2,7 @@ import type { ChatState } from "@/shared/types/chat";
|
||||
|
||||
export function isSessionRunning(chatState: ChatState): boolean {
|
||||
return (
|
||||
chatState === "spinning_up" ||
|
||||
chatState === "thinking" ||
|
||||
chatState === "streaming" ||
|
||||
chatState === "waiting" ||
|
||||
|
||||
@@ -217,8 +217,29 @@ export function ChatView({
|
||||
const cached = sessionStore.getCachedModels(providerId);
|
||||
sessionStore.switchSessionProvider(activeSessionId, providerId, cached);
|
||||
setGlobalSelectedProvider(providerId);
|
||||
|
||||
// Eagerly tell the backend about the new provider so it pushes the
|
||||
// model list via config_option_update — otherwise the model picker
|
||||
// stays on "Loading…" until the user sends a message.
|
||||
if (effectiveWorkingDir) {
|
||||
void acpPrepareSession(activeSessionId, providerId, {
|
||||
workingDir: effectiveWorkingDir,
|
||||
personaId: selectedPersonaId ?? undefined,
|
||||
}).catch((error) => {
|
||||
console.error(
|
||||
"Failed to prepare ACP session on provider change:",
|
||||
error,
|
||||
);
|
||||
});
|
||||
}
|
||||
},
|
||||
[activeSessionId, selectedProvider, setGlobalSelectedProvider],
|
||||
[
|
||||
activeSessionId,
|
||||
selectedProvider,
|
||||
setGlobalSelectedProvider,
|
||||
effectiveWorkingDir,
|
||||
selectedPersonaId,
|
||||
],
|
||||
);
|
||||
|
||||
const handleProjectChange = useCallback(
|
||||
@@ -427,10 +448,13 @@ export function ChatView({
|
||||
]);
|
||||
const isStreaming = chatState === "streaming";
|
||||
const showIndicator =
|
||||
chatState === "spinning_up" ||
|
||||
chatState === "thinking" ||
|
||||
chatState === "streaming" ||
|
||||
chatState === "waiting" ||
|
||||
chatState === "compacting";
|
||||
const selectedProviderLabel =
|
||||
providers.find((p) => p.id === selectedProvider)?.label ?? selectedProvider;
|
||||
const handleCreatePersona = useCallback(() => {
|
||||
useAgentStore.getState().openPersonaEditor();
|
||||
}, []);
|
||||
@@ -474,11 +498,13 @@ export function ChatView({
|
||||
key="loading-indicator"
|
||||
chatState={
|
||||
chatState as
|
||||
| "spinning_up"
|
||||
| "thinking"
|
||||
| "streaming"
|
||||
| "waiting"
|
||||
| "compacting"
|
||||
}
|
||||
providerName={selectedProviderLabel}
|
||||
/>
|
||||
) : null}
|
||||
</AnimatePresence>
|
||||
|
||||
@@ -4,6 +4,7 @@ import { Shimmer } from "@/shared/ui/ai-elements/shimmer";
|
||||
|
||||
export type LoadingChatState =
|
||||
| "idle"
|
||||
| "spinning_up"
|
||||
| "thinking"
|
||||
| "streaming"
|
||||
| "waiting"
|
||||
@@ -11,6 +12,11 @@ export type LoadingChatState =
|
||||
|
||||
interface LoadingGooseProps {
|
||||
chatState?: LoadingChatState;
|
||||
/**
|
||||
* Display name of the provider being warmed up. Only used when
|
||||
* `chatState === "spinning_up"`. Falls back to a generic message when absent.
|
||||
*/
|
||||
providerName?: string;
|
||||
}
|
||||
|
||||
const LOADING_FADE_S = 0.45;
|
||||
@@ -20,7 +26,7 @@ const LOADING_SHIMMER_DELAY_S = 0.35;
|
||||
const LOADING_SHIMMER_REPEAT_DELAY_S = 0.9;
|
||||
|
||||
const MESSAGE_KEY_BY_STATE: Record<
|
||||
Exclude<LoadingChatState, "idle">,
|
||||
Exclude<LoadingChatState, "idle" | "spinning_up">,
|
||||
"thinking" | "responding"
|
||||
> = {
|
||||
thinking: "thinking",
|
||||
@@ -29,14 +35,22 @@ const MESSAGE_KEY_BY_STATE: Record<
|
||||
compacting: "responding",
|
||||
};
|
||||
|
||||
export function LoadingGoose({ chatState = "idle" }: LoadingGooseProps) {
|
||||
export function LoadingGoose({
|
||||
chatState = "idle",
|
||||
providerName,
|
||||
}: LoadingGooseProps) {
|
||||
const { t } = useTranslation("chat");
|
||||
const shouldReduceMotion = useReducedMotion();
|
||||
if (chatState === "idle") {
|
||||
return null;
|
||||
}
|
||||
|
||||
const message = t(`loading.${MESSAGE_KEY_BY_STATE[chatState]}`);
|
||||
const message =
|
||||
chatState === "spinning_up"
|
||||
? providerName
|
||||
? t("loading.spinningUp", { providerName })
|
||||
: t("loading.spinningUpFallback")
|
||||
: t(`loading.${MESSAGE_KEY_BY_STATE[chatState]}`);
|
||||
|
||||
return (
|
||||
<motion.div
|
||||
|
||||
@@ -3,7 +3,7 @@ import { describe, expect, it } from "vitest";
|
||||
import { LoadingGoose } from "../LoadingGoose";
|
||||
import chat from "@/shared/i18n/locales/en/chat.json";
|
||||
|
||||
const { thinking, responding } = chat.loading;
|
||||
const { thinking, responding, spinningUpFallback } = chat.loading;
|
||||
|
||||
describe("LoadingGoose", () => {
|
||||
it("renders thinking copy for the thinking state", () => {
|
||||
@@ -35,4 +35,22 @@ describe("LoadingGoose", () => {
|
||||
|
||||
expect(container).toBeEmptyDOMElement();
|
||||
});
|
||||
|
||||
it("renders the provider name in the spinning_up state", () => {
|
||||
render(
|
||||
<LoadingGoose chatState="spinning_up" providerName="Claude Code" />,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.getByRole("status", { name: /Spinning up Claude Code/i }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("falls back to a generic message when no provider name is given", () => {
|
||||
render(<LoadingGoose chatState="spinning_up" />);
|
||||
|
||||
expect(
|
||||
screen.getByRole("status", { name: spinningUpFallback }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -78,6 +78,22 @@ export async function acpPrepareSession(
|
||||
);
|
||||
}
|
||||
|
||||
/** True if a `prepareSession` call for this session is currently awaiting the backend. */
|
||||
export function acpIsPrepareInFlight(
|
||||
sessionId: string,
|
||||
personaId?: string,
|
||||
): boolean {
|
||||
return sessionTracker.isPrepareInFlight(sessionId, personaId);
|
||||
}
|
||||
|
||||
/** True if this session has completed its initial backend prepare at least once. */
|
||||
export function acpIsSessionPrepared(
|
||||
sessionId: string,
|
||||
personaId?: string,
|
||||
): boolean {
|
||||
return sessionTracker.isSessionPrepared(sessionId, personaId);
|
||||
}
|
||||
|
||||
export async function acpSetModel(
|
||||
sessionId: string,
|
||||
modelId: string,
|
||||
|
||||
@@ -121,7 +121,10 @@ export async function newSession(
|
||||
workingDir: string,
|
||||
): Promise<NewSessionResponse> {
|
||||
const client = await getClient();
|
||||
return client.newSession({ cwd: workingDir, mcpServers: [] });
|
||||
return client.newSession({
|
||||
cwd: workingDir,
|
||||
mcpServers: [],
|
||||
});
|
||||
}
|
||||
|
||||
export async function loadSession(
|
||||
@@ -129,7 +132,11 @@ export async function loadSession(
|
||||
workingDir: string,
|
||||
): Promise<LoadSessionResponse> {
|
||||
const client = await getClient();
|
||||
return client.loadSession({ sessionId, cwd: workingDir, mcpServers: [] });
|
||||
return client.loadSession({
|
||||
sessionId,
|
||||
cwd: workingDir,
|
||||
mcpServers: [],
|
||||
});
|
||||
}
|
||||
|
||||
export async function prompt(
|
||||
|
||||
@@ -308,23 +308,31 @@ function handleShared(sessionId: string, update: SessionUpdate): void {
|
||||
const configUpdate = update as SessionUpdate & {
|
||||
sessionUpdate: "config_option_update";
|
||||
};
|
||||
if ("options" in configUpdate && Array.isArray(configUpdate.options)) {
|
||||
const modelOption = configUpdate.options.find(
|
||||
(opt: { category?: string; kind?: Record<string, unknown> }) =>
|
||||
opt.category === "model",
|
||||
);
|
||||
if (modelOption?.kind?.type === "select") {
|
||||
const select = modelOption.kind;
|
||||
const currentModelId = select.currentValue;
|
||||
const availableModels: Array<{ id: string; name: string }> = [];
|
||||
|
||||
if (select.options?.type === "ungrouped") {
|
||||
for (const v of select.options.values) {
|
||||
availableModels.push({ id: v.value, name: v.name });
|
||||
}
|
||||
} else if (select.options?.type === "grouped") {
|
||||
for (const group of select.options.groups) {
|
||||
for (const v of group.options) {
|
||||
if (Array.isArray(configUpdate.configOptions)) {
|
||||
type SelectOption = { value: string; name: string };
|
||||
type SelectGroup = { group: string; name: string; options: SelectOption[] };
|
||||
type ModelConfigOption = {
|
||||
category?: string;
|
||||
type?: string;
|
||||
currentValue?: string;
|
||||
options?: Array<SelectOption | SelectGroup>;
|
||||
};
|
||||
|
||||
const modelOption = (
|
||||
configUpdate.configOptions as ModelConfigOption[]
|
||||
).find((opt) => opt.category === "model");
|
||||
|
||||
if (modelOption?.type === "select") {
|
||||
const currentModelId = modelOption.currentValue ?? "";
|
||||
const availableModels: Array<{ id: string; name: string }> = [];
|
||||
const rawOptions = modelOption.options ?? [];
|
||||
|
||||
for (const entry of rawOptions) {
|
||||
if ("value" in entry) {
|
||||
availableModels.push({ id: entry.value, name: entry.name });
|
||||
} else if ("options" in entry) {
|
||||
for (const v of entry.options) {
|
||||
availableModels.push({ id: v.value, name: v.name });
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ interface PreparedSession {
|
||||
|
||||
const prepared = new Map<string, PreparedSession>();
|
||||
const gooseToLocal = new Map<string, string>();
|
||||
const inFlight = new Map<string, Promise<string>>();
|
||||
|
||||
function makeKey(sessionId: string, personaId?: string): string {
|
||||
if (personaId && personaId.length > 0) {
|
||||
@@ -24,6 +25,28 @@ export async function prepareSession(
|
||||
): Promise<string> {
|
||||
const key = makeKey(sessionId, personaId);
|
||||
|
||||
const inFlightExisting = inFlight.get(key);
|
||||
if (inFlightExisting) {
|
||||
return inFlightExisting;
|
||||
}
|
||||
|
||||
const promise = doPrepareSession(sessionId, providerId, workingDir, personaId);
|
||||
inFlight.set(key, promise);
|
||||
try {
|
||||
return await promise;
|
||||
} finally {
|
||||
inFlight.delete(key);
|
||||
}
|
||||
}
|
||||
|
||||
async function doPrepareSession(
|
||||
sessionId: string,
|
||||
providerId: string,
|
||||
workingDir: string,
|
||||
personaId?: string,
|
||||
): Promise<string> {
|
||||
const key = makeKey(sessionId, personaId);
|
||||
|
||||
const existing = prepared.get(key) ?? prepared.get(sessionId);
|
||||
if (existing) {
|
||||
if (existing.workingDir !== workingDir) {
|
||||
@@ -49,15 +72,32 @@ export async function prepareSession(
|
||||
gooseSessionId = response.sessionId;
|
||||
}
|
||||
|
||||
await acpApi.setProvider(gooseSessionId, providerId);
|
||||
|
||||
prepared.set(key, { gooseSessionId, providerId, workingDir });
|
||||
prepared.set(sessionId, { gooseSessionId, providerId, workingDir });
|
||||
const entry = { gooseSessionId, providerId, workingDir };
|
||||
prepared.set(key, entry);
|
||||
prepared.set(sessionId, entry);
|
||||
gooseToLocal.set(gooseSessionId, sessionId);
|
||||
|
||||
await acpApi.setProvider(gooseSessionId, providerId);
|
||||
|
||||
return gooseSessionId;
|
||||
}
|
||||
|
||||
export function isPrepareInFlight(
|
||||
sessionId: string,
|
||||
personaId?: string,
|
||||
): boolean {
|
||||
const key = makeKey(sessionId, personaId);
|
||||
return inFlight.has(key) || inFlight.has(sessionId);
|
||||
}
|
||||
|
||||
export function isSessionPrepared(
|
||||
sessionId: string,
|
||||
personaId?: string,
|
||||
): boolean {
|
||||
const key = makeKey(sessionId, personaId);
|
||||
return prepared.has(key) || prepared.has(sessionId);
|
||||
}
|
||||
|
||||
export function getGooseSessionId(
|
||||
sessionId: string,
|
||||
personaId?: string,
|
||||
|
||||
@@ -112,7 +112,9 @@
|
||||
},
|
||||
"loading": {
|
||||
"thinking": "Thinking...",
|
||||
"responding": "Responding..."
|
||||
"responding": "Responding...",
|
||||
"spinningUp": "Spinning up {{providerName}}...",
|
||||
"spinningUpFallback": "Spinning up..."
|
||||
},
|
||||
"mention": {
|
||||
"ariaLabel": "Mention suggestions",
|
||||
|
||||
@@ -4,6 +4,7 @@ import type { Agent } from "./agents";
|
||||
// Chat state machine
|
||||
export type ChatState =
|
||||
| "idle"
|
||||
| "spinning_up"
|
||||
| "thinking"
|
||||
| "streaming"
|
||||
| "waiting"
|
||||
|
||||
Reference in New Issue
Block a user