Files
2026-04-30 21:27:23 -07:00

126 lines
4.3 KiB
Python

"""Model routing for Claude-compatible requests."""
from __future__ import annotations
from dataclasses import dataclass
from loguru import logger
from config.provider_ids import SUPPORTED_PROVIDER_IDS
from config.settings import Settings
from .gateway_model_ids import decode_gateway_model_id
from .models.anthropic import MessagesRequest, TokenCountRequest
@dataclass(frozen=True, slots=True)
class ResolvedModel:
original_model: str
provider_id: str
provider_model: str
provider_model_ref: str
thinking_enabled: bool
@dataclass(frozen=True, slots=True)
class RoutedMessagesRequest:
request: MessagesRequest
resolved: ResolvedModel
@dataclass(frozen=True, slots=True)
class RoutedTokenCountRequest:
request: TokenCountRequest
resolved: ResolvedModel
class ModelRouter:
    """Resolve incoming Claude model names to configured provider/model pairs."""

    def __init__(self, settings: Settings) -> None:
        """Keep the settings object used for model and thinking lookups."""
        self._settings = settings

    def resolve(self, claude_model_name: str) -> ResolvedModel:
        """Resolve a requested model name to a provider and provider model.

        A name that directly identifies a supported provider (see
        ``_direct_provider_model``) is routed as-is. Any other name is mapped
        through the settings object.

        Args:
            claude_model_name: Model name taken from the incoming request.

        Returns:
            The resolved provider id, provider model, model reference and
            thinking flag, together with the original name.
        """
        (
            direct_provider_id,
            direct_provider_model,
            force_thinking_enabled,
        ) = self._direct_provider_model(claude_model_name)
        if direct_provider_id is not None and direct_provider_model is not None:
            # A thinking flag carried by the model id overrides the settings.
            # Without one, thinking is looked up by the *provider* model name
            # (the settings-based path below uses the requested name instead).
            thinking_enabled = (
                force_thinking_enabled
                if force_thinking_enabled is not None
                else self._settings.resolve_thinking(direct_provider_model)
            )
            logger.debug(
                "MODEL DIRECT: '{}' -> provider='{}' model='{}' thinking={}",
                claude_model_name,
                direct_provider_id,
                direct_provider_model,
                thinking_enabled,
            )
            return ResolvedModel(
                original_model=claude_model_name,
                provider_id=direct_provider_id,
                provider_model=direct_provider_model,
                provider_model_ref=claude_model_name,
                thinking_enabled=thinking_enabled,
            )

        # Settings-based routing: the settings return a model reference that
        # is then split into its provider and model parts.
        provider_model_ref = self._settings.resolve_model(claude_model_name)
        thinking_enabled = self._settings.resolve_thinking(claude_model_name)
        provider_id = Settings.parse_provider_type(provider_model_ref)
        provider_model = Settings.parse_model_name(provider_model_ref)
        if provider_model != claude_model_name:
            logger.debug(
                "MODEL MAPPING: '{}' -> '{}'", claude_model_name, provider_model
            )
        return ResolvedModel(
            original_model=claude_model_name,
            provider_id=provider_id,
            provider_model=provider_model,
            provider_model_ref=provider_model_ref,
            thinking_enabled=thinking_enabled,
        )

    def _direct_provider_model(
        self, model_name: str
    ) -> tuple[str | None, str | None, bool | None]:
        """Detect a model name that names its provider directly.

        Two forms are recognised, in this order:

        1. A gateway model id understood by ``decode_gateway_model_id``.
        2. ``"<provider_id>/<provider_model>"``, split at the first ``"/"``.

        Returns:
            ``(provider_id, provider_model, force_thinking_enabled)`` on a
            match, where the last item is ``None`` for the slash form, or
            ``(None, None, None)`` when the name is not a direct reference to
            a supported provider.
        """
        decoded = decode_gateway_model_id(model_name)
        if decoded is not None:
            # A decodable gateway id is authoritative: if its provider is not
            # supported, the slash form is not tried as a fallback.
            if decoded.provider_id not in SUPPORTED_PROVIDER_IDS:
                return None, None, None
            return (
                decoded.provider_id,
                decoded.provider_model,
                decoded.force_thinking_enabled,
            )

        provider_id, separator, provider_model = model_name.partition("/")
        if not separator:
            return None, None, None
        if provider_id not in SUPPORTED_PROVIDER_IDS:
            return None, None, None
        if not provider_model:
            # "provider/" with nothing after the slash names no model.
            return None, None, None
        return provider_id, provider_model, None

    def resolve_messages_request(
        self, request: MessagesRequest
    ) -> RoutedMessagesRequest:
        """Return an internal routed request context.

        The incoming request is deep-copied and the copy's ``model`` is set to
        the resolved provider model; ``request`` itself is not assigned to.
        """
        resolved = self.resolve(request.model)
        routed = request.model_copy(deep=True)
        # NOTE(review): the model is swapped by attribute assignment here but
        # via ``model_copy(update=...)`` in ``resolve_token_count_request``;
        # confirm the difference is intentional before unifying the two.
        routed.model = resolved.provider_model
        return RoutedMessagesRequest(request=routed, resolved=resolved)

    def resolve_token_count_request(
        self, request: TokenCountRequest
    ) -> RoutedTokenCountRequest:
        """Return an internal token-count request context.

        The incoming request is deep-copied with ``model`` replaced by the
        resolved provider model.
        """
        resolved = self.resolve(request.model)
        routed = request.model_copy(
            update={"model": resolved.provider_model}, deep=True
        )
        return RoutedTokenCountRequest(request=routed, resolved=resolved)