Optimized code in hot paths with z-ai/glm5

This commit is contained in:
Alishahryar1
2026-02-14 19:59:46 -08:00
parent 952a2351ec
commit 7259b1def8
12 changed files with 150 additions and 81 deletions
+5
View File
@@ -145,6 +145,11 @@ async def lifespan(app: FastAPI):
yield
# Cleanup
if message_handler and hasattr(message_handler, "session_store"):
try:
message_handler.session_store.flush_pending_save()
except Exception as e:
logger.warning(f"Session store flush on shutdown: {e}")
logger.info("Shutdown requested, cleaning up...")
if messaging_platform:
await _best_effort("messaging_platform.stop", messaging_platform.stop())
+2 -1
View File
@@ -126,9 +126,10 @@ def try_filepath_mock(
)
# Cheapest/most common optimizations first for faster short-circuit.
OPTIMIZATION_HANDLERS = [
try_prefix_detection,
try_quota_mock,
try_prefix_detection,
try_title_skip,
try_suggestion_skip,
try_filepath_mock,
+23 -38
View File
@@ -5,7 +5,7 @@ Contains token counting for API requests.
import json
import logging
from typing import List, Optional, Union
from typing import Any, List, Optional, Union
import tiktoken
@@ -15,6 +15,13 @@ ENCODER = tiktoken.get_encoding("cl100k_base")
__all__ = ["get_token_count"]
def _get_block_attr(block: object, key: str, default: Any = "") -> Any:
"""Get attribute from block (object or dict)."""
if isinstance(block, dict):
return block.get(key, default) # type: ignore[no-matching-overload]
return getattr(block, key, default)
def get_token_count(
messages: List,
system: Optional[Union[str, List]] = None,
@@ -32,13 +39,9 @@ def get_token_count(
total_tokens += len(ENCODER.encode(system))
elif isinstance(system, list):
for block in system:
text = (
getattr(block, "text", None)
if hasattr(block, "text")
else (block.get("text", "") if isinstance(block, dict) else "")
)
text = _get_block_attr(block, "text", "")
if text:
total_tokens += len(ENCODER.encode(text))
total_tokens += len(ENCODER.encode(str(text)))
total_tokens += 4 # System block formatting overhead
for msg in messages:
@@ -46,38 +49,24 @@ def get_token_count(
total_tokens += len(ENCODER.encode(msg.content))
elif isinstance(msg.content, list):
for block in msg.content:
b_type = getattr(block, "type", None) or (
block.get("type") if isinstance(block, dict) else None
)
b_type = _get_block_attr(block, "type") or None
if b_type == "text":
text = getattr(block, "text", "") or (
block.get("text", "") if isinstance(block, dict) else ""
)
total_tokens += len(ENCODER.encode(text))
text = _get_block_attr(block, "text", "")
total_tokens += len(ENCODER.encode(str(text)))
elif b_type == "thinking":
thinking = getattr(block, "thinking", "") or (
block.get("thinking", "") if isinstance(block, dict) else ""
)
total_tokens += len(ENCODER.encode(thinking))
thinking = _get_block_attr(block, "thinking", "")
total_tokens += len(ENCODER.encode(str(thinking)))
elif b_type == "tool_use":
name = getattr(block, "name", "") or (
block.get("name", "") if isinstance(block, dict) else ""
)
inp = getattr(block, "input", {}) or (
block.get("input", {}) if isinstance(block, dict) else {}
)
block_id = getattr(block, "id", "") or (
block.get("id", "") if isinstance(block, dict) else ""
)
total_tokens += len(ENCODER.encode(name))
name = _get_block_attr(block, "name", "")
inp = _get_block_attr(block, "input", {})
block_id = _get_block_attr(block, "id", "")
total_tokens += len(ENCODER.encode(str(name)))
total_tokens += len(ENCODER.encode(json.dumps(inp)))
total_tokens += len(ENCODER.encode(str(block_id)))
total_tokens += 15
elif b_type == "image":
source = getattr(block, "source", None) or (
block.get("source", {}) if isinstance(block, dict) else {}
)
source = _get_block_attr(block, "source")
if isinstance(source, dict):
data = source.get("data") or source.get("base64") or ""
if data:
@@ -87,12 +76,8 @@ def get_token_count(
else:
total_tokens += 765
elif b_type == "tool_result":
content = getattr(block, "content", "") or (
block.get("content", "") if isinstance(block, dict) else ""
)
tool_use_id = getattr(block, "tool_use_id", "") or (
block.get("tool_use_id", "") if isinstance(block, dict) else ""
)
content = _get_block_attr(block, "content", "")
tool_use_id = _get_block_attr(block, "tool_use_id", "")
if isinstance(content, str):
total_tokens += len(ENCODER.encode(content))
else:
@@ -102,7 +87,7 @@ def get_token_count(
else:
try:
total_tokens += len(ENCODER.encode(json.dumps(block)))
except (TypeError, ValueError):
except TypeError, ValueError:
total_tokens += len(ENCODER.encode(str(block)))
if tools:
+35 -28
View File
@@ -32,41 +32,45 @@ logger = logging.getLogger(__name__)
# Status message prefixes used to filter our own messages (ignore echo)
STATUS_MESSAGE_PREFIXES = ("", "💭", "🔧", "", "", "🚀", "🤖", "📋", "📊", "🔄")
# Event types that update the transcript
TRANSCRIPT_EVENT_TYPES = (
"thinking_start",
"thinking_delta",
"thinking_chunk",
"thinking_stop",
"text_start",
"text_delta",
"text_chunk",
"text_stop",
"tool_use_start",
"tool_use_delta",
"tool_use_stop",
"tool_use",
"tool_result",
"block_stop",
"error",
# Event types that update the transcript (frozenset for O(1) membership)
TRANSCRIPT_EVENT_TYPES = frozenset(
{
"thinking_start",
"thinking_delta",
"thinking_chunk",
"thinking_stop",
"text_start",
"text_delta",
"text_chunk",
"text_stop",
"tool_use_start",
"tool_use_delta",
"tool_use_stop",
"tool_use",
"tool_result",
"block_stop",
"error",
}
)
# Event types -> (emoji, label) for status updates
# Event type -> (emoji, label) for status updates (O(1) lookup)
_EVENT_STATUS_MAP = {
("thinking_start", "thinking_delta", "thinking_chunk"): (
"🧠",
"Claude is thinking...",
),
("text_start", "text_delta", "text_chunk"): ("🧠", "Claude is working..."),
("tool_result",): ("", "Executing tools..."),
"thinking_start": ("🧠", "Claude is thinking..."),
"thinking_delta": ("🧠", "Claude is thinking..."),
"thinking_chunk": ("🧠", "Claude is thinking..."),
"text_start": ("🧠", "Claude is working..."),
"text_delta": ("🧠", "Claude is working..."),
"text_chunk": ("🧠", "Claude is working..."),
"tool_result": ("", "Executing tools..."),
}
def _get_status_for_event(ptype: str, parsed: dict) -> Optional[str]:
"""Return status string for event type, or None if no status update needed."""
for types, (emoji, label) in _EVENT_STATUS_MAP.items():
if ptype in types:
return format_status(emoji, label)
entry = _EVENT_STATUS_MAP.get(ptype)
if entry is not None:
emoji, label = entry
return format_status(emoji, label)
if ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
if parsed.get("name") == "Task":
return format_status("🤖", "Subagent working...")
@@ -636,6 +640,7 @@ class ClaudeMessageHandler:
def _update_cancelled_nodes_ui(self, nodes: List[MessageNode]) -> None:
"""Update status messages and persist tree state for cancelled nodes."""
trees_to_save: dict[str, MessageTree] = {}
for node in nodes:
self.platform.fire_and_forget(
self.platform.queue_edit_message(
@@ -647,7 +652,9 @@ class ClaudeMessageHandler:
)
tree = self.tree_queue.get_tree_for_node(node.node_id)
if tree:
self.session_store.save_tree(tree.root_id, tree.to_dict())
trees_to_save[tree.root_id] = tree
for root_id, tree in trees_to_save.items():
self.session_store.save_tree(root_id, tree.to_dict())
async def _handle_stop_command(self, incoming: IncomingMessage) -> None:
"""Handle /stop command from messaging platform."""
-2
View File
@@ -79,8 +79,6 @@ class MessagingRateLimiter:
_lock = asyncio.Lock()
def __new__(cls, *args, **kwargs):
if not cls._instance:
pass
return super(MessagingRateLimiter, cls).__new__(cls)
@classmethod
+45 -7
View File
@@ -50,6 +50,9 @@ class SessionStore:
# Key: "{platform}:{chat_id}" -> list of records
self._message_log: Dict[str, List[Dict[str, Any]]] = {}
self._message_log_ids: Dict[str, set[str]] = {}
self._dirty = False
self._save_timer: Optional[threading.Timer] = None
self._save_debounce_secs = 0.5
self._load()
def _make_key(self, platform: str, chat_id: str, msg_id: str) -> str:
@@ -130,7 +133,7 @@ class SessionStore:
logger.error(f"Failed to load sessions: {e}")
def _save(self) -> None:
"""Persist sessions and trees to disk."""
"""Persist sessions and trees to disk. Caller must hold self._lock."""
try:
data = {
"sessions": {
@@ -145,6 +148,41 @@ class SessionStore:
except Exception as e:
logger.error(f"Failed to save sessions: {e}")
def _schedule_save(self) -> None:
"""Schedule a debounced save. Caller must hold self._lock."""
self._dirty = True
if self._save_timer is not None:
self._save_timer.cancel()
self._save_timer = None
self._save_timer = threading.Timer(
self._save_debounce_secs, self._save_from_timer
)
self._save_timer.daemon = True
self._save_timer.start()
def _save_from_timer(self) -> None:
"""Timer callback: save if dirty. Runs in timer thread."""
with self._lock:
if not self._dirty:
self._save_timer = None
return
self._save()
self._dirty = False
self._save_timer = None
def _flush_save(self) -> None:
"""Immediate save, cancel any pending debounced save. Caller must hold self._lock."""
if self._save_timer is not None:
self._save_timer.cancel()
self._save_timer = None
self._dirty = False
self._save()
def flush_pending_save(self) -> None:
"""Flush any pending debounced save. Call on shutdown to avoid losing data."""
with self._lock:
self._flush_save()
def record_message_id(
self,
platform: str,
@@ -192,7 +230,7 @@ class SessionStore:
except Exception:
pass
self._save()
self._schedule_save()
def get_message_ids_for_chat(self, platform: str, chat_id: str) -> List[str]:
"""Get all recorded message IDs for a chat (in insertion order)."""
@@ -214,7 +252,7 @@ class SessionStore:
self._node_to_tree.clear()
self._message_log.clear()
self._message_log_ids.clear()
self._save()
self._flush_save()
# ==================== Tree Methods ====================
@@ -233,7 +271,7 @@ class SessionStore:
for node_id in tree_data.get("nodes", {}).keys():
self._node_to_tree[node_id] = root_id
self._save()
self._schedule_save()
logger.debug(f"Saved tree {root_id}")
def get_tree(self, root_id: str) -> Optional[dict]:
@@ -250,7 +288,7 @@ class SessionStore:
"""Register a node ID to a tree root."""
with self._lock:
self._node_to_tree[node_id] = root_id
self._save()
self._schedule_save()
def get_all_trees(self) -> Dict[str, dict]:
"""Get all stored trees (public accessor)."""
@@ -269,7 +307,7 @@ class SessionStore:
with self._lock:
self._trees = trees
self._node_to_tree = node_to_tree
self._save()
self._schedule_save()
def cleanup_old_trees(self, max_age_days: int = 30) -> int:
"""Remove trees older than max_age_days."""
@@ -299,7 +337,7 @@ class SessionStore:
removed += 1
if removed:
self._save()
self._schedule_save()
logger.info(f"Cleaned up {removed} old trees")
return removed
+18 -4
View File
@@ -11,8 +11,9 @@ from __future__ import annotations
import json
import logging
import os
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, Iterable, List, Optional
logger = logging.getLogger(__name__)
@@ -299,6 +300,18 @@ class TranscriptBuffer:
if not self._subagent_stack:
return
if tool_id:
# O(1) common case: LIFO - top of stack matches.
if self._subagent_stack[-1] == tool_id:
self._subagent_stack.pop()
if self._subagent_segments:
self._subagent_segments.pop()
if self._debug_subagent_stack:
logger.debug(
"SUBAGENT_STACK: pop id=%r depth=%d (LIFO)",
tool_id,
len(self._subagent_stack),
)
return
# Pop to the matching id (defensive against non-LIFO emissions).
try:
idx = (
@@ -545,7 +558,7 @@ class TranscriptBuffer:
status_text = f"\n\n{status}" if status else ""
prefix_marker = ctx.escape_text("... (truncated)\n")
def _join(parts: List[str], add_marker: bool) -> str:
def _join(parts: Iterable[str], add_marker: bool) -> str:
body = "\n".join(parts)
if add_marker and body:
body = prefix_marker + body
@@ -557,13 +570,14 @@ class TranscriptBuffer:
return candidate
# Drop oldest segments until under limit (keep the tail).
parts = list(rendered)
# Use deque for O(1) popleft; list.pop(0) would be O(n) per iteration.
parts: deque[str] = deque(rendered)
dropped = False
while parts:
candidate = _join(parts, add_marker=True)
if len(candidate) <= limit_chars:
return candidate
parts.pop(0)
parts.popleft()
dropped = True
# Nothing fits; return status only with marker if possible.
+3
View File
@@ -284,6 +284,9 @@ class MessageTree:
Caller must hold the tree lock (e.g. via with_lock).
Returns True if node was removed, False if not in queue.
Note: asyncio.Queue has no built-in remove; we filter via the internal
deque. O(n) in queue size; acceptable for typical tree queue sizes.
"""
queue_deque: deque = self._queue._queue # type: ignore[attr-defined]
if node_id not in queue_deque:
+2 -1
View File
@@ -254,13 +254,14 @@ class TreeQueueManager:
# 2. Drain queue and mark nodes as cancelled
queue_nodes = tree.drain_queue_and_mark_cancelled()
cancelled_nodes.extend(queue_nodes)
cancelled_ids = {n.node_id for n in cancelled_nodes}
# 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR
cleanup_count = 0
for node in tree.all_nodes():
if (
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
and node not in cancelled_nodes
and node.node_id not in cancelled_ids
):
node.state = MessageState.ERROR
node.error_message = "Stale task cleaned up"
+2
View File
@@ -29,6 +29,7 @@ async def test_reply_to_old_status_message_after_restore_routes_to_parent(
handler1.tree_queue.register_node("status_A", tree.root_id)
store.register_node("status_A", tree.root_id)
store.save_tree(tree.root_id, tree.to_dict())
store.flush_pending_save()
# "Restart": new store instance loads from disk, and we restore TreeQueueManager.
store2 = SessionStore(storage_path=str(store_path))
@@ -81,6 +82,7 @@ async def test_reply_to_old_status_message_without_mapping_creates_new_conversat
)
# Intentionally do NOT register "status_A" mapping.
store.save_tree(tree.root_id, tree.to_dict())
store.flush_pending_save()
store2 = SessionStore(storage_path=str(store_path))
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
+1
View File
@@ -152,6 +152,7 @@ class TestSessionStoreClearAll:
ids = store.get_message_ids_for_chat("telegram", "c1")
assert ids == ["1", "2"]
store.flush_pending_save()
store2 = SessionStore(storage_path=path)
assert store2.get_message_ids_for_chat("telegram", "c1") == ["1", "2"]
+14
View File
@@ -123,6 +123,20 @@ def test_transcript_truncates_by_dropping_oldest_segments():
assert escape_md_v2("segment_0") not in out
def test_transcript_render_many_segments_completes_quickly():
"""Render with 200+ segments exercises O(n) truncation (deque popleft)."""
t = TranscriptBuffer()
for i in range(200):
t.apply({"type": "text_start", "index": i})
t.apply({"type": "text_delta", "index": i, "text": f"seg_{i} " + ("y" * 80)})
t.apply({"type": "block_stop", "index": i})
out = t.render(_ctx(), limit_chars=500, status="ok")
assert escape_md_v2("... (truncated)") in out
assert "199" in out # last segment (MarkdownV2 escapes underscores)
assert "seg_0 " not in out # oldest segment dropped
def test_transcript_reused_index_closes_previous_open_block():
t = TranscriptBuffer()
# Open a text block at index 0, but never close it.