feat(acp): paginate session list (#9199)

Signed-off-by: Kalvin Chau <kalvin@block.xyz>
This commit is contained in:
Kalvin C
2026-05-18 18:12:57 -07:00
committed by GitHub
parent bc9c1fc11a
commit 49dd575789
4 changed files with 549 additions and 27 deletions
+126 -8
View File
@@ -21,7 +21,7 @@ use crate::providers::inventory::{
InventoryIdentity, ProviderInventoryEntry, ProviderInventoryService, RefreshJobPlan,
RefreshPlan, RefreshSkipReason,
};
use crate::session::session_manager::SessionType;
use crate::session::session_manager::{SessionListCursor, SessionType};
use crate::session::{EnabledExtensionsState, Session, SessionManager};
use crate::source_roots::SourceRoot;
use crate::utils::sanitize_unicode_tags;
@@ -51,6 +51,7 @@ use agent_client_protocol::{
Responder,
};
use anyhow::Result;
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use fs_err as fs;
use futures::future::{BoxFuture, Either};
use futures::stream::{self, StreamExt};
@@ -58,7 +59,8 @@ use futures::FutureExt;
use rmcp::model::{
AnnotateAble, CallToolResult, RawContent, RawTextContent, ResourceContents, Role,
};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::{HashMap, HashSet};
use std::panic::AssertUnwindSafe;
use std::path::{Path, PathBuf};
@@ -93,6 +95,10 @@ pub type AcpProviderFactory = Arc<
+ Sync,
>;
const SESSION_LIST_PAGE_SIZE: usize = 50;
const ACP_SESSION_LIST_TYPES: [SessionType; 3] =
[SessionType::User, SessionType::Scheduled, SessionType::Acp];
/// Convenience conversions from any `Display` error into an `agent_client_protocol::Error`.
///
/// Replaces the repetitive `.internal_err()`
@@ -255,6 +261,94 @@ fn sid_short(id: &str) -> String {
id.chars().take(8).collect()
}
#[derive(Debug, Serialize, Deserialize)]
struct SessionListCursorToken {
updated_at: chrono::DateTime<chrono::Utc>,
// Goose stores updated_at with second precision in common write paths, so the
// cursor needs the full (updated_at, id) sort key to avoid skipping tied rows.
session_id: String,
filter_hash: String,
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct SessionListCursorFilters {
cwd: Option<String>,
session_types: Vec<String>,
non_empty: bool,
}
fn invalid_session_list_cursor(message: &'static str) -> agent_client_protocol::Error {
agent_client_protocol::Error::invalid_params().data(message)
}
// bind cursors to the effective filters so they cannot be reused for a different list.
fn session_list_filter_hash(
cwd: Option<&std::path::Path>,
session_types: &[SessionType],
) -> Result<String, agent_client_protocol::Error> {
let mut session_type_names = session_types
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>();
session_type_names.sort();
let filters = SessionListCursorFilters {
cwd: cwd.map(|path| path.to_string_lossy().to_string()),
session_types: session_type_names,
non_empty: true,
};
let bytes =
serde_json::to_vec(&filters).internal_err_ctx("Failed to encode session list filters")?;
Ok(URL_SAFE_NO_PAD.encode(Sha256::digest(bytes)))
}
fn decode_session_list_cursor(
cursor: Option<&str>,
cwd: Option<&std::path::Path>,
session_types: &[SessionType],
) -> Result<Option<SessionListCursor>, agent_client_protocol::Error> {
let Some(cursor) = cursor else {
return Ok(None);
};
let bytes = URL_SAFE_NO_PAD
.decode(cursor)
.map_err(|_| invalid_session_list_cursor("malformed session list cursor"))?;
let token: SessionListCursorToken = serde_json::from_slice(&bytes)
.map_err(|_| invalid_session_list_cursor("malformed session list cursor"))?;
if token.session_id.is_empty() || token.filter_hash.is_empty() {
return Err(invalid_session_list_cursor("malformed session list cursor"));
}
let expected_filter_hash = session_list_filter_hash(cwd, session_types)?;
if token.filter_hash != expected_filter_hash {
return Err(invalid_session_list_cursor(
"session list cursor does not match filters",
));
}
Ok(Some(SessionListCursor {
updated_at: token.updated_at,
session_id: token.session_id,
}))
}
fn encode_session_list_cursor(
cursor: &SessionListCursor,
cwd: Option<&std::path::Path>,
session_types: &[SessionType],
) -> Result<String, agent_client_protocol::Error> {
let token = SessionListCursorToken {
updated_at: cursor.updated_at,
session_id: cursor.session_id.clone(),
filter_hash: session_list_filter_hash(cwd, session_types)?,
};
let bytes =
serde_json::to_vec(&token).internal_err_ctx("Failed to encode session list cursor")?;
Ok(URL_SAFE_NO_PAD.encode(bytes))
}
fn session_meta(session: &Session) -> serde_json::Map<String, serde_json::Value> {
let mut meta = serde_json::Map::new();
meta.insert(
@@ -3368,16 +3462,35 @@ impl GooseAcpAgent {
Ok(())
}
async fn on_list_sessions(&self) -> Result<ListSessionsResponse, agent_client_protocol::Error> {
async fn on_list_sessions(
&self,
req: ListSessionsRequest,
) -> Result<ListSessionsResponse, agent_client_protocol::Error> {
if let Some(cwd) = req.cwd.as_deref() {
if !cwd.is_absolute() {
return Err(agent_client_protocol::Error::invalid_params()
.data("cwd must be an absolute path"));
}
}
let cwd = req.cwd.as_deref();
let cursor =
decode_session_list_cursor(req.cursor.as_deref(), cwd, &ACP_SESSION_LIST_TYPES)?;
// ACP clients see their own (Acp) sessions plus legacy User/Scheduled ones.
let sessions = self
let page = self
.session_manager
.list_sessions_by_types(&[SessionType::User, SessionType::Scheduled, SessionType::Acp])
.list_nonempty_sessions_by_types_paged(
&ACP_SESSION_LIST_TYPES,
cwd,
cursor.as_ref(),
SESSION_LIST_PAGE_SIZE,
)
.await
.internal_err()?;
let session_infos: Vec<SessionInfo> = sessions
let session_infos: Vec<SessionInfo> = page
.sessions
.into_iter()
.filter(|s| s.message_count > 0)
.map(|s| {
let meta = session_meta(&s);
SessionInfo::new(SessionId::new(s.id), s.working_dir)
@@ -3386,7 +3499,12 @@ impl GooseAcpAgent {
.meta(meta)
})
.collect();
Ok(ListSessionsResponse::new(session_infos))
let next_cursor = page
.next_cursor
.as_ref()
.map(|cursor| encode_session_list_cursor(cursor, cwd, &ACP_SESSION_LIST_TYPES))
.transpose()?;
Ok(ListSessionsResponse::new(session_infos).next_cursor(next_cursor))
}
async fn on_fork_session(
+5 -2
View File
@@ -330,9 +330,12 @@ impl HandleDispatchFrom<Client> for GooseAcpHandler {
.if_request({
let agent = agent.clone();
let cx = cx.clone();
|_req: ListSessionsRequest, responder: Responder<ListSessionsResponse>| async move {
|req: ListSessionsRequest, responder: Responder<ListSessionsResponse>| async move {
cx.spawn(async move {
responder.respond(agent.on_list_sessions().await?)?;
match agent.on_list_sessions(req).await {
Ok(response) => responder.respond(response)?,
Err(e) => responder.respond_with_error(e)?,
}
Ok(())
})?;
Ok(())
+292 -16
View File
@@ -274,6 +274,27 @@ pub struct SessionManager {
storage: Arc<SessionStorage>,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct SessionListCursor {
pub(crate) updated_at: DateTime<Utc>,
pub(crate) session_id: String,
}
#[derive(Debug, Clone)]
pub(crate) struct SessionListPage {
pub(crate) sessions: Vec<Session>,
pub(crate) next_cursor: Option<SessionListCursor>,
}
#[derive(Debug, Default)]
struct SessionListQuery<'a> {
types: Option<&'a [SessionType]>,
working_dir: Option<&'a Path>,
cursor: Option<&'a SessionListCursor>,
limit: Option<usize>,
require_messages: bool,
}
#[derive(Debug, Clone)]
pub struct SessionNameUpdate {
pub session_id: String,
@@ -340,6 +361,18 @@ impl SessionManager {
self.storage.list_sessions_by_types(Some(types)).await
}
pub(crate) async fn list_nonempty_sessions_by_types_paged(
&self,
types: &[SessionType],
working_dir: Option<&Path>,
cursor: Option<&SessionListCursor>,
page_size: usize,
) -> Result<SessionListPage> {
self.storage
.list_nonempty_sessions_by_types_paged(types, working_dir, cursor, page_size)
.await
}
pub async fn list_all_sessions(&self) -> Result<Vec<Session>> {
self.storage.list_sessions_by_types(None).await
}
@@ -1478,17 +1511,46 @@ impl SessionStorage {
Self::replace_conversation_inner(pool, session_id, conversation).await
}
async fn list_sessions_by_types(&self, types: Option<&[SessionType]>) -> Result<Vec<Session>> {
let (where_clause, binds): (String, Vec<String>) = match types {
Some(t) if !t.is_empty() => {
let placeholders: String = t.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
(
format!("WHERE s.session_type IN ({})", placeholders),
t.iter().map(|t| t.to_string()).collect(),
)
}
Some(_) => return Ok(Vec::new()),
None => (String::new(), Vec::new()),
async fn list_sessions_matching(&self, options: SessionListQuery<'_>) -> Result<Vec<Session>> {
if matches!(options.types, Some(types) if types.is_empty()) {
return Ok(Vec::new());
}
let mut where_clauses = Vec::new();
if let Some(types) = options.types {
let placeholders = types.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
where_clauses.push(format!("s.session_type IN ({})", placeholders));
}
if options.working_dir.is_some() {
where_clauses.push("s.working_dir = ?".to_string());
}
if options.cursor.is_some() {
where_clauses.push(
"(datetime(s.updated_at) < datetime(?) \
OR (datetime(s.updated_at) = datetime(?) AND s.id < ?))"
.to_string(),
);
}
let where_clause = if where_clauses.is_empty() {
String::new()
} else {
format!("WHERE {}", where_clauses.join(" AND "))
};
let message_join = if options.require_messages {
"JOIN messages m ON s.id = m.session_id"
} else {
"LEFT JOIN messages m ON s.id = m.session_id"
};
let order_by = if options.cursor.is_some() || options.limit.is_some() {
"ORDER BY datetime(s.updated_at) DESC, s.id DESC"
} else {
"ORDER BY s.updated_at DESC"
};
let limit_clause = if options.limit.is_some() {
"LIMIT ?"
} else {
""
};
let query = format!(
@@ -1502,23 +1564,90 @@ impl SessionStorage {
s.archived_at, s.project_id,
COUNT(m.id) as message_count
FROM sessions s
LEFT JOIN messages m ON s.id = m.session_id
{}
{}
GROUP BY s.id
ORDER BY s.updated_at DESC
{}
{}
"#,
where_clause
message_join, where_clause, order_by, limit_clause
);
let mut q = sqlx::query_as::<_, Session>(&query);
for b in &binds {
q = q.bind(b);
if let Some(types) = options.types {
for session_type in types {
q = q.bind(session_type.to_string());
}
}
if let Some(working_dir) = options.working_dir {
q = q.bind(working_dir.to_string_lossy().to_string());
}
if let Some(cursor) = options.cursor {
let updated_at = cursor.updated_at.to_rfc3339();
// Normalize mixed SQLite CURRENT_TIMESTAMP and RFC3339 stored values.
q = q.bind(updated_at.clone());
q = q.bind(updated_at);
q = q.bind(&cursor.session_id);
}
if let Some(limit) = options.limit {
q = q.bind(limit as i64);
}
let pool = self.pool().await?;
q.fetch_all(pool).await.map_err(Into::into)
}
async fn list_sessions_by_types(&self, types: Option<&[SessionType]>) -> Result<Vec<Session>> {
self.list_sessions_matching(SessionListQuery {
types,
..Default::default()
})
.await
}
async fn list_nonempty_sessions_by_types_paged(
&self,
types: &[SessionType],
working_dir: Option<&Path>,
cursor: Option<&SessionListCursor>,
page_size: usize,
) -> Result<SessionListPage> {
if types.is_empty() || page_size == 0 {
return Ok(SessionListPage {
sessions: Vec::new(),
next_cursor: None,
});
}
let mut sessions = self
.list_sessions_matching(SessionListQuery {
types: Some(types),
working_dir,
cursor,
limit: Some(page_size + 1),
require_messages: true,
})
.await?;
let has_next_page = sessions.len() > page_size;
let next_cursor = if has_next_page {
let anchor = &sessions[page_size - 1];
Some(SessionListCursor {
updated_at: anchor.updated_at,
session_id: anchor.id.clone(),
})
} else {
None
};
if has_next_page {
sessions.truncate(page_size);
}
Ok(SessionListPage {
sessions,
next_cursor,
})
}
async fn list_sessions(&self) -> Result<Vec<Session>> {
self.list_sessions_by_types(Some(&[SessionType::User, SessionType::Scheduled]))
.await
@@ -1843,6 +1972,89 @@ mod tests {
const NUM_CONCURRENT_SESSIONS: i32 = 10;
async fn create_session_for_list(
sm: &SessionManager,
working_dir: &str,
has_message: bool,
) -> String {
let session = sm
.create_session(
PathBuf::from(working_dir),
format!("Session in {working_dir}"),
SessionType::User,
GooseMode::default(),
)
.await
.unwrap();
if has_message {
sm.add_message(&session.id, &Message::user().with_text("message"))
.await
.unwrap();
}
session.id
}
async fn set_sessions_updated_at(
sm: &SessionManager,
session_ids: &[String],
updated_at: &str,
) {
let pool = sm.storage().pool().await.unwrap();
let updated_at = chrono::DateTime::parse_from_rfc3339(updated_at).unwrap();
let timestamp = updated_at.format("%Y-%m-%d %H:%M:%S").to_string();
for session_id in session_ids {
sqlx::query("UPDATE sessions SET updated_at = ? WHERE id = ?")
.bind(&timestamp)
.bind(session_id)
.execute(pool)
.await
.unwrap();
}
}
async fn expected_session_list_ids(sm: &SessionManager, session_ids: &[String]) -> Vec<String> {
let mut sessions = Vec::new();
for session_id in session_ids {
sessions.push(sm.get_session(session_id, false).await.unwrap());
}
sessions.sort_by(|a, b| {
b.updated_at
.cmp(&a.updated_at)
.then_with(|| b.id.cmp(&a.id))
});
sessions.into_iter().map(|session| session.id).collect()
}
async fn assert_session_list_page(
sm: &SessionManager,
cursor: Option<&SessionListCursor>,
working_dir: Option<&str>,
page_size: usize,
expected_ids: &[String],
expected_next_cursor: bool,
) -> Option<SessionListCursor> {
let page = sm
.list_nonempty_sessions_by_types_paged(
&[SessionType::User],
working_dir.map(Path::new),
cursor,
page_size,
)
.await
.unwrap();
let ids = page
.sessions
.iter()
.map(|session| session.id.clone())
.collect::<Vec<_>>();
assert_eq!(ids.as_slice(), expected_ids);
assert_eq!(page.next_cursor.is_some(), expected_next_cursor);
page.next_cursor
}
async fn run_lock_upgrade_attempt(
pool: Pool<Sqlite>,
session_id: String,
@@ -1935,6 +2147,70 @@ mod tests {
);
}
#[tokio::test]
async fn test_session_list_paged_first_second_and_final_page() {
let temp_dir = TempDir::new().unwrap();
let sm = SessionManager::new(temp_dir.path().to_path_buf());
let mut expected_ids = Vec::new();
for _ in 0..5 {
expected_ids.push(create_session_for_list(&sm, "/tmp/session-list", true).await);
}
let expected_ids = expected_session_list_ids(&sm, &expected_ids).await;
let cursor = assert_session_list_page(&sm, None, None, 2, &expected_ids[0..2], true).await;
let cursor =
assert_session_list_page(&sm, cursor.as_ref(), None, 2, &expected_ids[2..4], true)
.await;
assert_session_list_page(&sm, cursor.as_ref(), None, 2, &expected_ids[4..5], false).await;
}
#[tokio::test]
async fn test_session_list_paged_uses_id_tiebreaker_for_duplicate_updated_at() {
let temp_dir = TempDir::new().unwrap();
let sm = SessionManager::new(temp_dir.path().to_path_buf());
let mut expected_ids = Vec::new();
for _ in 0..3 {
expected_ids.push(create_session_for_list(&sm, "/tmp/session-list", true).await);
}
set_sessions_updated_at(&sm, &expected_ids, "2024-01-01T00:00:00Z").await;
let expected_ids = expected_session_list_ids(&sm, &expected_ids).await;
let cursor = assert_session_list_page(&sm, None, None, 2, &expected_ids[0..2], true).await;
assert_session_list_page(&sm, cursor.as_ref(), None, 2, &expected_ids[2..3], false).await;
}
#[tokio::test]
async fn test_session_list_paged_filters_empty_and_cwd_before_pagination() {
let temp_dir = TempDir::new().unwrap();
let sm = SessionManager::new(temp_dir.path().to_path_buf());
let expected_ids = vec![
create_session_for_list(&sm, "/tmp/session-list/a", true).await,
create_session_for_list(&sm, "/tmp/session-list/a", true).await,
];
create_session_for_list(&sm, "/tmp/session-list/a", false).await;
create_session_for_list(&sm, "/tmp/session-list/b", true).await;
let expected_ids = expected_session_list_ids(&sm, &expected_ids).await;
let cursor = assert_session_list_page(
&sm,
None,
Some("/tmp/session-list/a"),
1,
&expected_ids[0..1],
true,
)
.await;
assert_session_list_page(
&sm,
cursor.as_ref(),
Some("/tmp/session-list/a"),
1,
&expected_ids[1..2],
false,
)
.await;
}
#[tokio::test]
async fn test_concurrent_session_creation() {
let temp_dir = TempDir::new().unwrap();
+126 -1
View File
@@ -1,8 +1,10 @@
#[allow(dead_code)]
#[path = "acp_common_tests/mod.rs"]
mod common_tests;
use common_tests::fixtures::run_test;
use agent_client_protocol::schema::{ListSessionsRequest, ListSessionsResponse};
use agent_client_protocol::ErrorCode;
use common_tests::fixtures::server::AcpServerConnection;
use common_tests::fixtures::{run_test, Connection, OpenAiFixture, TestConnectionConfig};
use common_tests::{
run_close_session, run_config_mcp, run_config_option_mode_set, run_config_option_model_set,
run_delete_session, run_fs_read_text_file_true, run_fs_write_text_file_false,
@@ -14,10 +16,65 @@ use common_tests::{
run_prompt_mcp, run_prompt_model_mismatch, run_prompt_skill,
run_session_name_update_notification, run_shell_terminal_false, run_shell_terminal_true,
};
use goose::config::GooseMode;
use goose::conversation::message::Message;
use goose::session::{SessionManager, SessionType};
use std::path::Path;
tests_config_option_set_error!(AcpServerConnection);
tests_mode_set_error!(AcpServerConnection);
async fn seed_list_sessions(data_root: &Path, working_dir: &Path, count: usize) {
let session_manager = SessionManager::new(data_root.to_path_buf());
for index in 0..count {
let session = session_manager
.create_session(
working_dir.to_path_buf(),
format!("Seed session {index}"),
SessionType::Acp,
GooseMode::default(),
)
.await
.unwrap();
session_manager
.add_message(&session.id, &Message::user().with_text("hello"))
.await
.unwrap();
}
}
async fn new_connection(data_root: &Path) -> AcpServerConnection {
let openai = OpenAiFixture::new(
vec![],
<AcpServerConnection as Connection>::expected_session_id(),
)
.await;
<AcpServerConnection as Connection>::new(
TestConnectionConfig {
data_root: data_root.to_path_buf(),
..Default::default()
},
openai,
)
.await
}
async fn list_sessions_request(
conn: &AcpServerConnection,
request: ListSessionsRequest,
) -> anyhow::Result<ListSessionsResponse> {
conn.cx()
.send_request(request)
.block_task()
.await
.map_err(Into::into)
}
fn assert_invalid_params(error: anyhow::Error) {
let acp_error = error.downcast::<agent_client_protocol::Error>().unwrap();
assert_eq!(acp_error.code, ErrorCode::InvalidParams);
}
#[test]
fn test_config_mcp() {
run_test(async { run_config_mcp::<AcpServerConnection>().await });
@@ -33,6 +90,74 @@ fn test_list_sessions() {
run_test(async { run_list_sessions::<AcpServerConnection>().await });
}
#[test]
fn test_list_sessions_pagination() {
run_test(async {
let data_root = tempfile::tempdir().unwrap();
seed_list_sessions(data_root.path(), Path::new("/tmp/acp-session-list"), 51).await;
let conn = new_connection(data_root.path()).await;
let first = list_sessions_request(&conn, ListSessionsRequest::new())
.await
.unwrap();
assert_eq!(first.sessions.len(), 50);
let second = list_sessions_request(
&conn,
ListSessionsRequest::new().cursor(first.next_cursor.clone().unwrap()),
)
.await
.unwrap();
assert_eq!(second.sessions.len(), 1);
assert!(second.next_cursor.is_none());
let second_id = &second.sessions[0].session_id;
assert!(first
.sessions
.iter()
.all(|session| session.session_id != *second_id));
});
}
#[test]
fn test_list_sessions_invalid_params() {
run_test(async {
let data_root = tempfile::tempdir().unwrap();
let cwd = tempfile::tempdir().unwrap();
let other_cwd = tempfile::tempdir().unwrap();
seed_list_sessions(data_root.path(), cwd.path(), 51).await;
let conn = new_connection(data_root.path()).await;
let error =
list_sessions_request(&conn, ListSessionsRequest::new().cursor("*".to_string()))
.await
.unwrap_err();
assert_invalid_params(error);
let error = list_sessions_request(
&conn,
ListSessionsRequest::new().cwd(std::path::PathBuf::from("relative/path")),
)
.await
.unwrap_err();
assert_invalid_params(error);
let first = list_sessions_request(&conn, ListSessionsRequest::new().cwd(cwd.path()))
.await
.unwrap();
let error = list_sessions_request(
&conn,
ListSessionsRequest::new()
.cwd(other_cwd.path())
.cursor(first.next_cursor.unwrap()),
)
.await
.unwrap_err();
assert_invalid_params(error);
});
}
#[test]
fn test_session_name_update_notification() {
run_test(async { run_session_name_update_notification::<AcpServerConnection>().await });