mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-06-02 06:13:46 +02:00
Per claude model mapping (#66)
This commit is contained in:
+11
-8
@@ -10,20 +10,23 @@ OPENROUTER_API_KEY=""
|
||||
LM_STUDIO_BASE_URL="http://localhost:1234/v1"
|
||||
|
||||
|
||||
# All Claude model requests are mapped to this model
|
||||
# Claude model requests are mapped per tier (opus/sonnet/haiku) to the models below; plain MODEL is the fallback
|
||||
# Format: provider_type/model/name
|
||||
# Valid providers: "nvidia_nim" | "open_router" | "lmstudio"
|
||||
MODEL="nvidia_nim/stepfun-ai/step-3.5-flash"
|
||||
MODEL_OPUS="nvidia_nim/z-ai/glm4.7"
|
||||
MODEL_SONNET="open_router/arcee-ai/trinity-large-preview:free"
|
||||
MODEL_HAIKU="open_router/stepfun/step-3.5-flash:free"
|
||||
MODEL="nvidia_nim/z-ai/glm4.7"
|
||||
|
||||
|
||||
# Provider Config
|
||||
PROVIDER_RATE_LIMIT=40
|
||||
PROVIDER_RATE_WINDOW=60
|
||||
PROVIDER_MAX_CONCURRENCY=5
|
||||
# Provider config
|
||||
PROVIDER_RATE_LIMIT=1
|
||||
PROVIDER_RATE_WINDOW=3
|
||||
PROVIDER_MAX_CONCURRENCY=50
|
||||
|
||||
|
||||
# HTTP client timeouts (seconds) for provider API requests
|
||||
HTTP_READ_TIMEOUT=300
|
||||
HTTP_READ_TIMEOUT=60
|
||||
HTTP_WRITE_TIMEOUT=10
|
||||
HTTP_CONNECT_TIMEOUT=2
|
||||
|
||||
@@ -44,7 +47,7 @@ WHISPER_DEVICE="nvidia_nim"
|
||||
# - For cpu/cuda: Hugging Face ID or short name (tiny, base, small, medium, large-v2, large-v3, large-v3-turbo)
|
||||
# - For nvidia_nim: NVIDIA NIM model (e.g., "nvidia/parakeet-ctc-1.1b-asr", "openai/whisper-large-v3")
|
||||
# - For nvidia_nim, default to "openai/whisper-large-v3" for best performance
|
||||
WHISPER_MODEL="openai/whisper-large-v3"
|
||||
WHISPER_MODEL="openai/whisper-large-v3"
|
||||
HF_TOKEN=""
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
> This file is identical to CLAUDE.md. Keep them in sync.
|
||||
|
||||
## CODING ENVIRONMENT
|
||||
|
||||
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
|
||||
- Install Python 3.14.3 using `uv python install 3.14.3` if not already installed
|
||||
- Always use `uv run` to run files instead of the global `python` command.
|
||||
- Current uv ruff formatter is set to py314, which supports multiple exception types without parentheses (e.g. `except TypeError, ValueError:`)
|
||||
- Read `.env.example` for environment variables.
|
||||
@@ -14,11 +16,13 @@
|
||||
- All 5 checks are enforced in `tests.yml` on push/merge.
|
||||
|
||||
## IDENTITY & CONTEXT
|
||||
|
||||
- You are an expert Software Architect and Systems Engineer.
|
||||
- Goal: Zero-defect, root-cause-oriented engineering for bugs; test-driven engineering for new features. Think carefully; no need to rush.
|
||||
- Code: Write the simplest code possible. Keep the codebase minimal and modular.
|
||||
|
||||
## ARCHITECTURE PRINCIPLES (see PLAN.md)
|
||||
|
||||
- **Shared utilities**: Extract common logic into shared packages (e.g. `providers/common/`). Do not have one provider import from another provider's utils.
|
||||
- **DRY**: Extract shared base classes to eliminate duplication. Prefer composition over copy-paste.
|
||||
- **Encapsulation**: Use accessor methods for internal state (e.g. `set_current_task()`), not direct `_attribute` assignment from outside.
|
||||
@@ -30,6 +34,7 @@
|
||||
- **Backward compatibility**: When moving modules, add re-exports from old locations so existing imports keep working.
|
||||
|
||||
## COGNITIVE WORKFLOW
|
||||
|
||||
1. **ANALYZE**: Read relevant files. Do not guess.
|
||||
2. **PLAN**: Map out the logic. Identify root cause or required changes. Order changes by dependency.
|
||||
3. **EXECUTE**: Fix the cause, not the symptom. Execute incrementally with clear commits.
|
||||
@@ -38,8 +43,10 @@
|
||||
6. **PROPAGATION**: Changes impact multiple files; propagate updates correctly.
|
||||
|
||||
## SUMMARY STANDARDS
|
||||
|
||||
- Summaries must be technical and granular.
|
||||
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none).
|
||||
|
||||
## TOOLS
|
||||
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
|
||||
|
||||
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
> This file is identical to CLAUDE.md. Keep them in sync.
|
||||
|
||||
## CODING ENVIRONMENT
|
||||
|
||||
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
|
||||
- Install Python 3.14.3 using `uv python install 3.14.3` if not already installed
|
||||
- Always use `uv run` to run files instead of the global `python` command.
|
||||
- Current uv ruff formatter is set to py314, which supports multiple exception types without parentheses (e.g. `except TypeError, ValueError:`)
|
||||
- Read `.env.example` for environment variables.
|
||||
@@ -14,11 +16,13 @@
|
||||
- All 5 checks are enforced in `tests.yml` on push/merge.
|
||||
|
||||
## IDENTITY & CONTEXT
|
||||
|
||||
- You are an expert Software Architect and Systems Engineer.
|
||||
- Goal: Zero-defect, root-cause-oriented engineering for bugs; test-driven engineering for new features. Think carefully; no need to rush.
|
||||
- Code: Write the simplest code possible. Keep the codebase minimal and modular.
|
||||
|
||||
## ARCHITECTURE PRINCIPLES (see PLAN.md)
|
||||
|
||||
- **Shared utilities**: Extract common logic into shared packages (e.g. `providers/common/`). Do not have one provider import from another provider's utils.
|
||||
- **DRY**: Extract shared base classes to eliminate duplication. Prefer composition over copy-paste.
|
||||
- **Encapsulation**: Use accessor methods for internal state (e.g. `set_current_task()`), not direct `_attribute` assignment from outside.
|
||||
@@ -30,6 +34,7 @@
|
||||
- **Backward compatibility**: When moving modules, add re-exports from old locations so existing imports keep working.
|
||||
|
||||
## COGNITIVE WORKFLOW
|
||||
|
||||
1. **ANALYZE**: Read relevant files. Do not guess.
|
||||
2. **PLAN**: Map out the logic. Identify root cause or required changes. Order changes by dependency.
|
||||
3. **EXECUTE**: Fix the cause, not the symptom. Execute incrementally with clear commits.
|
||||
@@ -38,8 +43,10 @@
|
||||
6. **PROPAGATION**: Changes impact multiple files; propagate updates correctly.
|
||||
|
||||
## SUMMARY STANDARDS
|
||||
|
||||
- Summaries must be technical and granular.
|
||||
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none).
|
||||
|
||||
## TOOLS
|
||||
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
|
||||
|
||||
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
|
||||
|
||||
@@ -50,6 +50,7 @@ A lightweight proxy that routes Claude Code's Anthropic API calls to **NVIDIA NI
|
||||
- **LM Studio**: No API key needed. Run locally with [LM Studio](https://lmstudio.ai)
|
||||
2. Install [Claude Code](https://github.com/anthropics/claude-code)
|
||||
3. Install [uv](https://github.com/astral-sh/uv)
|
||||
4. Install Python 3.14.3: `uv python install 3.14.3`
|
||||
|
||||
### Clone & Configure
|
||||
|
||||
|
||||
+2
-1
@@ -1,7 +1,7 @@
|
||||
"""API layer for Claude Code Proxy."""
|
||||
|
||||
from .app import app, create_app
|
||||
from .dependencies import get_provider
|
||||
from .dependencies import get_provider, get_provider_for_type
|
||||
from .models import (
|
||||
MessagesRequest,
|
||||
MessagesResponse,
|
||||
@@ -17,4 +17,5 @@ __all__ = [
|
||||
"app",
|
||||
"create_app",
|
||||
"get_provider",
|
||||
"get_provider_for_type",
|
||||
]
|
||||
|
||||
+42
-32
@@ -12,8 +12,8 @@ from providers.lmstudio import LMStudioProvider
|
||||
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
|
||||
from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider
|
||||
|
||||
# Global provider instance (singleton)
|
||||
_provider: BaseProvider | None = None
|
||||
# Provider registry: keyed by provider type string, lazily populated
|
||||
_providers: dict[str, BaseProvider] = {}
|
||||
|
||||
|
||||
def get_settings() -> Settings:
|
||||
@@ -21,9 +21,9 @@ def get_settings() -> Settings:
|
||||
return _get_settings()
|
||||
|
||||
|
||||
def _create_provider(settings: Settings) -> BaseProvider:
|
||||
"""Construct and return a new provider instance from settings."""
|
||||
if settings.provider_type == "nvidia_nim":
|
||||
def _create_provider_for_type(provider_type: str, settings: Settings) -> BaseProvider:
|
||||
"""Construct and return a new provider instance for the given provider type."""
|
||||
if provider_type == "nvidia_nim":
|
||||
if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip():
|
||||
raise AuthenticationError(
|
||||
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
|
||||
@@ -39,8 +39,8 @@ def _create_provider(settings: Settings) -> BaseProvider:
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
)
|
||||
provider = NvidiaNimProvider(config, nim_settings=settings.nim)
|
||||
elif settings.provider_type == "open_router":
|
||||
return NvidiaNimProvider(config, nim_settings=settings.nim)
|
||||
if provider_type == "open_router":
|
||||
if not settings.open_router_api_key or not settings.open_router_api_key.strip():
|
||||
raise AuthenticationError(
|
||||
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
|
||||
@@ -56,8 +56,8 @@ def _create_provider(settings: Settings) -> BaseProvider:
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
)
|
||||
provider = OpenRouterProvider(config)
|
||||
elif settings.provider_type == "lmstudio":
|
||||
return OpenRouterProvider(config)
|
||||
if provider_type == "lmstudio":
|
||||
config = ProviderConfig(
|
||||
api_key="lm-studio",
|
||||
base_url=settings.lm_studio_base_url,
|
||||
@@ -68,37 +68,47 @@ def _create_provider(settings: Settings) -> BaseProvider:
|
||||
http_write_timeout=settings.http_write_timeout,
|
||||
http_connect_timeout=settings.http_connect_timeout,
|
||||
)
|
||||
provider = LMStudioProvider(config)
|
||||
else:
|
||||
logger.error(
|
||||
"Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
|
||||
settings.provider_type,
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unknown provider_type: '{settings.provider_type}'. "
|
||||
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
|
||||
)
|
||||
logger.info("Provider initialized: {}", settings.provider_type)
|
||||
return provider
|
||||
return LMStudioProvider(config)
|
||||
logger.error(
|
||||
"Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
|
||||
provider_type,
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unknown provider_type: '{provider_type}'. "
|
||||
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
|
||||
)
|
||||
|
||||
|
||||
def get_provider() -> BaseProvider:
|
||||
"""Get or create the provider instance based on settings.provider_type."""
|
||||
global _provider
|
||||
if _provider is None:
|
||||
def get_provider_for_type(provider_type: str) -> BaseProvider:
    """Return the provider registered for ``provider_type``, building it on first use.

    Instances are stored in the module-level ``_providers`` registry, so a
    later call with the same type returns the same object.

    Raises:
        HTTPException: status 503 when construction fails with an
            ``AuthenticationError``; the detail is the user-facing message.
    """
    # Guard clause: reuse the registered instance when one exists.
    if provider_type in _providers:
        return _providers[provider_type]

    try:
        created = _create_provider_for_type(provider_type, get_settings())
    except AuthenticationError as e:
        raise HTTPException(
            status_code=503, detail=get_user_facing_error_message(e)
        ) from e

    _providers[provider_type] = created
    logger.info("Provider initialized: {}", provider_type)
    return created
|
||||
|
||||
|
||||
def get_provider() -> BaseProvider:
    """Return the provider for the default model's provider type.

    Thin wrapper kept for backward compatibility (health/root endpoints and
    tests); it delegates to ``get_provider_for_type``.
    """
    default_type = get_settings().provider_type
    return get_provider_for_type(default_type)
|
||||
|
||||
|
||||
async def cleanup_provider():
    """Clean up every cached provider and reset the registry.

    The registry is emptied before any provider is touched, and each
    provider's ``cleanup()`` is attempted even when an earlier one fails, so
    a single failure can neither leak the remaining providers nor leave
    already-cleaned instances cached.

    Raises:
        Exception: the first error raised by a provider's ``cleanup()``,
            re-raised once every provider has been attempted.
    """
    global _providers
    # Detach the registry up front: whatever happens below, later lookups
    # build fresh providers instead of reusing half-closed ones.
    pending, _providers = _providers, {}

    first_error: Exception | None = None
    for provider_type, provider in pending.items():
        try:
            await provider.cleanup()
        except Exception as exc:
            # Keep going so the remaining providers still get released.
            logger.error("Provider cleanup failed for '{}': {}", provider_type, exc)
            if first_error is None:
                first_error = exc

    if first_error is not None:
        # Preserve the original contract: cleanup errors propagate to the caller.
        raise first_error
    logger.debug("Provider cleanup completed")
|
||||
|
||||
+11
-12
@@ -6,13 +6,12 @@ from typing import Any, Literal
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
|
||||
from config.settings import get_settings
|
||||
from config.settings import Settings, get_settings
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Content Block Types
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class Role(StrEnum):
|
||||
user = "user"
|
||||
assistant = "assistant"
|
||||
@@ -55,8 +54,6 @@ class SystemContent(BaseModel):
|
||||
# =============================================================================
|
||||
# Message Types
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: Literal["user", "assistant"]
|
||||
content: (
|
||||
@@ -85,8 +82,6 @@ class ThinkingConfig(BaseModel):
|
||||
# =============================================================================
|
||||
# Request Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class MessagesRequest(BaseModel):
|
||||
model: str
|
||||
max_tokens: int | None = None
|
||||
@@ -103,15 +98,18 @@ class MessagesRequest(BaseModel):
|
||||
thinking: ThinkingConfig | None = None
|
||||
extra_body: dict[str, Any] | None = None
|
||||
original_model: str | None = None
|
||||
resolved_provider_model: str | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def map_model(self) -> MessagesRequest:
|
||||
"""Map any Claude model name to the configured model."""
|
||||
"""Map any Claude model name to the configured model (tier-aware)."""
|
||||
settings = get_settings()
|
||||
if self.original_model is None:
|
||||
self.original_model = self.model
|
||||
|
||||
self.model = settings.model_name
|
||||
resolved_full = settings.resolve_model(self.original_model)
|
||||
self.resolved_provider_model = resolved_full
|
||||
self.model = Settings.parse_model_name(resolved_full)
|
||||
|
||||
if self.model != self.original_model:
|
||||
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
|
||||
@@ -129,7 +127,8 @@ class TokenCountRequest(BaseModel):
|
||||
|
||||
@field_validator("model")
@classmethod
def validate_model_field(cls, v: str, info) -> str:
    """Replace the requested Claude model with the configured model name.

    The incoming name is resolved through ``Settings.resolve_model`` and the
    provider prefix is then stripped from the resulting string.
    """
    provider_and_model = get_settings().resolve_model(v)
    return Settings.parse_model_name(provider_and_model)
|
||||
|
||||
+7
-5
@@ -9,12 +9,11 @@ from fastapi.responses import StreamingResponse
|
||||
from loguru import logger
|
||||
|
||||
from config.settings import Settings
|
||||
from providers.base import BaseProvider
|
||||
from providers.common import get_user_facing_error_message
|
||||
from providers.exceptions import InvalidRequestError, ProviderError
|
||||
from providers.logging_utils import build_request_summary, log_request_compact
|
||||
|
||||
from .dependencies import get_provider, get_settings
|
||||
from .dependencies import get_provider_for_type, get_settings
|
||||
from .models.anthropic import MessagesRequest, TokenCountRequest
|
||||
from .models.responses import TokenCountResponse
|
||||
from .optimization_handlers import try_optimizations
|
||||
@@ -26,13 +25,10 @@ router = APIRouter()
|
||||
# =============================================================================
|
||||
# Routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.post("/v1/messages")
|
||||
async def create_message(
|
||||
request_data: MessagesRequest,
|
||||
raw_request: Request,
|
||||
provider: BaseProvider = Depends(get_provider),
|
||||
settings: Settings = Depends(get_settings),
|
||||
):
|
||||
"""Create a message (always streaming)."""
|
||||
@@ -45,6 +41,12 @@ async def create_message(
|
||||
if optimized is not None:
|
||||
return optimized
|
||||
|
||||
# Resolve provider from the tier-aware model mapping
|
||||
provider_type = Settings.parse_provider_type(
|
||||
request_data.resolved_provider_model or settings.model
|
||||
)
|
||||
provider = get_provider_for_type(provider_type)
|
||||
|
||||
request_id = f"req_{uuid.uuid4().hex[:12]}"
|
||||
log_request_compact(logger, request_id, request_data)
|
||||
|
||||
|
||||
+38
-5
@@ -33,10 +33,16 @@ class Settings(BaseSettings):
|
||||
)
|
||||
|
||||
# ==================== Model ====================
|
||||
# All Claude model requests are mapped to this single model
|
||||
# All Claude model requests are mapped to this single model (fallback)
|
||||
# Format: provider_type/model/name
|
||||
model: str = "nvidia_nim/meta/llama3-70b-instruct"
|
||||
|
||||
# Per-tier model overrides (optional, falls back to MODEL)
|
||||
# Each can use a different provider
|
||||
model_opus: str | None = Field(default=None, validation_alias="MODEL_OPUS")
|
||||
model_sonnet: str | None = Field(default=None, validation_alias="MODEL_SONNET")
|
||||
model_haiku: str | None = Field(default=None, validation_alias="MODEL_HAIKU")
|
||||
|
||||
# ==================== Provider Rate Limiting ====================
|
||||
provider_rate_limit: int = Field(default=40, validation_alias="PROVIDER_RATE_LIMIT")
|
||||
provider_rate_window: int = Field(
|
||||
@@ -124,9 +130,11 @@ class Settings(BaseSettings):
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("model")
|
||||
@field_validator("model", "model_opus", "model_sonnet", "model_haiku")
|
||||
@classmethod
|
||||
def validate_model_format(cls, v: str) -> str:
|
||||
def validate_model_format(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return None
|
||||
valid_providers = ("nvidia_nim", "open_router", "lmstudio")
|
||||
if "/" not in v:
|
||||
raise ValueError(
|
||||
@@ -157,14 +165,39 @@ class Settings(BaseSettings):
|
||||
|
||||
@property
def provider_type(self) -> str:
    """Provider prefix of the default ``model`` string (text before the first '/')."""
    prefix, _sep, _rest = self.model.partition("/")
    return prefix
|
||||
|
||||
@property
def model_name(self) -> str:
    """Model part of the default ``model`` string (text after the first '/')."""
    pieces = self.model.split("/", maxsplit=1)
    return pieces[1]
|
||||
|
||||
def resolve_model(self, claude_model_name: str) -> str:
    """Pick the configured 'provider/model' string for a Claude model name.

    The name is matched case-insensitively against the tier keywords in the
    order opus, haiku, sonnet. The first tier whose keyword appears in the
    name and whose override is set wins; otherwise the fallback ``model``
    is returned.
    """
    lowered = claude_model_name.lower()
    tier_overrides = (
        ("opus", self.model_opus),
        ("haiku", self.model_haiku),
        ("sonnet", self.model_sonnet),
    )
    for keyword, override in tier_overrides:
        if override is not None and keyword in lowered:
            return override
    return self.model
|
||||
|
||||
@staticmethod
|
||||
def parse_provider_type(model_string: str) -> str:
|
||||
"""Extract provider type from any 'provider/model' string."""
|
||||
return model_string.split("/", 1)[0]
|
||||
|
||||
@staticmethod
|
||||
def parse_model_name(model_string: str) -> str:
|
||||
"""Extract model name from any 'provider/model' string."""
|
||||
return model_string.split("/", 1)[1]
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
|
||||
+1
-1
@@ -3,7 +3,7 @@ name = "free-claude-code"
|
||||
version = "2.0.0"
|
||||
description = "Middleware between Claude Code CLI (Anthropic API) and NVIDIA NIM"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.14.2"
|
||||
requires-python = ">=3.14.3"
|
||||
dependencies = [
|
||||
"fastapi[standard]>=0.115.11",
|
||||
"uvicorn>=0.34.0",
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from api.app import app
|
||||
from api.dependencies import get_provider
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
|
||||
# Mock provider
|
||||
@@ -22,12 +21,10 @@ async def _mock_stream_response(*args, **kwargs):
|
||||
|
||||
mock_provider.stream_response = _mock_stream_response
|
||||
|
||||
# Patch get_provider_for_type to always return mock_provider
|
||||
_patcher = patch("api.routes.get_provider_for_type", return_value=mock_provider)
|
||||
_patcher.start()
|
||||
|
||||
def override_get_provider():
|
||||
return mock_provider
|
||||
|
||||
|
||||
app.dependency_overrides[get_provider] = override_get_provider
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.dependencies import cleanup_provider, get_provider, get_settings
|
||||
from api.dependencies import (
|
||||
cleanup_provider,
|
||||
get_provider,
|
||||
get_provider_for_type,
|
||||
get_settings,
|
||||
)
|
||||
from config.nim import NimSettings
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
@@ -32,9 +37,13 @@ def _make_mock_settings(**overrides):
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_provider():
    """Give each test an empty ``_providers`` registry, then restore the previous one."""
    import api.dependencies as deps

    previous_registry = deps._providers
    deps._providers = {}
    yield
    deps._providers = previous_registry
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -213,6 +222,70 @@ async def test_cleanup_provider_aclose_raises():
|
||||
provider._client = AsyncMock()
|
||||
provider._client.aclose = AsyncMock(side_effect=RuntimeError("cleanup failed"))
|
||||
|
||||
# Should propagate the error (current behavior - no try/except)
|
||||
# Should propagate the error
|
||||
with pytest.raises(RuntimeError, match="cleanup failed"):
|
||||
await cleanup_provider()
|
||||
|
||||
|
||||
# --- Provider Registry Tests ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_provider_for_type_caches():
    """A second lookup for the same provider type returns the cached instance."""
    with patch("api.dependencies.get_settings") as settings_factory:
        settings_factory.return_value = _make_mock_settings()

        first = get_provider_for_type("nvidia_nim")
        second = get_provider_for_type("nvidia_nim")

        assert first is second
        assert isinstance(first, NvidiaNimProvider)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_provider_for_type_different_types():
    """Each provider type gets its own provider instance."""
    with patch("api.dependencies.get_settings") as settings_factory:
        settings_factory.return_value = _make_mock_settings()

        nim_provider = get_provider_for_type("nvidia_nim")
        lmstudio_provider = get_provider_for_type("lmstudio")

        assert isinstance(nim_provider, NvidiaNimProvider)
        assert isinstance(lmstudio_provider, LMStudioProvider)
        assert nim_provider is not lmstudio_provider
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_provider_for_type_missing_key_raises_503():
    """A missing API key surfaces as an HTTPException with status 503."""
    with patch("api.dependencies.get_settings") as settings_factory:
        settings_factory.return_value = _make_mock_settings(open_router_api_key="")

        with pytest.raises(HTTPException) as exc_info:
            get_provider_for_type("open_router")

        error = exc_info.value
        assert error.status_code == 503
        assert "OPENROUTER_API_KEY" in error.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cleanup_provider_cleans_all():
    """cleanup_provider closes the client of every provider in the registry."""
    with patch("api.dependencies.get_settings") as settings_factory:
        settings_factory.return_value = _make_mock_settings()

        nim_provider = get_provider_for_type("nvidia_nim")
        lmstudio_provider = get_provider_for_type("lmstudio")

        assert isinstance(nim_provider, NvidiaNimProvider)
        assert isinstance(lmstudio_provider, LMStudioProvider)

        # Swap in mock clients so the aclose() calls can be observed.
        nim_provider._client = AsyncMock()
        lmstudio_provider._client = AsyncMock()

        await cleanup_provider()

        nim_provider._client.aclose.assert_called_once()
        lmstudio_provider._client.aclose.assert_called_once()
|
||||
|
||||
@@ -61,3 +61,97 @@ def test_messages_request_model_mapping_logs(mock_settings):
|
||||
assert "MODEL MAPPING" in args
|
||||
assert "claude-2.1" in args
|
||||
assert "target-model-from-settings" in args
|
||||
|
||||
|
||||
def test_messages_request_resolved_provider_model_default(mock_settings):
    """resolved_provider_model holds the full provider/model string."""
    expected = "nvidia_nim/target-model-from-settings"
    with patch("api.models.anthropic.get_settings", return_value=mock_settings):
        req = MessagesRequest(
            model="claude-3-opus",
            max_tokens=100,
            messages=[Message(role="user", content="hello")],
        )
        assert req.resolved_provider_model == expected
|
||||
|
||||
|
||||
def test_messages_request_tier_aware_opus_override():
    """An opus request is routed to MODEL_OPUS when that override is set."""
    cfg = Settings()
    cfg.model = "nvidia_nim/fallback-model"
    cfg.model_opus = "open_router/deepseek/deepseek-r1"

    with patch("api.models.anthropic.get_settings", return_value=cfg):
        req = MessagesRequest(
            model="claude-opus-4-20250514",
            max_tokens=100,
            messages=[Message(role="user", content="hello")],
        )
        assert req.model == "deepseek/deepseek-r1"
        assert req.resolved_provider_model == "open_router/deepseek/deepseek-r1"
        assert req.original_model == "claude-opus-4-20250514"
|
||||
|
||||
|
||||
def test_messages_request_tier_aware_haiku_override():
    """A haiku request is routed to MODEL_HAIKU when that override is set."""
    cfg = Settings()
    cfg.model = "nvidia_nim/fallback-model"
    cfg.model_haiku = "lmstudio/qwen2.5-7b"

    with patch("api.models.anthropic.get_settings", return_value=cfg):
        req = MessagesRequest(
            model="claude-3-haiku-20240307",
            max_tokens=100,
            messages=[Message(role="user", content="hello")],
        )
        assert req.model == "qwen2.5-7b"
        assert req.resolved_provider_model == "lmstudio/qwen2.5-7b"
|
||||
|
||||
|
||||
def test_messages_request_tier_aware_sonnet_override():
    """A sonnet request is routed to MODEL_SONNET when that override is set."""
    cfg = Settings()
    cfg.model = "nvidia_nim/fallback-model"
    cfg.model_sonnet = "nvidia_nim/meta/llama-3.3-70b-instruct"

    with patch("api.models.anthropic.get_settings", return_value=cfg):
        req = MessagesRequest(
            model="claude-sonnet-4-20250514",
            max_tokens=100,
            messages=[Message(role="user", content="hello")],
        )
        assert req.model == "meta/llama-3.3-70b-instruct"
        assert req.resolved_provider_model == "nvidia_nim/meta/llama-3.3-70b-instruct"
|
||||
|
||||
|
||||
def test_messages_request_tier_fallback_when_not_set():
    """With no tier override assigned, an opus request falls back to MODEL."""
    cfg = Settings()
    cfg.model = "nvidia_nim/fallback-model"
    # cfg.model_opus is deliberately left unassigned here.

    with patch("api.models.anthropic.get_settings", return_value=cfg):
        req = MessagesRequest(
            model="claude-opus-4-20250514",
            max_tokens=100,
            messages=[Message(role="user", content="hello")],
        )
        assert req.model == "fallback-model"
        assert req.resolved_provider_model == "nvidia_nim/fallback-model"
|
||||
|
||||
|
||||
def test_token_count_request_tier_aware():
    """TokenCountRequest maps its model through the same tier-aware resolution."""
    cfg = Settings()
    cfg.model = "nvidia_nim/fallback-model"
    cfg.model_haiku = "lmstudio/qwen2.5-7b"

    with patch("api.models.anthropic.get_settings", return_value=cfg):
        req = TokenCountRequest(
            model="claude-3-haiku-20240307",
            messages=[Message(role="user", content="hello")],
        )
        assert req.model == "qwen2.5-7b"
|
||||
|
||||
+148
-2
@@ -106,8 +106,6 @@ class TestSettings:
|
||||
|
||||
|
||||
# --- NimSettings Validation Tests ---
|
||||
|
||||
|
||||
class TestNimSettingsValidBounds:
|
||||
"""Test that valid values within bounds are accepted."""
|
||||
|
||||
@@ -310,3 +308,151 @@ class TestSettingsOptionalStr:
|
||||
monkeypatch.setenv("WHISPER_DEVICE", device)
|
||||
s = Settings()
|
||||
assert s.whisper_device == device
|
||||
|
||||
|
||||
class TestPerTierModelMapping:
    """Test per-tier model fields and resolve_model().

    Covers loading MODEL_OPUS / MODEL_SONNET / MODEL_HAIKU from the
    environment, rejection of values with a missing or invalid provider
    prefix, tier resolution through ``Settings.resolve_model()``, and the
    ``parse_provider_type`` / ``parse_model_name`` helpers.
    """

    # Environment variables that feed the per-tier model fields.
    _TIER_ENV_VARS = ("MODEL_OPUS", "MODEL_SONNET", "MODEL_HAIKU")

    @classmethod
    def _clear_tier_env(cls, monkeypatch):
        """Remove the per-tier env vars so default-value tests are hermetic.

        Without this, a tier variable exported in the developer's shell or CI
        job leaks into ``Settings()`` and breaks tests that expect ``None``.

        NOTE(review): this only clears process environment variables. If
        Settings also reads a dotenv file, values coming from that file are
        not affected - confirm against config/settings.py.
        """
        for name in cls._TIER_ENV_VARS:
            # raising=False: the variable is usually absent; that is fine.
            monkeypatch.delenv(name, raising=False)

    def test_tier_fields_default_none(self, monkeypatch):
        """Per-tier model fields default to None."""
        from config.settings import Settings

        self._clear_tier_env(monkeypatch)
        s = Settings()
        assert s.model_opus is None
        assert s.model_sonnet is None
        assert s.model_haiku is None

    def test_model_opus_from_env(self, monkeypatch):
        """MODEL_OPUS env var is loaded."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_OPUS", "open_router/deepseek/deepseek-r1")
        s = Settings()
        assert s.model_opus == "open_router/deepseek/deepseek-r1"

    def test_model_sonnet_from_env(self, monkeypatch):
        """MODEL_SONNET env var is loaded."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_SONNET", "nvidia_nim/meta/llama-3.3-70b-instruct")
        s = Settings()
        assert s.model_sonnet == "nvidia_nim/meta/llama-3.3-70b-instruct"

    def test_model_haiku_from_env(self, monkeypatch):
        """MODEL_HAIKU env var is loaded."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_HAIKU", "lmstudio/qwen2.5-7b")
        s = Settings()
        assert s.model_haiku == "lmstudio/qwen2.5-7b"

    def test_model_opus_invalid_provider_raises(self, monkeypatch):
        """MODEL_OPUS with invalid provider prefix raises ValidationError."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_OPUS", "bad_provider/some-model")
        with pytest.raises(ValidationError, match="Invalid provider"):
            Settings()

    def test_model_opus_no_slash_raises(self, monkeypatch):
        """MODEL_OPUS without provider prefix raises ValidationError."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_OPUS", "noprefix")
        with pytest.raises(ValidationError, match="provider type"):
            Settings()

    def test_model_haiku_invalid_provider_raises(self, monkeypatch):
        """MODEL_HAIKU with invalid provider prefix raises ValidationError."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_HAIKU", "invalid/model")
        with pytest.raises(ValidationError, match="Invalid provider"):
            Settings()

    def test_resolve_model_opus_override(self):
        """resolve_model returns model_opus for opus model names."""
        from config.settings import Settings

        s = Settings()
        s.model_opus = "open_router/deepseek/deepseek-r1"
        assert (
            s.resolve_model("claude-opus-4-20250514")
            == "open_router/deepseek/deepseek-r1"
        )
        assert s.resolve_model("claude-3-opus") == "open_router/deepseek/deepseek-r1"
        assert (
            s.resolve_model("claude-3-opus-20240229")
            == "open_router/deepseek/deepseek-r1"
        )

    def test_resolve_model_sonnet_override(self):
        """resolve_model returns model_sonnet for sonnet model names."""
        from config.settings import Settings

        s = Settings()
        s.model_sonnet = "nvidia_nim/meta/llama-3.3-70b-instruct"
        assert (
            s.resolve_model("claude-sonnet-4-20250514")
            == "nvidia_nim/meta/llama-3.3-70b-instruct"
        )
        assert (
            s.resolve_model("claude-3-5-sonnet-20241022")
            == "nvidia_nim/meta/llama-3.3-70b-instruct"
        )

    def test_resolve_model_haiku_override(self):
        """resolve_model returns model_haiku for haiku model names."""
        from config.settings import Settings

        s = Settings()
        s.model_haiku = "lmstudio/qwen2.5-7b"
        assert s.resolve_model("claude-3-haiku-20240307") == "lmstudio/qwen2.5-7b"
        assert s.resolve_model("claude-3-5-haiku-20241022") == "lmstudio/qwen2.5-7b"
        assert s.resolve_model("claude-haiku-4-20250514") == "lmstudio/qwen2.5-7b"

    def test_resolve_model_fallback_when_tier_not_set(self, monkeypatch):
        """resolve_model falls back to MODEL when tier override is None."""
        from config.settings import Settings

        # The fallback is only exercised if no tier override leaks in from
        # the surrounding environment.
        self._clear_tier_env(monkeypatch)
        s = Settings()
        s.model = "nvidia_nim/fallback-model"
        # No tier overrides set
        assert s.resolve_model("claude-opus-4-20250514") == "nvidia_nim/fallback-model"
        assert (
            s.resolve_model("claude-sonnet-4-20250514") == "nvidia_nim/fallback-model"
        )
        assert s.resolve_model("claude-3-haiku-20240307") == "nvidia_nim/fallback-model"

    def test_resolve_model_unknown_tier_falls_back(self):
        """resolve_model falls back to MODEL for unrecognized model names."""
        from config.settings import Settings

        s = Settings()
        s.model = "nvidia_nim/fallback-model"
        # An opus override is set on purpose: it must not capture names that
        # belong to no known tier.
        s.model_opus = "open_router/opus-model"
        assert s.resolve_model("claude-2.1") == "nvidia_nim/fallback-model"
        assert s.resolve_model("some-unknown-model") == "nvidia_nim/fallback-model"

    def test_resolve_model_case_insensitive(self):
        """Tier classification is case-insensitive."""
        from config.settings import Settings

        s = Settings()
        s.model_opus = "open_router/opus-model"
        assert s.resolve_model("Claude-OPUS-4") == "open_router/opus-model"

    def test_parse_provider_type(self):
        """parse_provider_type extracts provider from model string."""
        from config.settings import Settings

        assert Settings.parse_provider_type("nvidia_nim/meta/llama") == "nvidia_nim"
        assert Settings.parse_provider_type("open_router/deepseek/r1") == "open_router"
        assert Settings.parse_provider_type("lmstudio/qwen") == "lmstudio"

    def test_parse_model_name(self):
        """parse_model_name extracts model name from model string."""
        from config.settings import Settings

        # Only the leading provider segment is dropped; later slashes stay.
        assert Settings.parse_model_name("nvidia_nim/meta/llama") == "meta/llama"
        assert Settings.parse_model_name("lmstudio/qwen") == "qwen"
|
||||
|
||||
Reference in New Issue
Block a user