mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-06-02 06:13:46 +02:00
Optimized code in hot paths with z-ai/glm5
This commit is contained in:
@@ -145,6 +145,11 @@ async def lifespan(app: FastAPI):
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
if message_handler and hasattr(message_handler, "session_store"):
|
||||
try:
|
||||
message_handler.session_store.flush_pending_save()
|
||||
except Exception as e:
|
||||
logger.warning(f"Session store flush on shutdown: {e}")
|
||||
logger.info("Shutdown requested, cleaning up...")
|
||||
if messaging_platform:
|
||||
await _best_effort("messaging_platform.stop", messaging_platform.stop())
|
||||
|
||||
@@ -126,9 +126,10 @@ def try_filepath_mock(
|
||||
)
|
||||
|
||||
|
||||
# Cheapest/most common optimizations first for faster short-circuit.
|
||||
OPTIMIZATION_HANDLERS = [
|
||||
try_prefix_detection,
|
||||
try_quota_mock,
|
||||
try_prefix_detection,
|
||||
try_title_skip,
|
||||
try_suggestion_skip,
|
||||
try_filepath_mock,
|
||||
|
||||
+23
-38
@@ -5,7 +5,7 @@ Contains token counting for API requests.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import tiktoken
|
||||
|
||||
@@ -15,6 +15,13 @@ ENCODER = tiktoken.get_encoding("cl100k_base")
|
||||
__all__ = ["get_token_count"]
|
||||
|
||||
|
||||
def _get_block_attr(block: object, key: str, default: Any = "") -> Any:
    """Read ``key`` from a content block given as either a dict or an object.

    Args:
        block: A dict (looked up by key) or any other object (looked up by
            attribute).
        key: Name of the field to read.
        default: Value returned when the field is missing or is ``None``.

    Returns:
        The stored value, or ``default`` when the field is absent or ``None``.
    """
    if isinstance(block, dict):
        value = block.get(key, default)  # type: ignore[no-matching-overload]
    else:
        value = getattr(block, key, default)
    # Callers pass the result through str()/json.dumps(); a field that exists
    # but holds None would otherwise be counted as the literal "None"/"null".
    return default if value is None else value
|
||||
|
||||
|
||||
def get_token_count(
|
||||
messages: List,
|
||||
system: Optional[Union[str, List]] = None,
|
||||
@@ -32,13 +39,9 @@ def get_token_count(
|
||||
total_tokens += len(ENCODER.encode(system))
|
||||
elif isinstance(system, list):
|
||||
for block in system:
|
||||
text = (
|
||||
getattr(block, "text", None)
|
||||
if hasattr(block, "text")
|
||||
else (block.get("text", "") if isinstance(block, dict) else "")
|
||||
)
|
||||
text = _get_block_attr(block, "text", "")
|
||||
if text:
|
||||
total_tokens += len(ENCODER.encode(text))
|
||||
total_tokens += len(ENCODER.encode(str(text)))
|
||||
total_tokens += 4 # System block formatting overhead
|
||||
|
||||
for msg in messages:
|
||||
@@ -46,38 +49,24 @@ def get_token_count(
|
||||
total_tokens += len(ENCODER.encode(msg.content))
|
||||
elif isinstance(msg.content, list):
|
||||
for block in msg.content:
|
||||
b_type = getattr(block, "type", None) or (
|
||||
block.get("type") if isinstance(block, dict) else None
|
||||
)
|
||||
b_type = _get_block_attr(block, "type") or None
|
||||
|
||||
if b_type == "text":
|
||||
text = getattr(block, "text", "") or (
|
||||
block.get("text", "") if isinstance(block, dict) else ""
|
||||
)
|
||||
total_tokens += len(ENCODER.encode(text))
|
||||
text = _get_block_attr(block, "text", "")
|
||||
total_tokens += len(ENCODER.encode(str(text)))
|
||||
elif b_type == "thinking":
|
||||
thinking = getattr(block, "thinking", "") or (
|
||||
block.get("thinking", "") if isinstance(block, dict) else ""
|
||||
)
|
||||
total_tokens += len(ENCODER.encode(thinking))
|
||||
thinking = _get_block_attr(block, "thinking", "")
|
||||
total_tokens += len(ENCODER.encode(str(thinking)))
|
||||
elif b_type == "tool_use":
|
||||
name = getattr(block, "name", "") or (
|
||||
block.get("name", "") if isinstance(block, dict) else ""
|
||||
)
|
||||
inp = getattr(block, "input", {}) or (
|
||||
block.get("input", {}) if isinstance(block, dict) else {}
|
||||
)
|
||||
block_id = getattr(block, "id", "") or (
|
||||
block.get("id", "") if isinstance(block, dict) else ""
|
||||
)
|
||||
total_tokens += len(ENCODER.encode(name))
|
||||
name = _get_block_attr(block, "name", "")
|
||||
inp = _get_block_attr(block, "input", {})
|
||||
block_id = _get_block_attr(block, "id", "")
|
||||
total_tokens += len(ENCODER.encode(str(name)))
|
||||
total_tokens += len(ENCODER.encode(json.dumps(inp)))
|
||||
total_tokens += len(ENCODER.encode(str(block_id)))
|
||||
total_tokens += 15
|
||||
elif b_type == "image":
|
||||
source = getattr(block, "source", None) or (
|
||||
block.get("source", {}) if isinstance(block, dict) else {}
|
||||
)
|
||||
source = _get_block_attr(block, "source")
|
||||
if isinstance(source, dict):
|
||||
data = source.get("data") or source.get("base64") or ""
|
||||
if data:
|
||||
@@ -87,12 +76,8 @@ def get_token_count(
|
||||
else:
|
||||
total_tokens += 765
|
||||
elif b_type == "tool_result":
|
||||
content = getattr(block, "content", "") or (
|
||||
block.get("content", "") if isinstance(block, dict) else ""
|
||||
)
|
||||
tool_use_id = getattr(block, "tool_use_id", "") or (
|
||||
block.get("tool_use_id", "") if isinstance(block, dict) else ""
|
||||
)
|
||||
content = _get_block_attr(block, "content", "")
|
||||
tool_use_id = _get_block_attr(block, "tool_use_id", "")
|
||||
if isinstance(content, str):
|
||||
total_tokens += len(ENCODER.encode(content))
|
||||
else:
|
||||
@@ -102,7 +87,7 @@ def get_token_count(
|
||||
else:
|
||||
try:
|
||||
total_tokens += len(ENCODER.encode(json.dumps(block)))
|
||||
except (TypeError, ValueError):
|
||||
except TypeError, ValueError:
|
||||
total_tokens += len(ENCODER.encode(str(block)))
|
||||
|
||||
if tools:
|
||||
|
||||
+35
-28
@@ -32,41 +32,45 @@ logger = logging.getLogger(__name__)
|
||||
# Status message prefixes used to filter our own messages (ignore echo)
|
||||
STATUS_MESSAGE_PREFIXES = ("⏳", "💭", "🔧", "✅", "❌", "🚀", "🤖", "📋", "📊", "🔄")
|
||||
|
||||
# Event types that update the transcript
|
||||
TRANSCRIPT_EVENT_TYPES = (
|
||||
"thinking_start",
|
||||
"thinking_delta",
|
||||
"thinking_chunk",
|
||||
"thinking_stop",
|
||||
"text_start",
|
||||
"text_delta",
|
||||
"text_chunk",
|
||||
"text_stop",
|
||||
"tool_use_start",
|
||||
"tool_use_delta",
|
||||
"tool_use_stop",
|
||||
"tool_use",
|
||||
"tool_result",
|
||||
"block_stop",
|
||||
"error",
|
||||
# Event types that update the transcript (frozenset for O(1) membership)
|
||||
TRANSCRIPT_EVENT_TYPES = frozenset(
|
||||
{
|
||||
"thinking_start",
|
||||
"thinking_delta",
|
||||
"thinking_chunk",
|
||||
"thinking_stop",
|
||||
"text_start",
|
||||
"text_delta",
|
||||
"text_chunk",
|
||||
"text_stop",
|
||||
"tool_use_start",
|
||||
"tool_use_delta",
|
||||
"tool_use_stop",
|
||||
"tool_use",
|
||||
"tool_result",
|
||||
"block_stop",
|
||||
"error",
|
||||
}
|
||||
)
|
||||
|
||||
# Event types -> (emoji, label) for status updates
|
||||
# Event type -> (emoji, label) for status updates (O(1) lookup)
|
||||
_EVENT_STATUS_MAP = {
|
||||
("thinking_start", "thinking_delta", "thinking_chunk"): (
|
||||
"🧠",
|
||||
"Claude is thinking...",
|
||||
),
|
||||
("text_start", "text_delta", "text_chunk"): ("🧠", "Claude is working..."),
|
||||
("tool_result",): ("⏳", "Executing tools..."),
|
||||
"thinking_start": ("🧠", "Claude is thinking..."),
|
||||
"thinking_delta": ("🧠", "Claude is thinking..."),
|
||||
"thinking_chunk": ("🧠", "Claude is thinking..."),
|
||||
"text_start": ("🧠", "Claude is working..."),
|
||||
"text_delta": ("🧠", "Claude is working..."),
|
||||
"text_chunk": ("🧠", "Claude is working..."),
|
||||
"tool_result": ("⏳", "Executing tools..."),
|
||||
}
|
||||
|
||||
|
||||
def _get_status_for_event(ptype: str, parsed: dict) -> Optional[str]:
|
||||
"""Return status string for event type, or None if no status update needed."""
|
||||
for types, (emoji, label) in _EVENT_STATUS_MAP.items():
|
||||
if ptype in types:
|
||||
return format_status(emoji, label)
|
||||
entry = _EVENT_STATUS_MAP.get(ptype)
|
||||
if entry is not None:
|
||||
emoji, label = entry
|
||||
return format_status(emoji, label)
|
||||
if ptype in ("tool_use_start", "tool_use_delta", "tool_use"):
|
||||
if parsed.get("name") == "Task":
|
||||
return format_status("🤖", "Subagent working...")
|
||||
@@ -636,6 +640,7 @@ class ClaudeMessageHandler:
|
||||
|
||||
def _update_cancelled_nodes_ui(self, nodes: List[MessageNode]) -> None:
|
||||
"""Update status messages and persist tree state for cancelled nodes."""
|
||||
trees_to_save: dict[str, MessageTree] = {}
|
||||
for node in nodes:
|
||||
self.platform.fire_and_forget(
|
||||
self.platform.queue_edit_message(
|
||||
@@ -647,7 +652,9 @@ class ClaudeMessageHandler:
|
||||
)
|
||||
tree = self.tree_queue.get_tree_for_node(node.node_id)
|
||||
if tree:
|
||||
self.session_store.save_tree(tree.root_id, tree.to_dict())
|
||||
trees_to_save[tree.root_id] = tree
|
||||
for root_id, tree in trees_to_save.items():
|
||||
self.session_store.save_tree(root_id, tree.to_dict())
|
||||
|
||||
async def _handle_stop_command(self, incoming: IncomingMessage) -> None:
|
||||
"""Handle /stop command from messaging platform."""
|
||||
|
||||
@@ -79,8 +79,6 @@ class MessagingRateLimiter:
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if not cls._instance:
|
||||
pass
|
||||
return super(MessagingRateLimiter, cls).__new__(cls)
|
||||
|
||||
@classmethod
|
||||
|
||||
+45
-7
@@ -50,6 +50,9 @@ class SessionStore:
|
||||
# Key: "{platform}:{chat_id}" -> list of records
|
||||
self._message_log: Dict[str, List[Dict[str, Any]]] = {}
|
||||
self._message_log_ids: Dict[str, set[str]] = {}
|
||||
self._dirty = False
|
||||
self._save_timer: Optional[threading.Timer] = None
|
||||
self._save_debounce_secs = 0.5
|
||||
self._load()
|
||||
|
||||
def _make_key(self, platform: str, chat_id: str, msg_id: str) -> str:
|
||||
@@ -130,7 +133,7 @@ class SessionStore:
|
||||
logger.error(f"Failed to load sessions: {e}")
|
||||
|
||||
def _save(self) -> None:
|
||||
"""Persist sessions and trees to disk."""
|
||||
"""Persist sessions and trees to disk. Caller must hold self._lock."""
|
||||
try:
|
||||
data = {
|
||||
"sessions": {
|
||||
@@ -145,6 +148,41 @@ class SessionStore:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save sessions: {e}")
|
||||
|
||||
def _schedule_save(self) -> None:
    """Mark the store dirty and (re)start the debounce timer.

    A timer that is still pending is cancelled and replaced, so the write is
    pushed back by ``self._save_debounce_secs`` on every call.
    Caller must hold self._lock.
    """
    self._dirty = True
    pending = self._save_timer
    if pending is not None:
        pending.cancel()
    timer = threading.Timer(self._save_debounce_secs, self._save_from_timer)
    # Daemon thread: a waiting save must not keep the interpreter alive.
    timer.daemon = True
    self._save_timer = timer
    timer.start()
|
||||
|
||||
def _save_from_timer(self) -> None:
    """Timer callback: write to disk if changes are pending. Runs in timer thread."""
    with self._lock:
        if self._dirty:
            self._save()
            self._dirty = False
        # Dirty or not, this timer has done its job; drop the reference.
        self._save_timer = None
|
||||
|
||||
def _flush_save(self) -> None:
    """Write to disk right now, dropping any pending debounced save.

    Caller must hold self._lock.
    """
    pending, self._save_timer = self._save_timer, None
    if pending is not None:
        pending.cancel()
    self._dirty = False
    self._save()
|
||||
|
||||
def flush_pending_save(self) -> None:
    """Force an immediate save, cancelling any debounced one.

    Call on shutdown to avoid losing data that is still waiting on the
    debounce timer. Takes self._lock itself, so callers must not hold it.
    """
    with self._lock:
        self._flush_save()
|
||||
|
||||
def record_message_id(
|
||||
self,
|
||||
platform: str,
|
||||
@@ -192,7 +230,7 @@ class SessionStore:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
self._save()
|
||||
self._schedule_save()
|
||||
|
||||
def get_message_ids_for_chat(self, platform: str, chat_id: str) -> List[str]:
|
||||
"""Get all recorded message IDs for a chat (in insertion order)."""
|
||||
@@ -214,7 +252,7 @@ class SessionStore:
|
||||
self._node_to_tree.clear()
|
||||
self._message_log.clear()
|
||||
self._message_log_ids.clear()
|
||||
self._save()
|
||||
self._flush_save()
|
||||
|
||||
# ==================== Tree Methods ====================
|
||||
|
||||
@@ -233,7 +271,7 @@ class SessionStore:
|
||||
for node_id in tree_data.get("nodes", {}).keys():
|
||||
self._node_to_tree[node_id] = root_id
|
||||
|
||||
self._save()
|
||||
self._schedule_save()
|
||||
logger.debug(f"Saved tree {root_id}")
|
||||
|
||||
def get_tree(self, root_id: str) -> Optional[dict]:
|
||||
@@ -250,7 +288,7 @@ class SessionStore:
|
||||
"""Register a node ID to a tree root."""
|
||||
with self._lock:
|
||||
self._node_to_tree[node_id] = root_id
|
||||
self._save()
|
||||
self._schedule_save()
|
||||
|
||||
def get_all_trees(self) -> Dict[str, dict]:
|
||||
"""Get all stored trees (public accessor)."""
|
||||
@@ -269,7 +307,7 @@ class SessionStore:
|
||||
with self._lock:
|
||||
self._trees = trees
|
||||
self._node_to_tree = node_to_tree
|
||||
self._save()
|
||||
self._schedule_save()
|
||||
|
||||
def cleanup_old_trees(self, max_age_days: int = 30) -> int:
|
||||
"""Remove trees older than max_age_days."""
|
||||
@@ -299,7 +337,7 @@ class SessionStore:
|
||||
removed += 1
|
||||
|
||||
if removed:
|
||||
self._save()
|
||||
self._schedule_save()
|
||||
logger.info(f"Cleaned up {removed} old trees")
|
||||
|
||||
return removed
|
||||
|
||||
+18
-4
@@ -11,8 +11,9 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -299,6 +300,18 @@ class TranscriptBuffer:
|
||||
if not self._subagent_stack:
|
||||
return
|
||||
if tool_id:
|
||||
# O(1) common case: LIFO - top of stack matches.
|
||||
if self._subagent_stack[-1] == tool_id:
|
||||
self._subagent_stack.pop()
|
||||
if self._subagent_segments:
|
||||
self._subagent_segments.pop()
|
||||
if self._debug_subagent_stack:
|
||||
logger.debug(
|
||||
"SUBAGENT_STACK: pop id=%r depth=%d (LIFO)",
|
||||
tool_id,
|
||||
len(self._subagent_stack),
|
||||
)
|
||||
return
|
||||
# Pop to the matching id (defensive against non-LIFO emissions).
|
||||
try:
|
||||
idx = (
|
||||
@@ -545,7 +558,7 @@ class TranscriptBuffer:
|
||||
status_text = f"\n\n{status}" if status else ""
|
||||
prefix_marker = ctx.escape_text("... (truncated)\n")
|
||||
|
||||
def _join(parts: List[str], add_marker: bool) -> str:
|
||||
def _join(parts: Iterable[str], add_marker: bool) -> str:
|
||||
body = "\n".join(parts)
|
||||
if add_marker and body:
|
||||
body = prefix_marker + body
|
||||
@@ -557,13 +570,14 @@ class TranscriptBuffer:
|
||||
return candidate
|
||||
|
||||
# Drop oldest segments until under limit (keep the tail).
|
||||
parts = list(rendered)
|
||||
# Use deque for O(1) popleft; list.pop(0) would be O(n) per iteration.
|
||||
parts: deque[str] = deque(rendered)
|
||||
dropped = False
|
||||
while parts:
|
||||
candidate = _join(parts, add_marker=True)
|
||||
if len(candidate) <= limit_chars:
|
||||
return candidate
|
||||
parts.pop(0)
|
||||
parts.popleft()
|
||||
dropped = True
|
||||
|
||||
# Nothing fits; return status only with marker if possible.
|
||||
|
||||
@@ -284,6 +284,9 @@ class MessageTree:
|
||||
|
||||
Caller must hold the tree lock (e.g. via with_lock).
|
||||
Returns True if node was removed, False if not in queue.
|
||||
|
||||
Note: asyncio.Queue has no built-in remove; we filter via the internal
|
||||
deque. O(n) in queue size; acceptable for typical tree queue sizes.
|
||||
"""
|
||||
queue_deque: deque = self._queue._queue # type: ignore[attr-defined]
|
||||
if node_id not in queue_deque:
|
||||
|
||||
@@ -254,13 +254,14 @@ class TreeQueueManager:
|
||||
# 2. Drain queue and mark nodes as cancelled
|
||||
queue_nodes = tree.drain_queue_and_mark_cancelled()
|
||||
cancelled_nodes.extend(queue_nodes)
|
||||
cancelled_ids = {n.node_id for n in cancelled_nodes}
|
||||
|
||||
# 3. Cleanup: Mark ANY other PENDING or IN_PROGRESS nodes as ERROR
|
||||
cleanup_count = 0
|
||||
for node in tree.all_nodes():
|
||||
if (
|
||||
node.state in (MessageState.PENDING, MessageState.IN_PROGRESS)
|
||||
and node not in cancelled_nodes
|
||||
and node.node_id not in cancelled_ids
|
||||
):
|
||||
node.state = MessageState.ERROR
|
||||
node.error_message = "Stale task cleaned up"
|
||||
|
||||
@@ -29,6 +29,7 @@ async def test_reply_to_old_status_message_after_restore_routes_to_parent(
|
||||
handler1.tree_queue.register_node("status_A", tree.root_id)
|
||||
store.register_node("status_A", tree.root_id)
|
||||
store.save_tree(tree.root_id, tree.to_dict())
|
||||
store.flush_pending_save()
|
||||
|
||||
# "Restart": new store instance loads from disk, and we restore TreeQueueManager.
|
||||
store2 = SessionStore(storage_path=str(store_path))
|
||||
@@ -81,6 +82,7 @@ async def test_reply_to_old_status_message_without_mapping_creates_new_conversat
|
||||
)
|
||||
# Intentionally do NOT register "status_A" mapping.
|
||||
store.save_tree(tree.root_id, tree.to_dict())
|
||||
store.flush_pending_save()
|
||||
|
||||
store2 = SessionStore(storage_path=str(store_path))
|
||||
handler2 = ClaudeMessageHandler(mock_platform, mock_cli_manager, store2)
|
||||
|
||||
@@ -152,6 +152,7 @@ class TestSessionStoreClearAll:
|
||||
ids = store.get_message_ids_for_chat("telegram", "c1")
|
||||
assert ids == ["1", "2"]
|
||||
|
||||
store.flush_pending_save()
|
||||
store2 = SessionStore(storage_path=path)
|
||||
assert store2.get_message_ids_for_chat("telegram", "c1") == ["1", "2"]
|
||||
|
||||
|
||||
@@ -123,6 +123,20 @@ def test_transcript_truncates_by_dropping_oldest_segments():
|
||||
assert escape_md_v2("segment_0") not in out
|
||||
|
||||
|
||||
def test_transcript_render_many_segments_completes_quickly():
    """Render 200 text segments under a small limit: newest kept, oldest dropped."""
    buf = TranscriptBuffer()
    for idx in range(200):
        events = (
            {"type": "text_start", "index": idx},
            {"type": "text_delta", "index": idx, "text": f"seg_{idx} " + ("y" * 80)},
            {"type": "block_stop", "index": idx},
        )
        for event in events:
            buf.apply(event)

    rendered = buf.render(_ctx(), limit_chars=500, status="ok")
    assert escape_md_v2("... (truncated)") in rendered
    assert "199" in rendered  # last segment (MarkdownV2 escapes underscores)
    assert "seg_0 " not in rendered  # oldest segment dropped
|
||||
|
||||
|
||||
def test_transcript_reused_index_closes_previous_open_block():
|
||||
t = TranscriptBuffer()
|
||||
# Open a text block at index 0, but never close it.
|
||||
|
||||
Reference in New Issue
Block a user