mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-06-02 06:13:46 +02:00
Add clear command functionality to message handler
- Implemented handling of the `/clear` command to clear specific branches or entire trees based on message replies. - Added tests for various scenarios of the clear command, including clearing branches, handling unknown replies, and clearing entire trees. - Enhanced `TreeQueueManager` with methods to cancel branches and remove subtrees, ensuring proper state management in the session store. - Updated `SessionStore` and `TreeRepository` to support removal of node mappings and trees, improving data integrity during clear operations.
This commit is contained in:
@@ -190,7 +190,7 @@ Control Claude Code remotely from Discord. Send tasks, watch live progress, and
|
||||
- Session persistence across server restarts
|
||||
- Live streaming of thinking tokens, tool calls, and results
|
||||
- Up to 10 concurrent Claude CLI sessions
|
||||
- Commands: `/stop` (cancel tasks), `/clear` (reset all sessions), `/stats`
|
||||
- Commands: `/stop` (cancel tasks; reply to a message to stop only that task), `/clear` (standalone: reset all sessions; reply to a message to clear that branch downwards), `/stats`
|
||||
|
||||
### Setup
|
||||
|
||||
@@ -219,7 +219,7 @@ ALLOWED_DIR=C:/Users/yourname/projects
|
||||
uv run uvicorn server:app --host 0.0.0.0 --port 8082
|
||||
```
|
||||
|
||||
5. **Invite the bot** to your server (OAuth2 → URL Generator, scopes: `bot`, permissions: Read Messages, Send Messages, Manage Messages, Read Message History). Send a message in an allowed channel with a task. Claude responds with thinking tokens, tool calls as they execute, and the final result. Reply `/stop` to a running task to cancel it.
|
||||
5. **Invite the bot** to your server (OAuth2 → URL Generator, scopes: `bot`, permissions: Read Messages, Send Messages, Manage Messages, Read Message History). Send a message in an allowed channel with a task. Claude responds with thinking tokens, tool calls as they execute, and the final result. Reply to messages to cancel tasks or clear branches (see Commands above).
|
||||
|
||||
### Telegram (Alternative)
|
||||
|
||||
|
||||
+117
-47
@@ -802,15 +802,126 @@ class ClaudeMessageHandler:
|
||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||
)
|
||||
|
||||
async def _handle_clear_branch(
|
||||
self, incoming: IncomingMessage, branch_root_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Clear a branch (replied-to node + all descendants).
|
||||
|
||||
Order: cancel tasks, delete messages, remove branch, update session store.
|
||||
"""
|
||||
tree = self.tree_queue.get_tree_for_node(branch_root_id)
|
||||
if not tree:
|
||||
return
|
||||
|
||||
# 1) Cancel branch tasks (no stop_all)
|
||||
cancelled = await self.tree_queue.cancel_branch(branch_root_id)
|
||||
self._update_cancelled_nodes_ui(cancelled)
|
||||
|
||||
# 2) Collect message IDs from branch nodes only
|
||||
msg_ids: set[str] = set()
|
||||
branch_ids = tree.get_descendants(branch_root_id)
|
||||
for nid in branch_ids:
|
||||
node = tree.get_node(nid)
|
||||
if node:
|
||||
if node.incoming.message_id:
|
||||
msg_ids.add(str(node.incoming.message_id))
|
||||
if node.status_message_id:
|
||||
msg_ids.add(str(node.status_message_id))
|
||||
if incoming.message_id:
|
||||
msg_ids.add(str(incoming.message_id))
|
||||
|
||||
# 3) Delete messages (best-effort)
|
||||
await self._delete_message_ids(incoming.chat_id, msg_ids)
|
||||
|
||||
# 4) Remove branch from tree
|
||||
removed, root_id, removed_entire_tree = await self.tree_queue.remove_branch(
|
||||
branch_root_id
|
||||
)
|
||||
|
||||
# 5) Update session store
|
||||
try:
|
||||
self.session_store.remove_node_mappings([n.node_id for n in removed])
|
||||
if removed_entire_tree:
|
||||
self.session_store.remove_tree(root_id)
|
||||
else:
|
||||
updated_tree = self.tree_queue.get_tree(root_id)
|
||||
if updated_tree:
|
||||
self.session_store.save_tree(root_id, updated_tree.to_dict())
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update session store after branch clear: {e}")
|
||||
|
||||
async def _delete_message_ids(self, chat_id: str, msg_ids: set[str]) -> None:
|
||||
"""Best-effort delete messages by ID. Sorts numeric IDs descending."""
|
||||
if not msg_ids:
|
||||
return
|
||||
|
||||
def _as_int(s: str) -> int | None:
|
||||
try:
|
||||
return int(str(s))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
numeric: list[tuple[int, str]] = []
|
||||
non_numeric: list[str] = []
|
||||
for mid in msg_ids:
|
||||
n = _as_int(mid)
|
||||
if n is None:
|
||||
non_numeric.append(mid)
|
||||
else:
|
||||
numeric.append((n, mid))
|
||||
numeric.sort(reverse=True)
|
||||
ordered = [mid for _, mid in numeric] + non_numeric
|
||||
|
||||
batch_fn = getattr(self.platform, "queue_delete_messages", None)
|
||||
if callable(batch_fn):
|
||||
try:
|
||||
CHUNK = 100
|
||||
for i in range(0, len(ordered), CHUNK):
|
||||
chunk = ordered[i : i + CHUNK]
|
||||
await batch_fn(chat_id, chunk, fire_and_forget=False)
|
||||
except Exception as e:
|
||||
logger.debug(f"Batch delete failed: {type(e).__name__}: {e}")
|
||||
else:
|
||||
for mid in ordered:
|
||||
try:
|
||||
await self.platform.queue_delete_message(
|
||||
chat_id, mid, fire_and_forget=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Delete failed for msg {mid}: {type(e).__name__}: {e}"
|
||||
)
|
||||
|
||||
async def _handle_clear_command(self, incoming: IncomingMessage) -> None:
|
||||
"""
|
||||
Handle /clear global command.
|
||||
Handle /clear command.
|
||||
|
||||
Order:
|
||||
1. Stop all pending/in-progress tasks.
|
||||
2. Best-effort delete tracked chat messages for this chat.
|
||||
3. Clear sessions.json (entire store) and reset in-memory queue state.
|
||||
Reply-scoped: reply to a message to clear that branch (node + descendants).
|
||||
Standalone: global clear (stop all, delete all chat messages, reset store).
|
||||
"""
|
||||
if incoming.is_reply() and incoming.reply_to_message_id:
|
||||
reply_id = incoming.reply_to_message_id
|
||||
tree = self.tree_queue.get_tree_for_node(reply_id)
|
||||
branch_root_id = (
|
||||
self.tree_queue.resolve_parent_node_id(reply_id) if tree else None
|
||||
)
|
||||
if not branch_root_id:
|
||||
msg_id = await self.platform.queue_send_message(
|
||||
incoming.chat_id,
|
||||
self._format_status(
|
||||
"🗑", "Cleared.", "Nothing to clear for that message."
|
||||
),
|
||||
fire_and_forget=False,
|
||||
)
|
||||
self._record_outgoing_message(
|
||||
incoming.platform, incoming.chat_id, msg_id, "command"
|
||||
)
|
||||
return
|
||||
await self._handle_clear_branch(incoming, branch_root_id)
|
||||
return
|
||||
|
||||
# Global clear
|
||||
# 1) Stop tasks first (ensures no more work is running).
|
||||
await self.stop_all_tasks()
|
||||
|
||||
@@ -853,48 +964,7 @@ class ClaudeMessageHandler:
|
||||
if incoming.message_id is not None:
|
||||
msg_ids.add(str(incoming.message_id))
|
||||
|
||||
def _as_int(s: str) -> int | None:
|
||||
try:
|
||||
return int(str(s))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
numeric: list[tuple[int, str]] = []
|
||||
non_numeric: list[str] = []
|
||||
for mid in msg_ids:
|
||||
n = _as_int(mid)
|
||||
if n is None:
|
||||
non_numeric.append(mid)
|
||||
else:
|
||||
numeric.append((n, mid))
|
||||
|
||||
numeric.sort(reverse=True)
|
||||
ordered = [mid for _, mid in numeric] + non_numeric
|
||||
|
||||
# If platform supports batch deletes, prefer it.
|
||||
batch_fn = getattr(self.platform, "queue_delete_messages", None)
|
||||
if callable(batch_fn):
|
||||
try:
|
||||
# Telegram supports up to 100 per request.
|
||||
CHUNK = 100
|
||||
for i in range(0, len(ordered), CHUNK):
|
||||
chunk = ordered[i : i + CHUNK]
|
||||
await batch_fn(incoming.chat_id, chunk, fire_and_forget=False)
|
||||
except Exception as e:
|
||||
logger.debug(f"/clear batch delete failed: {type(e).__name__}: {e}")
|
||||
else:
|
||||
for mid in ordered:
|
||||
try:
|
||||
await self.platform.queue_delete_message(
|
||||
incoming.chat_id,
|
||||
mid,
|
||||
fire_and_forget=False,
|
||||
)
|
||||
except Exception as e:
|
||||
# Deleting is best-effort; platform adapters also treat common cases as no-op.
|
||||
logger.debug(
|
||||
f"/clear delete failed for msg {mid}: {type(e).__name__}: {e}"
|
||||
)
|
||||
await self._delete_message_ids(incoming.chat_id, msg_ids)
|
||||
|
||||
# 3) Clear persistent state and reset in-memory queue/tree state.
|
||||
try:
|
||||
|
||||
@@ -288,6 +288,22 @@ class SessionStore:
|
||||
self._node_to_tree[node_id] = root_id
|
||||
self._schedule_save()
|
||||
|
||||
def remove_node_mappings(self, node_ids: List[str]) -> None:
|
||||
"""Remove node IDs from the node-to-tree mapping."""
|
||||
with self._lock:
|
||||
for nid in node_ids:
|
||||
self._node_to_tree.pop(nid, None)
|
||||
self._schedule_save()
|
||||
|
||||
def remove_tree(self, root_id: str) -> None:
|
||||
"""Remove a tree and all its node mappings from the store."""
|
||||
with self._lock:
|
||||
tree_data = self._trees.pop(root_id, None)
|
||||
if tree_data:
|
||||
for node_id in tree_data.get("nodes", {}).keys():
|
||||
self._node_to_tree.pop(node_id, None)
|
||||
self._schedule_save()
|
||||
|
||||
def get_all_trees(self) -> Dict[str, dict]:
|
||||
"""Get all stored trees (public accessor)."""
|
||||
with self._lock:
|
||||
|
||||
@@ -386,3 +386,48 @@ class MessageTree:
|
||||
"""Find the node that has this status message ID (O(1) lookup)."""
|
||||
node_id = self._status_to_node.get(status_msg_id)
|
||||
return self._nodes.get(node_id) if node_id else None
|
||||
|
||||
def get_descendants(self, node_id: str) -> List[str]:
|
||||
"""
|
||||
Get node_id and all descendant IDs (subtree).
|
||||
|
||||
Returns:
|
||||
List of node IDs including the given node.
|
||||
"""
|
||||
if node_id not in self._nodes:
|
||||
return []
|
||||
result = [node_id]
|
||||
node = self._nodes[node_id]
|
||||
for child_id in node.children_ids:
|
||||
result.extend(self.get_descendants(child_id))
|
||||
return result
|
||||
|
||||
def remove_branch(self, branch_root_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Remove a subtree (branch_root and all descendants) from the tree.
|
||||
|
||||
Updates parent's children_ids. Caller must hold lock for consistency.
|
||||
Does not acquire lock internally.
|
||||
|
||||
Returns:
|
||||
List of removed nodes.
|
||||
"""
|
||||
if branch_root_id not in self._nodes:
|
||||
return []
|
||||
|
||||
parent = self.get_parent(branch_root_id)
|
||||
removed = []
|
||||
for nid in self.get_descendants(branch_root_id):
|
||||
node = self._nodes.get(nid)
|
||||
if node:
|
||||
removed.append(node)
|
||||
del self._nodes[nid]
|
||||
del self._status_to_node[node.status_message_id]
|
||||
|
||||
if parent and branch_root_id in parent.children_ids:
|
||||
parent.children_ids = [
|
||||
c for c in parent.children_ids if c != branch_root_id
|
||||
]
|
||||
|
||||
logger.debug(f"Removed branch {branch_root_id} ({len(removed)} nodes)")
|
||||
return removed
|
||||
|
||||
@@ -368,6 +368,75 @@ class TreeQueueManager:
|
||||
"""Register a node ID to a tree (for external mapping)."""
|
||||
self._repository.register_node(node_id, root_id)
|
||||
|
||||
async def cancel_branch(self, branch_root_id: str) -> List[MessageNode]:
|
||||
"""
|
||||
Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants).
|
||||
|
||||
Does not call cli_manager.stop_all(). Returns list of cancelled nodes.
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(branch_root_id)
|
||||
if not tree:
|
||||
return []
|
||||
|
||||
branch_ids = set(tree.get_descendants(branch_root_id))
|
||||
cancelled: List[MessageNode] = []
|
||||
|
||||
async with tree.with_lock():
|
||||
for nid in branch_ids:
|
||||
node = tree.get_node(nid)
|
||||
if not node or node.state in (
|
||||
MessageState.COMPLETED,
|
||||
MessageState.ERROR,
|
||||
):
|
||||
continue
|
||||
|
||||
if tree.is_current_node(nid):
|
||||
self._processor.cancel_current(tree)
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Cancelled by user"
|
||||
node.completed_at = datetime.now(timezone.utc)
|
||||
cancelled.append(node)
|
||||
else:
|
||||
tree.remove_from_queue(nid)
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Cancelled by user"
|
||||
node.completed_at = datetime.now(timezone.utc)
|
||||
cancelled.append(node)
|
||||
|
||||
if cancelled:
|
||||
logger.info(f"Cancelled {len(cancelled)} nodes in branch {branch_root_id}")
|
||||
return cancelled
|
||||
|
||||
async def remove_branch(
|
||||
self, branch_root_id: str
|
||||
) -> tuple[List[MessageNode], str, bool]:
|
||||
"""
|
||||
Remove a branch (subtree) from the tree.
|
||||
|
||||
If branch_root is the tree root, removes the entire tree.
|
||||
|
||||
Returns:
|
||||
(removed_nodes, root_id, removed_entire_tree)
|
||||
"""
|
||||
tree = self._repository.get_tree_for_node(branch_root_id)
|
||||
if not tree:
|
||||
return ([], "", False)
|
||||
|
||||
root_id = tree.root_id
|
||||
|
||||
if branch_root_id == root_id:
|
||||
cancelled = self.cancel_tree(root_id)
|
||||
removed_tree = self._repository.remove_tree(root_id)
|
||||
if removed_tree:
|
||||
return (removed_tree.all_nodes(), root_id, True)
|
||||
return (cancelled, root_id, True)
|
||||
|
||||
async with tree.with_lock():
|
||||
removed = tree.remove_branch(branch_root_id)
|
||||
|
||||
self._repository.unregister_nodes([n.node_id for n in removed])
|
||||
return (removed, root_id, False)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize all trees."""
|
||||
return self._repository.to_dict()
|
||||
|
||||
@@ -129,6 +129,26 @@ class TreeRepository:
|
||||
"""Get all tree root IDs."""
|
||||
return list(self._trees.keys())
|
||||
|
||||
def unregister_nodes(self, node_ids: List[str]) -> None:
|
||||
"""Remove node IDs from the node-to-tree mapping."""
|
||||
for nid in node_ids:
|
||||
self._node_to_tree.pop(nid, None)
|
||||
|
||||
def remove_tree(self, root_id: str) -> Optional[MessageTree]:
|
||||
"""
|
||||
Remove a tree and all its node mappings from the repository.
|
||||
|
||||
Returns:
|
||||
The removed tree, or None if not found.
|
||||
"""
|
||||
tree = self._trees.pop(root_id, None)
|
||||
if not tree:
|
||||
return None
|
||||
for node in tree.all_nodes():
|
||||
self._node_to_tree.pop(node.node_id, None)
|
||||
logger.debug("TREE_REPO: remove_tree root_id=%s", root_id)
|
||||
return tree
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Serialize all trees."""
|
||||
return {
|
||||
|
||||
@@ -517,3 +517,103 @@ async def test_handle_message_clear_command_deletes_message_log_ids(
|
||||
|
||||
deleted = {c.args[1] for c in mock_platform.queue_delete_message.call_args_list}
|
||||
assert deleted == {"42", "43", "150"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_clear_command_reply_clears_branch(
|
||||
handler, mock_platform, mock_session_store, incoming_message_factory
|
||||
):
|
||||
"""Reply /clear to a message clears only that branch."""
|
||||
root_incoming = incoming_message_factory(
|
||||
text="root", chat_id="chat_1", message_id="100", reply_to_message_id=None
|
||||
)
|
||||
tree = await handler.tree_queue.create_tree(
|
||||
node_id="100", incoming=root_incoming, status_message_id="101"
|
||||
)
|
||||
handler.tree_queue.register_node("101", tree.root_id)
|
||||
|
||||
child_incoming = incoming_message_factory(
|
||||
text="child",
|
||||
chat_id="chat_1",
|
||||
message_id="102",
|
||||
reply_to_message_id="100",
|
||||
)
|
||||
await handler.tree_queue.add_to_tree(
|
||||
parent_node_id="100",
|
||||
node_id="102",
|
||||
incoming=child_incoming,
|
||||
status_message_id="103",
|
||||
)
|
||||
|
||||
deleted_ids = []
|
||||
|
||||
async def _capture_delete(chat_id, message_id, fire_and_forget=True):
|
||||
deleted_ids.append(message_id)
|
||||
|
||||
mock_platform.queue_delete_message = AsyncMock(side_effect=_capture_delete)
|
||||
|
||||
incoming = incoming_message_factory(
|
||||
text="/clear",
|
||||
chat_id="chat_1",
|
||||
message_id="150",
|
||||
reply_to_message_id="102",
|
||||
)
|
||||
await handler.handle_message(incoming)
|
||||
|
||||
assert set(deleted_ids) == {"102", "103", "150"}
|
||||
assert "100" not in deleted_ids
|
||||
assert "101" not in deleted_ids
|
||||
mock_session_store.remove_node_mappings.assert_called()
|
||||
assert handler.tree_queue.get_tree_for_node("102") is None
|
||||
assert handler.tree_queue.get_tree_for_node("100") is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_clear_command_reply_unknown_sends_nothing(
|
||||
handler, mock_platform, mock_session_store, incoming_message_factory
|
||||
):
|
||||
"""Reply /clear to unknown message sends 'Nothing to clear'."""
|
||||
incoming = incoming_message_factory(
|
||||
text="/clear",
|
||||
chat_id="chat_1",
|
||||
message_id="150",
|
||||
reply_to_message_id="999",
|
||||
)
|
||||
await handler.handle_message(incoming)
|
||||
|
||||
mock_platform.queue_send_message.assert_called_once()
|
||||
call_args = mock_platform.queue_send_message.call_args[0]
|
||||
assert "Nothing to clear" in call_args[1]
|
||||
mock_session_store.clear_all.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_message_clear_command_reply_to_root_clears_tree(
|
||||
handler, mock_platform, mock_session_store, incoming_message_factory
|
||||
):
|
||||
"""Reply /clear to root message clears entire tree."""
|
||||
root_incoming = incoming_message_factory(
|
||||
text="root", chat_id="chat_1", message_id="100", reply_to_message_id=None
|
||||
)
|
||||
await handler.tree_queue.create_tree(
|
||||
node_id="100", incoming=root_incoming, status_message_id="101"
|
||||
)
|
||||
|
||||
deleted_ids = []
|
||||
|
||||
async def _capture_delete(chat_id, message_id, fire_and_forget=True):
|
||||
deleted_ids.append(message_id)
|
||||
|
||||
mock_platform.queue_delete_message = AsyncMock(side_effect=_capture_delete)
|
||||
|
||||
incoming = incoming_message_factory(
|
||||
text="/clear",
|
||||
chat_id="chat_1",
|
||||
message_id="150",
|
||||
reply_to_message_id="100",
|
||||
)
|
||||
await handler.handle_message(incoming)
|
||||
|
||||
assert set(deleted_ids) == {"100", "101", "150"}
|
||||
mock_session_store.remove_tree.assert_called_once_with("100")
|
||||
assert handler.tree_queue.get_tree_count() == 0
|
||||
|
||||
@@ -282,6 +282,83 @@ class TestMessageTree:
|
||||
assert node is not None
|
||||
assert node.session_id == "sess_1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_descendants(self):
|
||||
"""Test get_descendants returns node and all descendants."""
|
||||
root_incoming = IncomingMessage(
|
||||
text="Root", chat_id="1", user_id="1", message_id="root", platform="test"
|
||||
)
|
||||
root = MessageNode(
|
||||
node_id="root", incoming=root_incoming, status_message_id="s1"
|
||||
)
|
||||
tree = MessageTree(root)
|
||||
|
||||
child_incoming = IncomingMessage(
|
||||
text="Child",
|
||||
chat_id="1",
|
||||
user_id="1",
|
||||
message_id="child",
|
||||
platform="test",
|
||||
reply_to_message_id="root",
|
||||
)
|
||||
await tree.add_node("child", child_incoming, "s2", "root")
|
||||
|
||||
grandchild_incoming = IncomingMessage(
|
||||
text="Grand",
|
||||
chat_id="1",
|
||||
user_id="1",
|
||||
message_id="grand",
|
||||
platform="test",
|
||||
reply_to_message_id="child",
|
||||
)
|
||||
await tree.add_node("grand", grandchild_incoming, "s3", "child")
|
||||
|
||||
assert tree.get_descendants("root") == ["root", "child", "grand"]
|
||||
assert tree.get_descendants("child") == ["child", "grand"]
|
||||
assert tree.get_descendants("grand") == ["grand"]
|
||||
assert tree.get_descendants("nonexistent") == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_branch(self):
|
||||
"""Test remove_branch removes subtree and updates parent."""
|
||||
root_incoming = IncomingMessage(
|
||||
text="Root", chat_id="1", user_id="1", message_id="root", platform="test"
|
||||
)
|
||||
root = MessageNode(
|
||||
node_id="root", incoming=root_incoming, status_message_id="s1"
|
||||
)
|
||||
tree = MessageTree(root)
|
||||
|
||||
child_incoming = IncomingMessage(
|
||||
text="Child",
|
||||
chat_id="1",
|
||||
user_id="1",
|
||||
message_id="child",
|
||||
platform="test",
|
||||
reply_to_message_id="root",
|
||||
)
|
||||
await tree.add_node("child", child_incoming, "s2", "root")
|
||||
|
||||
grandchild_incoming = IncomingMessage(
|
||||
text="Grand",
|
||||
chat_id="1",
|
||||
user_id="1",
|
||||
message_id="grand",
|
||||
platform="test",
|
||||
reply_to_message_id="child",
|
||||
)
|
||||
await tree.add_node("grand", grandchild_incoming, "s3", "child")
|
||||
|
||||
async with tree.with_lock():
|
||||
removed = tree.remove_branch("child")
|
||||
|
||||
assert len(removed) == 2
|
||||
assert {n.node_id for n in removed} == {"child", "grand"}
|
||||
assert tree.get_node("child") is None
|
||||
assert tree.get_node("grand") is None
|
||||
assert tree.get_node("root") is not None
|
||||
assert "child" not in tree.get_root().children_ids
|
||||
|
||||
|
||||
class TestTreeQueueManager:
|
||||
"""Test TreeQueueManager class."""
|
||||
@@ -441,6 +518,97 @@ class TestTreeQueueManager:
|
||||
|
||||
processing_complete.set()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_branch(self):
|
||||
"""Test cancel_branch cancels only nodes in subtree."""
|
||||
manager = TreeQueueManager()
|
||||
|
||||
root_incoming = IncomingMessage(
|
||||
text="Root", chat_id="1", user_id="1", message_id="root", platform="test"
|
||||
)
|
||||
await manager.create_tree("root", root_incoming, "s1")
|
||||
|
||||
child_incoming = IncomingMessage(
|
||||
text="Child",
|
||||
chat_id="1",
|
||||
user_id="1",
|
||||
message_id="child",
|
||||
platform="test",
|
||||
reply_to_message_id="root",
|
||||
)
|
||||
tree, _ = await manager.add_to_tree("root", "child", child_incoming, "s2")
|
||||
|
||||
sibling_incoming = IncomingMessage(
|
||||
text="Sibling",
|
||||
chat_id="1",
|
||||
user_id="1",
|
||||
message_id="sibling",
|
||||
platform="test",
|
||||
reply_to_message_id="root",
|
||||
)
|
||||
await manager.add_to_tree("root", "sibling", sibling_incoming, "s3")
|
||||
|
||||
cancelled = await manager.cancel_branch("child")
|
||||
assert len(cancelled) == 1
|
||||
assert cancelled[0].node_id == "child"
|
||||
|
||||
child_node = tree.get_node("child")
|
||||
assert child_node is not None
|
||||
assert child_node.state == MessageState.ERROR
|
||||
|
||||
sibling_node = tree.get_node("sibling")
|
||||
assert sibling_node is not None
|
||||
assert sibling_node.state == MessageState.PENDING
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_branch_non_root(self):
|
||||
"""Test remove_branch removes only the subtree when branch is not root."""
|
||||
manager = TreeQueueManager()
|
||||
|
||||
root_incoming = IncomingMessage(
|
||||
text="Root", chat_id="1", user_id="1", message_id="root", platform="test"
|
||||
)
|
||||
await manager.create_tree("root", root_incoming, "s1")
|
||||
|
||||
child_incoming = IncomingMessage(
|
||||
text="Child",
|
||||
chat_id="1",
|
||||
user_id="1",
|
||||
message_id="child",
|
||||
platform="test",
|
||||
reply_to_message_id="root",
|
||||
)
|
||||
tree, _ = await manager.add_to_tree("root", "child", child_incoming, "s2")
|
||||
|
||||
removed, root_id, removed_entire = await manager.remove_branch("child")
|
||||
|
||||
assert len(removed) == 1
|
||||
assert removed[0].node_id == "child"
|
||||
assert root_id == "root"
|
||||
assert removed_entire is False
|
||||
assert manager.get_tree_for_node("child") is None
|
||||
assert manager.get_tree("root") is not None
|
||||
assert tree.get_node("child") is None
|
||||
assert "child" not in tree.get_root().children_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_remove_branch_root_removes_tree(self):
|
||||
"""Test remove_branch when branch is root removes entire tree."""
|
||||
manager = TreeQueueManager()
|
||||
|
||||
root_incoming = IncomingMessage(
|
||||
text="Root", chat_id="1", user_id="1", message_id="root", platform="test"
|
||||
)
|
||||
await manager.create_tree("root", root_incoming, "s1")
|
||||
|
||||
removed, root_id, removed_entire = await manager.remove_branch("root")
|
||||
|
||||
assert len(removed) == 1
|
||||
assert root_id == "root"
|
||||
assert removed_entire is True
|
||||
assert manager.get_tree("root") is None
|
||||
assert manager.get_tree_for_node("root") is None
|
||||
|
||||
|
||||
class TestSessionStoreTrees:
|
||||
"""Test SessionStore tree methods."""
|
||||
|
||||
Reference in New Issue
Block a user