added tree based session handler

This commit is contained in:
Alishahryar1
2026-01-29 18:17:42 -08:00
parent 2c6b40012a
commit 5fd7268d3f
9 changed files with 1350 additions and 418 deletions
+117 -59
View File
@@ -1,8 +1,8 @@
"""
Session Store for Messaging Platforms
Provides persistent storage for mapping platform messages to Claude CLI session IDs.
This enables conversation continuation when replying to old messages.
Provides persistent storage for mapping platform messages to Claude CLI session IDs
and message trees for conversation continuation.
"""
import json
@@ -21,17 +21,17 @@ class SessionRecord:
"""A single session record."""
session_id: str
chat_id: str # Changed to str for platform-agnostic support
chat_id: str
initial_msg_id: str
last_msg_id: str
platform: str # "telegram", "discord", etc.
platform: str
created_at: str
updated_at: str
class SessionStore:
"""
Persistent storage for message ↔ Claude session mappings.
Persistent storage for message ↔ Claude session mappings and message trees.
Uses a JSON file for storage with thread-safe operations.
Platform-agnostic: works with any messaging platform.
@@ -40,10 +40,12 @@ class SessionStore:
def __init__(self, storage_path: str = "sessions.json"):
self.storage_path = storage_path
self._lock = threading.Lock()
self._sessions: Dict[str, SessionRecord] = {} # session_id -> record
self._sessions: Dict[str, SessionRecord] = {}
self._msg_to_session: Dict[
str, str
] = {} # "platform:chat_id:msg_id" -> session_id
self._trees: Dict[str, dict] = {} # root_id -> tree data
self._node_to_tree: Dict[str, str] = {} # node_id -> root_id
self._load()
def _make_key(self, platform: str, chat_id: str, msg_id: str) -> str:
@@ -51,7 +53,7 @@ class SessionStore:
return f"{platform}:{chat_id}:{msg_id}"
def _load(self) -> None:
"""Load sessions from disk."""
"""Load sessions and trees from disk."""
if not os.path.exists(self.storage_path):
return
@@ -59,18 +61,16 @@ class SessionStore:
with open(self.storage_path, "r", encoding="utf-8") as f:
data = json.load(f)
# Load sessions (legacy support)
for sid, record_data in data.get("sessions", {}).items():
# Handle legacy records without platform field
if "platform" not in record_data:
record_data["platform"] = "telegram"
# Convert int to str for backwards compatibility
for field in ["chat_id", "initial_msg_id", "last_msg_id"]:
if isinstance(record_data.get(field), int):
record_data[field] = str(record_data[field])
record = SessionRecord(**record_data)
self._sessions[sid] = record
# Index by initial and last message
self._msg_to_session[
self._make_key(
record.platform, record.chat_id, record.initial_msg_id
@@ -80,25 +80,33 @@ class SessionStore:
self._make_key(record.platform, record.chat_id, record.last_msg_id)
] = sid
# Load trees
self._trees = data.get("trees", {})
self._node_to_tree = data.get("node_to_tree", {})
logger.info(
f"Loaded {len(self._sessions)} sessions from {self.storage_path}"
f"Loaded {len(self._sessions)} sessions and {len(self._trees)} trees from {self.storage_path}"
)
except Exception as e:
logger.error(f"Failed to load sessions: {e}")
def _save(self) -> None:
"""Persist sessions to disk."""
"""Persist sessions and trees to disk."""
try:
data = {
"sessions": {
sid: asdict(record) for sid, record in self._sessions.items()
}
},
"trees": self._trees,
"node_to_tree": self._node_to_tree,
}
with open(self.storage_path, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
except Exception as e:
logger.error(f"Failed to save sessions: {e}")
# ==================== Session Methods ====================
def save_session(
self,
session_id: str,
@@ -106,15 +114,7 @@ class SessionStore:
initial_msg_id: str,
platform: str = "telegram",
) -> None:
"""
Save a new session mapping.
Args:
session_id: Claude CLI session ID
chat_id: Chat ID (platform-specific)
initial_msg_id: The message ID that started this session
platform: Messaging platform name
"""
"""Save a new session mapping."""
with self._lock:
now = datetime.utcnow().isoformat()
record = SessionRecord(
@@ -138,65 +138,38 @@ class SessionStore:
def get_session_by_msg(
self, chat_id: str, msg_id: str, platform: str = "telegram"
) -> Optional[str]:
"""
Look up a session ID by a message that's part of that session.
Args:
chat_id: Chat ID
msg_id: Message ID to look up
platform: Messaging platform name
Returns:
Session ID if found, None otherwise
"""
"""Look up a session ID by a message that's part of that session."""
with self._lock:
key = self._make_key(platform, str(chat_id), str(msg_id))
return self._msg_to_session.get(key)
def update_last_message(self, session_id: str, msg_id: str) -> None:
"""
Update the last message ID for a session.
Args:
session_id: Claude session ID
msg_id: New last message ID
"""
"""Update the last message ID for a session."""
with self._lock:
if session_id not in self._sessions:
logger.warning(f"Session {session_id} not found for update")
return
record = self._sessions[session_id]
# Update record
record.last_msg_id = str(msg_id)
record.updated_at = datetime.utcnow().isoformat()
# Update index - add new key, keep old one for chain lookups
new_key = self._make_key(record.platform, record.chat_id, str(msg_id))
self._msg_to_session[new_key] = session_id
self._save()
logger.debug(f"Updated session {session_id} last_msg to {msg_id}")
def rename_session(self, old_id: str, new_id: str) -> bool:
"""
Rename a session ID, migrating all message mappings.
This is crucial for handing over "pending_" sessions to real Claude session IDs.
"""
"""Rename a session ID, migrating all message mappings."""
with self._lock:
if old_id not in self._sessions:
logger.warning(f"Session {old_id} not found for rename to {new_id}")
return False
# Transfer record
record = self._sessions.pop(old_id)
record.session_id = new_id
record.updated_at = datetime.utcnow().isoformat()
self._sessions[new_id] = record
# Update all message mappings pointing to the old ID
items_to_update = [
k for k, v in self._msg_to_session.items() if v == old_id
]
@@ -215,12 +188,7 @@ class SessionStore:
return self._sessions.get(session_id)
def cleanup_old_sessions(self, max_age_days: int = 30) -> int:
"""
Remove sessions older than max_age_days.
Returns:
Number of sessions removed
"""
"""Remove sessions older than max_age_days."""
with self._lock:
cutoff = datetime.utcnow()
removed = 0
@@ -237,7 +205,6 @@ class SessionStore:
for sid in to_remove:
record = self._sessions.pop(sid)
# Remove index entries
self._msg_to_session.pop(
self._make_key(
record.platform, record.chat_id, record.initial_msg_id
@@ -255,3 +222,94 @@ class SessionStore:
logger.info(f"Cleaned up {removed} old sessions")
return removed
# ==================== Tree Methods ====================
def save_tree(self, root_id: str, tree_data: dict) -> None:
"""
Save a message tree.
Args:
root_id: Root node ID of the tree
tree_data: Serialized tree data from tree.to_dict()
"""
with self._lock:
self._trees[root_id] = tree_data
# Update node-to-tree mapping
for node_id in tree_data.get("nodes", {}).keys():
self._node_to_tree[node_id] = root_id
self._save()
logger.debug(f"Saved tree {root_id}")
def get_tree(self, root_id: str) -> Optional[dict]:
"""Get a tree by its root ID."""
with self._lock:
return self._trees.get(root_id)
def get_tree_by_node(self, node_id: str) -> Optional[dict]:
"""Get the tree containing a node."""
with self._lock:
root_id = self._node_to_tree.get(node_id)
if not root_id:
return None
return self._trees.get(root_id)
def get_tree_root_for_node(self, node_id: str) -> Optional[str]:
"""Get the root ID of the tree containing a node."""
with self._lock:
return self._node_to_tree.get(node_id)
def register_node(self, node_id: str, root_id: str) -> None:
"""Register a node ID to a tree root."""
with self._lock:
self._node_to_tree[node_id] = root_id
self._save()
def update_tree_node(self, root_id: str, node_id: str, node_data: dict) -> None:
"""Update a specific node in a tree."""
with self._lock:
if root_id not in self._trees:
logger.warning(f"Tree {root_id} not found")
return
if "nodes" not in self._trees[root_id]:
self._trees[root_id]["nodes"] = {}
self._trees[root_id]["nodes"][node_id] = node_data
self._node_to_tree[node_id] = root_id
self._save()
def cleanup_old_trees(self, max_age_days: int = 30) -> int:
"""Remove trees older than max_age_days."""
with self._lock:
cutoff = datetime.utcnow()
removed = 0
to_remove = []
for root_id, tree_data in self._trees.items():
try:
nodes = tree_data.get("nodes", {})
root_node = nodes.get(root_id, {})
created_str = root_node.get("created_at")
if created_str:
created = datetime.fromisoformat(created_str)
age_days = (cutoff - created).days
if age_days > max_age_days:
to_remove.append(root_id)
except Exception:
pass
for root_id in to_remove:
tree_data = self._trees.pop(root_id)
# Remove node mappings
for node_id in tree_data.get("nodes", {}).keys():
self._node_to_tree.pop(node_id, None)
removed += 1
if removed:
self._save()
logger.info(f"Cleaned up {removed} old trees")
return removed