Added claude-code native model picker

This commit is contained in:
Alishahryar1
2026-04-30 20:34:35 -07:00
parent 7d80cc3426
commit 72b34ad57c
12 changed files with 452 additions and 216 deletions
+46
View File
@@ -6,10 +6,13 @@ from dataclasses import dataclass
from loguru import logger
from config.provider_ids import SUPPORTED_PROVIDER_IDS
from config.settings import Settings
from .models.anthropic import MessagesRequest, TokenCountRequest
# First path segment that marks a gateway-encoded model name. When a requested
# model name starts with this segment, ModelRouter drops it and parses the
# remainder as "<provider>/<model>".
GATEWAY_MODEL_ID_PREFIX = "anthropic"
@dataclass(frozen=True, slots=True)
class ResolvedModel:
@@ -39,6 +42,25 @@ class ModelRouter:
self._settings = settings
def resolve(self, claude_model_name: str) -> ResolvedModel:
direct_provider_id, direct_provider_model = self._direct_provider_model(
claude_model_name
)
if direct_provider_id is not None and direct_provider_model is not None:
thinking_enabled = self._settings.resolve_thinking(direct_provider_model)
logger.debug(
"MODEL DIRECT: '{}' -> provider='{}' model='{}'",
claude_model_name,
direct_provider_id,
direct_provider_model,
)
return ResolvedModel(
original_model=claude_model_name,
provider_id=direct_provider_id,
provider_model=direct_provider_model,
provider_model_ref=claude_model_name,
thinking_enabled=thinking_enabled,
)
provider_model_ref = self._settings.resolve_model(claude_model_name)
thinking_enabled = self._settings.resolve_thinking(claude_model_name)
provider_id = Settings.parse_provider_type(provider_model_ref)
@@ -55,6 +77,30 @@ class ModelRouter:
thinking_enabled=thinking_enabled,
)
def _direct_provider_model(self, model_name: str) -> tuple[str | None, str | None]:
provider_id, separator, provider_model = model_name.partition("/")
if not separator:
return None, None
if provider_id == GATEWAY_MODEL_ID_PREFIX:
return self._gateway_encoded_provider_model(provider_model)
if provider_id not in SUPPORTED_PROVIDER_IDS:
return None, None
if not provider_model:
return None, None
return provider_id, provider_model
def _gateway_encoded_provider_model(
self, model_name: str
) -> tuple[str | None, str | None]:
provider_id, separator, provider_model = model_name.partition("/")
if not separator:
return None, None
if provider_id not in SUPPORTED_PROVIDER_IDS:
return None, None
if not provider_model:
return None, None
return provider_id, provider_model
def resolve_messages_request(
self, request: MessagesRequest
) -> RoutedMessagesRequest:
+70 -8
View File
@@ -5,6 +5,7 @@ from loguru import logger
from config.settings import Settings
from core.anthropic import get_token_count
from providers.registry import ProviderRegistry
from . import dependencies
from .dependencies import get_settings, require_api_key
@@ -14,6 +15,9 @@ from .services import ClaudeProxyService
# Router on which this module's route handlers are registered.
router = APIRouter()
# Fixed ``created_at`` value given to every entry built by
# ``_discovered_model_response``.
DISCOVERED_MODEL_CREATED_AT = "1970-01-01T00:00:00Z"
# Prefix that ``_gateway_model_id`` places in front of a provider model ref,
# producing ids of the form "anthropic/<provider_model_ref>".
GATEWAY_MODEL_ID_PREFIX = "anthropic"
SUPPORTED_CLAUDE_MODELS = [
ModelResponse(
@@ -73,6 +77,63 @@ def _probe_response(allow: str) -> Response:
return Response(status_code=204, headers={"Allow": allow})
def _gateway_model_id(provider_model_ref: str) -> str:
    """Return ``provider_model_ref`` prefixed with ``GATEWAY_MODEL_ID_PREFIX`` and a ``/``."""
    return "{}/{}".format(GATEWAY_MODEL_ID_PREFIX, provider_model_ref)
def _discovered_model_response(model_id: str, *, display_name: str) -> ModelResponse:
    """Build a ``ModelResponse`` for a model that is not in the static Claude list.

    ``created_at`` is always set to ``DISCOVERED_MODEL_CREATED_AT``.
    """
    fields = {
        "id": model_id,
        "display_name": display_name,
        "created_at": DISCOVERED_MODEL_CREATED_AT,
    }
    return ModelResponse(**fields)
def _append_unique_model(
models: list[ModelResponse], seen: set[str], model: ModelResponse
) -> None:
if model.id in seen:
return
seen.add(model.id)
models.append(model)
def _build_models_list_response(
    settings: Settings, provider_registry: ProviderRegistry | None
) -> ModelsListResponse:
    """Assemble the de-duplicated model list returned by the models endpoint.

    Entries are added in this order, and the first occurrence of an id wins:

    1. Every ref from ``settings.configured_chat_model_refs()``.
    2. Every ref from ``provider_registry.cached_prefixed_model_refs()``, when
       a registry is supplied.
    3. The static ``SUPPORTED_CLAUDE_MODELS`` entries.

    Refs from the first two sources are exposed under a gateway-prefixed id,
    with the bare ref as display name.
    """
    collected: list[ModelResponse] = []
    seen_ids: set[str] = set()

    def _advertise(model_ref: str) -> None:
        # Wrap a provider model ref as a discovered entry and record it once.
        entry = _discovered_model_response(
            _gateway_model_id(model_ref), display_name=model_ref
        )
        _append_unique_model(collected, seen_ids, entry)

    for configured in settings.configured_chat_model_refs():
        _advertise(configured.model_ref)

    if provider_registry is not None:
        for cached_ref in provider_registry.cached_prefixed_model_refs():
            _advertise(cached_ref)

    for claude_model in SUPPORTED_CLAUDE_MODELS:
        _append_unique_model(collected, seen_ids, claude_model)

    # Paging markers are None when nothing was collected.
    first_id = collected[0].id if collected else None
    last_id = collected[-1].id if collected else None
    return ModelsListResponse(
        data=collected,
        first_id=first_id,
        has_more=False,
        last_id=last_id,
    )
# =============================================================================
# Routes
# =============================================================================
@@ -139,14 +200,15 @@ async def probe_health():
@router.get("/v1/models", response_model=ModelsListResponse)
async def list_models(
    request: Request,
    settings: Settings = Depends(get_settings),
    _auth=Depends(require_api_key),
):
    """List the model ids this proxy advertises to Claude-compatible clients.

    The response combines the configured chat model refs, any model refs the
    provider registry has cached, and the static Claude-compatible entries;
    see ``_build_models_list_response`` for ordering and de-duplication.

    Args:
        request: Incoming request; used only to reach ``app.state``.
        settings: Application settings, injected via ``get_settings``.
        _auth: Result of ``require_api_key``; declared so the dependency runs
            for this route, the value itself is unused.

    Returns:
        The ``ModelsListResponse`` for this proxy.
    """
    # The registry is optional: if app.state has no "provider_registry", or it
    # holds something that is not a ProviderRegistry, the list is built
    # without registry-cached models.
    registry = getattr(request.app.state, "provider_registry", None)
    provider_registry = registry if isinstance(registry, ProviderRegistry) else None
    return _build_models_list_response(settings, provider_registry)
@router.post("/stop")
+1
View File
@@ -105,6 +105,7 @@ class AppRuntime:
try:
warn_if_process_auth_token(self.settings)
await self._provider_registry.validate_configured_models(self.settings)
self._provider_registry.start_model_list_refresh(self.settings)
await self._start_messaging_if_configured()
self._publish_state()
except Exception as exc: