Add clear command functionality to message handler

- Implemented handling of the `/clear` command to clear specific branches or entire trees based on message replies.
- Added tests for various scenarios of the clear command, including clearing branches, handling unknown replies, and clearing entire trees.
- Enhanced `TreeQueueManager` with methods to cancel branches and remove subtrees, ensuring proper state management in the session store.
- Updated `SessionStore` and `TreeRepository` to support removal of node mappings and trees, improving data integrity during clear operations.
This commit is contained in:
Alishahryar1
2026-02-16 16:23:26 -08:00
parent e4ae59511e
commit 6abcdb4017
8 changed files with 537 additions and 49 deletions
+2 -2
View File
@@ -190,7 +190,7 @@ Control Claude Code remotely from Discord. Send tasks, watch live progress, and
- Session persistence across server restarts
- Live streaming of thinking tokens, tool calls, and results
- Up to 10 concurrent Claude CLI sessions
- Commands: `/stop` (cancel tasks), `/clear` (reset all sessions), `/stats`
- Commands: `/stop` (cancel tasks; reply to a message to stop only that task), `/clear` (standalone: reset all sessions; reply to a message to clear that branch downwards), `/stats`
### Setup
@@ -219,7 +219,7 @@ ALLOWED_DIR=C:/Users/yourname/projects
uv run uvicorn server:app --host 0.0.0.0 --port 8082
```
5. **Invite the bot** to your server (OAuth2 → URL Generator, scopes: `bot`, permissions: Read Messages, Send Messages, Manage Messages, Read Message History). Send a message in an allowed channel with a task. Claude responds with thinking tokens, tool calls as they execute, and the final result. Reply `/stop` to a running task to cancel it.
5. **Invite the bot** to your server (OAuth2 → URL Generator, scopes: `bot`, permissions: Read Messages, Send Messages, Manage Messages, Read Message History). Send a message in an allowed channel with a task. Claude responds with thinking tokens, tool calls as they execute, and the final result. Reply to messages to cancel tasks or clear branches (see Commands above).
### Telegram (Alternative)
+117 -47
View File
@@ -802,15 +802,126 @@ class ClaudeMessageHandler:
incoming.platform, incoming.chat_id, msg_id, "command"
)
async def _handle_clear_branch(
self, incoming: IncomingMessage, branch_root_id: str
) -> None:
"""
Clear a branch (replied-to node + all descendants).
Order: cancel tasks, delete messages, remove branch, update session store.
"""
tree = self.tree_queue.get_tree_for_node(branch_root_id)
if not tree:
return
# 1) Cancel branch tasks (no stop_all)
cancelled = await self.tree_queue.cancel_branch(branch_root_id)
self._update_cancelled_nodes_ui(cancelled)
# 2) Collect message IDs from branch nodes only
msg_ids: set[str] = set()
branch_ids = tree.get_descendants(branch_root_id)
for nid in branch_ids:
node = tree.get_node(nid)
if node:
if node.incoming.message_id:
msg_ids.add(str(node.incoming.message_id))
if node.status_message_id:
msg_ids.add(str(node.status_message_id))
if incoming.message_id:
msg_ids.add(str(incoming.message_id))
# 3) Delete messages (best-effort)
await self._delete_message_ids(incoming.chat_id, msg_ids)
# 4) Remove branch from tree
removed, root_id, removed_entire_tree = await self.tree_queue.remove_branch(
branch_root_id
)
# 5) Update session store
try:
self.session_store.remove_node_mappings([n.node_id for n in removed])
if removed_entire_tree:
self.session_store.remove_tree(root_id)
else:
updated_tree = self.tree_queue.get_tree(root_id)
if updated_tree:
self.session_store.save_tree(root_id, updated_tree.to_dict())
except Exception as e:
logger.warning(f"Failed to update session store after branch clear: {e}")
async def _delete_message_ids(self, chat_id: str, msg_ids: set[str]) -> None:
"""Best-effort delete messages by ID. Sorts numeric IDs descending."""
if not msg_ids:
return
def _as_int(s: str) -> int | None:
try:
return int(str(s))
except Exception:
return None
numeric: list[tuple[int, str]] = []
non_numeric: list[str] = []
for mid in msg_ids:
n = _as_int(mid)
if n is None:
non_numeric.append(mid)
else:
numeric.append((n, mid))
numeric.sort(reverse=True)
ordered = [mid for _, mid in numeric] + non_numeric
batch_fn = getattr(self.platform, "queue_delete_messages", None)
if callable(batch_fn):
try:
CHUNK = 100
for i in range(0, len(ordered), CHUNK):
chunk = ordered[i : i + CHUNK]
await batch_fn(chat_id, chunk, fire_and_forget=False)
except Exception as e:
logger.debug(f"Batch delete failed: {type(e).__name__}: {e}")
else:
for mid in ordered:
try:
await self.platform.queue_delete_message(
chat_id, mid, fire_and_forget=False
)
except Exception as e:
logger.debug(
f"Delete failed for msg {mid}: {type(e).__name__}: {e}"
)
async def _handle_clear_command(self, incoming: IncomingMessage) -> None:
"""
Handle /clear global command.
Handle /clear command.
Order:
1. Stop all pending/in-progress tasks.
2. Best-effort delete tracked chat messages for this chat.
3. Clear sessions.json (entire store) and reset in-memory queue state.
Reply-scoped: reply to a message to clear that branch (node + descendants).
Standalone: global clear (stop all, delete all chat messages, reset store).
"""
if incoming.is_reply() and incoming.reply_to_message_id:
reply_id = incoming.reply_to_message_id
tree = self.tree_queue.get_tree_for_node(reply_id)
branch_root_id = (
self.tree_queue.resolve_parent_node_id(reply_id) if tree else None
)
if not branch_root_id:
msg_id = await self.platform.queue_send_message(
incoming.chat_id,
self._format_status(
"🗑", "Cleared.", "Nothing to clear for that message."
),
fire_and_forget=False,
)
self._record_outgoing_message(
incoming.platform, incoming.chat_id, msg_id, "command"
)
return
await self._handle_clear_branch(incoming, branch_root_id)
return
# Global clear
# 1) Stop tasks first (ensures no more work is running).
await self.stop_all_tasks()
@@ -853,48 +964,7 @@ class ClaudeMessageHandler:
if incoming.message_id is not None:
msg_ids.add(str(incoming.message_id))
def _as_int(s: str) -> int | None:
try:
return int(str(s))
except Exception:
return None
numeric: list[tuple[int, str]] = []
non_numeric: list[str] = []
for mid in msg_ids:
n = _as_int(mid)
if n is None:
non_numeric.append(mid)
else:
numeric.append((n, mid))
numeric.sort(reverse=True)
ordered = [mid for _, mid in numeric] + non_numeric
# If platform supports batch deletes, prefer it.
batch_fn = getattr(self.platform, "queue_delete_messages", None)
if callable(batch_fn):
try:
# Telegram supports up to 100 per request.
CHUNK = 100
for i in range(0, len(ordered), CHUNK):
chunk = ordered[i : i + CHUNK]
await batch_fn(incoming.chat_id, chunk, fire_and_forget=False)
except Exception as e:
logger.debug(f"/clear batch delete failed: {type(e).__name__}: {e}")
else:
for mid in ordered:
try:
await self.platform.queue_delete_message(
incoming.chat_id,
mid,
fire_and_forget=False,
)
except Exception as e:
# Deleting is best-effort; platform adapters also treat common cases as no-op.
logger.debug(
f"/clear delete failed for msg {mid}: {type(e).__name__}: {e}"
)
await self._delete_message_ids(incoming.chat_id, msg_ids)
# 3) Clear persistent state and reset in-memory queue/tree state.
try:
+16
View File
@@ -288,6 +288,22 @@ class SessionStore:
self._node_to_tree[node_id] = root_id
self._schedule_save()
def remove_node_mappings(self, node_ids: List[str]) -> None:
"""Remove node IDs from the node-to-tree mapping."""
with self._lock:
for nid in node_ids:
self._node_to_tree.pop(nid, None)
self._schedule_save()
def remove_tree(self, root_id: str) -> None:
"""Remove a tree and all its node mappings from the store."""
with self._lock:
tree_data = self._trees.pop(root_id, None)
if tree_data:
for node_id in tree_data.get("nodes", {}).keys():
self._node_to_tree.pop(node_id, None)
self._schedule_save()
def get_all_trees(self) -> Dict[str, dict]:
"""Get all stored trees (public accessor)."""
with self._lock:
+45
View File
@@ -386,3 +386,48 @@ class MessageTree:
"""Find the node that has this status message ID (O(1) lookup)."""
node_id = self._status_to_node.get(status_msg_id)
return self._nodes.get(node_id) if node_id else None
def get_descendants(self, node_id: str) -> List[str]:
"""
Get node_id and all descendant IDs (subtree).
Returns:
List of node IDs including the given node.
"""
if node_id not in self._nodes:
return []
result = [node_id]
node = self._nodes[node_id]
for child_id in node.children_ids:
result.extend(self.get_descendants(child_id))
return result
def remove_branch(self, branch_root_id: str) -> List[MessageNode]:
"""
Remove a subtree (branch_root and all descendants) from the tree.
Updates parent's children_ids. Caller must hold lock for consistency.
Does not acquire lock internally.
Returns:
List of removed nodes.
"""
if branch_root_id not in self._nodes:
return []
parent = self.get_parent(branch_root_id)
removed = []
for nid in self.get_descendants(branch_root_id):
node = self._nodes.get(nid)
if node:
removed.append(node)
del self._nodes[nid]
del self._status_to_node[node.status_message_id]
if parent and branch_root_id in parent.children_ids:
parent.children_ids = [
c for c in parent.children_ids if c != branch_root_id
]
logger.debug(f"Removed branch {branch_root_id} ({len(removed)} nodes)")
return removed
+69
View File
@@ -368,6 +368,75 @@ class TreeQueueManager:
"""Register a node ID to a tree (for external mapping)."""
self._repository.register_node(node_id, root_id)
async def cancel_branch(self, branch_root_id: str) -> List[MessageNode]:
"""
Cancel all PENDING/IN_PROGRESS nodes in the subtree (branch_root + descendants).
Does not call cli_manager.stop_all(). Returns list of cancelled nodes.
"""
tree = self._repository.get_tree_for_node(branch_root_id)
if not tree:
return []
branch_ids = set(tree.get_descendants(branch_root_id))
cancelled: List[MessageNode] = []
async with tree.with_lock():
for nid in branch_ids:
node = tree.get_node(nid)
if not node or node.state in (
MessageState.COMPLETED,
MessageState.ERROR,
):
continue
if tree.is_current_node(nid):
self._processor.cancel_current(tree)
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
node.completed_at = datetime.now(timezone.utc)
cancelled.append(node)
else:
tree.remove_from_queue(nid)
node.state = MessageState.ERROR
node.error_message = "Cancelled by user"
node.completed_at = datetime.now(timezone.utc)
cancelled.append(node)
if cancelled:
logger.info(f"Cancelled {len(cancelled)} nodes in branch {branch_root_id}")
return cancelled
async def remove_branch(
self, branch_root_id: str
) -> tuple[List[MessageNode], str, bool]:
"""
Remove a branch (subtree) from the tree.
If branch_root is the tree root, removes the entire tree.
Returns:
(removed_nodes, root_id, removed_entire_tree)
"""
tree = self._repository.get_tree_for_node(branch_root_id)
if not tree:
return ([], "", False)
root_id = tree.root_id
if branch_root_id == root_id:
cancelled = self.cancel_tree(root_id)
removed_tree = self._repository.remove_tree(root_id)
if removed_tree:
return (removed_tree.all_nodes(), root_id, True)
return (cancelled, root_id, True)
async with tree.with_lock():
removed = tree.remove_branch(branch_root_id)
self._repository.unregister_nodes([n.node_id for n in removed])
return (removed, root_id, False)
def to_dict(self) -> dict:
"""Serialize all trees."""
return self._repository.to_dict()
+20
View File
@@ -129,6 +129,26 @@ class TreeRepository:
"""Get all tree root IDs."""
return list(self._trees.keys())
def unregister_nodes(self, node_ids: List[str]) -> None:
"""Remove node IDs from the node-to-tree mapping."""
for nid in node_ids:
self._node_to_tree.pop(nid, None)
def remove_tree(self, root_id: str) -> Optional[MessageTree]:
"""
Remove a tree and all its node mappings from the repository.
Returns:
The removed tree, or None if not found.
"""
tree = self._trees.pop(root_id, None)
if not tree:
return None
for node in tree.all_nodes():
self._node_to_tree.pop(node.node_id, None)
logger.debug("TREE_REPO: remove_tree root_id=%s", root_id)
return tree
def to_dict(self) -> dict:
"""Serialize all trees."""
return {
+100
View File
@@ -517,3 +517,103 @@ async def test_handle_message_clear_command_deletes_message_log_ids(
deleted = {c.args[1] for c in mock_platform.queue_delete_message.call_args_list}
assert deleted == {"42", "43", "150"}
@pytest.mark.asyncio
async def test_handle_message_clear_command_reply_clears_branch(
handler, mock_platform, mock_session_store, incoming_message_factory
):
"""Reply /clear to a message clears only that branch."""
root_incoming = incoming_message_factory(
text="root", chat_id="chat_1", message_id="100", reply_to_message_id=None
)
tree = await handler.tree_queue.create_tree(
node_id="100", incoming=root_incoming, status_message_id="101"
)
handler.tree_queue.register_node("101", tree.root_id)
child_incoming = incoming_message_factory(
text="child",
chat_id="chat_1",
message_id="102",
reply_to_message_id="100",
)
await handler.tree_queue.add_to_tree(
parent_node_id="100",
node_id="102",
incoming=child_incoming,
status_message_id="103",
)
deleted_ids = []
async def _capture_delete(chat_id, message_id, fire_and_forget=True):
deleted_ids.append(message_id)
mock_platform.queue_delete_message = AsyncMock(side_effect=_capture_delete)
incoming = incoming_message_factory(
text="/clear",
chat_id="chat_1",
message_id="150",
reply_to_message_id="102",
)
await handler.handle_message(incoming)
assert set(deleted_ids) == {"102", "103", "150"}
assert "100" not in deleted_ids
assert "101" not in deleted_ids
mock_session_store.remove_node_mappings.assert_called()
assert handler.tree_queue.get_tree_for_node("102") is None
assert handler.tree_queue.get_tree_for_node("100") is not None
@pytest.mark.asyncio
async def test_handle_message_clear_command_reply_unknown_sends_nothing(
handler, mock_platform, mock_session_store, incoming_message_factory
):
"""Reply /clear to unknown message sends 'Nothing to clear'."""
incoming = incoming_message_factory(
text="/clear",
chat_id="chat_1",
message_id="150",
reply_to_message_id="999",
)
await handler.handle_message(incoming)
mock_platform.queue_send_message.assert_called_once()
call_args = mock_platform.queue_send_message.call_args[0]
assert "Nothing to clear" in call_args[1]
mock_session_store.clear_all.assert_not_called()
@pytest.mark.asyncio
async def test_handle_message_clear_command_reply_to_root_clears_tree(
handler, mock_platform, mock_session_store, incoming_message_factory
):
"""Reply /clear to root message clears entire tree."""
root_incoming = incoming_message_factory(
text="root", chat_id="chat_1", message_id="100", reply_to_message_id=None
)
await handler.tree_queue.create_tree(
node_id="100", incoming=root_incoming, status_message_id="101"
)
deleted_ids = []
async def _capture_delete(chat_id, message_id, fire_and_forget=True):
deleted_ids.append(message_id)
mock_platform.queue_delete_message = AsyncMock(side_effect=_capture_delete)
incoming = incoming_message_factory(
text="/clear",
chat_id="chat_1",
message_id="150",
reply_to_message_id="100",
)
await handler.handle_message(incoming)
assert set(deleted_ids) == {"100", "101", "150"}
mock_session_store.remove_tree.assert_called_once_with("100")
assert handler.tree_queue.get_tree_count() == 0
+168
View File
@@ -282,6 +282,83 @@ class TestMessageTree:
assert node is not None
assert node.session_id == "sess_1"
@pytest.mark.asyncio
async def test_get_descendants(self):
"""Test get_descendants returns node and all descendants."""
root_incoming = IncomingMessage(
text="Root", chat_id="1", user_id="1", message_id="root", platform="test"
)
root = MessageNode(
node_id="root", incoming=root_incoming, status_message_id="s1"
)
tree = MessageTree(root)
child_incoming = IncomingMessage(
text="Child",
chat_id="1",
user_id="1",
message_id="child",
platform="test",
reply_to_message_id="root",
)
await tree.add_node("child", child_incoming, "s2", "root")
grandchild_incoming = IncomingMessage(
text="Grand",
chat_id="1",
user_id="1",
message_id="grand",
platform="test",
reply_to_message_id="child",
)
await tree.add_node("grand", grandchild_incoming, "s3", "child")
assert tree.get_descendants("root") == ["root", "child", "grand"]
assert tree.get_descendants("child") == ["child", "grand"]
assert tree.get_descendants("grand") == ["grand"]
assert tree.get_descendants("nonexistent") == []
@pytest.mark.asyncio
async def test_remove_branch(self):
"""Test remove_branch removes subtree and updates parent."""
root_incoming = IncomingMessage(
text="Root", chat_id="1", user_id="1", message_id="root", platform="test"
)
root = MessageNode(
node_id="root", incoming=root_incoming, status_message_id="s1"
)
tree = MessageTree(root)
child_incoming = IncomingMessage(
text="Child",
chat_id="1",
user_id="1",
message_id="child",
platform="test",
reply_to_message_id="root",
)
await tree.add_node("child", child_incoming, "s2", "root")
grandchild_incoming = IncomingMessage(
text="Grand",
chat_id="1",
user_id="1",
message_id="grand",
platform="test",
reply_to_message_id="child",
)
await tree.add_node("grand", grandchild_incoming, "s3", "child")
async with tree.with_lock():
removed = tree.remove_branch("child")
assert len(removed) == 2
assert {n.node_id for n in removed} == {"child", "grand"}
assert tree.get_node("child") is None
assert tree.get_node("grand") is None
assert tree.get_node("root") is not None
assert "child" not in tree.get_root().children_ids
class TestTreeQueueManager:
"""Test TreeQueueManager class."""
@@ -441,6 +518,97 @@ class TestTreeQueueManager:
processing_complete.set()
@pytest.mark.asyncio
async def test_cancel_branch(self):
"""Test cancel_branch cancels only nodes in subtree."""
manager = TreeQueueManager()
root_incoming = IncomingMessage(
text="Root", chat_id="1", user_id="1", message_id="root", platform="test"
)
await manager.create_tree("root", root_incoming, "s1")
child_incoming = IncomingMessage(
text="Child",
chat_id="1",
user_id="1",
message_id="child",
platform="test",
reply_to_message_id="root",
)
tree, _ = await manager.add_to_tree("root", "child", child_incoming, "s2")
sibling_incoming = IncomingMessage(
text="Sibling",
chat_id="1",
user_id="1",
message_id="sibling",
platform="test",
reply_to_message_id="root",
)
await manager.add_to_tree("root", "sibling", sibling_incoming, "s3")
cancelled = await manager.cancel_branch("child")
assert len(cancelled) == 1
assert cancelled[0].node_id == "child"
child_node = tree.get_node("child")
assert child_node is not None
assert child_node.state == MessageState.ERROR
sibling_node = tree.get_node("sibling")
assert sibling_node is not None
assert sibling_node.state == MessageState.PENDING
@pytest.mark.asyncio
async def test_remove_branch_non_root(self):
"""Test remove_branch removes only the subtree when branch is not root."""
manager = TreeQueueManager()
root_incoming = IncomingMessage(
text="Root", chat_id="1", user_id="1", message_id="root", platform="test"
)
await manager.create_tree("root", root_incoming, "s1")
child_incoming = IncomingMessage(
text="Child",
chat_id="1",
user_id="1",
message_id="child",
platform="test",
reply_to_message_id="root",
)
tree, _ = await manager.add_to_tree("root", "child", child_incoming, "s2")
removed, root_id, removed_entire = await manager.remove_branch("child")
assert len(removed) == 1
assert removed[0].node_id == "child"
assert root_id == "root"
assert removed_entire is False
assert manager.get_tree_for_node("child") is None
assert manager.get_tree("root") is not None
assert tree.get_node("child") is None
assert "child" not in tree.get_root().children_ids
@pytest.mark.asyncio
async def test_remove_branch_root_removes_tree(self):
"""Test remove_branch when branch is root removes entire tree."""
manager = TreeQueueManager()
root_incoming = IncomingMessage(
text="Root", chat_id="1", user_id="1", message_id="root", platform="test"
)
await manager.create_tree("root", root_incoming, "s1")
removed, root_id, removed_entire = await manager.remove_branch("root")
assert len(removed) == 1
assert root_id == "root"
assert removed_entire is True
assert manager.get_tree("root") is None
assert manager.get_tree_for_node("root") is None
class TestSessionStoreTrees:
"""Test SessionStore tree methods."""