mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-06-01 22:09:04 +02:00
Major refactor done by itself
This commit is contained in:
@@ -0,0 +1,20 @@
|
||||
"""API layer for Claude Code Proxy."""
|
||||
|
||||
from .app import create_app, app
|
||||
from .models import (
|
||||
MessagesRequest,
|
||||
MessagesResponse,
|
||||
TokenCountRequest,
|
||||
TokenCountResponse,
|
||||
)
|
||||
from .dependencies import get_provider
|
||||
|
||||
__all__ = [
|
||||
"create_app",
|
||||
"app",
|
||||
"MessagesRequest",
|
||||
"MessagesResponse",
|
||||
"TokenCountRequest",
|
||||
"TokenCountResponse",
|
||||
"get_provider",
|
||||
]
|
||||
+102
@@ -0,0 +1,102 @@
|
||||
"""FastAPI application factory and configuration."""
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from .routes import router
|
||||
from .dependencies import cleanup_provider
|
||||
from providers.exceptions import ProviderError
|
||||
from config.settings import get_settings
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.FileHandler("server.log", encoding="utf-8", mode="w")],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Suppress noisy uvicorn logs
|
||||
logging.getLogger("uvicorn").setLevel(logging.WARNING)
|
||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
logging.getLogger("uvicorn.error").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Application lifespan manager."""
|
||||
settings = get_settings()
|
||||
logger.info("Starting Claude Code Proxy...")
|
||||
|
||||
# Initialize messaging platform if configured
|
||||
messaging_platform = None
|
||||
try:
|
||||
if settings.telegram_api_id and settings.telegram_api_hash:
|
||||
from messaging.telegram import TelegramPlatform
|
||||
|
||||
messaging_platform = TelegramPlatform()
|
||||
await messaging_platform.start()
|
||||
logger.info("Telegram platform started")
|
||||
except ImportError:
|
||||
logger.warning("Messaging module not yet available, skipping Telegram init")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start messaging platform: {e}")
|
||||
|
||||
# Store in app state for access in routes
|
||||
app.state.messaging_platform = messaging_platform
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
if messaging_platform:
|
||||
await messaging_platform.stop()
|
||||
await cleanup_provider()
|
||||
logger.info("Server shutting down...")
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create and configure the FastAPI application."""
|
||||
app = FastAPI(
|
||||
title="Claude Code Proxy",
|
||||
version="2.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Register routes
|
||||
app.include_router(router)
|
||||
|
||||
# Exception handlers
|
||||
@app.exception_handler(ProviderError)
|
||||
async def provider_error_handler(request: Request, exc: ProviderError):
|
||||
"""Handle provider-specific errors and return Anthropic format."""
|
||||
logger.error(f"Provider Error: {exc.error_type} - {exc.message}")
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=exc.to_anthropic_format(),
|
||||
)
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def general_error_handler(request: Request, exc: Exception):
|
||||
"""Handle general errors and return Anthropic format."""
|
||||
logger.error(f"General Error: {str(exc)}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"type": "error",
|
||||
"error": {
|
||||
"type": "api_error",
|
||||
"message": "An unexpected error occurred.",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# Default app instance for uvicorn
|
||||
app = create_app()
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Dependency injection for FastAPI."""
|
||||
|
||||
from typing import Optional
|
||||
from providers.base import ProviderConfig
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
from config.settings import get_settings
|
||||
|
||||
# Global provider instance (singleton)
|
||||
_provider: Optional[NvidiaNimProvider] = None
|
||||
|
||||
|
||||
def get_provider() -> NvidiaNimProvider:
|
||||
"""Get or create the NvidiaNimProvider instance."""
|
||||
global _provider
|
||||
if _provider is None:
|
||||
settings = get_settings()
|
||||
config = ProviderConfig(
|
||||
api_key=settings.nvidia_nim_api_key,
|
||||
base_url=settings.nvidia_nim_base_url,
|
||||
rate_limit=settings.nvidia_nim_rate_limit,
|
||||
rate_window=settings.nvidia_nim_rate_window,
|
||||
)
|
||||
_provider = NvidiaNimProvider(config)
|
||||
return _provider
|
||||
|
||||
|
||||
async def cleanup_provider():
|
||||
"""Cleanup provider resources."""
|
||||
global _provider
|
||||
if _provider and hasattr(_provider, "_client"):
|
||||
await _provider._client.aclose()
|
||||
_provider = None
|
||||
+177
@@ -0,0 +1,177 @@
|
||||
"""Pydantic models for API requests and responses."""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Union, Literal
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
|
||||
from config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Content Block Types
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class ContentBlockText(BaseModel):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class ContentBlockImage(BaseModel):
|
||||
type: Literal["image"]
|
||||
source: Dict[str, Any]
|
||||
|
||||
|
||||
class ContentBlockToolUse(BaseModel):
|
||||
type: Literal["tool_use"]
|
||||
id: str
|
||||
name: str
|
||||
input: Dict[str, Any]
|
||||
|
||||
|
||||
class ContentBlockToolResult(BaseModel):
|
||||
type: Literal["tool_result"]
|
||||
tool_use_id: str
|
||||
content: Union[str, List[Dict[str, Any]], Dict[str, Any], List[Any], Any]
|
||||
|
||||
|
||||
class ContentBlockThinking(BaseModel):
|
||||
type: Literal["thinking"]
|
||||
thinking: str
|
||||
|
||||
|
||||
class SystemContent(BaseModel):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Message Types
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: Literal["user", "assistant"]
|
||||
content: Union[
|
||||
str,
|
||||
List[
|
||||
Union[
|
||||
ContentBlockText,
|
||||
ContentBlockImage,
|
||||
ContentBlockToolUse,
|
||||
ContentBlockToolResult,
|
||||
ContentBlockThinking,
|
||||
]
|
||||
],
|
||||
]
|
||||
reasoning_content: Optional[str] = None
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
input_schema: Dict[str, Any]
|
||||
|
||||
|
||||
class ThinkingConfig(BaseModel):
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request/Response Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MessagesRequest(BaseModel):
|
||||
model: str
|
||||
max_tokens: int
|
||||
messages: List[Message]
|
||||
system: Optional[Union[str, List[SystemContent]]] = None
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
stream: Optional[bool] = False
|
||||
temperature: Optional[float] = 1.0
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
tools: Optional[List[Tool]] = None
|
||||
tool_choice: Optional[Dict[str, Any]] = None
|
||||
thinking: Optional[ThinkingConfig] = None
|
||||
extra_body: Optional[Dict[str, Any]] = None
|
||||
original_model: Optional[str] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def map_model(self) -> "MessagesRequest":
|
||||
settings = get_settings()
|
||||
if self.original_model is None:
|
||||
self.original_model = self.model
|
||||
|
||||
clean_v = self.model
|
||||
for prefix in ["anthropic/", "openai/", "gemini/"]:
|
||||
if clean_v.startswith(prefix):
|
||||
clean_v = clean_v[len(prefix) :]
|
||||
break
|
||||
|
||||
if "haiku" in clean_v.lower():
|
||||
self.model = settings.small_model
|
||||
elif "sonnet" in clean_v.lower() or "opus" in clean_v.lower():
|
||||
self.model = settings.big_model
|
||||
|
||||
if self.model != self.original_model:
|
||||
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class TokenCountRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Message]
|
||||
system: Optional[Union[str, List[SystemContent]]] = None
|
||||
tools: Optional[List[Tool]] = None
|
||||
thinking: Optional[ThinkingConfig] = None
|
||||
tool_choice: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model_field(cls, v, info):
|
||||
settings = get_settings()
|
||||
clean_v = v
|
||||
for prefix in ["anthropic/", "openai/", "gemini/"]:
|
||||
if clean_v.startswith(prefix):
|
||||
clean_v = clean_v[len(prefix) :]
|
||||
break
|
||||
|
||||
if "haiku" in clean_v.lower():
|
||||
return settings.small_model
|
||||
elif "sonnet" in clean_v.lower() or "opus" in clean_v.lower():
|
||||
return settings.big_model
|
||||
return v
|
||||
|
||||
|
||||
class TokenCountResponse(BaseModel):
|
||||
input_tokens: int
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
input_tokens: int
|
||||
output_tokens: int
|
||||
cache_creation_input_tokens: int = 0
|
||||
cache_read_input_tokens: int = 0
|
||||
|
||||
|
||||
class MessagesResponse(BaseModel):
|
||||
id: str
|
||||
model: str
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: List[
|
||||
Union[
|
||||
ContentBlockText, ContentBlockToolUse, ContentBlockThinking, Dict[str, Any]
|
||||
]
|
||||
]
|
||||
type: Literal["message"] = "message"
|
||||
stop_reason: Optional[
|
||||
Literal["end_turn", "max_tokens", "stop_sequence", "tool_use"]
|
||||
] = None
|
||||
stop_sequence: Optional[str] = None
|
||||
usage: Usage
|
||||
+253
@@ -0,0 +1,253 @@
|
||||
"""FastAPI route handlers."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import tiktoken
|
||||
from fastapi import APIRouter, Request, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from .models import (
|
||||
MessagesRequest,
|
||||
MessagesResponse,
|
||||
TokenCountRequest,
|
||||
TokenCountResponse,
|
||||
Usage,
|
||||
)
|
||||
from .dependencies import get_provider
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
from providers.exceptions import ProviderError
|
||||
from providers.logging_utils import log_request_compact
|
||||
from config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
ENCODER = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def extract_command_prefix(command: str) -> str:
|
||||
"""Extract the command prefix for fast prefix detection."""
|
||||
import shlex
|
||||
|
||||
if "`" in command or "$(" in command:
|
||||
return "command_injection_detected"
|
||||
|
||||
try:
|
||||
parts = shlex.split(command)
|
||||
if not parts:
|
||||
return "none"
|
||||
|
||||
env_prefix = []
|
||||
cmd_start = 0
|
||||
for i, part in enumerate(parts):
|
||||
if "=" in part and not part.startswith("-"):
|
||||
env_prefix.append(part)
|
||||
cmd_start = i + 1
|
||||
else:
|
||||
break
|
||||
|
||||
if cmd_start >= len(parts):
|
||||
return "none"
|
||||
|
||||
cmd_parts = parts[cmd_start:]
|
||||
if not cmd_parts:
|
||||
return "none"
|
||||
|
||||
first_word = cmd_parts[0]
|
||||
two_word_commands = {
|
||||
"git",
|
||||
"npm",
|
||||
"docker",
|
||||
"kubectl",
|
||||
"cargo",
|
||||
"go",
|
||||
"pip",
|
||||
"yarn",
|
||||
}
|
||||
|
||||
if first_word in two_word_commands and len(cmd_parts) > 1:
|
||||
second_word = cmd_parts[1]
|
||||
if not second_word.startswith("-"):
|
||||
return f"{first_word} {second_word}"
|
||||
return first_word
|
||||
return first_word if not env_prefix else " ".join(env_prefix) + " " + first_word
|
||||
|
||||
except ValueError:
|
||||
return command.split()[0] if command.split() else "none"
|
||||
|
||||
|
||||
def is_prefix_detection_request(request_data: MessagesRequest) -> tuple[bool, str]:
|
||||
"""Check if this is a fast prefix detection request."""
|
||||
if len(request_data.messages) != 1 or request_data.messages[0].role != "user":
|
||||
return False, ""
|
||||
|
||||
msg = request_data.messages[0]
|
||||
content = ""
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
elif isinstance(msg.content, list):
|
||||
for block in msg.content:
|
||||
if hasattr(block, "text"):
|
||||
content += block.text
|
||||
|
||||
if "<policy_spec>" in content and "Command:" in content:
|
||||
try:
|
||||
cmd_start = content.rfind("Command:") + len("Command:")
|
||||
return True, content[cmd_start:].strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False, ""
|
||||
|
||||
|
||||
def get_token_count(messages, system=None, tools=None) -> int:
|
||||
"""Estimate token count for a request."""
|
||||
total_tokens = 0
|
||||
|
||||
if system:
|
||||
if isinstance(system, str):
|
||||
total_tokens += len(ENCODER.encode(system))
|
||||
elif isinstance(system, list):
|
||||
for block in system:
|
||||
if hasattr(block, "text"):
|
||||
total_tokens += len(ENCODER.encode(block.text))
|
||||
|
||||
for msg in messages:
|
||||
if isinstance(msg.content, str):
|
||||
total_tokens += len(ENCODER.encode(msg.content))
|
||||
elif isinstance(msg.content, list):
|
||||
for block in msg.content:
|
||||
b_type = getattr(block, "type", None)
|
||||
|
||||
if b_type == "text":
|
||||
total_tokens += len(ENCODER.encode(getattr(block, "text", "")))
|
||||
elif b_type == "thinking":
|
||||
total_tokens += len(ENCODER.encode(getattr(block, "thinking", "")))
|
||||
elif b_type == "tool_use":
|
||||
name = getattr(block, "name", "")
|
||||
inp = getattr(block, "input", {})
|
||||
total_tokens += len(ENCODER.encode(name))
|
||||
total_tokens += len(ENCODER.encode(json.dumps(inp)))
|
||||
total_tokens += 10
|
||||
elif b_type == "tool_result":
|
||||
content = getattr(block, "content", "")
|
||||
if isinstance(content, str):
|
||||
total_tokens += len(ENCODER.encode(content))
|
||||
else:
|
||||
total_tokens += len(ENCODER.encode(json.dumps(content)))
|
||||
total_tokens += 5
|
||||
|
||||
if tools:
|
||||
for tool in tools:
|
||||
tool_str = (
|
||||
tool.name + (tool.description or "") + json.dumps(tool.input_schema)
|
||||
)
|
||||
total_tokens += len(ENCODER.encode(tool_str))
|
||||
|
||||
total_tokens += len(messages) * 3
|
||||
if tools:
|
||||
total_tokens += len(tools) * 5
|
||||
|
||||
return max(1, total_tokens)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/v1/messages")
|
||||
async def create_message(
|
||||
request_data: MessagesRequest,
|
||||
raw_request: Request,
|
||||
provider: NvidiaNimProvider = Depends(get_provider),
|
||||
):
|
||||
"""Create a message (streaming or non-streaming)."""
|
||||
settings = get_settings()
|
||||
|
||||
try:
|
||||
if settings.fast_prefix_detection:
|
||||
is_prefix_req, command = is_prefix_detection_request(request_data)
|
||||
if is_prefix_req:
|
||||
return MessagesResponse(
|
||||
id=f"msg_{uuid.uuid4()}",
|
||||
model=request_data.model,
|
||||
content=[{"type": "text", "text": extract_command_prefix(command)}],
|
||||
stop_reason="end_turn",
|
||||
usage=Usage(input_tokens=100, output_tokens=5),
|
||||
)
|
||||
|
||||
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
||||
log_request_compact(logger, request_id, request_data)
|
||||
|
||||
if request_data.stream:
|
||||
input_tokens = get_token_count(
|
||||
request_data.messages, request_data.system, request_data.tools
|
||||
)
|
||||
return StreamingResponse(
|
||||
provider.stream_response(request_data, input_tokens=input_tokens),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"X-Accel-Buffering": "no",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
else:
|
||||
response_json = await provider.complete(request_data)
|
||||
return provider.convert_response(response_json, request_data)
|
||||
|
||||
except ProviderError:
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(f"Error: {str(e)}\n{traceback.format_exc()}")
|
||||
raise HTTPException(status_code=getattr(e, "status_code", 500), detail=str(e))
|
||||
|
||||
|
||||
@router.post("/v1/messages/count_tokens")
|
||||
async def count_tokens(request_data: TokenCountRequest):
|
||||
"""Count tokens for a request."""
|
||||
try:
|
||||
return TokenCountResponse(
|
||||
input_tokens=get_token_count(
|
||||
request_data.messages, request_data.system, request_data.tools
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def root():
|
||||
"""Root endpoint."""
|
||||
settings = get_settings()
|
||||
return {
|
||||
"status": "ok",
|
||||
"provider": "nvidia_nim",
|
||||
"big_model": settings.big_model,
|
||||
"small_model": settings.small_model,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
@router.post("/stop")
|
||||
async def stop_cli():
|
||||
"""Stop all CLI sessions."""
|
||||
from cli import CLISessionManager
|
||||
|
||||
# This will be properly injected when messaging layer is complete
|
||||
return {"status": "not_implemented"}
|
||||
@@ -0,0 +1,7 @@
|
||||
"""CLI integration for Claude Code."""
|
||||
|
||||
from .session import CLISession
|
||||
from .manager import CLISessionManager
|
||||
from .parser import CLIParser
|
||||
|
||||
__all__ = ["CLISession", "CLISessionManager", "CLIParser"]
|
||||
+166
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
CLI Session Manager for Multi-Instance Claude CLI Support
|
||||
|
||||
Manages a pool of CLISession instances, each handling one conversation.
|
||||
This enables true parallel processing where multiple conversations run
|
||||
simultaneously in separate CLI processes.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
import logging
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
|
||||
from .session import CLISession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CLISessionManager:
|
||||
"""
|
||||
Manages multiple CLISession instances for parallel conversation processing.
|
||||
|
||||
Each new conversation gets its own CLISession with its own subprocess.
|
||||
Replies to existing conversations reuse the same CLISession instance.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_path: str,
|
||||
api_url: str,
|
||||
allowed_dirs: Optional[List[str]] = None,
|
||||
max_sessions: int = 10,
|
||||
):
|
||||
"""
|
||||
Initialize the session manager.
|
||||
|
||||
Args:
|
||||
workspace_path: Working directory for CLI processes
|
||||
api_url: API URL for the proxy
|
||||
allowed_dirs: Directories the CLI is allowed to access
|
||||
max_sessions: Maximum concurrent sessions
|
||||
"""
|
||||
self.workspace = workspace_path
|
||||
self.api_url = api_url
|
||||
self.allowed_dirs = allowed_dirs or []
|
||||
self.max_sessions = max_sessions
|
||||
|
||||
self._sessions: Dict[str, CLISession] = {}
|
||||
self._pending_sessions: Dict[str, CLISession] = {}
|
||||
self._temp_to_real: Dict[str, str] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
logger.info(f"CLISessionManager initialized (max_sessions={max_sessions})")
|
||||
|
||||
async def get_or_create_session(
|
||||
self, session_id: Optional[str] = None
|
||||
) -> Tuple[CLISession, str, bool]:
|
||||
"""
|
||||
Get an existing session or create a new one.
|
||||
|
||||
Returns:
|
||||
Tuple of (CLISession instance, session_id, is_new_session)
|
||||
"""
|
||||
async with self._lock:
|
||||
if session_id:
|
||||
lookup_id = self._temp_to_real.get(session_id, session_id)
|
||||
|
||||
if lookup_id in self._sessions:
|
||||
return self._sessions[lookup_id], lookup_id, False
|
||||
if lookup_id in self._pending_sessions:
|
||||
return self._pending_sessions[lookup_id], lookup_id, False
|
||||
|
||||
total_sessions = len(self._sessions) + len(self._pending_sessions)
|
||||
if total_sessions >= self.max_sessions:
|
||||
await self._cleanup_idle_sessions_unlocked()
|
||||
total_sessions = len(self._sessions) + len(self._pending_sessions)
|
||||
if total_sessions >= self.max_sessions:
|
||||
raise RuntimeError(
|
||||
f"Maximum concurrent sessions ({self.max_sessions}) reached."
|
||||
)
|
||||
|
||||
temp_id = session_id if session_id else f"pending_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
new_session = CLISession(
|
||||
workspace_path=self.workspace,
|
||||
api_url=self.api_url,
|
||||
allowed_dirs=self.allowed_dirs,
|
||||
)
|
||||
self._pending_sessions[temp_id] = new_session
|
||||
logger.info(f"Created new session: {temp_id}")
|
||||
|
||||
return new_session, temp_id, True
|
||||
|
||||
async def register_real_session_id(
|
||||
self, temp_id: str, real_session_id: str
|
||||
) -> bool:
|
||||
"""Register the real session ID from CLI output."""
|
||||
async with self._lock:
|
||||
if temp_id not in self._pending_sessions:
|
||||
logger.warning(f"Temp session {temp_id} not found")
|
||||
return False
|
||||
|
||||
session = self._pending_sessions.pop(temp_id)
|
||||
self._sessions[real_session_id] = session
|
||||
self._temp_to_real[temp_id] = real_session_id
|
||||
|
||||
logger.info(f"Registered session: {temp_id} -> {real_session_id}")
|
||||
return True
|
||||
|
||||
async def get_real_session_id(self, temp_id: str) -> Optional[str]:
|
||||
"""Get the real session ID for a temporary ID."""
|
||||
async with self._lock:
|
||||
return self._temp_to_real.get(temp_id)
|
||||
|
||||
async def remove_session(self, session_id: str) -> bool:
|
||||
"""Remove a session from the manager."""
|
||||
async with self._lock:
|
||||
if session_id in self._pending_sessions:
|
||||
session = self._pending_sessions.pop(session_id)
|
||||
await session.stop()
|
||||
return True
|
||||
|
||||
if session_id in self._sessions:
|
||||
session = self._sessions.pop(session_id)
|
||||
await session.stop()
|
||||
for temp, real in list(self._temp_to_real.items()):
|
||||
if real == session_id:
|
||||
del self._temp_to_real[temp]
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _cleanup_idle_sessions_unlocked(self):
|
||||
"""Clean up idle sessions (must hold lock)."""
|
||||
idle = [sid for sid, s in self._sessions.items() if not s.is_busy]
|
||||
|
||||
for sid in idle[:3]:
|
||||
session = self._sessions.pop(sid)
|
||||
await session.stop()
|
||||
logger.debug(f"Cleaned up idle session: {sid}")
|
||||
|
||||
async def stop_all(self):
|
||||
"""Stop all sessions."""
|
||||
async with self._lock:
|
||||
all_sessions = list(self._sessions.values()) + list(
|
||||
self._pending_sessions.values()
|
||||
)
|
||||
for session in all_sessions:
|
||||
try:
|
||||
await session.stop()
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping session: {e}")
|
||||
|
||||
self._sessions.clear()
|
||||
self._pending_sessions.clear()
|
||||
self._temp_to_real.clear()
|
||||
logger.info("All sessions stopped")
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""Get session statistics."""
|
||||
return {
|
||||
"active_sessions": len(self._sessions),
|
||||
"pending_sessions": len(self._pending_sessions),
|
||||
"max_sessions": self.max_sessions,
|
||||
"busy_count": sum(1 for s in self._sessions.values() if s.is_busy),
|
||||
}
|
||||
+107
@@ -0,0 +1,107 @@
|
||||
"""CLI event parser for Claude Code CLI output."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CLIParser:
|
||||
"""Helper to structure raw CLI events."""
|
||||
|
||||
@staticmethod
|
||||
def parse_event(event: Dict) -> Optional[Dict]:
|
||||
"""
|
||||
Parse a CLI event and return a structured result.
|
||||
|
||||
Args:
|
||||
event: Raw event dictionary from CLI
|
||||
|
||||
Returns:
|
||||
Parsed event dict or None if not a recognized event type
|
||||
"""
|
||||
if not isinstance(event, dict):
|
||||
return None
|
||||
|
||||
etype = event.get("type")
|
||||
|
||||
# 1. Handle full messages (assistant or result)
|
||||
msg_obj = None
|
||||
if etype == "assistant":
|
||||
msg_obj = event.get("message")
|
||||
elif etype == "result":
|
||||
res = event.get("result")
|
||||
if isinstance(res, dict):
|
||||
msg_obj = res.get("message")
|
||||
if not msg_obj:
|
||||
msg_obj = event.get("message")
|
||||
|
||||
if msg_obj and isinstance(msg_obj, dict):
|
||||
content = msg_obj.get("content", [])
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
thinking_parts = []
|
||||
tool_calls = []
|
||||
for c in content:
|
||||
if not isinstance(c, dict):
|
||||
continue
|
||||
ctype = c.get("type")
|
||||
if ctype == "text":
|
||||
parts.append(c.get("text", ""))
|
||||
elif ctype == "thinking":
|
||||
thinking_parts.append(c.get("thinking", ""))
|
||||
elif ctype == "tool_use":
|
||||
tool_calls.append(c)
|
||||
|
||||
if tool_calls:
|
||||
# Check for subagents (Task tool)
|
||||
subagents = [
|
||||
t.get("input", {}).get("description", "Subagent")
|
||||
for t in tool_calls
|
||||
if t.get("name") == "Task"
|
||||
]
|
||||
if subagents:
|
||||
return {"type": "subagent_start", "tasks": subagents}
|
||||
return {"type": "tool_start", "tools": tool_calls}
|
||||
|
||||
# Return combined result if we have content
|
||||
result = {}
|
||||
if thinking_parts:
|
||||
result["thinking"] = "\n".join(thinking_parts)
|
||||
if parts:
|
||||
result["text"] = "".join(parts)
|
||||
if result:
|
||||
result["type"] = "content"
|
||||
return result
|
||||
|
||||
# 2. Handle streaming deltas
|
||||
if etype == "content_block_delta":
|
||||
delta = event.get("delta", {})
|
||||
if not isinstance(delta, dict):
|
||||
return None
|
||||
if delta.get("type") == "text_delta":
|
||||
return {"type": "content", "text": delta.get("text", "")}
|
||||
if delta.get("type") == "thinking_delta":
|
||||
return {"type": "thinking", "text": delta.get("thinking", "")}
|
||||
|
||||
# 3. Handle tool usage start
|
||||
if etype == "content_block_start":
|
||||
block = event.get("content_block", {})
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
||||
if block.get("name") == "Task":
|
||||
desc = block.get("input", {}).get("description", "Subagent")
|
||||
return {"type": "subagent_start", "tasks": [desc]}
|
||||
return {"type": "tool_start", "tools": [block]}
|
||||
|
||||
# 4. Handle errors and exit
|
||||
if etype == "error":
|
||||
err = event.get("error")
|
||||
msg = err.get("message") if isinstance(err, dict) else str(err)
|
||||
return {"type": "error", "message": msg}
|
||||
elif etype == "exit":
|
||||
return {
|
||||
"type": "complete",
|
||||
"status": "success" if event.get("code") == 0 else "failed",
|
||||
}
|
||||
|
||||
return None
|
||||
+213
@@ -0,0 +1,213 @@
|
||||
"""Claude Code CLI session management."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncGenerator, Optional, Dict, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CLISession:
|
||||
"""Manages a single persistent Claude Code CLI subprocess."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_path: str,
|
||||
api_url: str,
|
||||
allowed_dirs: Optional[List[str]] = None,
|
||||
):
|
||||
self.workspace = os.path.normpath(os.path.abspath(workspace_path))
|
||||
self.api_url = api_url
|
||||
self.allowed_dirs = [os.path.normpath(d) for d in (allowed_dirs or [])]
|
||||
self.process: Optional[asyncio.subprocess.Process] = None
|
||||
self.current_session_id: Optional[str] = None
|
||||
self._is_busy = False
|
||||
self._cli_lock = asyncio.Lock()
|
||||
|
||||
@property
|
||||
def is_busy(self) -> bool:
|
||||
"""Check if a task is currently running."""
|
||||
return self._is_busy
|
||||
|
||||
async def start_task(
|
||||
self, prompt: str, session_id: Optional[str] = None
|
||||
) -> AsyncGenerator[dict, None]:
|
||||
"""
|
||||
Start a new task or continue an existing session.
|
||||
|
||||
Args:
|
||||
prompt: The user's message/prompt
|
||||
session_id: Optional session ID to resume
|
||||
|
||||
Yields:
|
||||
Event dictionaries from the CLI
|
||||
"""
|
||||
async with self._cli_lock:
|
||||
self._is_busy = True
|
||||
env = os.environ.copy()
|
||||
|
||||
if "ANTHROPIC_API_KEY" not in env:
|
||||
env["ANTHROPIC_API_KEY"] = "sk-placeholder-key-for-proxy"
|
||||
|
||||
env["ANTHROPIC_API_URL"] = self.api_url
|
||||
if self.api_url.endswith("/v1"):
|
||||
env["ANTHROPIC_BASE_URL"] = self.api_url[:-3]
|
||||
else:
|
||||
env["ANTHROPIC_BASE_URL"] = self.api_url
|
||||
|
||||
env["TERM"] = "dumb"
|
||||
env["PYTHONIOENCODING"] = "utf-8"
|
||||
|
||||
# Build command
|
||||
if session_id and not session_id.startswith("pending_"):
|
||||
cmd = [
|
||||
"claude",
|
||||
"--resume",
|
||||
session_id,
|
||||
"-p",
|
||||
prompt,
|
||||
"--output-format",
|
||||
"stream-json",
|
||||
"--dangerously-skip-permissions",
|
||||
"--verbose",
|
||||
]
|
||||
logger.info(f"Resuming Claude session {session_id}")
|
||||
else:
|
||||
cmd = [
|
||||
"claude",
|
||||
"-p",
|
||||
prompt,
|
||||
"--output-format",
|
||||
"stream-json",
|
||||
"--dangerously-skip-permissions",
|
||||
"--verbose",
|
||||
]
|
||||
logger.info(f"Starting new Claude session")
|
||||
|
||||
if self.allowed_dirs:
|
||||
for d in self.allowed_dirs:
|
||||
cmd.extend(["--add-dir", d])
|
||||
|
||||
try:
|
||||
self.process = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
cwd=self.workspace,
|
||||
env=env,
|
||||
)
|
||||
|
||||
if not self.process or not self.process.stdout:
|
||||
yield {"type": "exit", "code": 1}
|
||||
return
|
||||
|
||||
session_id_extracted = False
|
||||
buffer = bytearray()
|
||||
|
||||
while True:
|
||||
chunk = await self.process.stdout.read(65536)
|
||||
if not chunk:
|
||||
if buffer:
|
||||
line_str = buffer.decode("utf-8", errors="replace").strip()
|
||||
if line_str:
|
||||
async for event in self._handle_line_gen(
|
||||
line_str, session_id_extracted
|
||||
):
|
||||
if event.get("type") == "session_info":
|
||||
session_id_extracted = True
|
||||
yield event
|
||||
break
|
||||
|
||||
buffer.extend(chunk)
|
||||
|
||||
while True:
|
||||
newline_pos = buffer.find(b"\n")
|
||||
if newline_pos == -1:
|
||||
break
|
||||
|
||||
line = buffer[:newline_pos]
|
||||
buffer = buffer[newline_pos + 1 :]
|
||||
|
||||
line_str = line.decode("utf-8", errors="replace").strip()
|
||||
if line_str:
|
||||
async for event in self._handle_line_gen(
|
||||
line_str, session_id_extracted
|
||||
):
|
||||
if event.get("type") == "session_info":
|
||||
session_id_extracted = True
|
||||
yield event
|
||||
|
||||
if self.process.stderr:
|
||||
stderr_output = await self.process.stderr.read()
|
||||
if stderr_output:
|
||||
logger.error(
|
||||
f"Claude CLI Stderr: {stderr_output.decode('utf-8', errors='replace')}"
|
||||
)
|
||||
|
||||
return_code = await self.process.wait()
|
||||
logger.info(f"Claude CLI exited with code {return_code}")
|
||||
yield {"type": "exit", "code": return_code}
|
||||
finally:
|
||||
self._is_busy = False
|
||||
|
||||
async def _handle_line_gen(
|
||||
self, line_str: str, session_id_extracted: bool
|
||||
) -> AsyncGenerator[dict, None]:
|
||||
"""Process a single line and yield events."""
|
||||
try:
|
||||
event = json.loads(line_str)
|
||||
if not session_id_extracted:
|
||||
extracted_id = self._extract_session_id(event)
|
||||
if extracted_id:
|
||||
self.current_session_id = extracted_id
|
||||
logger.info(f"Extracted session ID: {extracted_id}")
|
||||
yield {"type": "session_info", "session_id": extracted_id}
|
||||
|
||||
yield event
|
||||
except json.JSONDecodeError:
|
||||
logger.debug(f"Non-JSON output: {line_str[:100]}")
|
||||
yield {"type": "raw", "content": line_str}
|
||||
|
||||
def _extract_session_id(self, event: Dict) -> Optional[str]:
|
||||
"""Extract session ID from CLI event."""
|
||||
if not isinstance(event, dict):
|
||||
return None
|
||||
|
||||
if "session_id" in event:
|
||||
return event["session_id"]
|
||||
if "sessionId" in event:
|
||||
return event["sessionId"]
|
||||
|
||||
for key in ["init", "system", "result", "metadata"]:
|
||||
if key in event and isinstance(event[key], dict):
|
||||
nested = event[key]
|
||||
if "session_id" in nested:
|
||||
return nested["session_id"]
|
||||
if "sessionId" in nested:
|
||||
return nested["sessionId"]
|
||||
|
||||
if "conversation" in event and isinstance(event["conversation"], dict):
|
||||
conv = event["conversation"]
|
||||
if "id" in conv:
|
||||
return conv["id"]
|
||||
|
||||
return None
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the CLI process."""
|
||||
if self.process and self.process.returncode is None:
|
||||
try:
|
||||
logger.info(f"Stopping Claude CLI process {self.process.pid}")
|
||||
self.process.terminate()
|
||||
try:
|
||||
await asyncio.wait_for(self.process.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
self.process.kill()
|
||||
await self.process.wait()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping process: {e}")
|
||||
return False
|
||||
return False
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Configuration management."""
|
||||
|
||||
from .settings import Settings, get_settings
|
||||
|
||||
__all__ = ["Settings", "get_settings"]
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Centralized configuration using Pydantic Settings."""
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from pydantic import field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Application settings loaded from environment variables."""
|
||||
|
||||
# ==================== NVIDIA NIM Config ====================
|
||||
nvidia_nim_api_key: str = ""
|
||||
nvidia_nim_base_url: str = "https://integrate.api.nvidia.com/v1"
|
||||
|
||||
# ==================== Model Mapping ====================
|
||||
big_model: str = "moonshotai/kimi-k2-thinking"
|
||||
small_model: str = "moonshotai/kimi-k2-thinking"
|
||||
|
||||
# ==================== Rate Limiting ====================
|
||||
nvidia_nim_rate_limit: int = 40
|
||||
nvidia_nim_rate_window: int = 60
|
||||
|
||||
# ==================== Fast Prefix Detection ====================
|
||||
fast_prefix_detection: bool = True
|
||||
|
||||
# ==================== Logging ====================
|
||||
log_full_payloads: bool = False
|
||||
|
||||
# ==================== NIM Core Parameters ====================
|
||||
nvidia_nim_temperature: float = 1.0
|
||||
nvidia_nim_top_p: float = 1.0
|
||||
nvidia_nim_top_k: int = -1
|
||||
nvidia_nim_max_tokens: int = 81920
|
||||
nvidia_nim_presence_penalty: float = 0.0
|
||||
nvidia_nim_frequency_penalty: float = 0.0
|
||||
|
||||
# ==================== NIM Advanced Parameters ====================
|
||||
nvidia_nim_min_p: float = 0.0
|
||||
nvidia_nim_repetition_penalty: float = 1.0
|
||||
nvidia_nim_seed: Optional[int] = None
|
||||
nvidia_nim_stop: Optional[str] = None
|
||||
|
||||
# ==================== NIM Flag Parameters ====================
|
||||
nvidia_nim_parallel_tool_calls: bool = True
|
||||
nvidia_nim_return_tokens_as_token_ids: bool = False
|
||||
nvidia_nim_include_stop_str_in_output: bool = False
|
||||
nvidia_nim_ignore_eos: bool = False
|
||||
|
||||
nvidia_nim_min_tokens: int = 0
|
||||
nvidia_nim_chat_template: str = ""
|
||||
nvidia_nim_request_id: str = ""
|
||||
|
||||
# ==================== Thinking/Reasoning Parameters ====================
|
||||
nvidia_nim_reasoning_effort: str = "high"
|
||||
nvidia_nim_include_reasoning: bool = True
|
||||
|
||||
# ==================== Bot Wrapper Config ====================
|
||||
telegram_api_id: Optional[str] = None
|
||||
telegram_api_hash: Optional[str] = None
|
||||
allowed_telegram_user_id: Optional[str] = None
|
||||
claude_workspace: str = "./agent_workspace"
|
||||
allowed_dir: str = ""
|
||||
max_cli_sessions: int = 10
|
||||
wrapper_ws_url: str = "ws://localhost:8083/ws"
|
||||
|
||||
# ==================== Server ====================
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8082
|
||||
|
||||
# Handle empty strings for optional int fields
|
||||
@field_validator("nvidia_nim_seed", mode="before")
|
||||
@classmethod
|
||||
def parse_optional_int(cls, v):
|
||||
if v == "" or v is None:
|
||||
return None
|
||||
return int(v)
|
||||
|
||||
# Handle empty strings for optional string fields
|
||||
@field_validator(
|
||||
"nvidia_nim_stop",
|
||||
"telegram_api_id",
|
||||
"telegram_api_hash",
|
||||
"allowed_telegram_user_id",
|
||||
mode="before",
|
||||
)
|
||||
@classmethod
|
||||
def parse_optional_str(cls, v):
|
||||
if v == "":
|
||||
return None
|
||||
return v
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""Get cached settings instance."""
|
||||
return Settings()
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Platform-agnostic messaging layer."""
|
||||
|
||||
from .base import MessagingPlatform
|
||||
from .models import IncomingMessage, OutgoingMessage
|
||||
from .handler import ClaudeMessageHandler
|
||||
from .session import SessionStore
|
||||
from .queue import MessageQueueManager
|
||||
|
||||
__all__ = [
|
||||
"MessagingPlatform",
|
||||
"IncomingMessage",
|
||||
"OutgoingMessage",
|
||||
"ClaudeMessageHandler",
|
||||
"SessionStore",
|
||||
"MessageQueueManager",
|
||||
]
|
||||
@@ -0,0 +1,86 @@
|
||||
"""Abstract base class for messaging platforms."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Awaitable, Optional, Any
|
||||
from .models import IncomingMessage
|
||||
|
||||
|
||||
class MessagingPlatform(ABC):
|
||||
"""
|
||||
Base class for all messaging platform adapters.
|
||||
|
||||
Implement this to add support for Telegram, Discord, Slack, etc.
|
||||
"""
|
||||
|
||||
name: str = "base"
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""Initialize and connect to the messaging platform."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""Disconnect and cleanup resources."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Send a message to a chat.
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID to send to
|
||||
text: Message content
|
||||
reply_to: Optional message ID to reply to
|
||||
parse_mode: Optional formatting mode ("markdown", "html")
|
||||
|
||||
Returns:
|
||||
The message ID of the sent message
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Edit an existing message.
|
||||
|
||||
Args:
|
||||
chat_id: The chat/channel ID
|
||||
message_id: The message ID to edit
|
||||
text: New message content
|
||||
parse_mode: Optional formatting mode
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def on_message(
|
||||
self,
|
||||
handler: Callable[[IncomingMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""
|
||||
Register a message handler callback.
|
||||
|
||||
The handler will be called for each incoming message.
|
||||
|
||||
Args:
|
||||
handler: Async function that processes incoming messages
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if the platform is connected."""
|
||||
return False
|
||||
@@ -0,0 +1,315 @@
|
||||
"""
|
||||
Claude Message Handler
|
||||
|
||||
Platform-agnostic Claude interaction logic.
|
||||
Handles the core workflow of processing user messages via Claude CLI.
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, List, Tuple, TYPE_CHECKING
|
||||
|
||||
from .base import MessagingPlatform
|
||||
from .models import IncomingMessage, MessageContext
|
||||
from .session import SessionStore
|
||||
from .queue import MessageQueueManager, QueuedMessage
|
||||
from cli import CLISession, CLISessionManager, CLIParser
|
||||
from config.settings import get_settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClaudeMessageHandler:
|
||||
"""
|
||||
Platform-agnostic handler for Claude interactions.
|
||||
|
||||
This class contains the core logic for:
|
||||
- Processing user messages
|
||||
- Managing Claude CLI sessions
|
||||
- Updating status messages
|
||||
- Handling tool calls and thinking
|
||||
|
||||
It works with any MessagingPlatform implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
platform: MessagingPlatform,
|
||||
cli_manager: CLISessionManager,
|
||||
session_store: SessionStore,
|
||||
message_queue: MessageQueueManager,
|
||||
):
|
||||
self.platform = platform
|
||||
self.cli_manager = cli_manager
|
||||
self.session_store = session_store
|
||||
self.message_queue = message_queue
|
||||
self._flood_wait_until = 0
|
||||
|
||||
async def handle_message(self, incoming: IncomingMessage) -> None:
|
||||
"""
|
||||
Main entry point for handling an incoming message.
|
||||
|
||||
Determines if this is a new session or continuation,
|
||||
sends status message, and queues for processing.
|
||||
"""
|
||||
# Check for commands
|
||||
if incoming.text == "/stop":
|
||||
await self._handle_stop_command(incoming)
|
||||
return
|
||||
|
||||
if incoming.text == "/stats":
|
||||
await self._handle_stats_command(incoming)
|
||||
return
|
||||
|
||||
# Filter out status messages (our own messages)
|
||||
if any(
|
||||
incoming.text.startswith(p)
|
||||
for p in ["⏳", "💭", "🔧", "✅", "❌", "🚀", "🤖", "📋", "📊", "🔄"]
|
||||
):
|
||||
return
|
||||
|
||||
# Check if this is a reply to an existing conversation
|
||||
session_id_to_resume = None
|
||||
if incoming.is_reply():
|
||||
session_id_to_resume = self.session_store.get_session_by_msg(
|
||||
incoming.chat_id,
|
||||
incoming.reply_to_message_id,
|
||||
incoming.platform,
|
||||
)
|
||||
if session_id_to_resume:
|
||||
logger.info(f"Found session {session_id_to_resume} for reply")
|
||||
|
||||
# Send initial status message
|
||||
status_text = self._get_initial_status(session_id_to_resume)
|
||||
status_msg_id = await self.platform.send_message(
|
||||
incoming.chat_id,
|
||||
status_text,
|
||||
reply_to=incoming.message_id,
|
||||
)
|
||||
|
||||
# Create queued message
|
||||
queued = QueuedMessage(
|
||||
incoming=incoming,
|
||||
status_message_id=status_msg_id,
|
||||
)
|
||||
|
||||
# Determine session ID for queuing
|
||||
if session_id_to_resume:
|
||||
queue_session_id = session_id_to_resume
|
||||
else:
|
||||
# New session - use temp ID
|
||||
queue_session_id = f"pending_{incoming.message_id}"
|
||||
# Pre-register so replies work immediately
|
||||
self.session_store.save_session(
|
||||
session_id=queue_session_id,
|
||||
chat_id=incoming.chat_id,
|
||||
initial_msg_id=incoming.message_id,
|
||||
platform=incoming.platform,
|
||||
)
|
||||
self.session_store.update_last_message(queue_session_id, status_msg_id)
|
||||
|
||||
# Enqueue for processing
|
||||
await self.message_queue.enqueue(
|
||||
session_id=queue_session_id,
|
||||
message=queued,
|
||||
processor=self._process_task,
|
||||
)
|
||||
|
||||
async def _process_task(
|
||||
self,
|
||||
session_id_to_resume: Optional[str],
|
||||
queued: QueuedMessage,
|
||||
) -> None:
|
||||
"""Core task processor - handles a single Claude CLI interaction."""
|
||||
incoming = queued.incoming
|
||||
status_msg_id = queued.status_message_id
|
||||
chat_id = incoming.chat_id
|
||||
|
||||
# Unified message accumulator
|
||||
message_parts: List[Tuple[str, str]] = []
|
||||
last_ui_update = 0.0
|
||||
captured_session_id = (
|
||||
session_id_to_resume
|
||||
if not session_id_to_resume.startswith("pending_")
|
||||
else None
|
||||
)
|
||||
temp_session_id = (
|
||||
session_id_to_resume
|
||||
if session_id_to_resume.startswith("pending_")
|
||||
else None
|
||||
)
|
||||
|
||||
async def update_ui(status: Optional[str] = None, force: bool = False) -> None:
|
||||
nonlocal last_ui_update
|
||||
now = time.time()
|
||||
|
||||
# Check flood wait
|
||||
if now < self._flood_wait_until:
|
||||
return
|
||||
|
||||
if not force and now - last_ui_update < 1.0:
|
||||
return
|
||||
|
||||
try:
|
||||
display = self._build_message(message_parts, status)
|
||||
if display:
|
||||
await self.platform.edit_message(
|
||||
chat_id, status_msg_id, display, parse_mode="markdown"
|
||||
)
|
||||
last_ui_update = now
|
||||
except Exception as e:
|
||||
logger.error(f"UI update failed: {e}")
|
||||
|
||||
try:
|
||||
# Get or create CLI session
|
||||
try:
|
||||
(
|
||||
cli_session,
|
||||
session_or_temp_id,
|
||||
is_new,
|
||||
) = await self.cli_manager.get_or_create_session(
|
||||
session_id=captured_session_id
|
||||
)
|
||||
if is_new:
|
||||
temp_session_id = session_or_temp_id
|
||||
else:
|
||||
captured_session_id = session_or_temp_id
|
||||
except RuntimeError as e:
|
||||
message_parts.append(("error", str(e)))
|
||||
await update_ui("⏳ **Session limit reached**", force=True)
|
||||
return
|
||||
|
||||
# Process CLI events
|
||||
async for event_data in cli_session.start_task(
|
||||
incoming.text, session_id=captured_session_id
|
||||
):
|
||||
if not isinstance(event_data, dict):
|
||||
continue
|
||||
|
||||
# Handle session_info event
|
||||
if event_data.get("type") == "session_info":
|
||||
real_session_id = event_data.get("session_id")
|
||||
if real_session_id and temp_session_id:
|
||||
await self.cli_manager.register_real_session_id(
|
||||
temp_session_id, real_session_id
|
||||
)
|
||||
captured_session_id = real_session_id
|
||||
self.session_store.save_session(
|
||||
session_id=real_session_id,
|
||||
chat_id=chat_id,
|
||||
initial_msg_id=incoming.message_id,
|
||||
platform=incoming.platform,
|
||||
)
|
||||
continue
|
||||
|
||||
parsed = CLIParser.parse_event(event_data)
|
||||
|
||||
if not parsed:
|
||||
continue
|
||||
|
||||
if parsed["type"] == "thinking":
|
||||
message_parts.append(("thinking", parsed["text"]))
|
||||
await update_ui("🧠 **Claude is thinking...**")
|
||||
|
||||
elif parsed["type"] == "content":
|
||||
if parsed.get("text"):
|
||||
if message_parts and message_parts[-1][0] == "content":
|
||||
msg_type, content = message_parts[-1]
|
||||
message_parts[-1] = ("content", content + parsed["text"])
|
||||
else:
|
||||
message_parts.append(("content", parsed["text"]))
|
||||
await update_ui("🧠 **Claude is working...**")
|
||||
|
||||
elif parsed["type"] == "tool_start":
|
||||
names = [t.get("name") for t in parsed["tools"]]
|
||||
message_parts.append(("tool", ", ".join(names)))
|
||||
await update_ui("⏳ **Executing tools...**")
|
||||
|
||||
elif parsed["type"] == "complete":
|
||||
if not message_parts:
|
||||
message_parts.append(("content", "Done."))
|
||||
await update_ui("✅ **Complete**", force=True)
|
||||
|
||||
# Update session's last message
|
||||
if captured_session_id:
|
||||
self.session_store.update_last_message(
|
||||
captured_session_id, status_msg_id
|
||||
)
|
||||
|
||||
elif parsed["type"] == "error":
|
||||
message_parts.append(
|
||||
("error", parsed.get("message", "Unknown error"))
|
||||
)
|
||||
await update_ui("❌ **Error**", force=True)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
message_parts.append(("error", "Task was cancelled"))
|
||||
await update_ui("❌ **Cancelled**", force=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Task failed: {e}")
|
||||
message_parts.append(("error", str(e)[:200]))
|
||||
await update_ui("💥 **Task Failed**", force=True)
|
||||
|
||||
def _build_message(
|
||||
self,
|
||||
parts: List[Tuple[str, str]],
|
||||
status: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Build unified message from parts."""
|
||||
lines = []
|
||||
if status:
|
||||
lines.append(status)
|
||||
lines.append("")
|
||||
|
||||
for part_type, content in parts:
|
||||
if part_type == "thinking":
|
||||
display = content[:1200] + ("..." if len(content) > 1200 else "")
|
||||
lines.append(f"💭 **Thinking:**\n```\n{display}\n```")
|
||||
elif part_type == "tool":
|
||||
lines.append(f"🔧 **Tools:** `{content}`")
|
||||
elif part_type == "content":
|
||||
lines.append(content)
|
||||
elif part_type == "error":
|
||||
lines.append(f"⚠️ {content}")
|
||||
|
||||
result = "\n".join(lines)
|
||||
# Truncate if too long
|
||||
if len(result) > 3800:
|
||||
result = "..." + result[-3795:]
|
||||
if result.count("```") % 2 != 0:
|
||||
result += "\n```"
|
||||
return result
|
||||
|
||||
def _get_initial_status(self, session_id: Optional[str]) -> str:
|
||||
"""Get initial status message text."""
|
||||
if session_id:
|
||||
if self.message_queue.is_session_busy(session_id):
|
||||
queue_size = self.message_queue.get_queue_size(session_id) + 1
|
||||
return f"📋 **Queued** (position {queue_size}) - waiting..."
|
||||
return "🔄 **Continuing conversation...**"
|
||||
|
||||
stats = self.cli_manager.get_stats()
|
||||
if stats["active_sessions"] >= stats["max_sessions"]:
|
||||
return f"⏳ **Waiting for slot...** ({stats['active_sessions']}/{stats['max_sessions']})"
|
||||
return "⏳ **Launching new Claude CLI instance...**"
|
||||
|
||||
async def _handle_stop_command(self, incoming: IncomingMessage) -> None:
|
||||
"""Handle /stop command."""
|
||||
cancelled = await self.message_queue.cancel_all()
|
||||
await self.cli_manager.stop_all()
|
||||
await self.platform.send_message(
|
||||
incoming.chat_id,
|
||||
f"⏹ **Stopped.** Cancelled {len(cancelled)} pending messages.",
|
||||
)
|
||||
|
||||
async def _handle_stats_command(self, incoming: IncomingMessage) -> None:
|
||||
"""Handle /stats command."""
|
||||
stats = self.cli_manager.get_stats()
|
||||
await self.platform.send_message(
|
||||
incoming.chat_id,
|
||||
f"📊 **Stats**\n• Active: {stats['active_sessions']}\n• Max: {stats['max_sessions']}",
|
||||
)
|
||||
@@ -0,0 +1,64 @@
|
||||
"""Platform-agnostic message models."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Any
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class IncomingMessage:
|
||||
"""
|
||||
Platform-agnostic incoming message.
|
||||
|
||||
Adapters convert platform-specific events to this format.
|
||||
"""
|
||||
|
||||
text: str
|
||||
chat_id: str
|
||||
user_id: str
|
||||
message_id: str
|
||||
platform: str # "telegram", "discord", "slack", etc.
|
||||
|
||||
# Optional fields
|
||||
reply_to_message_id: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
# Platform-specific raw event for edge cases
|
||||
raw_event: Any = None
|
||||
|
||||
def is_reply(self) -> bool:
|
||||
"""Check if this message is a reply to another message."""
|
||||
return self.reply_to_message_id is not None
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutgoingMessage:
|
||||
"""
|
||||
Platform-agnostic outgoing message.
|
||||
|
||||
The handler creates these, adapters convert to platform-specific format.
|
||||
"""
|
||||
|
||||
text: str
|
||||
chat_id: str
|
||||
|
||||
# Optional fields
|
||||
reply_to: Optional[str] = None
|
||||
parse_mode: Optional[str] = "markdown"
|
||||
|
||||
# For editing existing messages
|
||||
edit_message_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageContext:
|
||||
"""
|
||||
Context for message processing.
|
||||
|
||||
Passed to handlers to track state across a conversation.
|
||||
"""
|
||||
|
||||
session_id: Optional[str] = None
|
||||
is_new_session: bool = True
|
||||
status_message_id: Optional[str] = None
|
||||
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
Message Queue Manager for Messaging Platforms
|
||||
|
||||
Handles queuing of messages when Claude is busy processing a request.
|
||||
Messages are processed one-by-one in order per session.
|
||||
Platform-agnostic: works with any messaging platform.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Callable, Awaitable, Dict, Optional, List, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .models import IncomingMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueuedMessage:
|
||||
"""A message waiting to be processed."""
|
||||
|
||||
incoming: IncomingMessage
|
||||
status_message_id: str # The status message to update
|
||||
context: Any = None # Additional context if needed
|
||||
|
||||
|
||||
class SessionQueue:
|
||||
"""Queue for a single session."""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self.session_id = session_id
|
||||
self.queue: asyncio.Queue[QueuedMessage] = asyncio.Queue()
|
||||
self.is_processing = False
|
||||
self.current_task: Optional[asyncio.Task] = None
|
||||
self.current_message: Optional[QueuedMessage] = None
|
||||
|
||||
|
||||
class MessageQueueManager:
|
||||
"""
|
||||
Manages per-session message queues.
|
||||
|
||||
When a session is busy, new messages are queued and processed
|
||||
one-by-one after the current request completes.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._queues: Dict[str, SessionQueue] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def _get_or_create_queue(self, session_id: str) -> SessionQueue:
|
||||
"""Get existing queue or create new one for session."""
|
||||
if session_id not in self._queues:
|
||||
self._queues[session_id] = SessionQueue(session_id)
|
||||
return self._queues[session_id]
|
||||
|
||||
def is_session_busy(self, session_id: str) -> bool:
|
||||
"""Check if a session is currently processing a request."""
|
||||
if session_id not in self._queues:
|
||||
return False
|
||||
return self._queues[session_id].is_processing
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
session_id: str,
|
||||
message: QueuedMessage,
|
||||
processor: Callable[[str, QueuedMessage], Awaitable[None]],
|
||||
) -> bool:
|
||||
"""
|
||||
Add a message to the session's queue.
|
||||
|
||||
If the session is not busy, processing starts immediately.
|
||||
If busy, the message is queued for later processing.
|
||||
|
||||
Args:
|
||||
session_id: Claude session ID
|
||||
message: The queued message data
|
||||
processor: Async function to process the message
|
||||
|
||||
Returns:
|
||||
True if message was queued (session busy), False if processed immediately
|
||||
"""
|
||||
async with self._lock:
|
||||
sq = self._get_or_create_queue(session_id)
|
||||
|
||||
if sq.is_processing:
|
||||
# Session is busy, queue the message
|
||||
await sq.queue.put(message)
|
||||
queue_size = sq.queue.qsize()
|
||||
logger.info(
|
||||
f"Queued message for session {session_id}, queue size: {queue_size}"
|
||||
)
|
||||
return True
|
||||
else:
|
||||
# Session is free, start processing
|
||||
sq.is_processing = True
|
||||
|
||||
# Process outside the lock
|
||||
sq = self._queues[session_id]
|
||||
sq.current_task = asyncio.create_task(
|
||||
self._process_message(session_id, message, processor)
|
||||
)
|
||||
return False
|
||||
|
||||
async def _process_message(
|
||||
self,
|
||||
session_id: str,
|
||||
message: QueuedMessage,
|
||||
processor: Callable[[str, QueuedMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Process a single message and then check the queue."""
|
||||
sq = self._queues.get(session_id)
|
||||
if sq:
|
||||
sq.current_message = message
|
||||
try:
|
||||
await processor(session_id, message)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"Task for session {session_id} was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing message for session {session_id}: {e}")
|
||||
finally:
|
||||
if sq:
|
||||
sq.current_message = None
|
||||
# Check if there are more messages in the queue
|
||||
await self._process_next(session_id, processor)
|
||||
|
||||
async def _process_next(
|
||||
self,
|
||||
session_id: str,
|
||||
processor: Callable[[str, QueuedMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Process the next message in queue, if any."""
|
||||
async with self._lock:
|
||||
if session_id not in self._queues:
|
||||
return
|
||||
|
||||
sq = self._queues[session_id]
|
||||
|
||||
if sq.queue.empty():
|
||||
# No more messages, mark session as free
|
||||
sq.is_processing = False
|
||||
logger.debug(f"Session {session_id} queue empty, marking as free")
|
||||
return
|
||||
|
||||
# Get next message
|
||||
try:
|
||||
next_msg = sq.queue.get_nowait()
|
||||
logger.info(f"Processing next queued message for session {session_id}")
|
||||
except asyncio.QueueEmpty:
|
||||
sq.is_processing = False
|
||||
return
|
||||
|
||||
# Process next message (outside lock)
|
||||
sq.current_task = asyncio.create_task(
|
||||
self._process_message(session_id, next_msg, processor)
|
||||
)
|
||||
|
||||
def get_queue_size(self, session_id: str) -> int:
|
||||
"""Get the number of messages waiting in a session's queue."""
|
||||
if session_id not in self._queues:
|
||||
return 0
|
||||
return self._queues[session_id].queue.qsize()
|
||||
|
||||
def cancel_session(self, session_id: str) -> List[QueuedMessage]:
|
||||
"""
|
||||
Cancel all queued messages for a session and the running task.
|
||||
|
||||
Returns:
|
||||
List of messages that were cancelled (including the current one if any)
|
||||
"""
|
||||
if session_id not in self._queues:
|
||||
return []
|
||||
|
||||
sq = self._queues[session_id]
|
||||
cancelled_messages = []
|
||||
|
||||
# 1. Cancel running task
|
||||
if sq.current_task and not sq.current_task.done():
|
||||
sq.current_task.cancel()
|
||||
if sq.current_message:
|
||||
cancelled_messages.append(sq.current_message)
|
||||
|
||||
# 2. Clear queue
|
||||
while not sq.queue.empty():
|
||||
try:
|
||||
msg = sq.queue.get_nowait()
|
||||
cancelled_messages.append(msg)
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
sq.is_processing = False
|
||||
logger.info(
|
||||
f"Cancelled {len(cancelled_messages)} messages for session {session_id}"
|
||||
)
|
||||
return cancelled_messages
|
||||
|
||||
async def cancel_all(self) -> List[QueuedMessage]:
|
||||
"""
|
||||
Cancel everything in all sessions.
|
||||
|
||||
Returns:
|
||||
List of all cancelled messages across all sessions.
|
||||
"""
|
||||
async with self._lock:
|
||||
all_cancelled = []
|
||||
session_ids = list(self._queues.keys())
|
||||
for sid in session_ids:
|
||||
all_cancelled.extend(self.cancel_session(sid))
|
||||
return all_cancelled
|
||||
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Session Store for Messaging Platforms
|
||||
|
||||
Provides persistent storage for mapping platform messages to Claude CLI session IDs.
|
||||
This enables conversation continuation when replying to old messages.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict
|
||||
from dataclasses import dataclass, asdict
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionRecord:
|
||||
"""A single session record."""
|
||||
|
||||
session_id: str
|
||||
chat_id: str # Changed to str for platform-agnostic support
|
||||
initial_msg_id: str
|
||||
last_msg_id: str
|
||||
platform: str # "telegram", "discord", etc.
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
|
||||
class SessionStore:
|
||||
"""
|
||||
Persistent storage for message ↔ Claude session mappings.
|
||||
|
||||
Uses a JSON file for storage with thread-safe operations.
|
||||
Platform-agnostic: works with any messaging platform.
|
||||
"""
|
||||
|
||||
def __init__(self, storage_path: str = "sessions.json"):
|
||||
self.storage_path = storage_path
|
||||
self._lock = threading.Lock()
|
||||
self._sessions: Dict[str, SessionRecord] = {} # session_id -> record
|
||||
self._msg_to_session: Dict[
|
||||
str, str
|
||||
] = {} # "platform:chat_id:msg_id" -> session_id
|
||||
self._load()
|
||||
|
||||
def _make_key(self, platform: str, chat_id: str, msg_id: str) -> str:
|
||||
"""Create a unique key from platform, chat_id and msg_id."""
|
||||
return f"{platform}:{chat_id}:{msg_id}"
|
||||
|
||||
def _load(self) -> None:
|
||||
"""Load sessions from disk."""
|
||||
if not os.path.exists(self.storage_path):
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self.storage_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for sid, record_data in data.get("sessions", {}).items():
|
||||
# Handle legacy records without platform field
|
||||
if "platform" not in record_data:
|
||||
record_data["platform"] = "telegram"
|
||||
# Convert int to str for backwards compatibility
|
||||
for field in ["chat_id", "initial_msg_id", "last_msg_id"]:
|
||||
if isinstance(record_data.get(field), int):
|
||||
record_data[field] = str(record_data[field])
|
||||
|
||||
record = SessionRecord(**record_data)
|
||||
self._sessions[sid] = record
|
||||
# Index by initial and last message
|
||||
self._msg_to_session[
|
||||
self._make_key(
|
||||
record.platform, record.chat_id, record.initial_msg_id
|
||||
)
|
||||
] = sid
|
||||
self._msg_to_session[
|
||||
self._make_key(record.platform, record.chat_id, record.last_msg_id)
|
||||
] = sid
|
||||
|
||||
logger.info(
|
||||
f"Loaded {len(self._sessions)} sessions from {self.storage_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load sessions: {e}")
|
||||
|
||||
def _save(self) -> None:
|
||||
"""Persist sessions to disk."""
|
||||
try:
|
||||
data = {
|
||||
"sessions": {
|
||||
sid: asdict(record) for sid, record in self._sessions.items()
|
||||
}
|
||||
}
|
||||
with open(self.storage_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save sessions: {e}")
|
||||
|
||||
def save_session(
|
||||
self,
|
||||
session_id: str,
|
||||
chat_id: str,
|
||||
initial_msg_id: str,
|
||||
platform: str = "telegram",
|
||||
) -> None:
|
||||
"""
|
||||
Save a new session mapping.
|
||||
|
||||
Args:
|
||||
session_id: Claude CLI session ID
|
||||
chat_id: Chat ID (platform-specific)
|
||||
initial_msg_id: The message ID that started this session
|
||||
platform: Messaging platform name
|
||||
"""
|
||||
with self._lock:
|
||||
now = datetime.utcnow().isoformat()
|
||||
record = SessionRecord(
|
||||
session_id=session_id,
|
||||
chat_id=str(chat_id),
|
||||
initial_msg_id=str(initial_msg_id),
|
||||
last_msg_id=str(initial_msg_id),
|
||||
platform=platform,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
self._sessions[session_id] = record
|
||||
self._msg_to_session[
|
||||
self._make_key(platform, str(chat_id), str(initial_msg_id))
|
||||
] = session_id
|
||||
self._save()
|
||||
logger.info(
|
||||
f"Saved session {session_id} for {platform} chat {chat_id}, msg {initial_msg_id}"
|
||||
)
|
||||
|
||||
def get_session_by_msg(
|
||||
self, chat_id: str, msg_id: str, platform: str = "telegram"
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Look up a session ID by a message that's part of that session.
|
||||
|
||||
Args:
|
||||
chat_id: Chat ID
|
||||
msg_id: Message ID to look up
|
||||
platform: Messaging platform name
|
||||
|
||||
Returns:
|
||||
Session ID if found, None otherwise
|
||||
"""
|
||||
with self._lock:
|
||||
key = self._make_key(platform, str(chat_id), str(msg_id))
|
||||
return self._msg_to_session.get(key)
|
||||
|
||||
def update_last_message(self, session_id: str, msg_id: str) -> None:
|
||||
"""
|
||||
Update the last message ID for a session.
|
||||
|
||||
Args:
|
||||
session_id: Claude session ID
|
||||
msg_id: New last message ID
|
||||
"""
|
||||
with self._lock:
|
||||
if session_id not in self._sessions:
|
||||
logger.warning(f"Session {session_id} not found for update")
|
||||
return
|
||||
|
||||
record = self._sessions[session_id]
|
||||
|
||||
# Update record
|
||||
record.last_msg_id = str(msg_id)
|
||||
record.updated_at = datetime.utcnow().isoformat()
|
||||
|
||||
# Update index - add new key, keep old one for chain lookups
|
||||
new_key = self._make_key(record.platform, record.chat_id, str(msg_id))
|
||||
self._msg_to_session[new_key] = session_id
|
||||
|
||||
self._save()
|
||||
logger.debug(f"Updated session {session_id} last_msg to {msg_id}")
|
||||
|
||||
def get_session_record(self, session_id: str) -> Optional[SessionRecord]:
|
||||
"""Get full session record."""
|
||||
with self._lock:
|
||||
return self._sessions.get(session_id)
|
||||
|
||||
def cleanup_old_sessions(self, max_age_days: int = 30) -> int:
|
||||
"""
|
||||
Remove sessions older than max_age_days.
|
||||
|
||||
Returns:
|
||||
Number of sessions removed
|
||||
"""
|
||||
with self._lock:
|
||||
cutoff = datetime.utcnow()
|
||||
removed = 0
|
||||
|
||||
to_remove = []
|
||||
for sid, record in self._sessions.items():
|
||||
try:
|
||||
created = datetime.fromisoformat(record.created_at)
|
||||
age_days = (cutoff - created).days
|
||||
if age_days > max_age_days:
|
||||
to_remove.append(sid)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for sid in to_remove:
|
||||
record = self._sessions.pop(sid)
|
||||
# Remove index entries
|
||||
self._msg_to_session.pop(
|
||||
self._make_key(
|
||||
record.platform, record.chat_id, record.initial_msg_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
self._msg_to_session.pop(
|
||||
self._make_key(record.platform, record.chat_id, record.last_msg_id),
|
||||
None,
|
||||
)
|
||||
removed += 1
|
||||
|
||||
if removed:
|
||||
self._save()
|
||||
logger.info(f"Cleaned up {removed} old sessions")
|
||||
|
||||
return removed
|
||||
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Telegram Platform Adapter
|
||||
|
||||
Implements MessagingPlatform for Telegram using Telethon.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Awaitable, Optional, Any
|
||||
|
||||
from .base import MessagingPlatform
|
||||
from .models import IncomingMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Optional import - Telethon may not be installed
|
||||
try:
|
||||
from telethon import TelegramClient, events, errors
|
||||
|
||||
TELETHON_AVAILABLE = True
|
||||
except ImportError:
|
||||
TelegramClient = None
|
||||
events = None
|
||||
errors = None
|
||||
TELETHON_AVAILABLE = False
|
||||
|
||||
|
||||
class TelegramPlatform(MessagingPlatform):
|
||||
"""
|
||||
Telegram messaging platform adapter.
|
||||
|
||||
Uses Telethon for Telegram API access.
|
||||
Designed for personal use (sending messages to yourself).
|
||||
"""
|
||||
|
||||
name = "telegram"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_id: Optional[str] = None,
|
||||
api_hash: Optional[str] = None,
|
||||
allowed_user_id: Optional[str] = None,
|
||||
session_path: str = "claude_bot.session",
|
||||
):
|
||||
if not TELETHON_AVAILABLE:
|
||||
raise ImportError(
|
||||
"Telethon is required for Telegram support. Install with: pip install telethon"
|
||||
)
|
||||
|
||||
self.api_id = api_id or os.getenv("TELEGRAM_API_ID")
|
||||
self.api_hash = api_hash or os.getenv("TELEGRAM_API_HASH")
|
||||
self.allowed_user_id = allowed_user_id or os.getenv("ALLOWED_TELEGRAM_USER_ID")
|
||||
self.session_path = session_path
|
||||
|
||||
self._client: Optional[TelegramClient] = None
|
||||
self._message_handler: Optional[
|
||||
Callable[[IncomingMessage], Awaitable[None]]
|
||||
] = None
|
||||
self._connected = False
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Initialize and connect to Telegram."""
|
||||
if not self.api_id or not self.api_hash:
|
||||
raise ValueError("TELEGRAM_API_ID and TELEGRAM_API_HASH are required")
|
||||
|
||||
self._client = TelegramClient(
|
||||
self.session_path,
|
||||
int(self.api_id),
|
||||
self.api_hash,
|
||||
)
|
||||
|
||||
# Register event handler
|
||||
@self._client.on(events.NewMessage())
|
||||
async def on_new_message(event):
|
||||
await self._handle_event(event)
|
||||
|
||||
await self._client.start()
|
||||
self._connected = True
|
||||
|
||||
# Run in background
|
||||
asyncio.create_task(self._client.run_until_disconnected())
|
||||
|
||||
# Send startup notification
|
||||
try:
|
||||
await self._client.send_message("me", "🚀 **Claude Code Proxy is online!**")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not send startup message: {e}")
|
||||
|
||||
logger.info("Telegram platform started")
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Disconnect from Telegram."""
|
||||
if self._client:
|
||||
await self._client.disconnect()
|
||||
self._connected = False
|
||||
logger.info("Telegram platform stopped")
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
text: str,
|
||||
reply_to: Optional[str] = None,
|
||||
parse_mode: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Send a message to a chat."""
|
||||
if not self._client:
|
||||
raise RuntimeError("Telegram client not connected")
|
||||
|
||||
try:
|
||||
msg = await self._client.send_message(
|
||||
int(chat_id),
|
||||
text,
|
||||
reply_to=int(reply_to) if reply_to else None,
|
||||
parse_mode=parse_mode,
|
||||
)
|
||||
return str(msg.id)
|
||||
except errors.FloodWaitError as e:
|
||||
logger.error(f"Telegram flood wait: {e.seconds}s")
|
||||
raise
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
text: str,
|
||||
parse_mode: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Edit an existing message."""
|
||||
if not self._client:
|
||||
raise RuntimeError("Telegram client not connected")
|
||||
|
||||
try:
|
||||
await self._client.edit_message(
|
||||
int(chat_id),
|
||||
int(message_id),
|
||||
text,
|
||||
parse_mode=parse_mode,
|
||||
)
|
||||
except errors.FloodWaitError as e:
|
||||
logger.error(f"Telegram flood wait on edit: {e.seconds}s")
|
||||
raise
|
||||
except errors.MessageNotModifiedError:
|
||||
# Message content unchanged, ignore
|
||||
pass
|
||||
|
||||
def on_message(
|
||||
self,
|
||||
handler: Callable[[IncomingMessage], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Register a message handler callback."""
|
||||
self._message_handler = handler
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if connected to Telegram."""
|
||||
return self._connected and self._client is not None
|
||||
|
||||
async def _handle_event(self, event: Any) -> None:
|
||||
"""Handle incoming Telegram event."""
|
||||
# Security check
|
||||
if self.allowed_user_id:
|
||||
if str(event.sender_id) != str(self.allowed_user_id).strip():
|
||||
logger.debug(
|
||||
f"Ignored message from unauthorized user: {event.sender_id}"
|
||||
)
|
||||
return
|
||||
|
||||
if not event.text:
|
||||
return
|
||||
|
||||
if not self._message_handler:
|
||||
logger.warning("No message handler registered")
|
||||
return
|
||||
|
||||
# Convert to platform-agnostic message
|
||||
incoming = IncomingMessage(
|
||||
text=event.text,
|
||||
chat_id=str(event.chat_id),
|
||||
user_id=str(event.sender_id),
|
||||
message_id=str(event.id),
|
||||
platform="telegram",
|
||||
reply_to_message_id=str(event.reply_to_msg_id)
|
||||
if event.reply_to_msg_id
|
||||
else None,
|
||||
raw_event=event,
|
||||
)
|
||||
|
||||
try:
|
||||
await self._message_handler(incoming)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
try:
|
||||
await self.send_message(
|
||||
incoming.chat_id,
|
||||
f"❌ **Error:** {str(e)[:200]}",
|
||||
reply_to=incoming.message_id,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
@@ -13,6 +13,7 @@ dependencies = [
|
||||
"tiktoken>=0.7.0",
|
||||
"websockets>=13.0",
|
||||
"telethon>=1.35.0",
|
||||
"pydantic-settings>=2.12.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
@@ -2,6 +2,11 @@ import pytest
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Set mock environment BEFORE any imports that use Settings
|
||||
os.environ.setdefault("NVIDIA_NIM_API_KEY", "test_key")
|
||||
os.environ.setdefault("BIG_MODEL", "test-model")
|
||||
os.environ.setdefault("SMALL_MODEL", "test-model")
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
|
||||
+2
-4
@@ -1,6 +1,7 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from server import app, get_provider
|
||||
from api.app import app
|
||||
from api.dependencies import get_provider
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
|
||||
@@ -63,11 +64,8 @@ def test_model_mapping():
|
||||
"max_tokens": 100,
|
||||
}
|
||||
client.post("/v1/messages", json=payload_haiku)
|
||||
# The actual call to provider should use the mapped model
|
||||
args, _ = mock_provider.complete.call_args
|
||||
# It should not be the original model
|
||||
assert args[0].model != "claude-3-haiku-20240307"
|
||||
# It should have original_model set
|
||||
assert args[0].original_model == "claude-3-haiku-20240307"
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,224 @@
|
||||
"""Tests for cli/ module."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
|
||||
class TestCLIParser:
|
||||
"""Test CLIParser event parsing."""
|
||||
|
||||
def test_parse_text_content(self):
|
||||
"""Test parsing text content from assistant message."""
|
||||
from cli.parser import CLIParser
|
||||
|
||||
event = {
|
||||
"type": "assistant",
|
||||
"message": {"content": [{"type": "text", "text": "Hello, world!"}]},
|
||||
}
|
||||
result = CLIParser.parse_event(event)
|
||||
assert result["type"] == "content"
|
||||
assert result["text"] == "Hello, world!"
|
||||
|
||||
def test_parse_thinking_content(self):
|
||||
"""Test parsing thinking content."""
|
||||
from cli.parser import CLIParser
|
||||
|
||||
event = {
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
"content": [{"type": "thinking", "thinking": "Let me think..."}]
|
||||
},
|
||||
}
|
||||
result = CLIParser.parse_event(event)
|
||||
assert result["type"] == "content"
|
||||
assert result["thinking"] == "Let me think..."
|
||||
|
||||
def test_parse_tool_use(self):
|
||||
"""Test parsing tool use content."""
|
||||
from cli.parser import CLIParser
|
||||
|
||||
event = {
|
||||
"type": "assistant",
|
||||
"message": {
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"name": "read_file",
|
||||
"input": {"path": "/test"},
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
result = CLIParser.parse_event(event)
|
||||
assert result["type"] == "tool_start"
|
||||
assert len(result["tools"]) == 1
|
||||
assert result["tools"][0]["name"] == "read_file"
|
||||
|
||||
def test_parse_text_delta(self):
|
||||
"""Test parsing streaming text delta."""
|
||||
from cli.parser import CLIParser
|
||||
|
||||
event = {
|
||||
"type": "content_block_delta",
|
||||
"delta": {"type": "text_delta", "text": "streaming text"},
|
||||
}
|
||||
result = CLIParser.parse_event(event)
|
||||
assert result["type"] == "content"
|
||||
assert result["text"] == "streaming text"
|
||||
|
||||
def test_parse_thinking_delta(self):
|
||||
"""Test parsing streaming thinking delta."""
|
||||
from cli.parser import CLIParser
|
||||
|
||||
event = {
|
||||
"type": "content_block_delta",
|
||||
"delta": {"type": "thinking_delta", "thinking": "thinking..."},
|
||||
}
|
||||
result = CLIParser.parse_event(event)
|
||||
assert result["type"] == "thinking"
|
||||
assert result["text"] == "thinking..."
|
||||
|
||||
def test_parse_error(self):
|
||||
"""Test parsing error event."""
|
||||
from cli.parser import CLIParser
|
||||
|
||||
event = {"type": "error", "error": {"message": "Something went wrong"}}
|
||||
result = CLIParser.parse_event(event)
|
||||
assert result["type"] == "error"
|
||||
assert result["message"] == "Something went wrong"
|
||||
|
||||
def test_parse_exit_success(self):
|
||||
"""Test parsing exit event with success."""
|
||||
from cli.parser import CLIParser
|
||||
|
||||
event = {"type": "exit", "code": 0}
|
||||
result = CLIParser.parse_event(event)
|
||||
assert result["type"] == "complete"
|
||||
assert result["status"] == "success"
|
||||
|
||||
def test_parse_exit_failure(self):
|
||||
"""Test parsing exit event with failure."""
|
||||
from cli.parser import CLIParser
|
||||
|
||||
event = {"type": "exit", "code": 1}
|
||||
result = CLIParser.parse_event(event)
|
||||
assert result["type"] == "complete"
|
||||
assert result["status"] == "failed"
|
||||
|
||||
def test_parse_invalid_event(self):
|
||||
"""Test parsing returns None for unrecognized event."""
|
||||
from cli.parser import CLIParser
|
||||
|
||||
result = CLIParser.parse_event({"type": "unknown"})
|
||||
assert result is None
|
||||
|
||||
def test_parse_non_dict(self):
|
||||
"""Test parsing returns None for non-dict input."""
|
||||
from cli.parser import CLIParser
|
||||
|
||||
result = CLIParser.parse_event("not a dict")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestCLISession:
|
||||
"""Test CLISession."""
|
||||
|
||||
def test_session_init(self):
|
||||
"""Test CLISession initialization."""
|
||||
from cli.session import CLISession
|
||||
|
||||
session = CLISession(
|
||||
workspace_path="/tmp/test",
|
||||
api_url="http://localhost:8082/v1",
|
||||
allowed_dirs=["/home/user/projects"],
|
||||
)
|
||||
assert session.workspace == "/tmp/test" or "test" in session.workspace
|
||||
assert session.api_url == "http://localhost:8082/v1"
|
||||
assert not session.is_busy
|
||||
|
||||
def test_session_extract_session_id(self):
|
||||
"""Test session ID extraction from various event formats."""
|
||||
from cli.session import CLISession
|
||||
|
||||
session = CLISession("/tmp", "http://localhost:8082/v1")
|
||||
|
||||
# Direct session_id field
|
||||
assert session._extract_session_id({"session_id": "abc123"}) == "abc123"
|
||||
assert session._extract_session_id({"sessionId": "abc123"}) == "abc123"
|
||||
|
||||
# Nested in init
|
||||
assert (
|
||||
session._extract_session_id({"init": {"session_id": "nested123"}})
|
||||
== "nested123"
|
||||
)
|
||||
|
||||
# Nested in result
|
||||
assert (
|
||||
session._extract_session_id({"result": {"session_id": "res123"}})
|
||||
== "res123"
|
||||
)
|
||||
|
||||
# Conversation id
|
||||
assert (
|
||||
session._extract_session_id({"conversation": {"id": "conv123"}})
|
||||
== "conv123"
|
||||
)
|
||||
|
||||
# No session ID
|
||||
assert session._extract_session_id({"type": "message"}) is None
|
||||
assert session._extract_session_id("not a dict") is None
|
||||
|
||||
|
||||
class TestCLISessionManager:
|
||||
"""Test CLISessionManager."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_create_session(self):
|
||||
"""Test creating a new session."""
|
||||
from cli.manager import CLISessionManager
|
||||
|
||||
manager = CLISessionManager(
|
||||
workspace_path="/tmp/test",
|
||||
api_url="http://localhost:8082/v1",
|
||||
max_sessions=5,
|
||||
)
|
||||
|
||||
session, sid, is_new = await manager.get_or_create_session()
|
||||
assert session is not None
|
||||
assert sid.startswith("pending_")
|
||||
assert is_new is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_reuse_session(self):
|
||||
"""Test reusing an existing session."""
|
||||
from cli.manager import CLISessionManager
|
||||
|
||||
manager = CLISessionManager(
|
||||
workspace_path="/tmp/test",
|
||||
api_url="http://localhost:8082/v1",
|
||||
)
|
||||
|
||||
# Create first session
|
||||
s1, sid1, is_new1 = await manager.get_or_create_session()
|
||||
|
||||
# Request same session
|
||||
s2, sid2, is_new2 = await manager.get_or_create_session(session_id=sid1)
|
||||
|
||||
assert s1 is s2
|
||||
assert is_new2 is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manager_stats(self):
|
||||
"""Test manager stats."""
|
||||
from cli.manager import CLISessionManager
|
||||
|
||||
manager = CLISessionManager(
|
||||
workspace_path="/tmp/test",
|
||||
api_url="http://localhost:8082/v1",
|
||||
max_sessions=10,
|
||||
)
|
||||
|
||||
stats = manager.get_stats()
|
||||
assert stats["max_sessions"] == 10
|
||||
assert stats["active_sessions"] == 0
|
||||
assert stats["pending_sessions"] == 0
|
||||
@@ -0,0 +1,51 @@
|
||||
"""Tests for config/settings.py"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
|
||||
|
||||
class TestSettings:
|
||||
"""Test Settings configuration."""
|
||||
|
||||
def test_settings_loads(self):
|
||||
"""Ensure Settings can be instantiated."""
|
||||
from config.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
assert settings is not None
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set and have correct types."""
|
||||
from config.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
assert isinstance(settings.nvidia_nim_rate_limit, int)
|
||||
assert isinstance(settings.nvidia_nim_rate_window, int)
|
||||
assert isinstance(settings.fast_prefix_detection, bool)
|
||||
assert isinstance(settings.max_cli_sessions, int)
|
||||
|
||||
def test_get_settings_cached(self):
|
||||
"""Test get_settings returns cached instance."""
|
||||
from config.settings import get_settings
|
||||
|
||||
s1 = get_settings()
|
||||
s2 = get_settings()
|
||||
assert s1 is s2 # Same object (cached)
|
||||
|
||||
def test_empty_string_to_none_for_optional_int(self):
|
||||
"""Test that empty string converts to None for optional int fields."""
|
||||
from config.settings import Settings
|
||||
|
||||
# Settings should handle NVIDIA_NIM_SEED="" gracefully
|
||||
settings = Settings()
|
||||
assert settings.nvidia_nim_seed is None or isinstance(
|
||||
settings.nvidia_nim_seed, int
|
||||
)
|
||||
|
||||
def test_model_mapping_defaults(self):
|
||||
"""Test model mapping defaults."""
|
||||
from config.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
assert "kimi" in settings.big_model.lower() or settings.big_model != ""
|
||||
assert "kimi" in settings.small_model.lower() or settings.small_model != ""
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Tests for messaging/ module."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
|
||||
class TestMessagingModels:
|
||||
"""Test messaging models."""
|
||||
|
||||
def test_incoming_message_creation(self):
|
||||
"""Test IncomingMessage dataclass."""
|
||||
from messaging.models import IncomingMessage
|
||||
|
||||
msg = IncomingMessage(
|
||||
text="Hello",
|
||||
chat_id="123",
|
||||
user_id="456",
|
||||
message_id="789",
|
||||
platform="telegram",
|
||||
)
|
||||
assert msg.text == "Hello"
|
||||
assert msg.chat_id == "123"
|
||||
assert msg.platform == "telegram"
|
||||
assert msg.is_reply() is False
|
||||
|
||||
def test_incoming_message_with_reply(self):
|
||||
"""Test IncomingMessage as a reply."""
|
||||
from messaging.models import IncomingMessage
|
||||
|
||||
msg = IncomingMessage(
|
||||
text="Reply text",
|
||||
chat_id="123",
|
||||
user_id="456",
|
||||
message_id="789",
|
||||
platform="discord",
|
||||
reply_to_message_id="100",
|
||||
)
|
||||
assert msg.is_reply() is True
|
||||
assert msg.reply_to_message_id == "100"
|
||||
|
||||
def test_outgoing_message_creation(self):
|
||||
"""Test OutgoingMessage dataclass."""
|
||||
from messaging.models import OutgoingMessage
|
||||
|
||||
msg = OutgoingMessage(
|
||||
text="Response",
|
||||
chat_id="123",
|
||||
parse_mode="markdown",
|
||||
)
|
||||
assert msg.text == "Response"
|
||||
assert msg.parse_mode == "markdown"
|
||||
assert msg.edit_message_id is None
|
||||
|
||||
def test_message_context(self):
|
||||
"""Test MessageContext dataclass."""
|
||||
from messaging.models import MessageContext
|
||||
|
||||
ctx = MessageContext(session_id="sess123", is_new_session=False)
|
||||
assert ctx.session_id == "sess123"
|
||||
assert ctx.is_new_session is False
|
||||
|
||||
|
||||
class TestMessagingBase:
|
||||
"""Test MessagingPlatform ABC."""
|
||||
|
||||
def test_platform_is_abstract(self):
|
||||
"""Verify MessagingPlatform cannot be instantiated."""
|
||||
from messaging.base import MessagingPlatform
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
MessagingPlatform()
|
||||
|
||||
|
||||
class TestSessionStore:
|
||||
"""Test SessionStore."""
|
||||
|
||||
def test_session_store_init(self, tmp_path):
|
||||
"""Test SessionStore initialization."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
assert store._sessions == {}
|
||||
|
||||
def test_save_and_get_session(self, tmp_path):
|
||||
"""Test saving and retrieving a session."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
|
||||
store.save_session(
|
||||
session_id="sess_123",
|
||||
chat_id="chat_456",
|
||||
initial_msg_id="msg_789",
|
||||
platform="telegram",
|
||||
)
|
||||
|
||||
# Retrieve by message
|
||||
found = store.get_session_by_msg("chat_456", "msg_789", "telegram")
|
||||
assert found == "sess_123"
|
||||
|
||||
def test_update_last_message(self, tmp_path):
|
||||
"""Test updating last message in session."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
|
||||
store.save_session("sess_1", "chat_1", "msg_1", "telegram")
|
||||
store.update_last_message("sess_1", "msg_2")
|
||||
|
||||
# Should find session by new message too
|
||||
found = store.get_session_by_msg("chat_1", "msg_2", "telegram")
|
||||
assert found == "sess_1"
|
||||
|
||||
def test_get_session_record(self, tmp_path):
|
||||
"""Test getting full session record."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
store.save_session("sess_1", "chat_1", "msg_1", "telegram")
|
||||
|
||||
record = store.get_session_record("sess_1")
|
||||
assert record is not None
|
||||
assert record.session_id == "sess_1"
|
||||
assert record.platform == "telegram"
|
||||
|
||||
def test_session_not_found(self, tmp_path):
|
||||
"""Test getting non-existent session returns None."""
|
||||
from messaging.session import SessionStore
|
||||
|
||||
store = SessionStore(storage_path=str(tmp_path / "sessions.json"))
|
||||
|
||||
found = store.get_session_by_msg("notexist", "notexist", "telegram")
|
||||
assert found is None
|
||||
|
||||
|
||||
class TestMessageQueueManager:
|
||||
"""Test MessageQueueManager."""
|
||||
|
||||
def test_queue_manager_init(self):
|
||||
"""Test MessageQueueManager initialization."""
|
||||
from messaging.queue import MessageQueueManager
|
||||
|
||||
mgr = MessageQueueManager()
|
||||
assert mgr._queues == {}
|
||||
|
||||
def test_session_not_busy_initially(self):
|
||||
"""Test session is not busy when no messages."""
|
||||
from messaging.queue import MessageQueueManager
|
||||
|
||||
mgr = MessageQueueManager()
|
||||
assert mgr.is_session_busy("nonexistent") is False
|
||||
|
||||
def test_get_queue_size_empty(self):
|
||||
"""Test queue size is 0 for non-existent session."""
|
||||
from messaging.queue import MessageQueueManager
|
||||
|
||||
mgr = MessageQueueManager()
|
||||
assert mgr.get_queue_size("nonexistent") == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_and_process(self):
|
||||
"""Test enqueueing a message starts processing."""
|
||||
from messaging.queue import MessageQueueManager, QueuedMessage
|
||||
from messaging.models import IncomingMessage
|
||||
|
||||
mgr = MessageQueueManager()
|
||||
processed = []
|
||||
|
||||
async def processor(sid, msg):
|
||||
processed.append(msg)
|
||||
|
||||
incoming = IncomingMessage(
|
||||
text="test", chat_id="1", user_id="1", message_id="1", platform="test"
|
||||
)
|
||||
queued = QueuedMessage(incoming=incoming, status_message_id="status_1")
|
||||
|
||||
was_queued = await mgr.enqueue("session_1", queued, processor)
|
||||
|
||||
# First message should process immediately, not queue
|
||||
assert was_queued is False
|
||||
|
||||
def test_cancel_session_empty(self):
|
||||
"""Test cancelling non-existent session."""
|
||||
from messaging.queue import MessageQueueManager
|
||||
|
||||
mgr = MessageQueueManager()
|
||||
cancelled = mgr.cancel_session("nonexistent")
|
||||
assert cancelled == []
|
||||
@@ -49,6 +49,7 @@ dependencies = [
|
||||
{ name = "fastapi", extra = ["standard"] },
|
||||
{ name = "httpx" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "telethon" },
|
||||
{ name = "tiktoken" },
|
||||
@@ -68,6 +69,7 @@ requires-dist = [
|
||||
{ name = "fastapi", extras = ["standard"], specifier = ">=0.115.11" },
|
||||
{ name = "httpx", specifier = ">=0.25.0" },
|
||||
{ name = "pydantic", specifier = ">=2.0.0" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.12.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.0" },
|
||||
{ name = "telethon", specifier = ">=1.35.0" },
|
||||
{ name = "tiktoken", specifier = ">=0.7.0" },
|
||||
@@ -722,6 +724,20 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/63/37/3e32eeb2a451fddaa3898e2163746b0cffbbdbb4740d38372db0490d67f3/pydantic_core-2.27.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:7e17b560be3c98a8e3aa66ce828bdebb9e9ac6ad5466fba92eb74c4c95cb1151", size = 2004715, upload-time = "2024-12-18T11:31:22.821Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pydantic-settings"
|
||||
version = "2.12.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pydantic" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "typing-inspection" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/43/4b/ac7e0aae12027748076d72a8764ff1c9d82ca75a7a52622e67ed3f765c54/pydantic_settings-2.12.0.tar.gz", hash = "sha256:005538ef951e3c2a68e1c08b292b5f2e71490def8589d4221b95dab00dafcfd0", size = 194184, upload-time = "2025-11-10T14:25:47.013Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/60/5d4751ba3f4a40a6891f24eec885f51afd78d208498268c734e256fb13c4/pydantic_settings-2.12.0-py3-none-any.whl", hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809", size = 51880, upload-time = "2025-11-10T14:25:45.546Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.19.1"
|
||||
@@ -1197,6 +1213,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438, upload-time = "2024-06-07T18:52:13.582Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typing-inspection"
|
||||
version = "0.4.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/55/e3/70399cb7dd41c10ac53367ae42139cf4b1ca5f36bb3dc6c9d33acdb43655/typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464", size = 75949, upload-time = "2025-10-01T02:14:41.687Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urllib3"
|
||||
version = "2.6.3"
|
||||
|
||||
Reference in New Issue
Block a user