# Mirror of https://github.com/Alishahryar1/free-claude-code.git
# Synced 2026-06-02 06:13:46 +02:00
# 126 lines, 4.3 KiB, Python
"""Model routing for Claude-compatible requests."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
|
|
from loguru import logger
|
|
|
|
from config.provider_ids import SUPPORTED_PROVIDER_IDS
|
|
from config.settings import Settings
|
|
|
|
from .gateway_model_ids import decode_gateway_model_id
|
|
from .models.anthropic import MessagesRequest, TokenCountRequest
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class ResolvedModel:
|
|
original_model: str
|
|
provider_id: str
|
|
provider_model: str
|
|
provider_model_ref: str
|
|
thinking_enabled: bool
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class RoutedMessagesRequest:
|
|
request: MessagesRequest
|
|
resolved: ResolvedModel
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class RoutedTokenCountRequest:
|
|
request: TokenCountRequest
|
|
resolved: ResolvedModel
|
|
|
|
|
|
class ModelRouter:
    """Resolve incoming Claude model names to configured provider/model pairs."""

    def __init__(self, settings: Settings) -> None:
        """Keep a reference to the settings consulted during resolution."""
        self._settings = settings

    def resolve(self, claude_model_name: str) -> ResolvedModel:
        """Map an incoming model name to a provider and provider model.

        A name that directly addresses a supported provider (see
        ``_direct_provider_model``) is used as given. Any other name is
        looked up through the settings object.

        Args:
            claude_model_name: Model name as received from the client.

        Returns:
            The resolved provider, provider model, originating reference and
            thinking flag.
        """
        provider_id, provider_model, forced_thinking = self._direct_provider_model(
            claude_model_name
        )

        if provider_id is None or provider_model is None:
            # Not directly addressed: let the settings choose the target.
            model_ref = self._settings.resolve_model(claude_model_name)
            thinking = self._settings.resolve_thinking(claude_model_name)
            mapped_provider = Settings.parse_provider_type(model_ref)
            mapped_model = Settings.parse_model_name(model_ref)
            if mapped_model != claude_model_name:
                logger.debug(
                    "MODEL MAPPING: '{}' -> '{}'", claude_model_name, mapped_model
                )
            return ResolvedModel(
                original_model=claude_model_name,
                provider_id=mapped_provider,
                provider_model=mapped_model,
                provider_model_ref=model_ref,
                thinking_enabled=thinking,
            )

        # Directly addressed: an explicit thinking override from the decoded
        # name wins; otherwise the settings decide for the provider model.
        if forced_thinking is None:
            thinking = self._settings.resolve_thinking(provider_model)
        else:
            thinking = forced_thinking

        logger.debug(
            "MODEL DIRECT: '{}' -> provider='{}' model='{}' thinking={}",
            claude_model_name,
            provider_id,
            provider_model,
            thinking,
        )
        return ResolvedModel(
            original_model=claude_model_name,
            provider_id=provider_id,
            provider_model=provider_model,
            provider_model_ref=claude_model_name,
            thinking_enabled=thinking,
        )

    def _direct_provider_model(
        self, model_name: str
    ) -> tuple[str | None, str | None, bool | None]:
        """Detect a model name that names its provider explicitly.

        Two forms are recognised, in this order: a gateway model id accepted
        by ``decode_gateway_model_id``, then a plain ``provider/model``
        string split at the first ``/``.

        Returns:
            ``(provider_id, provider_model, force_thinking_enabled)``. All
            three are ``None`` when the name is not a direct reference to a
            supported provider. The third element is always ``None`` for the
            plain ``provider/model`` form.
        """
        unmatched = (None, None, None)

        gateway = decode_gateway_model_id(model_name)
        if gateway is not None:
            # A decodable gateway id is final: an unsupported provider is
            # reported as "no match" rather than retried as "provider/model".
            if gateway.provider_id in SUPPORTED_PROVIDER_IDS:
                return (
                    gateway.provider_id,
                    gateway.provider_model,
                    gateway.force_thinking_enabled,
                )
            return unmatched

        prefix, slash, remainder = model_name.partition("/")
        # Requires a slash, a supported provider before it, and a non-empty
        # model name after it.
        if slash and prefix in SUPPORTED_PROVIDER_IDS and remainder:
            return prefix, remainder, None
        return unmatched

    def resolve_messages_request(
        self, request: MessagesRequest
    ) -> RoutedMessagesRequest:
        """Return an internal routed request context.

        The incoming request is deep-copied and the copy's ``model`` is set
        to the resolved provider model; the caller's object is not modified.
        """
        # NOTE(review): the model is assigned after copying here, whereas
        # resolve_token_count_request passes it via ``update=``; confirm the
        # difference is intentional.
        resolution = self.resolve(request.model)
        routed_copy = request.model_copy(deep=True)
        routed_copy.model = resolution.provider_model
        return RoutedMessagesRequest(request=routed_copy, resolved=resolution)

    def resolve_token_count_request(
        self, request: TokenCountRequest
    ) -> RoutedTokenCountRequest:
        """Return an internal token-count request context.

        The incoming request is deep-copied with ``model`` replaced by the
        resolved provider model; the caller's object is not modified.
        """
        resolution = self.resolve(request.model)
        overrides = {"model": resolution.provider_model}
        routed_copy = request.model_copy(update=overrides, deep=True)
        return RoutedTokenCountRequest(request=routed_copy, resolved=resolution)
|