mirror of
https://github.com/block/goose.git
synced 2026-06-01 22:11:07 +02:00
feat(acp): paginate session list (#9199)
Signed-off-by: Kalvin Chau <kalvin@block.xyz>
This commit is contained in:
@@ -21,7 +21,7 @@ use crate::providers::inventory::{
|
||||
InventoryIdentity, ProviderInventoryEntry, ProviderInventoryService, RefreshJobPlan,
|
||||
RefreshPlan, RefreshSkipReason,
|
||||
};
|
||||
use crate::session::session_manager::SessionType;
|
||||
use crate::session::session_manager::{SessionListCursor, SessionType};
|
||||
use crate::session::{EnabledExtensionsState, Session, SessionManager};
|
||||
use crate::source_roots::SourceRoot;
|
||||
use crate::utils::sanitize_unicode_tags;
|
||||
@@ -51,6 +51,7 @@ use agent_client_protocol::{
|
||||
Responder,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
||||
use fs_err as fs;
|
||||
use futures::future::{BoxFuture, Either};
|
||||
use futures::stream::{self, StreamExt};
|
||||
@@ -58,7 +59,8 @@ use futures::FutureExt;
|
||||
use rmcp::model::{
|
||||
AnnotateAble, CallToolResult, RawContent, RawTextContent, ResourceContents, Role,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::panic::AssertUnwindSafe;
|
||||
use std::path::{Path, PathBuf};
|
||||
@@ -93,6 +95,10 @@ pub type AcpProviderFactory = Arc<
|
||||
+ Sync,
|
||||
>;
|
||||
|
||||
const SESSION_LIST_PAGE_SIZE: usize = 50;
|
||||
const ACP_SESSION_LIST_TYPES: [SessionType; 3] =
|
||||
[SessionType::User, SessionType::Scheduled, SessionType::Acp];
|
||||
|
||||
/// Convenience conversions from any `Display` error into an `agent_client_protocol::Error`.
|
||||
///
|
||||
/// Replaces the repetitive `.internal_err()`
|
||||
@@ -255,6 +261,94 @@ fn sid_short(id: &str) -> String {
|
||||
id.chars().take(8).collect()
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct SessionListCursorToken {
|
||||
updated_at: chrono::DateTime<chrono::Utc>,
|
||||
// Goose stores updated_at with second precision in common write paths, so the
|
||||
// cursor needs the full (updated_at, id) sort key to avoid skipping tied rows.
|
||||
session_id: String,
|
||||
filter_hash: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
struct SessionListCursorFilters {
|
||||
cwd: Option<String>,
|
||||
session_types: Vec<String>,
|
||||
non_empty: bool,
|
||||
}
|
||||
|
||||
fn invalid_session_list_cursor(message: &'static str) -> agent_client_protocol::Error {
|
||||
agent_client_protocol::Error::invalid_params().data(message)
|
||||
}
|
||||
|
||||
// bind cursors to the effective filters so they cannot be reused for a different list.
|
||||
fn session_list_filter_hash(
|
||||
cwd: Option<&std::path::Path>,
|
||||
session_types: &[SessionType],
|
||||
) -> Result<String, agent_client_protocol::Error> {
|
||||
let mut session_type_names = session_types
|
||||
.iter()
|
||||
.map(ToString::to_string)
|
||||
.collect::<Vec<_>>();
|
||||
session_type_names.sort();
|
||||
let filters = SessionListCursorFilters {
|
||||
cwd: cwd.map(|path| path.to_string_lossy().to_string()),
|
||||
session_types: session_type_names,
|
||||
non_empty: true,
|
||||
};
|
||||
let bytes =
|
||||
serde_json::to_vec(&filters).internal_err_ctx("Failed to encode session list filters")?;
|
||||
Ok(URL_SAFE_NO_PAD.encode(Sha256::digest(bytes)))
|
||||
}
|
||||
|
||||
fn decode_session_list_cursor(
|
||||
cursor: Option<&str>,
|
||||
cwd: Option<&std::path::Path>,
|
||||
session_types: &[SessionType],
|
||||
) -> Result<Option<SessionListCursor>, agent_client_protocol::Error> {
|
||||
let Some(cursor) = cursor else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let bytes = URL_SAFE_NO_PAD
|
||||
.decode(cursor)
|
||||
.map_err(|_| invalid_session_list_cursor("malformed session list cursor"))?;
|
||||
let token: SessionListCursorToken = serde_json::from_slice(&bytes)
|
||||
.map_err(|_| invalid_session_list_cursor("malformed session list cursor"))?;
|
||||
|
||||
if token.session_id.is_empty() || token.filter_hash.is_empty() {
|
||||
return Err(invalid_session_list_cursor("malformed session list cursor"));
|
||||
}
|
||||
|
||||
let expected_filter_hash = session_list_filter_hash(cwd, session_types)?;
|
||||
if token.filter_hash != expected_filter_hash {
|
||||
return Err(invalid_session_list_cursor(
|
||||
"session list cursor does not match filters",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(Some(SessionListCursor {
|
||||
updated_at: token.updated_at,
|
||||
session_id: token.session_id,
|
||||
}))
|
||||
}
|
||||
|
||||
fn encode_session_list_cursor(
|
||||
cursor: &SessionListCursor,
|
||||
cwd: Option<&std::path::Path>,
|
||||
session_types: &[SessionType],
|
||||
) -> Result<String, agent_client_protocol::Error> {
|
||||
let token = SessionListCursorToken {
|
||||
updated_at: cursor.updated_at,
|
||||
session_id: cursor.session_id.clone(),
|
||||
filter_hash: session_list_filter_hash(cwd, session_types)?,
|
||||
};
|
||||
let bytes =
|
||||
serde_json::to_vec(&token).internal_err_ctx("Failed to encode session list cursor")?;
|
||||
Ok(URL_SAFE_NO_PAD.encode(bytes))
|
||||
}
|
||||
|
||||
fn session_meta(session: &Session) -> serde_json::Map<String, serde_json::Value> {
|
||||
let mut meta = serde_json::Map::new();
|
||||
meta.insert(
|
||||
@@ -3368,16 +3462,35 @@ impl GooseAcpAgent {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn on_list_sessions(&self) -> Result<ListSessionsResponse, agent_client_protocol::Error> {
|
||||
async fn on_list_sessions(
|
||||
&self,
|
||||
req: ListSessionsRequest,
|
||||
) -> Result<ListSessionsResponse, agent_client_protocol::Error> {
|
||||
if let Some(cwd) = req.cwd.as_deref() {
|
||||
if !cwd.is_absolute() {
|
||||
return Err(agent_client_protocol::Error::invalid_params()
|
||||
.data("cwd must be an absolute path"));
|
||||
}
|
||||
}
|
||||
|
||||
let cwd = req.cwd.as_deref();
|
||||
let cursor =
|
||||
decode_session_list_cursor(req.cursor.as_deref(), cwd, &ACP_SESSION_LIST_TYPES)?;
|
||||
|
||||
// ACP clients see their own (Acp) sessions plus legacy User/Scheduled ones.
|
||||
let sessions = self
|
||||
let page = self
|
||||
.session_manager
|
||||
.list_sessions_by_types(&[SessionType::User, SessionType::Scheduled, SessionType::Acp])
|
||||
.list_nonempty_sessions_by_types_paged(
|
||||
&ACP_SESSION_LIST_TYPES,
|
||||
cwd,
|
||||
cursor.as_ref(),
|
||||
SESSION_LIST_PAGE_SIZE,
|
||||
)
|
||||
.await
|
||||
.internal_err()?;
|
||||
let session_infos: Vec<SessionInfo> = sessions
|
||||
let session_infos: Vec<SessionInfo> = page
|
||||
.sessions
|
||||
.into_iter()
|
||||
.filter(|s| s.message_count > 0)
|
||||
.map(|s| {
|
||||
let meta = session_meta(&s);
|
||||
SessionInfo::new(SessionId::new(s.id), s.working_dir)
|
||||
@@ -3386,7 +3499,12 @@ impl GooseAcpAgent {
|
||||
.meta(meta)
|
||||
})
|
||||
.collect();
|
||||
Ok(ListSessionsResponse::new(session_infos))
|
||||
let next_cursor = page
|
||||
.next_cursor
|
||||
.as_ref()
|
||||
.map(|cursor| encode_session_list_cursor(cursor, cwd, &ACP_SESSION_LIST_TYPES))
|
||||
.transpose()?;
|
||||
Ok(ListSessionsResponse::new(session_infos).next_cursor(next_cursor))
|
||||
}
|
||||
|
||||
async fn on_fork_session(
|
||||
|
||||
@@ -330,9 +330,12 @@ impl HandleDispatchFrom<Client> for GooseAcpHandler {
|
||||
.if_request({
|
||||
let agent = agent.clone();
|
||||
let cx = cx.clone();
|
||||
|_req: ListSessionsRequest, responder: Responder<ListSessionsResponse>| async move {
|
||||
|req: ListSessionsRequest, responder: Responder<ListSessionsResponse>| async move {
|
||||
cx.spawn(async move {
|
||||
responder.respond(agent.on_list_sessions().await?)?;
|
||||
match agent.on_list_sessions(req).await {
|
||||
Ok(response) => responder.respond(response)?,
|
||||
Err(e) => responder.respond_with_error(e)?,
|
||||
}
|
||||
Ok(())
|
||||
})?;
|
||||
Ok(())
|
||||
|
||||
@@ -274,6 +274,27 @@ pub struct SessionManager {
|
||||
storage: Arc<SessionStorage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub(crate) struct SessionListCursor {
|
||||
pub(crate) updated_at: DateTime<Utc>,
|
||||
pub(crate) session_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct SessionListPage {
|
||||
pub(crate) sessions: Vec<Session>,
|
||||
pub(crate) next_cursor: Option<SessionListCursor>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct SessionListQuery<'a> {
|
||||
types: Option<&'a [SessionType]>,
|
||||
working_dir: Option<&'a Path>,
|
||||
cursor: Option<&'a SessionListCursor>,
|
||||
limit: Option<usize>,
|
||||
require_messages: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SessionNameUpdate {
|
||||
pub session_id: String,
|
||||
@@ -340,6 +361,18 @@ impl SessionManager {
|
||||
self.storage.list_sessions_by_types(Some(types)).await
|
||||
}
|
||||
|
||||
pub(crate) async fn list_nonempty_sessions_by_types_paged(
|
||||
&self,
|
||||
types: &[SessionType],
|
||||
working_dir: Option<&Path>,
|
||||
cursor: Option<&SessionListCursor>,
|
||||
page_size: usize,
|
||||
) -> Result<SessionListPage> {
|
||||
self.storage
|
||||
.list_nonempty_sessions_by_types_paged(types, working_dir, cursor, page_size)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn list_all_sessions(&self) -> Result<Vec<Session>> {
|
||||
self.storage.list_sessions_by_types(None).await
|
||||
}
|
||||
@@ -1478,17 +1511,46 @@ impl SessionStorage {
|
||||
Self::replace_conversation_inner(pool, session_id, conversation).await
|
||||
}
|
||||
|
||||
async fn list_sessions_by_types(&self, types: Option<&[SessionType]>) -> Result<Vec<Session>> {
|
||||
let (where_clause, binds): (String, Vec<String>) = match types {
|
||||
Some(t) if !t.is_empty() => {
|
||||
let placeholders: String = t.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
|
||||
(
|
||||
format!("WHERE s.session_type IN ({})", placeholders),
|
||||
t.iter().map(|t| t.to_string()).collect(),
|
||||
)
|
||||
}
|
||||
Some(_) => return Ok(Vec::new()),
|
||||
None => (String::new(), Vec::new()),
|
||||
async fn list_sessions_matching(&self, options: SessionListQuery<'_>) -> Result<Vec<Session>> {
|
||||
if matches!(options.types, Some(types) if types.is_empty()) {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut where_clauses = Vec::new();
|
||||
if let Some(types) = options.types {
|
||||
let placeholders = types.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
|
||||
where_clauses.push(format!("s.session_type IN ({})", placeholders));
|
||||
}
|
||||
if options.working_dir.is_some() {
|
||||
where_clauses.push("s.working_dir = ?".to_string());
|
||||
}
|
||||
if options.cursor.is_some() {
|
||||
where_clauses.push(
|
||||
"(datetime(s.updated_at) < datetime(?) \
|
||||
OR (datetime(s.updated_at) = datetime(?) AND s.id < ?))"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let where_clause = if where_clauses.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("WHERE {}", where_clauses.join(" AND "))
|
||||
};
|
||||
let message_join = if options.require_messages {
|
||||
"JOIN messages m ON s.id = m.session_id"
|
||||
} else {
|
||||
"LEFT JOIN messages m ON s.id = m.session_id"
|
||||
};
|
||||
let order_by = if options.cursor.is_some() || options.limit.is_some() {
|
||||
"ORDER BY datetime(s.updated_at) DESC, s.id DESC"
|
||||
} else {
|
||||
"ORDER BY s.updated_at DESC"
|
||||
};
|
||||
let limit_clause = if options.limit.is_some() {
|
||||
"LIMIT ?"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
let query = format!(
|
||||
@@ -1502,23 +1564,90 @@ impl SessionStorage {
|
||||
s.archived_at, s.project_id,
|
||||
COUNT(m.id) as message_count
|
||||
FROM sessions s
|
||||
LEFT JOIN messages m ON s.id = m.session_id
|
||||
{}
|
||||
{}
|
||||
GROUP BY s.id
|
||||
ORDER BY s.updated_at DESC
|
||||
{}
|
||||
{}
|
||||
"#,
|
||||
where_clause
|
||||
message_join, where_clause, order_by, limit_clause
|
||||
);
|
||||
|
||||
let mut q = sqlx::query_as::<_, Session>(&query);
|
||||
for b in &binds {
|
||||
q = q.bind(b);
|
||||
if let Some(types) = options.types {
|
||||
for session_type in types {
|
||||
q = q.bind(session_type.to_string());
|
||||
}
|
||||
}
|
||||
if let Some(working_dir) = options.working_dir {
|
||||
q = q.bind(working_dir.to_string_lossy().to_string());
|
||||
}
|
||||
if let Some(cursor) = options.cursor {
|
||||
let updated_at = cursor.updated_at.to_rfc3339();
|
||||
// Normalize mixed SQLite CURRENT_TIMESTAMP and RFC3339 stored values.
|
||||
q = q.bind(updated_at.clone());
|
||||
q = q.bind(updated_at);
|
||||
q = q.bind(&cursor.session_id);
|
||||
}
|
||||
if let Some(limit) = options.limit {
|
||||
q = q.bind(limit as i64);
|
||||
}
|
||||
|
||||
let pool = self.pool().await?;
|
||||
q.fetch_all(pool).await.map_err(Into::into)
|
||||
}
|
||||
|
||||
async fn list_sessions_by_types(&self, types: Option<&[SessionType]>) -> Result<Vec<Session>> {
|
||||
self.list_sessions_matching(SessionListQuery {
|
||||
types,
|
||||
..Default::default()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn list_nonempty_sessions_by_types_paged(
|
||||
&self,
|
||||
types: &[SessionType],
|
||||
working_dir: Option<&Path>,
|
||||
cursor: Option<&SessionListCursor>,
|
||||
page_size: usize,
|
||||
) -> Result<SessionListPage> {
|
||||
if types.is_empty() || page_size == 0 {
|
||||
return Ok(SessionListPage {
|
||||
sessions: Vec::new(),
|
||||
next_cursor: None,
|
||||
});
|
||||
}
|
||||
|
||||
let mut sessions = self
|
||||
.list_sessions_matching(SessionListQuery {
|
||||
types: Some(types),
|
||||
working_dir,
|
||||
cursor,
|
||||
limit: Some(page_size + 1),
|
||||
require_messages: true,
|
||||
})
|
||||
.await?;
|
||||
let has_next_page = sessions.len() > page_size;
|
||||
let next_cursor = if has_next_page {
|
||||
let anchor = &sessions[page_size - 1];
|
||||
Some(SessionListCursor {
|
||||
updated_at: anchor.updated_at,
|
||||
session_id: anchor.id.clone(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if has_next_page {
|
||||
sessions.truncate(page_size);
|
||||
}
|
||||
|
||||
Ok(SessionListPage {
|
||||
sessions,
|
||||
next_cursor,
|
||||
})
|
||||
}
|
||||
|
||||
async fn list_sessions(&self) -> Result<Vec<Session>> {
|
||||
self.list_sessions_by_types(Some(&[SessionType::User, SessionType::Scheduled]))
|
||||
.await
|
||||
@@ -1843,6 +1972,89 @@ mod tests {
|
||||
|
||||
const NUM_CONCURRENT_SESSIONS: i32 = 10;
|
||||
|
||||
async fn create_session_for_list(
|
||||
sm: &SessionManager,
|
||||
working_dir: &str,
|
||||
has_message: bool,
|
||||
) -> String {
|
||||
let session = sm
|
||||
.create_session(
|
||||
PathBuf::from(working_dir),
|
||||
format!("Session in {working_dir}"),
|
||||
SessionType::User,
|
||||
GooseMode::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if has_message {
|
||||
sm.add_message(&session.id, &Message::user().with_text("message"))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
session.id
|
||||
}
|
||||
|
||||
async fn set_sessions_updated_at(
|
||||
sm: &SessionManager,
|
||||
session_ids: &[String],
|
||||
updated_at: &str,
|
||||
) {
|
||||
let pool = sm.storage().pool().await.unwrap();
|
||||
let updated_at = chrono::DateTime::parse_from_rfc3339(updated_at).unwrap();
|
||||
let timestamp = updated_at.format("%Y-%m-%d %H:%M:%S").to_string();
|
||||
|
||||
for session_id in session_ids {
|
||||
sqlx::query("UPDATE sessions SET updated_at = ? WHERE id = ?")
|
||||
.bind(×tamp)
|
||||
.bind(session_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
async fn expected_session_list_ids(sm: &SessionManager, session_ids: &[String]) -> Vec<String> {
|
||||
let mut sessions = Vec::new();
|
||||
for session_id in session_ids {
|
||||
sessions.push(sm.get_session(session_id, false).await.unwrap());
|
||||
}
|
||||
sessions.sort_by(|a, b| {
|
||||
b.updated_at
|
||||
.cmp(&a.updated_at)
|
||||
.then_with(|| b.id.cmp(&a.id))
|
||||
});
|
||||
sessions.into_iter().map(|session| session.id).collect()
|
||||
}
|
||||
|
||||
async fn assert_session_list_page(
|
||||
sm: &SessionManager,
|
||||
cursor: Option<&SessionListCursor>,
|
||||
working_dir: Option<&str>,
|
||||
page_size: usize,
|
||||
expected_ids: &[String],
|
||||
expected_next_cursor: bool,
|
||||
) -> Option<SessionListCursor> {
|
||||
let page = sm
|
||||
.list_nonempty_sessions_by_types_paged(
|
||||
&[SessionType::User],
|
||||
working_dir.map(Path::new),
|
||||
cursor,
|
||||
page_size,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let ids = page
|
||||
.sessions
|
||||
.iter()
|
||||
.map(|session| session.id.clone())
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(ids.as_slice(), expected_ids);
|
||||
assert_eq!(page.next_cursor.is_some(), expected_next_cursor);
|
||||
page.next_cursor
|
||||
}
|
||||
|
||||
async fn run_lock_upgrade_attempt(
|
||||
pool: Pool<Sqlite>,
|
||||
session_id: String,
|
||||
@@ -1935,6 +2147,70 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_list_paged_first_second_and_final_page() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let sm = SessionManager::new(temp_dir.path().to_path_buf());
|
||||
let mut expected_ids = Vec::new();
|
||||
for _ in 0..5 {
|
||||
expected_ids.push(create_session_for_list(&sm, "/tmp/session-list", true).await);
|
||||
}
|
||||
let expected_ids = expected_session_list_ids(&sm, &expected_ids).await;
|
||||
|
||||
let cursor = assert_session_list_page(&sm, None, None, 2, &expected_ids[0..2], true).await;
|
||||
let cursor =
|
||||
assert_session_list_page(&sm, cursor.as_ref(), None, 2, &expected_ids[2..4], true)
|
||||
.await;
|
||||
assert_session_list_page(&sm, cursor.as_ref(), None, 2, &expected_ids[4..5], false).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_list_paged_uses_id_tiebreaker_for_duplicate_updated_at() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let sm = SessionManager::new(temp_dir.path().to_path_buf());
|
||||
let mut expected_ids = Vec::new();
|
||||
for _ in 0..3 {
|
||||
expected_ids.push(create_session_for_list(&sm, "/tmp/session-list", true).await);
|
||||
}
|
||||
set_sessions_updated_at(&sm, &expected_ids, "2024-01-01T00:00:00Z").await;
|
||||
let expected_ids = expected_session_list_ids(&sm, &expected_ids).await;
|
||||
|
||||
let cursor = assert_session_list_page(&sm, None, None, 2, &expected_ids[0..2], true).await;
|
||||
assert_session_list_page(&sm, cursor.as_ref(), None, 2, &expected_ids[2..3], false).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_list_paged_filters_empty_and_cwd_before_pagination() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let sm = SessionManager::new(temp_dir.path().to_path_buf());
|
||||
let expected_ids = vec![
|
||||
create_session_for_list(&sm, "/tmp/session-list/a", true).await,
|
||||
create_session_for_list(&sm, "/tmp/session-list/a", true).await,
|
||||
];
|
||||
create_session_for_list(&sm, "/tmp/session-list/a", false).await;
|
||||
create_session_for_list(&sm, "/tmp/session-list/b", true).await;
|
||||
let expected_ids = expected_session_list_ids(&sm, &expected_ids).await;
|
||||
|
||||
let cursor = assert_session_list_page(
|
||||
&sm,
|
||||
None,
|
||||
Some("/tmp/session-list/a"),
|
||||
1,
|
||||
&expected_ids[0..1],
|
||||
true,
|
||||
)
|
||||
.await;
|
||||
assert_session_list_page(
|
||||
&sm,
|
||||
cursor.as_ref(),
|
||||
Some("/tmp/session-list/a"),
|
||||
1,
|
||||
&expected_ids[1..2],
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_session_creation() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
#[allow(dead_code)]
|
||||
#[path = "acp_common_tests/mod.rs"]
|
||||
mod common_tests;
|
||||
use common_tests::fixtures::run_test;
|
||||
use agent_client_protocol::schema::{ListSessionsRequest, ListSessionsResponse};
|
||||
use agent_client_protocol::ErrorCode;
|
||||
use common_tests::fixtures::server::AcpServerConnection;
|
||||
use common_tests::fixtures::{run_test, Connection, OpenAiFixture, TestConnectionConfig};
|
||||
use common_tests::{
|
||||
run_close_session, run_config_mcp, run_config_option_mode_set, run_config_option_model_set,
|
||||
run_delete_session, run_fs_read_text_file_true, run_fs_write_text_file_false,
|
||||
@@ -14,10 +16,65 @@ use common_tests::{
|
||||
run_prompt_mcp, run_prompt_model_mismatch, run_prompt_skill,
|
||||
run_session_name_update_notification, run_shell_terminal_false, run_shell_terminal_true,
|
||||
};
|
||||
use goose::config::GooseMode;
|
||||
use goose::conversation::message::Message;
|
||||
use goose::session::{SessionManager, SessionType};
|
||||
use std::path::Path;
|
||||
|
||||
tests_config_option_set_error!(AcpServerConnection);
|
||||
tests_mode_set_error!(AcpServerConnection);
|
||||
|
||||
async fn seed_list_sessions(data_root: &Path, working_dir: &Path, count: usize) {
|
||||
let session_manager = SessionManager::new(data_root.to_path_buf());
|
||||
for index in 0..count {
|
||||
let session = session_manager
|
||||
.create_session(
|
||||
working_dir.to_path_buf(),
|
||||
format!("Seed session {index}"),
|
||||
SessionType::Acp,
|
||||
GooseMode::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
session_manager
|
||||
.add_message(&session.id, &Message::user().with_text("hello"))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
async fn new_connection(data_root: &Path) -> AcpServerConnection {
|
||||
let openai = OpenAiFixture::new(
|
||||
vec![],
|
||||
<AcpServerConnection as Connection>::expected_session_id(),
|
||||
)
|
||||
.await;
|
||||
<AcpServerConnection as Connection>::new(
|
||||
TestConnectionConfig {
|
||||
data_root: data_root.to_path_buf(),
|
||||
..Default::default()
|
||||
},
|
||||
openai,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn list_sessions_request(
|
||||
conn: &AcpServerConnection,
|
||||
request: ListSessionsRequest,
|
||||
) -> anyhow::Result<ListSessionsResponse> {
|
||||
conn.cx()
|
||||
.send_request(request)
|
||||
.block_task()
|
||||
.await
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
fn assert_invalid_params(error: anyhow::Error) {
|
||||
let acp_error = error.downcast::<agent_client_protocol::Error>().unwrap();
|
||||
assert_eq!(acp_error.code, ErrorCode::InvalidParams);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_config_mcp() {
|
||||
run_test(async { run_config_mcp::<AcpServerConnection>().await });
|
||||
@@ -33,6 +90,74 @@ fn test_list_sessions() {
|
||||
run_test(async { run_list_sessions::<AcpServerConnection>().await });
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_sessions_pagination() {
|
||||
run_test(async {
|
||||
let data_root = tempfile::tempdir().unwrap();
|
||||
seed_list_sessions(data_root.path(), Path::new("/tmp/acp-session-list"), 51).await;
|
||||
let conn = new_connection(data_root.path()).await;
|
||||
|
||||
let first = list_sessions_request(&conn, ListSessionsRequest::new())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(first.sessions.len(), 50);
|
||||
|
||||
let second = list_sessions_request(
|
||||
&conn,
|
||||
ListSessionsRequest::new().cursor(first.next_cursor.clone().unwrap()),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(second.sessions.len(), 1);
|
||||
assert!(second.next_cursor.is_none());
|
||||
|
||||
let second_id = &second.sessions[0].session_id;
|
||||
assert!(first
|
||||
.sessions
|
||||
.iter()
|
||||
.all(|session| session.session_id != *second_id));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_sessions_invalid_params() {
|
||||
run_test(async {
|
||||
let data_root = tempfile::tempdir().unwrap();
|
||||
let cwd = tempfile::tempdir().unwrap();
|
||||
let other_cwd = tempfile::tempdir().unwrap();
|
||||
seed_list_sessions(data_root.path(), cwd.path(), 51).await;
|
||||
let conn = new_connection(data_root.path()).await;
|
||||
|
||||
let error =
|
||||
list_sessions_request(&conn, ListSessionsRequest::new().cursor("*".to_string()))
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_invalid_params(error);
|
||||
|
||||
let error = list_sessions_request(
|
||||
&conn,
|
||||
ListSessionsRequest::new().cwd(std::path::PathBuf::from("relative/path")),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_invalid_params(error);
|
||||
|
||||
let first = list_sessions_request(&conn, ListSessionsRequest::new().cwd(cwd.path()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let error = list_sessions_request(
|
||||
&conn,
|
||||
ListSessionsRequest::new()
|
||||
.cwd(other_cwd.path())
|
||||
.cursor(first.next_cursor.unwrap()),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert_invalid_params(error);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_name_update_notification() {
|
||||
run_test(async { run_session_name_update_notification::<AcpServerConnection>().await });
|
||||
|
||||
Reference in New Issue
Block a user