mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-06-01 22:09:04 +02:00
Refactor provider routing and smoke coverage
This commit is contained in:
+18
-128
@@ -5,14 +5,10 @@ from loguru import logger
|
||||
|
||||
from config.settings import Settings
|
||||
from config.settings import get_settings as _get_settings
|
||||
from providers.base import BaseProvider, ProviderConfig
|
||||
from providers.base import BaseProvider
|
||||
from providers.common import get_user_facing_error_message
|
||||
from providers.deepseek import DEEPSEEK_BASE_URL, DeepSeekProvider
|
||||
from providers.exceptions import AuthenticationError
|
||||
from providers.llamacpp import LlamaCppProvider
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
|
||||
from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider
|
||||
from providers.registry import PROVIDER_DESCRIPTORS, ProviderRegistry
|
||||
|
||||
# Provider registry: keyed by provider type string, lazily populated
|
||||
_providers: dict[str, BaseProvider] = {}
|
||||
@@ -23,132 +19,27 @@ def get_settings() -> Settings:
|
||||
return _get_settings()
|
||||
|
||||
|
||||
def _get_proxy_value(settings: Settings, attr_name: str) -> str:
|
||||
"""Return a provider proxy only when configured as a string."""
|
||||
value = getattr(settings, attr_name, "")
|
||||
return value if isinstance(value, str) else ""
|
||||
|
||||
|
||||
def _create_provider_for_type(provider_type: str, settings: Settings) -> BaseProvider:
|
||||
"""Construct and return a new provider instance for the given provider type."""
|
||||
_proxy_map = {
|
||||
"nvidia_nim": _get_proxy_value(settings, "nvidia_nim_proxy"),
|
||||
"open_router": _get_proxy_value(settings, "open_router_proxy"),
|
||||
"lmstudio": _get_proxy_value(settings, "lmstudio_proxy"),
|
||||
"llamacpp": _get_proxy_value(settings, "llamacpp_proxy"),
|
||||
}
|
||||
proxy = _proxy_map.get(provider_type, "")
|
||||
|
||||
if provider_type == "nvidia_nim":
|
||||
if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip():
|
||||
raise AuthenticationError(
|
||||
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
|
||||
"Get a key at https://build.nvidia.com/settings/api-keys"
|
||||
)
|
||||
config = ProviderConfig(
|
||||
api_key=settings.nvidia_nim_api_key,
|
||||
base_url=NVIDIA_NIM_BASE_URL,
|
||||
rate_limit=settings.provider_rate_limit,
|
||||
rate_window=settings.provider_rate_window,
|
||||
max_concurrency=settings.provider_max_concurrency,
|
||||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
enable_thinking=settings.enable_thinking,
|
||||
proxy=proxy,
|
||||
)
|
||||
return NvidiaNimProvider(config, nim_settings=settings.nim)
|
||||
if provider_type == "open_router":
|
||||
if not settings.open_router_api_key or not settings.open_router_api_key.strip():
|
||||
raise AuthenticationError(
|
||||
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
|
||||
"Get a key at https://openrouter.ai/keys"
|
||||
)
|
||||
config = ProviderConfig(
|
||||
api_key=settings.open_router_api_key,
|
||||
base_url=OPENROUTER_BASE_URL,
|
||||
rate_limit=settings.provider_rate_limit,
|
||||
rate_window=settings.provider_rate_window,
|
||||
max_concurrency=settings.provider_max_concurrency,
|
||||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
enable_thinking=settings.enable_thinking,
|
||||
proxy=proxy,
|
||||
)
|
||||
return OpenRouterProvider(config)
|
||||
if provider_type == "deepseek":
|
||||
if not settings.deepseek_api_key or not settings.deepseek_api_key.strip():
|
||||
raise AuthenticationError(
|
||||
"DEEPSEEK_API_KEY is not set. Add it to your .env file. "
|
||||
"Get a key at https://platform.deepseek.com/api_keys"
|
||||
)
|
||||
config = ProviderConfig(
|
||||
api_key=settings.deepseek_api_key,
|
||||
base_url=DEEPSEEK_BASE_URL,
|
||||
rate_limit=settings.provider_rate_limit,
|
||||
rate_window=settings.provider_rate_window,
|
||||
max_concurrency=settings.provider_max_concurrency,
|
||||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
enable_thinking=settings.enable_thinking,
|
||||
)
|
||||
return DeepSeekProvider(config)
|
||||
if provider_type == "lmstudio":
|
||||
config = ProviderConfig(
|
||||
api_key="lm-studio",
|
||||
base_url=settings.lm_studio_base_url,
|
||||
rate_limit=settings.provider_rate_limit,
|
||||
rate_window=settings.provider_rate_window,
|
||||
max_concurrency=settings.provider_max_concurrency,
|
||||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
enable_thinking=settings.enable_thinking,
|
||||
proxy=proxy,
|
||||
)
|
||||
return LMStudioProvider(config)
|
||||
if provider_type == "llamacpp":
|
||||
config = ProviderConfig(
|
||||
api_key="llamacpp",
|
||||
base_url=settings.llamacpp_base_url,
|
||||
rate_limit=settings.provider_rate_limit,
|
||||
rate_window=settings.provider_rate_window,
|
||||
max_concurrency=settings.provider_max_concurrency,
|
||||
http_read_timeout=settings.http_read_timeout,
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
enable_thinking=settings.enable_thinking,
|
||||
proxy=proxy,
|
||||
)
|
||||
return LlamaCppProvider(config)
|
||||
logger.error(
|
||||
"Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'deepseek', 'lmstudio', 'llamacpp'",
|
||||
provider_type,
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unknown provider_type: '{provider_type}'. "
|
||||
f"Supported: 'nvidia_nim', 'open_router', 'deepseek', 'lmstudio', 'llamacpp'"
|
||||
)
|
||||
|
||||
|
||||
def get_provider_for_type(provider_type: str) -> BaseProvider:
|
||||
"""Get or create a provider for the given provider type.
|
||||
|
||||
Providers are cached in the registry and reused across requests.
|
||||
"""
|
||||
if provider_type not in _providers:
|
||||
try:
|
||||
_providers[provider_type] = _create_provider_for_type(
|
||||
provider_type, get_settings()
|
||||
)
|
||||
except AuthenticationError as e:
|
||||
raise HTTPException(
|
||||
status_code=503, detail=get_user_facing_error_message(e)
|
||||
) from e
|
||||
try:
|
||||
provider = ProviderRegistry(_providers).get(provider_type, get_settings())
|
||||
except AuthenticationError as e:
|
||||
raise HTTPException(
|
||||
status_code=503, detail=get_user_facing_error_message(e)
|
||||
) from e
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Unknown provider_type: '{}'. Supported: {}",
|
||||
provider_type,
|
||||
", ".join(f"'{key}'" for key in PROVIDER_DESCRIPTORS),
|
||||
)
|
||||
raise
|
||||
if provider_type in _providers:
|
||||
logger.info("Provider initialized: {}", provider_type)
|
||||
return _providers[provider_type]
|
||||
return provider
|
||||
|
||||
|
||||
def require_api_key(
|
||||
@@ -196,7 +87,6 @@ def get_provider() -> BaseProvider:
|
||||
async def cleanup_provider():
|
||||
"""Cleanup all provider resources."""
|
||||
global _providers
|
||||
for provider in _providers.values():
|
||||
await provider.cleanup()
|
||||
await ProviderRegistry(_providers).cleanup()
|
||||
_providers = {}
|
||||
logger.debug("Provider cleanup completed")
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Model routing for Claude-compatible requests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from config.settings import Settings
|
||||
|
||||
from .models.anthropic import MessagesRequest, TokenCountRequest
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ResolvedModel:
|
||||
original_model: str
|
||||
provider_id: str
|
||||
provider_model: str
|
||||
provider_model_ref: str
|
||||
|
||||
|
||||
class ModelRouter:
|
||||
"""Resolve incoming Claude model names to configured provider/model pairs."""
|
||||
|
||||
def __init__(self, settings: Settings):
|
||||
self._settings = settings
|
||||
|
||||
def resolve(self, claude_model_name: str) -> ResolvedModel:
|
||||
provider_model_ref = self._settings.resolve_model(claude_model_name)
|
||||
provider_id = Settings.parse_provider_type(provider_model_ref)
|
||||
provider_model = Settings.parse_model_name(provider_model_ref)
|
||||
if provider_model != claude_model_name:
|
||||
logger.debug(
|
||||
"MODEL MAPPING: '{}' -> '{}'", claude_model_name, provider_model
|
||||
)
|
||||
return ResolvedModel(
|
||||
original_model=claude_model_name,
|
||||
provider_id=provider_id,
|
||||
provider_model=provider_model,
|
||||
provider_model_ref=provider_model_ref,
|
||||
)
|
||||
|
||||
def resolve_messages_request(self, request: MessagesRequest) -> MessagesRequest:
|
||||
"""Return a routed copy of a MessagesRequest."""
|
||||
original_model = request.original_model or request.model
|
||||
resolved = self.resolve(original_model)
|
||||
routed = request.model_copy(deep=True)
|
||||
routed.original_model = resolved.original_model
|
||||
routed.resolved_provider_model = resolved.provider_model_ref
|
||||
routed.model = resolved.provider_model
|
||||
return routed
|
||||
|
||||
def resolve_token_count_request(
|
||||
self, request: TokenCountRequest
|
||||
) -> TokenCountRequest:
|
||||
"""Return a token-count request copy with provider model name applied."""
|
||||
resolved = self.resolve(request.model)
|
||||
return request.model_copy(update={"model": resolved.provider_model}, deep=True)
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from .anthropic import (
|
||||
ContentBlockImage,
|
||||
ContentBlockRedactedThinking,
|
||||
ContentBlockText,
|
||||
ContentBlockThinking,
|
||||
ContentBlockToolResult,
|
||||
@@ -24,6 +25,7 @@ from .responses import (
|
||||
|
||||
__all__ = [
|
||||
"ContentBlockImage",
|
||||
"ContentBlockRedactedThinking",
|
||||
"ContentBlockText",
|
||||
"ContentBlockThinking",
|
||||
"ContentBlockToolResult",
|
||||
|
||||
+11
-29
@@ -3,10 +3,7 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any, Literal
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
|
||||
from config.settings import Settings, get_settings
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -44,6 +41,12 @@ class ContentBlockToolResult(BaseModel):
|
||||
class ContentBlockThinking(BaseModel):
|
||||
type: Literal["thinking"]
|
||||
thinking: str
|
||||
signature: str | None = None
|
||||
|
||||
|
||||
class ContentBlockRedactedThinking(BaseModel):
|
||||
type: Literal["redacted_thinking"]
|
||||
data: str
|
||||
|
||||
|
||||
class SystemContent(BaseModel):
|
||||
@@ -64,6 +67,7 @@ class Message(BaseModel):
|
||||
| ContentBlockToolUse
|
||||
| ContentBlockToolResult
|
||||
| ContentBlockThinking
|
||||
| ContentBlockRedactedThinking
|
||||
]
|
||||
)
|
||||
reasoning_content: str | None = None
|
||||
@@ -76,7 +80,9 @@ class Tool(BaseModel):
|
||||
|
||||
|
||||
class ThinkingConfig(BaseModel):
|
||||
enabled: bool = True
|
||||
enabled: bool | None = True
|
||||
type: str | None = None
|
||||
budget_tokens: int | None = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
@@ -100,22 +106,6 @@ class MessagesRequest(BaseModel):
|
||||
original_model: str | None = None
|
||||
resolved_provider_model: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def map_model(self) -> MessagesRequest:
|
||||
"""Map any Claude model name to the configured model (model-aware)."""
|
||||
settings = get_settings()
|
||||
if self.original_model is None:
|
||||
self.original_model = self.model
|
||||
|
||||
resolved_full = settings.resolve_model(self.original_model)
|
||||
self.resolved_provider_model = resolved_full
|
||||
self.model = Settings.parse_model_name(resolved_full)
|
||||
|
||||
if self.model != self.original_model:
|
||||
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class TokenCountRequest(BaseModel):
|
||||
model: str
|
||||
@@ -124,11 +114,3 @@ class TokenCountRequest(BaseModel):
|
||||
tools: list[Tool] | None = None
|
||||
thinking: ThinkingConfig | None = None
|
||||
tool_choice: dict[str, Any] | None = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model_field(cls, v: str, info) -> str:
|
||||
"""Map any Claude model name to the configured model (model-aware)."""
|
||||
settings = get_settings()
|
||||
resolved_full = settings.resolve_model(v)
|
||||
return Settings.parse_model_name(resolved_full)
|
||||
|
||||
+11
-2
@@ -4,7 +4,12 @@ from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .anthropic import ContentBlockText, ContentBlockThinking, ContentBlockToolUse
|
||||
from .anthropic import (
|
||||
ContentBlockRedactedThinking,
|
||||
ContentBlockText,
|
||||
ContentBlockThinking,
|
||||
ContentBlockToolUse,
|
||||
)
|
||||
|
||||
|
||||
class TokenCountResponse(BaseModel):
|
||||
@@ -37,7 +42,11 @@ class MessagesResponse(BaseModel):
|
||||
model: str
|
||||
role: Literal["assistant"] = "assistant"
|
||||
content: list[
|
||||
ContentBlockText | ContentBlockToolUse | ContentBlockThinking | dict[str, Any]
|
||||
ContentBlockText
|
||||
| ContentBlockToolUse
|
||||
| ContentBlockThinking
|
||||
| ContentBlockRedactedThinking
|
||||
| dict[str, Any]
|
||||
]
|
||||
type: Literal["message"] = "message"
|
||||
stop_reason: (
|
||||
|
||||
+22
-85
@@ -1,21 +1,15 @@
|
||||
"""FastAPI route handlers."""
|
||||
|
||||
import traceback
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
from config.settings import Settings
|
||||
from providers.common import get_user_facing_error_message
|
||||
from providers.exceptions import InvalidRequestError, ProviderError
|
||||
|
||||
from .dependencies import get_provider_for_type, get_settings, require_api_key
|
||||
from .models.anthropic import MessagesRequest, TokenCountRequest
|
||||
from .models.responses import ModelResponse, ModelsListResponse, TokenCountResponse
|
||||
from .optimization_handlers import try_optimizations
|
||||
from .models.responses import ModelResponse, ModelsListResponse
|
||||
from .request_utils import get_token_count
|
||||
from .services import ClaudeProxyService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -59,6 +53,17 @@ SUPPORTED_CLAUDE_MODELS = [
|
||||
]
|
||||
|
||||
|
||||
def get_proxy_service(
|
||||
settings: Settings = Depends(get_settings),
|
||||
) -> ClaudeProxyService:
|
||||
"""Build the request service for route handlers."""
|
||||
return ClaudeProxyService(
|
||||
settings,
|
||||
provider_getter=get_provider_for_type,
|
||||
token_counter=get_token_count,
|
||||
)
|
||||
|
||||
|
||||
def _probe_response(allow: str) -> Response:
|
||||
"""Return an empty success response for compatibility probes."""
|
||||
return Response(status_code=204, headers={"Allow": allow})
|
||||
@@ -70,61 +75,12 @@ def _probe_response(allow: str) -> Response:
|
||||
@router.post("/v1/messages")
|
||||
async def create_message(
|
||||
request_data: MessagesRequest,
|
||||
raw_request: Request,
|
||||
settings: Settings = Depends(get_settings),
|
||||
_raw_request: Request,
|
||||
service: ClaudeProxyService = Depends(get_proxy_service),
|
||||
_auth=Depends(require_api_key),
|
||||
):
|
||||
"""Create a message (always streaming)."""
|
||||
|
||||
try:
|
||||
if not request_data.messages:
|
||||
raise InvalidRequestError("messages cannot be empty")
|
||||
|
||||
optimized = try_optimizations(request_data, settings)
|
||||
if optimized is not None:
|
||||
return optimized
|
||||
logger.debug("No optimization matched, routing to provider")
|
||||
|
||||
# Resolve provider from the model-aware mapping
|
||||
provider_type = Settings.parse_provider_type(
|
||||
request_data.resolved_provider_model or settings.model
|
||||
)
|
||||
provider = get_provider_for_type(provider_type)
|
||||
|
||||
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
||||
logger.info(
|
||||
"API_REQUEST: request_id={} model={} messages={}",
|
||||
request_id,
|
||||
request_data.model,
|
||||
len(request_data.messages),
|
||||
)
|
||||
logger.debug("FULL_PAYLOAD [{}]: {}", request_id, request_data.model_dump())
|
||||
|
||||
input_tokens = get_token_count(
|
||||
request_data.messages, request_data.system, request_data.tools
|
||||
)
|
||||
return StreamingResponse(
|
||||
provider.stream_response(
|
||||
request_data,
|
||||
input_tokens=input_tokens,
|
||||
request_id=request_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"X-Accel-Buffering": "no",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
except ProviderError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e!s}\n{traceback.format_exc()}")
|
||||
raise HTTPException(
|
||||
status_code=getattr(e, "status_code", 500),
|
||||
detail=get_user_facing_error_message(e),
|
||||
) from e
|
||||
return service.create_message(request_data)
|
||||
|
||||
|
||||
@router.api_route("/v1/messages", methods=["HEAD", "OPTIONS"])
|
||||
@@ -134,32 +90,13 @@ async def probe_messages(_auth=Depends(require_api_key)):
|
||||
|
||||
|
||||
@router.post("/v1/messages/count_tokens")
|
||||
async def count_tokens(request_data: TokenCountRequest, _auth=Depends(require_api_key)):
|
||||
async def count_tokens(
|
||||
request_data: TokenCountRequest,
|
||||
service: ClaudeProxyService = Depends(get_proxy_service),
|
||||
_auth=Depends(require_api_key),
|
||||
):
|
||||
"""Count tokens for a request."""
|
||||
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
||||
with logger.contextualize(request_id=request_id):
|
||||
try:
|
||||
tokens = get_token_count(
|
||||
request_data.messages, request_data.system, request_data.tools
|
||||
)
|
||||
logger.info(
|
||||
"COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}",
|
||||
request_id,
|
||||
getattr(request_data, "model", "unknown"),
|
||||
len(request_data.messages),
|
||||
tokens,
|
||||
)
|
||||
return TokenCountResponse(input_tokens=tokens)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"COUNT_TOKENS_ERROR: request_id={} error={}\n{}",
|
||||
request_id,
|
||||
get_user_facing_error_message(e),
|
||||
traceback.format_exc(),
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=get_user_facing_error_message(e)
|
||||
) from e
|
||||
return service.count_tokens(request_data)
|
||||
|
||||
|
||||
@router.api_route("/v1/messages/count_tokens", methods=["HEAD", "OPTIONS"])
|
||||
|
||||
+128
@@ -0,0 +1,128 @@
|
||||
"""Application services for the Claude-compatible API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
from config.settings import Settings
|
||||
from providers.base import BaseProvider
|
||||
from providers.common import get_user_facing_error_message
|
||||
from providers.exceptions import InvalidRequestError, ProviderError
|
||||
|
||||
from .model_router import ModelRouter
|
||||
from .models.anthropic import MessagesRequest, TokenCountRequest
|
||||
from .models.responses import TokenCountResponse
|
||||
from .optimization_handlers import try_optimizations
|
||||
from .request_utils import get_token_count
|
||||
|
||||
TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int]
|
||||
|
||||
ProviderGetter = Callable[[str], BaseProvider]
|
||||
|
||||
|
||||
class ClaudeProxyService:
|
||||
"""Coordinate request optimization, model routing, token count, and providers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings: Settings,
|
||||
provider_getter: ProviderGetter,
|
||||
model_router: ModelRouter | None = None,
|
||||
token_counter: TokenCounter = get_token_count,
|
||||
):
|
||||
self._settings = settings
|
||||
self._provider_getter = provider_getter
|
||||
self._model_router = model_router or ModelRouter(settings)
|
||||
self._token_counter = token_counter
|
||||
|
||||
def create_message(self, request_data: MessagesRequest) -> object:
|
||||
"""Create a message response or streaming response."""
|
||||
try:
|
||||
if not request_data.messages:
|
||||
raise InvalidRequestError("messages cannot be empty")
|
||||
|
||||
routed_request = self._model_router.resolve_messages_request(request_data)
|
||||
|
||||
optimized = try_optimizations(routed_request, self._settings)
|
||||
if optimized is not None:
|
||||
return optimized
|
||||
logger.debug("No optimization matched, routing to provider")
|
||||
|
||||
provider_type = (
|
||||
routed_request.resolved_provider_model or self._settings.model
|
||||
).split("/", 1)[0]
|
||||
provider = self._provider_getter(provider_type)
|
||||
|
||||
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
||||
logger.info(
|
||||
"API_REQUEST: request_id={} model={} messages={}",
|
||||
request_id,
|
||||
routed_request.model,
|
||||
len(routed_request.messages),
|
||||
)
|
||||
logger.debug(
|
||||
"FULL_PAYLOAD [{}]: {}", request_id, routed_request.model_dump()
|
||||
)
|
||||
|
||||
input_tokens = self._token_counter(
|
||||
routed_request.messages, routed_request.system, routed_request.tools
|
||||
)
|
||||
return StreamingResponse(
|
||||
provider.stream_response(
|
||||
routed_request,
|
||||
input_tokens=input_tokens,
|
||||
request_id=request_id,
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"X-Accel-Buffering": "no",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
except ProviderError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e!s}\n{traceback.format_exc()}")
|
||||
raise HTTPException(
|
||||
status_code=getattr(e, "status_code", 500),
|
||||
detail=get_user_facing_error_message(e),
|
||||
) from e
|
||||
|
||||
def count_tokens(self, request_data: TokenCountRequest) -> TokenCountResponse:
|
||||
"""Count tokens for a request after applying configured model routing."""
|
||||
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
||||
with logger.contextualize(request_id=request_id):
|
||||
try:
|
||||
routed_request = self._model_router.resolve_token_count_request(
|
||||
request_data
|
||||
)
|
||||
tokens = self._token_counter(
|
||||
routed_request.messages, routed_request.system, routed_request.tools
|
||||
)
|
||||
logger.info(
|
||||
"COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}",
|
||||
request_id,
|
||||
routed_request.model,
|
||||
len(routed_request.messages),
|
||||
tokens,
|
||||
)
|
||||
return TokenCountResponse(input_tokens=tokens)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"COUNT_TOKENS_ERROR: request_id={} error={}\n{}",
|
||||
request_id,
|
||||
get_user_facing_error_message(e),
|
||||
traceback.format_exc(),
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=get_user_facing_error_message(e)
|
||||
) from e
|
||||
Reference in New Issue
Block a user