Refactor provider routing and smoke coverage

This commit is contained in:
Alishahryar1
2026-04-24 19:34:34 -07:00
parent efa9f36c3a
commit 66ef23072c
68 changed files with 5115 additions and 1312 deletions
+18 -128
View File
@@ -5,14 +5,10 @@ from loguru import logger
from config.settings import Settings
from config.settings import get_settings as _get_settings
from providers.base import BaseProvider, ProviderConfig
from providers.base import BaseProvider
from providers.common import get_user_facing_error_message
from providers.deepseek import DEEPSEEK_BASE_URL, DeepSeekProvider
from providers.exceptions import AuthenticationError
from providers.llamacpp import LlamaCppProvider
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider
from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry
# Provider registry: keyed by provider type string, lazily populated
_providers: dict[str, BaseProvider] = {}
@@ -23,132 +19,27 @@ def get_settings() -> Settings:
return _get_settings()
def _get_proxy_value(settings: Settings, attr_name: str) -> str:
"""Return a provider proxy only when configured as a string."""
value = getattr(settings, attr_name, "")
return value if isinstance(value, str) else ""
def _create_provider_for_type(provider_type: str, settings: Settings) -> BaseProvider:
"""Construct and return a new provider instance for the given provider type."""
_proxy_map = {
"nvidia_nim": _get_proxy_value(settings, "nvidia_nim_proxy"),
"open_router": _get_proxy_value(settings, "open_router_proxy"),
"lmstudio": _get_proxy_value(settings, "lmstudio_proxy"),
"llamacpp": _get_proxy_value(settings, "llamacpp_proxy"),
}
proxy = _proxy_map.get(provider_type, "")
if provider_type == "nvidia_nim":
if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip():
raise AuthenticationError(
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
"Get a key at https://build.nvidia.com/settings/api-keys"
)
config = ProviderConfig(
api_key=settings.nvidia_nim_api_key,
base_url=NVIDIA_NIM_BASE_URL,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
enable_thinking=settings.enable_thinking,
proxy=proxy,
)
return NvidiaNimProvider(config, nim_settings=settings.nim)
if provider_type == "open_router":
if not settings.open_router_api_key or not settings.open_router_api_key.strip():
raise AuthenticationError(
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
"Get a key at https://openrouter.ai/keys"
)
config = ProviderConfig(
api_key=settings.open_router_api_key,
base_url=OPENROUTER_BASE_URL,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
enable_thinking=settings.enable_thinking,
proxy=proxy,
)
return OpenRouterProvider(config)
if provider_type == "deepseek":
if not settings.deepseek_api_key or not settings.deepseek_api_key.strip():
raise AuthenticationError(
"DEEPSEEK_API_KEY is not set. Add it to your .env file. "
"Get a key at https://platform.deepseek.com/api_keys"
)
config = ProviderConfig(
api_key=settings.deepseek_api_key,
base_url=DEEPSEEK_BASE_URL,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
enable_thinking=settings.enable_thinking,
)
return DeepSeekProvider(config)
if provider_type == "lmstudio":
config = ProviderConfig(
api_key="lm-studio",
base_url=settings.lm_studio_base_url,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
enable_thinking=settings.enable_thinking,
proxy=proxy,
)
return LMStudioProvider(config)
if provider_type == "llamacpp":
config = ProviderConfig(
api_key="llamacpp",
base_url=settings.llamacpp_base_url,
rate_limit=settings.provider_rate_limit,
rate_window=settings.provider_rate_window,
max_concurrency=settings.provider_max_concurrency,
http_read_timeout=settings.http_read_timeout,
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
enable_thinking=settings.enable_thinking,
proxy=proxy,
)
return LlamaCppProvider(config)
logger.error(
"Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'deepseek', 'lmstudio', 'llamacpp'",
provider_type,
)
raise ValueError(
f"Unknown provider_type: '{provider_type}'. "
f"Supported: 'nvidia_nim', 'open_router', 'deepseek', 'lmstudio', 'llamacpp'"
)
def get_provider_for_type(provider_type: str) -> BaseProvider:
"""Get or create a provider for the given provider type.
Providers are cached in the registry and reused across requests.
"""
if provider_type not in _providers:
try:
_providers[provider_type] = _create_provider_for_type(
provider_type, get_settings()
)
except AuthenticationError as e:
raise HTTPException(
status_code=503, detail=get_user_facing_error_message(e)
) from e
try:
provider = ProviderRegistry(_providers).get(provider_type, get_settings())
except AuthenticationError as e:
raise HTTPException(
status_code=503, detail=get_user_facing_error_message(e)
) from e
except ValueError:
logger.error(
"Unknown provider_type: '{}'. Supported: {}",
provider_type,
", ".join(f"'{key}'" for key in PROVIDER_DESCRIPTORS),
)
raise
if provider_type in _providers:
logger.info("Provider initialized: {}", provider_type)
return _providers[provider_type]
return provider
def require_api_key(
@@ -196,7 +87,6 @@ def get_provider() -> BaseProvider:
async def cleanup_provider():
"""Cleanup all provider resources."""
global _providers
for provider in _providers.values():
await provider.cleanup()
await ProviderRegistry(_providers).cleanup()
_providers = {}
logger.debug("Provider cleanup completed")
+58
View File
@@ -0,0 +1,58 @@
"""Model routing for Claude-compatible requests."""
from __future__ import annotations
from dataclasses import dataclass
from loguru import logger
from config.settings import Settings
from .models.anthropic import MessagesRequest, TokenCountRequest
@dataclass(frozen=True, slots=True)
class ResolvedModel:
original_model: str
provider_id: str
provider_model: str
provider_model_ref: str
class ModelRouter:
    """Translate incoming Claude model names into provider/model targets."""

    def __init__(self, settings: Settings):
        self._settings = settings

    def resolve(self, claude_model_name: str) -> ResolvedModel:
        """Resolve one model name into its provider id and provider model.

        The settings object supplies the full provider/model reference, which
        is then split with the ``Settings`` parsing helpers. A debug line is
        logged whenever the provider model differs from the requested name.
        """
        model_ref = self._settings.resolve_model(claude_model_name)
        provider = Settings.parse_provider_type(model_ref)
        target_model = Settings.parse_model_name(model_ref)

        if target_model != claude_model_name:
            logger.debug(
                "MODEL MAPPING: '{}' -> '{}'", claude_model_name, target_model
            )

        return ResolvedModel(
            original_model=claude_model_name,
            provider_id=provider,
            provider_model=target_model,
            provider_model_ref=model_ref,
        )

    def resolve_messages_request(self, request: MessagesRequest) -> MessagesRequest:
        """Return a deep copy of ``request`` with routing fields filled in.

        The incoming request object is left untouched. An already recorded
        ``original_model`` takes precedence over ``model`` as the routing key.
        """
        resolved = self.resolve(request.original_model or request.model)

        routed = request.model_copy(deep=True)
        routed.original_model = resolved.original_model
        routed.resolved_provider_model = resolved.provider_model_ref
        routed.model = resolved.provider_model
        return routed

    def resolve_token_count_request(
        self, request: TokenCountRequest
    ) -> TokenCountRequest:
        """Return a deep copy of ``request`` whose model is the provider model."""
        target_model = self.resolve(request.model).provider_model
        return request.model_copy(update={"model": target_model}, deep=True)
+2
View File
@@ -2,6 +2,7 @@
from .anthropic import (
ContentBlockImage,
ContentBlockRedactedThinking,
ContentBlockText,
ContentBlockThinking,
ContentBlockToolResult,
@@ -24,6 +25,7 @@ from .responses import (
__all__ = [
"ContentBlockImage",
"ContentBlockRedactedThinking",
"ContentBlockText",
"ContentBlockThinking",
"ContentBlockToolResult",
+11 -29
View File
@@ -3,10 +3,7 @@
from enum import StrEnum
from typing import Any, Literal
from loguru import logger
from pydantic import BaseModel, field_validator, model_validator
from config.settings import Settings, get_settings
from pydantic import BaseModel
# =============================================================================
@@ -44,6 +41,12 @@ class ContentBlockToolResult(BaseModel):
# Content block carrying model "thinking" text.
class ContentBlockThinking(BaseModel):
# Discriminator: only the literal "thinking" is accepted.
type: Literal["thinking"]
# The thinking text itself.
thinking: str
# Optional; defaults to None when the block carries no signature.
signature: str | None = None
# Content block for a "redacted_thinking" entry.
class ContentBlockRedactedThinking(BaseModel):
# Discriminator: only the literal "redacted_thinking" is accepted.
type: Literal["redacted_thinking"]
# Required string payload. NOTE(review): presumably opaque to this proxy and
# passed through unchanged — confirm against the providers.
data: str
class SystemContent(BaseModel):
@@ -64,6 +67,7 @@ class Message(BaseModel):
| ContentBlockToolUse
| ContentBlockToolResult
| ContentBlockThinking
| ContentBlockRedactedThinking
]
)
reasoning_content: str | None = None
@@ -76,7 +80,9 @@ class Tool(BaseModel):
class ThinkingConfig(BaseModel):
enabled: bool = True
enabled: bool | None = True
type: str | None = None
budget_tokens: int | None = None
# =============================================================================
@@ -100,22 +106,6 @@ class MessagesRequest(BaseModel):
original_model: str | None = None
resolved_provider_model: str | None = None
@model_validator(mode="after")
def map_model(self) -> MessagesRequest:
    """Swap the requested Claude model for the configured provider model.

    Runs after field validation. The incoming name is kept in
    ``original_model`` (unless already set), the full resolved reference goes
    to ``resolved_provider_model``, and ``model`` becomes the parsed model name.
    """
    active_settings = get_settings()

    # Preserve the name the client sent before it gets overwritten below.
    if self.original_model is None:
        self.original_model = self.model

    provider_ref = active_settings.resolve_model(self.original_model)
    self.resolved_provider_model = provider_ref
    self.model = Settings.parse_model_name(provider_ref)

    remapped = self.model != self.original_model
    if remapped:
        logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
    return self
class TokenCountRequest(BaseModel):
model: str
@@ -124,11 +114,3 @@ class TokenCountRequest(BaseModel):
tools: list[Tool] | None = None
thinking: ThinkingConfig | None = None
tool_choice: dict[str, Any] | None = None
@field_validator("model")
@classmethod
def validate_model_field(cls, v: str, info) -> str:
    """Replace the incoming model name with the configured provider model name.

    ``info`` is accepted for the validator signature but not used.
    """
    provider_ref = get_settings().resolve_model(v)
    return Settings.parse_model_name(provider_ref)
+11 -2
View File
@@ -4,7 +4,12 @@ from typing import Any, Literal
from pydantic import BaseModel
from .anthropic import ContentBlockText, ContentBlockThinking, ContentBlockToolUse
from .anthropic import (
ContentBlockRedactedThinking,
ContentBlockText,
ContentBlockThinking,
ContentBlockToolUse,
)
class TokenCountResponse(BaseModel):
@@ -37,7 +42,11 @@ class MessagesResponse(BaseModel):
model: str
role: Literal["assistant"] = "assistant"
content: list[
ContentBlockText | ContentBlockToolUse | ContentBlockThinking | dict[str, Any]
ContentBlockText
| ContentBlockToolUse
| ContentBlockThinking
| ContentBlockRedactedThinking
| dict[str, Any]
]
type: Literal["message"] = "message"
stop_reason: (
+22 -85
View File
@@ -1,21 +1,15 @@
"""FastAPI route handlers."""
import traceback
import uuid
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from loguru import logger
from config.settings import Settings
from providers.common import get_user_facing_error_message
from providers.exceptions import InvalidRequestError, ProviderError
from .dependencies import get_provider_for_type, get_settings, require_api_key
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import ModelResponse, ModelsListResponse, TokenCountResponse
from .optimization_handlers import try_optimizations
from .models.responses import ModelResponse, ModelsListResponse
from .request_utils import get_token_count
from .services import ClaudeProxyService
router = APIRouter()
@@ -59,6 +53,17 @@ SUPPORTED_CLAUDE_MODELS = [
]
def get_proxy_service(
    settings: Settings = Depends(get_settings),
) -> ClaudeProxyService:
    """FastAPI dependency that assembles the service used by route handlers.

    A new ``ClaudeProxyService`` is built on each call, wired to this module's
    provider lookup and token counter.
    """
    service = ClaudeProxyService(
        settings,
        token_counter=get_token_count,
        provider_getter=get_provider_for_type,
    )
    return service
def _probe_response(allow: str) -> Response:
    """Build the bodiless 204 reply used for compatibility probes.

    ``allow`` is sent back verbatim as the ``Allow`` response header.
    """
    probe_headers = {"Allow": allow}
    return Response(status_code=204, headers=probe_headers)
@@ -70,61 +75,12 @@ def _probe_response(allow: str) -> Response:
@router.post("/v1/messages")
async def create_message(
request_data: MessagesRequest,
raw_request: Request,
settings: Settings = Depends(get_settings),
_raw_request: Request,
service: ClaudeProxyService = Depends(get_proxy_service),
_auth=Depends(require_api_key),
):
"""Create a message (always streaming)."""
try:
if not request_data.messages:
raise InvalidRequestError("messages cannot be empty")
optimized = try_optimizations(request_data, settings)
if optimized is not None:
return optimized
logger.debug("No optimization matched, routing to provider")
# Resolve provider from the model-aware mapping
provider_type = Settings.parse_provider_type(
request_data.resolved_provider_model or settings.model
)
provider = get_provider_for_type(provider_type)
request_id = f"req_{uuid.uuid4().hex[:12]}"
logger.info(
"API_REQUEST: request_id={} model={} messages={}",
request_id,
request_data.model,
len(request_data.messages),
)
logger.debug("FULL_PAYLOAD [{}]: {}", request_id, request_data.model_dump())
input_tokens = get_token_count(
request_data.messages, request_data.system, request_data.tools
)
return StreamingResponse(
provider.stream_response(
request_data,
input_tokens=input_tokens,
request_id=request_id,
),
media_type="text/event-stream",
headers={
"X-Accel-Buffering": "no",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
except ProviderError:
raise
except Exception as e:
logger.error(f"Error: {e!s}\n{traceback.format_exc()}")
raise HTTPException(
status_code=getattr(e, "status_code", 500),
detail=get_user_facing_error_message(e),
) from e
return service.create_message(request_data)
@router.api_route("/v1/messages", methods=["HEAD", "OPTIONS"])
@@ -134,32 +90,13 @@ async def probe_messages(_auth=Depends(require_api_key)):
@router.post("/v1/messages/count_tokens")
async def count_tokens(request_data: TokenCountRequest, _auth=Depends(require_api_key)):
async def count_tokens(
request_data: TokenCountRequest,
service: ClaudeProxyService = Depends(get_proxy_service),
_auth=Depends(require_api_key),
):
"""Count tokens for a request."""
request_id = f"req_{uuid.uuid4().hex[:12]}"
with logger.contextualize(request_id=request_id):
try:
tokens = get_token_count(
request_data.messages, request_data.system, request_data.tools
)
logger.info(
"COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}",
request_id,
getattr(request_data, "model", "unknown"),
len(request_data.messages),
tokens,
)
return TokenCountResponse(input_tokens=tokens)
except Exception as e:
logger.error(
"COUNT_TOKENS_ERROR: request_id={} error={}\n{}",
request_id,
get_user_facing_error_message(e),
traceback.format_exc(),
)
raise HTTPException(
status_code=500, detail=get_user_facing_error_message(e)
) from e
return service.count_tokens(request_data)
@router.api_route("/v1/messages/count_tokens", methods=["HEAD", "OPTIONS"])
+128
View File
@@ -0,0 +1,128 @@
"""Application services for the Claude-compatible API."""
from __future__ import annotations
import traceback
import uuid
from collections.abc import Callable
from typing import Any
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
from config.settings import Settings
from providers.base import BaseProvider
from providers.common import get_user_facing_error_message
from providers.exceptions import InvalidRequestError, ProviderError
from .model_router import ModelRouter
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import TokenCountResponse
from .optimization_handlers import try_optimizations
from .request_utils import get_token_count
TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
ProviderGetter = Callable[[str], BaseProvider]
class ClaudeProxyService:
    """Coordinate request optimization, model routing, token count, and providers.

    Collaborators are injected so route handlers and tests can supply their
    own provider lookup, model router, and token counter.
    """

    def __init__(
        self,
        settings: Settings,
        provider_getter: ProviderGetter,
        model_router: ModelRouter | None = None,
        token_counter: TokenCounter = get_token_count,
    ):
        """Store collaborators; a default ``ModelRouter`` is built if omitted.

        Args:
            settings: Application settings used for routing and optimizations.
            provider_getter: Callable mapping a provider id to a provider.
            model_router: Optional router; defaults to ``ModelRouter(settings)``.
            token_counter: Callable taking (messages, system, tools) and
                returning the input token count.
        """
        self._settings = settings
        self._provider_getter = provider_getter
        self._model_router = model_router or ModelRouter(settings)
        self._token_counter = token_counter

    def create_message(self, request_data: MessagesRequest) -> object:
        """Create a message response or streaming response.

        The request is routed to its provider model first. If an optimization
        handler produces a response it is returned directly; otherwise the
        provider's stream is wrapped in a server-sent-events response.

        Raises:
            ProviderError: re-raised unchanged (this includes the empty
                ``messages`` rejection if InvalidRequestError derives from it).
            HTTPException: any other failure, using the exception's
                ``status_code`` attribute when present and 500 otherwise.
        """
        try:
            if not request_data.messages:
                raise InvalidRequestError("messages cannot be empty")

            routed_request = self._model_router.resolve_messages_request(request_data)

            optimized = try_optimizations(routed_request, self._settings)
            if optimized is not None:
                return optimized
            logger.debug("No optimization matched, routing to provider")

            # Use the shared Settings helper, as ModelRouter.resolve does,
            # instead of re-implementing the "provider/model" split here.
            provider_type = Settings.parse_provider_type(
                routed_request.resolved_provider_model or self._settings.model
            )
            provider = self._provider_getter(provider_type)

            request_id = f"req_{uuid.uuid4().hex[:12]}"
            logger.info(
                "API_REQUEST: request_id={} model={} messages={}",
                request_id,
                routed_request.model,
                len(routed_request.messages),
            )
            logger.debug(
                "FULL_PAYLOAD [{}]: {}", request_id, routed_request.model_dump()
            )

            input_tokens = self._token_counter(
                routed_request.messages, routed_request.system, routed_request.tools
            )
            return StreamingResponse(
                provider.stream_response(
                    routed_request,
                    input_tokens=input_tokens,
                    request_id=request_id,
                ),
                media_type="text/event-stream",
                headers={
                    "X-Accel-Buffering": "no",
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                },
            )

        except ProviderError:
            # Provider errors propagate untouched to the caller.
            raise
        except Exception as e:
            logger.error(f"Error: {e!s}\n{traceback.format_exc()}")
            raise HTTPException(
                status_code=getattr(e, "status_code", 500),
                detail=get_user_facing_error_message(e),
            ) from e

    def count_tokens(self, request_data: TokenCountRequest) -> TokenCountResponse:
        """Count tokens for a request after applying configured model routing.

        Raises:
            HTTPException: status 500 with a user-facing message when routing
                or counting fails.
        """
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        # Bind the request id into the logging context for this call.
        with logger.contextualize(request_id=request_id):
            try:
                routed_request = self._model_router.resolve_token_count_request(
                    request_data
                )
                tokens = self._token_counter(
                    routed_request.messages, routed_request.system, routed_request.tools
                )
                logger.info(
                    "COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}",
                    request_id,
                    routed_request.model,
                    len(routed_request.messages),
                    tokens,
                )
                return TokenCountResponse(input_tokens=tokens)
            except Exception as e:
                logger.error(
                    "COUNT_TOKENS_ERROR: request_id={} error={}\n{}",
                    request_id,
                    get_user_facing_error_message(e),
                    traceback.format_exc(),
                )
                raise HTTPException(
                    status_code=500, detail=get_user_facing_error_message(e)
                ) from e