Per claude model mapping (#66)

This commit is contained in:
Ali Khokhar
2026-03-01 21:32:23 -08:00
committed by GitHub
parent 763c8b62b7
commit 0b324e0421
15 changed files with 454 additions and 81 deletions
+11 -8
View File
@@ -10,20 +10,23 @@ OPENROUTER_API_KEY=""
LM_STUDIO_BASE_URL="http://localhost:1234/v1"
# All Claude model requests are mapped to this model
# Claude model requests are mapped per tier (opus/sonnet/haiku) to the models below; plain MODEL is the fallback
# Format: provider_type/model/name
# Valid providers: "nvidia_nim" | "open_router" | "lmstudio"
MODEL="nvidia_nim/stepfun-ai/step-3.5-flash"
MODEL_OPUS="nvidia_nim/z-ai/glm4.7"
MODEL_SONNET="open_router/arcee-ai/trinity-large-preview:free"
MODEL_HAIKU="open_router/stepfun/step-3.5-flash:free"
MODEL="nvidia_nim/z-ai/glm4.7"
# Provider Config
PROVIDER_RATE_LIMIT=40
PROVIDER_RATE_WINDOW=60
PROVIDER_MAX_CONCURRENCY=5
# Provider config
PROVIDER_RATE_LIMIT=1
PROVIDER_RATE_WINDOW=3
PROVIDER_MAX_CONCURRENCY=50
# HTTP client timeouts (seconds) for provider API requests
HTTP_READ_TIMEOUT=300
HTTP_READ_TIMEOUT=60
HTTP_WRITE_TIMEOUT=10
HTTP_CONNECT_TIMEOUT=2
@@ -44,7 +47,7 @@ WHISPER_DEVICE="nvidia_nim"
# - For cpu/cuda: Hugging Face ID or short name (tiny, base, small, medium, large-v2, large-v3, large-v3-turbo)
# - For nvidia_nim: NVIDIA NIM model (e.g., "nvidia/parakeet-ctc-1.1b-asr", "openai/whisper-large-v3")
# - For nvidia_nim, default to "openai/whisper-large-v3" for best performance
WHISPER_MODEL="openai/whisper-large-v3"
WHISPER_MODEL="openai/whisper-large-v3"
HF_TOKEN=""
+8 -1
View File
@@ -3,7 +3,9 @@
> This file is identical to CLAUDE.md. Keep them in sync.
## CODING ENVIRONMENT
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
- Install Python 3.14.3 using `uv python install 3.14.3` if not already installed
- Always use `uv run` to run files instead of the global `python` command.
- The current uv ruff formatter targets py314, which supports multiple exception types without parentheses (`except TypeError, ValueError:`)
- Read `.env.example` for environment variables.
@@ -14,11 +16,13 @@
- All 5 checks are enforced in `tests.yml` on push/merge.
## IDENTITY & CONTEXT
- You are an expert Software Architect and Systems Engineer.
- Goal: Zero-defect, root-cause-oriented engineering for bugs; test-driven engineering for new features. Think carefully; no need to rush.
- Code: Write the simplest code possible. Keep the codebase minimal and modular.
## ARCHITECTURE PRINCIPLES (see PLAN.md)
- **Shared utilities**: Extract common logic into shared packages (e.g. `providers/common/`). Do not have one provider import from another provider's utils.
- **DRY**: Extract shared base classes to eliminate duplication. Prefer composition over copy-paste.
- **Encapsulation**: Use accessor methods for internal state (e.g. `set_current_task()`), not direct `_attribute` assignment from outside.
@@ -30,6 +34,7 @@
- **Backward compatibility**: When moving modules, add re-exports from old locations so existing imports keep working.
## COGNITIVE WORKFLOW
1. **ANALYZE**: Read relevant files. Do not guess.
2. **PLAN**: Map out the logic. Identify root cause or required changes. Order changes by dependency.
3. **EXECUTE**: Fix the cause, not the symptom. Execute incrementally with clear commits.
@@ -38,8 +43,10 @@
6. **PROPAGATION**: Changes impact multiple files; propagate updates correctly.
## SUMMARY STANDARDS
- Summaries must be technical and granular.
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none).
## TOOLS
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
+8 -1
View File
@@ -3,7 +3,9 @@
> This file is identical to CLAUDE.md. Keep them in sync.
## CODING ENVIRONMENT
- Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed
- Install Python 3.14.3 using `uv python install 3.14.3` if not already installed
- Always use `uv run` to run files instead of the global `python` command.
- The current uv ruff formatter targets py314, which supports multiple exception types without parentheses (`except TypeError, ValueError:`)
- Read `.env.example` for environment variables.
@@ -14,11 +16,13 @@
- All 5 checks are enforced in `tests.yml` on push/merge.
## IDENTITY & CONTEXT
- You are an expert Software Architect and Systems Engineer.
- Goal: Zero-defect, root-cause-oriented engineering for bugs; test-driven engineering for new features. Think carefully; no need to rush.
- Code: Write the simplest code possible. Keep the codebase minimal and modular.
## ARCHITECTURE PRINCIPLES (see PLAN.md)
- **Shared utilities**: Extract common logic into shared packages (e.g. `providers/common/`). Do not have one provider import from another provider's utils.
- **DRY**: Extract shared base classes to eliminate duplication. Prefer composition over copy-paste.
- **Encapsulation**: Use accessor methods for internal state (e.g. `set_current_task()`), not direct `_attribute` assignment from outside.
@@ -30,6 +34,7 @@
- **Backward compatibility**: When moving modules, add re-exports from old locations so existing imports keep working.
## COGNITIVE WORKFLOW
1. **ANALYZE**: Read relevant files. Do not guess.
2. **PLAN**: Map out the logic. Identify root cause or required changes. Order changes by dependency.
3. **EXECUTE**: Fix the cause, not the symptom. Execute incrementally with clear commits.
@@ -38,8 +43,10 @@
6. **PROPAGATION**: Changes impact multiple files; propagate updates correctly.
## SUMMARY STANDARDS
- Summaries must be technical and granular.
- Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none).
## TOOLS
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use.
+1
View File
@@ -50,6 +50,7 @@ A lightweight proxy that routes Claude Code's Anthropic API calls to **NVIDIA NI
- **LM Studio**: No API key needed. Run locally with [LM Studio](https://lmstudio.ai)
2. Install [Claude Code](https://github.com/anthropics/claude-code)
3. Install [uv](https://github.com/astral-sh/uv)
4. Install Python 3.14.3: `uv python install 3.14.3`
### Clone & Configure
+2 -1
View File
@@ -1,7 +1,7 @@
"""API layer for Claude Code Proxy."""
from .app import app, create_app
from .dependencies import get_provider
from .dependencies import get_provider, get_provider_for_type
from .models import (
MessagesRequest,
MessagesResponse,
@@ -17,4 +17,5 @@ __all__ = [
"app",
"create_app",
"get_provider",
"get_provider_for_type",
]
+42 -32
View File
@@ -12,8 +12,8 @@ from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider
from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider
# Global provider instance (singleton)
_provider: BaseProvider | None = None
# Provider registry: keyed by provider type string, lazily populated
_providers: dict[str, BaseProvider] = {}
def get_settings() -> Settings:
@@ -21,9 +21,9 @@ def get_settings() -> Settings:
return _get_settings()
def _create_provider(settings: Settings) -> BaseProvider:
"""Construct and return a new provider instance from settings."""
if settings.provider_type == "nvidia_nim":
def _create_provider_for_type(provider_type: str, settings: Settings) -> BaseProvider:
"""Construct and return a new provider instance for the given provider type."""
if provider_type == "nvidia_nim":
if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip():
raise AuthenticationError(
"NVIDIA_NIM_API_KEY is not set. Add it to your .env file. "
@@ -39,8 +39,8 @@ def _create_provider(settings: Settings) -> BaseProvider:
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
provider = NvidiaNimProvider(config, nim_settings=settings.nim)
elif settings.provider_type == "open_router":
return NvidiaNimProvider(config, nim_settings=settings.nim)
if provider_type == "open_router":
if not settings.open_router_api_key or not settings.open_router_api_key.strip():
raise AuthenticationError(
"OPENROUTER_API_KEY is not set. Add it to your .env file. "
@@ -56,8 +56,8 @@ def _create_provider(settings: Settings) -> BaseProvider:
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
provider = OpenRouterProvider(config)
elif settings.provider_type == "lmstudio":
return OpenRouterProvider(config)
if provider_type == "lmstudio":
config = ProviderConfig(
api_key="lm-studio",
base_url=settings.lm_studio_base_url,
@@ -68,37 +68,47 @@ def _create_provider(settings: Settings) -> BaseProvider:
http_write_timeout=settings.http_write_timeout,
http_connect_timeout=settings.http_connect_timeout,
)
provider = LMStudioProvider(config)
else:
logger.error(
"Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
settings.provider_type,
)
raise ValueError(
f"Unknown provider_type: '{settings.provider_type}'. "
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
)
logger.info("Provider initialized: {}", settings.provider_type)
return provider
return LMStudioProvider(config)
logger.error(
"Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'",
provider_type,
)
raise ValueError(
f"Unknown provider_type: '{provider_type}'. "
f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'"
)
def get_provider() -> BaseProvider:
"""Get or create the provider instance based on settings.provider_type."""
global _provider
if _provider is None:
def get_provider_for_type(provider_type: str) -> BaseProvider:
"""Get or create a provider for the given provider type.
Providers are cached in the registry and reused across requests.
"""
if provider_type not in _providers:
try:
_provider = _create_provider(get_settings())
_providers[provider_type] = _create_provider_for_type(
provider_type, get_settings()
)
except AuthenticationError as e:
raise HTTPException(
status_code=503, detail=get_user_facing_error_message(e)
) from e
return _provider
logger.info("Provider initialized: {}", provider_type)
return _providers[provider_type]
def get_provider() -> BaseProvider:
    """Return the provider for the default provider type.

    Reads ``provider_type`` from the current settings and delegates to
    ``get_provider_for_type``. Kept as a backward-compatible convenience
    for health/root endpoints and tests.
    """
    default_type = get_settings().provider_type
    return get_provider_for_type(default_type)
async def cleanup_provider():
"""Cleanup provider resources."""
global _provider
if _provider:
await _provider.cleanup()
_provider = None
"""Cleanup all provider resources."""
global _providers
for provider in _providers.values():
await provider.cleanup()
_providers = {}
logger.debug("Provider cleanup completed")
+11 -12
View File
@@ -6,13 +6,12 @@ from typing import Any, Literal
from loguru import logger
from pydantic import BaseModel, field_validator, model_validator
from config.settings import get_settings
from config.settings import Settings, get_settings
# =============================================================================
# Content Block Types
# =============================================================================
class Role(StrEnum):
user = "user"
assistant = "assistant"
@@ -55,8 +54,6 @@ class SystemContent(BaseModel):
# =============================================================================
# Message Types
# =============================================================================
class Message(BaseModel):
role: Literal["user", "assistant"]
content: (
@@ -85,8 +82,6 @@ class ThinkingConfig(BaseModel):
# =============================================================================
# Request Models
# =============================================================================
class MessagesRequest(BaseModel):
model: str
max_tokens: int | None = None
@@ -103,15 +98,18 @@ class MessagesRequest(BaseModel):
thinking: ThinkingConfig | None = None
extra_body: dict[str, Any] | None = None
original_model: str | None = None
resolved_provider_model: str | None = None
@model_validator(mode="after")
def map_model(self) -> MessagesRequest:
"""Map any Claude model name to the configured model."""
"""Map any Claude model name to the configured model (tier-aware)."""
settings = get_settings()
if self.original_model is None:
self.original_model = self.model
self.model = settings.model_name
resolved_full = settings.resolve_model(self.original_model)
self.resolved_provider_model = resolved_full
self.model = Settings.parse_model_name(resolved_full)
if self.model != self.original_model:
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
@@ -129,7 +127,8 @@ class TokenCountRequest(BaseModel):
@field_validator("model")
@classmethod
def validate_model_field(cls, v, info):
"""Map any Claude model name to the configured model."""
def validate_model_field(cls, v: str, info) -> str:
"""Map any Claude model name to the configured model (tier-aware)."""
settings = get_settings()
return settings.model_name
resolved_full = settings.resolve_model(v)
return Settings.parse_model_name(resolved_full)
+7 -5
View File
@@ -9,12 +9,11 @@ from fastapi.responses import StreamingResponse
from loguru import logger
from config.settings import Settings
from providers.base import BaseProvider
from providers.common import get_user_facing_error_message
from providers.exceptions import InvalidRequestError, ProviderError
from providers.logging_utils import build_request_summary, log_request_compact
from .dependencies import get_provider, get_settings
from .dependencies import get_provider_for_type, get_settings
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import TokenCountResponse
from .optimization_handlers import try_optimizations
@@ -26,13 +25,10 @@ router = APIRouter()
# =============================================================================
# Routes
# =============================================================================
@router.post("/v1/messages")
async def create_message(
request_data: MessagesRequest,
raw_request: Request,
provider: BaseProvider = Depends(get_provider),
settings: Settings = Depends(get_settings),
):
"""Create a message (always streaming)."""
@@ -45,6 +41,12 @@ async def create_message(
if optimized is not None:
return optimized
# Resolve provider from the tier-aware model mapping
provider_type = Settings.parse_provider_type(
request_data.resolved_provider_model or settings.model
)
provider = get_provider_for_type(provider_type)
request_id = f"req_{uuid.uuid4().hex[:12]}"
log_request_compact(logger, request_id, request_data)
+38 -5
View File
@@ -33,10 +33,16 @@ class Settings(BaseSettings):
)
# ==================== Model ====================
# All Claude model requests are mapped to this single model
# All Claude model requests are mapped to this single model (fallback)
# Format: provider_type/model/name
model: str = "nvidia_nim/meta/llama3-70b-instruct"
# Per-tier model overrides (optional, falls back to MODEL)
# Each can use a different provider
model_opus: str | None = Field(default=None, validation_alias="MODEL_OPUS")
model_sonnet: str | None = Field(default=None, validation_alias="MODEL_SONNET")
model_haiku: str | None = Field(default=None, validation_alias="MODEL_HAIKU")
# ==================== Provider Rate Limiting ====================
provider_rate_limit: int = Field(default=40, validation_alias="PROVIDER_RATE_LIMIT")
provider_rate_window: int = Field(
@@ -124,9 +130,11 @@ class Settings(BaseSettings):
)
return v
@field_validator("model")
@field_validator("model", "model_opus", "model_sonnet", "model_haiku")
@classmethod
def validate_model_format(cls, v: str) -> str:
def validate_model_format(cls, v: str | None) -> str | None:
if v is None:
return None
valid_providers = ("nvidia_nim", "open_router", "lmstudio")
if "/" not in v:
raise ValueError(
@@ -157,14 +165,39 @@ class Settings(BaseSettings):
@property
def provider_type(self) -> str:
"""Extract provider type from the model string."""
"""Extract provider type from the default model string."""
return self.model.split("/", 1)[0]
@property
def model_name(self) -> str:
"""Extract the actual model name from the model string."""
"""Extract the actual model name from the default model string."""
return self.model.split("/", 1)[1]
def resolve_model(self, claude_model_name: str) -> str:
    """Map an incoming Claude model name to a configured provider/model string.

    The name is lower-cased and checked for the tier keywords in the order
    opus, haiku, sonnet. The first tier whose keyword occurs in the name and
    whose override is not None is returned; if none qualifies, the fallback
    ``self.model`` is returned.
    """
    lowered = claude_model_name.lower()
    # Tuple order encodes the precedence: opus, then haiku, then sonnet.
    tier_overrides = (
        ("opus", self.model_opus),
        ("haiku", self.model_haiku),
        ("sonnet", self.model_sonnet),
    )
    for keyword, override in tier_overrides:
        if override is not None and keyword in lowered:
            return override
    return self.model
@staticmethod
def parse_provider_type(model_string: str) -> str:
"""Extract provider type from any 'provider/model' string."""
return model_string.split("/", 1)[0]
@staticmethod
def parse_model_name(model_string: str) -> str:
"""Extract model name from any 'provider/model' string."""
return model_string.split("/", 1)[1]
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
+1 -1
View File
@@ -3,7 +3,7 @@ name = "free-claude-code"
version = "2.0.0"
description = "Middleware between Claude Code CLI (Anthropic API) and NVIDIA NIM"
readme = "README.md"
requires-python = ">=3.14.2"
requires-python = ">=3.14.3"
dependencies = [
"fastapi[standard]>=0.115.11",
"uvicorn>=0.34.0",
+4 -7
View File
@@ -1,9 +1,8 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
from fastapi.testclient import TestClient
from api.app import app
from api.dependencies import get_provider
from providers.nvidia_nim import NvidiaNimProvider
# Mock provider
@@ -22,12 +21,10 @@ async def _mock_stream_response(*args, **kwargs):
mock_provider.stream_response = _mock_stream_response
# Patch get_provider_for_type to always return mock_provider
_patcher = patch("api.routes.get_provider_for_type", return_value=mock_provider)
_patcher.start()
def override_get_provider():
return mock_provider
app.dependency_overrides[get_provider] = override_get_provider
client = TestClient(app)
+78 -5
View File
@@ -3,7 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from api.dependencies import cleanup_provider, get_provider, get_settings
from api.dependencies import (
cleanup_provider,
get_provider,
get_provider_for_type,
get_settings,
)
from config.nim import NimSettings
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NvidiaNimProvider
@@ -32,9 +37,13 @@ def _make_mock_settings(**overrides):
@pytest.fixture(autouse=True)
def reset_provider():
"""Reset the global _provider singleton between tests."""
with patch("api.dependencies._provider", None):
yield
"""Reset the global _providers registry between tests."""
import api.dependencies
saved = api.dependencies._providers
api.dependencies._providers = {}
yield
api.dependencies._providers = saved
@pytest.mark.asyncio
@@ -213,6 +222,70 @@ async def test_cleanup_provider_aclose_raises():
provider._client = AsyncMock()
provider._client.aclose = AsyncMock(side_effect=RuntimeError("cleanup failed"))
# Should propagate the error (current behavior - no try/except)
# Should propagate the error
with pytest.raises(RuntimeError, match="cleanup failed"):
await cleanup_provider()
# --- Provider Registry Tests ---
@pytest.mark.asyncio
async def test_get_provider_for_type_caches():
    """Two lookups of the same provider type yield the same cached object."""
    with patch("api.dependencies.get_settings") as settings_mock:
        settings_mock.return_value = _make_mock_settings()

        first = get_provider_for_type("nvidia_nim")
        second = get_provider_for_type("nvidia_nim")

        assert first is second
        assert isinstance(first, NvidiaNimProvider)
@pytest.mark.asyncio
async def test_get_provider_for_type_different_types():
    """Each provider type gets its own, distinct provider instance."""
    with patch("api.dependencies.get_settings") as settings_mock:
        settings_mock.return_value = _make_mock_settings()

        nim_provider = get_provider_for_type("nvidia_nim")
        lmstudio_provider = get_provider_for_type("lmstudio")

        assert isinstance(nim_provider, NvidiaNimProvider)
        assert isinstance(lmstudio_provider, LMStudioProvider)
        assert nim_provider is not lmstudio_provider
@pytest.mark.asyncio
async def test_get_provider_for_type_missing_key_raises_503():
    """An empty OpenRouter API key surfaces as an HTTPException with status 503."""
    with patch("api.dependencies.get_settings") as settings_mock:
        settings_mock.return_value = _make_mock_settings(open_router_api_key="")

        with pytest.raises(HTTPException) as exc_info:
            get_provider_for_type("open_router")

    # Inspect the captured exception only after the raises-block has exited.
    raised = exc_info.value
    assert raised.status_code == 503
    assert "OPENROUTER_API_KEY" in raised.detail
@pytest.mark.asyncio
async def test_cleanup_provider_cleans_all():
    """cleanup_provider closes the client of every provider in the registry."""
    with patch("api.dependencies.get_settings") as settings_mock:
        settings_mock.return_value = _make_mock_settings()

        nim_provider = get_provider_for_type("nvidia_nim")
        lmstudio_provider = get_provider_for_type("lmstudio")

        assert isinstance(nim_provider, NvidiaNimProvider)
        assert isinstance(lmstudio_provider, LMStudioProvider)

        # Swap in mock clients so the aclose() calls can be observed.
        nim_provider._client = AsyncMock()
        lmstudio_provider._client = AsyncMock()

        await cleanup_provider()

        nim_provider._client.aclose.assert_called_once()
        lmstudio_provider._client.aclose.assert_called_once()
+94
View File
@@ -61,3 +61,97 @@ def test_messages_request_model_mapping_logs(mock_settings):
assert "MODEL MAPPING" in args
assert "claude-2.1" in args
assert "target-model-from-settings" in args
def test_messages_request_resolved_provider_model_default(mock_settings):
    """resolved_provider_model carries the full provider/model string."""
    user_turn = Message(role="user", content="hello")

    with patch("api.models.anthropic.get_settings", return_value=mock_settings):
        req = MessagesRequest(
            model="claude-3-opus",
            max_tokens=100,
            messages=[user_turn],
        )

    expected = "nvidia_nim/target-model-from-settings"
    assert req.resolved_provider_model == expected
def test_messages_request_tier_aware_opus_override():
    """An opus request is routed to MODEL_OPUS when that override is set."""
    cfg = Settings()
    cfg.model = "nvidia_nim/fallback-model"
    cfg.model_opus = "open_router/deepseek/deepseek-r1"
    user_turn = Message(role="user", content="hello")

    with patch("api.models.anthropic.get_settings", return_value=cfg):
        req = MessagesRequest(
            model="claude-opus-4-20250514",
            max_tokens=100,
            messages=[user_turn],
        )

    # model holds the provider-stripped name; the full string is kept separately.
    assert req.model == "deepseek/deepseek-r1"
    assert req.resolved_provider_model == "open_router/deepseek/deepseek-r1"
    assert req.original_model == "claude-opus-4-20250514"
def test_messages_request_tier_aware_haiku_override():
    """A haiku request is routed to MODEL_HAIKU when that override is set."""
    cfg = Settings()
    cfg.model = "nvidia_nim/fallback-model"
    cfg.model_haiku = "lmstudio/qwen2.5-7b"
    user_turn = Message(role="user", content="hello")

    with patch("api.models.anthropic.get_settings", return_value=cfg):
        req = MessagesRequest(
            model="claude-3-haiku-20240307",
            max_tokens=100,
            messages=[user_turn],
        )

    assert req.model == "qwen2.5-7b"
    assert req.resolved_provider_model == "lmstudio/qwen2.5-7b"
def test_messages_request_tier_aware_sonnet_override():
    """A sonnet request is routed to MODEL_SONNET when that override is set."""
    sonnet_target = "nvidia_nim/meta/llama-3.3-70b-instruct"
    cfg = Settings()
    cfg.model = "nvidia_nim/fallback-model"
    cfg.model_sonnet = sonnet_target
    user_turn = Message(role="user", content="hello")

    with patch("api.models.anthropic.get_settings", return_value=cfg):
        req = MessagesRequest(
            model="claude-sonnet-4-20250514",
            max_tokens=100,
            messages=[user_turn],
        )

    assert req.model == "meta/llama-3.3-70b-instruct"
    assert req.resolved_provider_model == sonnet_target
def test_messages_request_tier_fallback_when_not_set():
    """When the opus override is None, an opus request falls back to MODEL."""
    settings = Settings()
    settings.model = "nvidia_nim/fallback-model"
    # Settings() loads the ambient environment / .env, which may define
    # MODEL_OPUS; pin it to None so the fallback path is what gets tested.
    settings.model_opus = None

    with patch("api.models.anthropic.get_settings", return_value=settings):
        request = MessagesRequest(
            model="claude-opus-4-20250514",
            max_tokens=100,
            messages=[Message(role="user", content="hello")],
        )

    assert request.model == "fallback-model"
    assert request.resolved_provider_model == "nvidia_nim/fallback-model"
def test_token_count_request_tier_aware():
    """TokenCountRequest resolves its model through the same tier mapping."""
    cfg = Settings()
    cfg.model = "nvidia_nim/fallback-model"
    cfg.model_haiku = "lmstudio/qwen2.5-7b"
    user_turn = Message(role="user", content="hello")

    with patch("api.models.anthropic.get_settings", return_value=cfg):
        req = TokenCountRequest(
            model="claude-3-haiku-20240307",
            messages=[user_turn],
        )

    assert req.model == "qwen2.5-7b"
+148 -2
View File
@@ -106,8 +106,6 @@ class TestSettings:
# --- NimSettings Validation Tests ---
class TestNimSettingsValidBounds:
"""Test that valid values within bounds are accepted."""
@@ -310,3 +308,151 @@ class TestSettingsOptionalStr:
monkeypatch.setenv("WHISPER_DEVICE", device)
s = Settings()
assert s.whisper_device == device
class TestPerTierModelMapping:
    """Test per-tier model fields and resolve_model()."""

    def test_tier_fields_default_none(self, monkeypatch):
        """Per-tier model fields default to None when nothing configures them."""
        from config.settings import Settings

        # Isolate from the shell environment and any local .env file (the
        # example .env sets all three tier variables), so that the field
        # defaults themselves are what is asserted.
        for var in ("MODEL_OPUS", "MODEL_SONNET", "MODEL_HAIKU"):
            monkeypatch.delenv(var, raising=False)

        s = Settings(_env_file=None)
        assert s.model_opus is None
        assert s.model_sonnet is None
        assert s.model_haiku is None

    def test_model_opus_from_env(self, monkeypatch):
        """MODEL_OPUS env var is loaded."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_OPUS", "open_router/deepseek/deepseek-r1")
        s = Settings()
        assert s.model_opus == "open_router/deepseek/deepseek-r1"

    def test_model_sonnet_from_env(self, monkeypatch):
        """MODEL_SONNET env var is loaded."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_SONNET", "nvidia_nim/meta/llama-3.3-70b-instruct")
        s = Settings()
        assert s.model_sonnet == "nvidia_nim/meta/llama-3.3-70b-instruct"

    def test_model_haiku_from_env(self, monkeypatch):
        """MODEL_HAIKU env var is loaded."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_HAIKU", "lmstudio/qwen2.5-7b")
        s = Settings()
        assert s.model_haiku == "lmstudio/qwen2.5-7b"

    def test_model_opus_invalid_provider_raises(self, monkeypatch):
        """MODEL_OPUS with invalid provider prefix raises ValidationError."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_OPUS", "bad_provider/some-model")
        with pytest.raises(ValidationError, match="Invalid provider"):
            Settings()

    def test_model_opus_no_slash_raises(self, monkeypatch):
        """MODEL_OPUS without provider prefix raises ValidationError."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_OPUS", "noprefix")
        with pytest.raises(ValidationError, match="provider type"):
            Settings()

    def test_model_haiku_invalid_provider_raises(self, monkeypatch):
        """MODEL_HAIKU with invalid provider prefix raises ValidationError."""
        from config.settings import Settings

        monkeypatch.setenv("MODEL_HAIKU", "invalid/model")
        with pytest.raises(ValidationError, match="Invalid provider"):
            Settings()

    def test_resolve_model_opus_override(self):
        """resolve_model returns model_opus for opus model names."""
        from config.settings import Settings

        s = Settings()
        s.model_opus = "open_router/deepseek/deepseek-r1"
        assert (
            s.resolve_model("claude-opus-4-20250514")
            == "open_router/deepseek/deepseek-r1"
        )
        assert s.resolve_model("claude-3-opus") == "open_router/deepseek/deepseek-r1"
        assert (
            s.resolve_model("claude-3-opus-20240229")
            == "open_router/deepseek/deepseek-r1"
        )

    def test_resolve_model_sonnet_override(self):
        """resolve_model returns model_sonnet for sonnet model names."""
        from config.settings import Settings

        s = Settings()
        s.model_sonnet = "nvidia_nim/meta/llama-3.3-70b-instruct"
        assert (
            s.resolve_model("claude-sonnet-4-20250514")
            == "nvidia_nim/meta/llama-3.3-70b-instruct"
        )
        assert (
            s.resolve_model("claude-3-5-sonnet-20241022")
            == "nvidia_nim/meta/llama-3.3-70b-instruct"
        )

    def test_resolve_model_haiku_override(self):
        """resolve_model returns model_haiku for haiku model names."""
        from config.settings import Settings

        s = Settings()
        s.model_haiku = "lmstudio/qwen2.5-7b"
        assert s.resolve_model("claude-3-haiku-20240307") == "lmstudio/qwen2.5-7b"
        assert s.resolve_model("claude-3-5-haiku-20241022") == "lmstudio/qwen2.5-7b"
        assert s.resolve_model("claude-haiku-4-20250514") == "lmstudio/qwen2.5-7b"

    def test_resolve_model_fallback_when_tier_not_set(self):
        """resolve_model falls back to MODEL when tier override is None."""
        from config.settings import Settings

        s = Settings()
        s.model = "nvidia_nim/fallback-model"
        # Settings() may have picked tier overrides up from the environment
        # or a local .env; clear them so the fallback path is really exercised.
        s.model_opus = None
        s.model_sonnet = None
        s.model_haiku = None
        assert s.resolve_model("claude-opus-4-20250514") == "nvidia_nim/fallback-model"
        assert (
            s.resolve_model("claude-sonnet-4-20250514") == "nvidia_nim/fallback-model"
        )
        assert s.resolve_model("claude-3-haiku-20240307") == "nvidia_nim/fallback-model"

    def test_resolve_model_unknown_tier_falls_back(self):
        """resolve_model falls back to MODEL for unrecognized model names."""
        from config.settings import Settings

        s = Settings()
        s.model = "nvidia_nim/fallback-model"
        s.model_opus = "open_router/opus-model"
        assert s.resolve_model("claude-2.1") == "nvidia_nim/fallback-model"
        assert s.resolve_model("some-unknown-model") == "nvidia_nim/fallback-model"

    def test_resolve_model_case_insensitive(self):
        """Tier classification is case-insensitive."""
        from config.settings import Settings

        s = Settings()
        s.model_opus = "open_router/opus-model"
        assert s.resolve_model("Claude-OPUS-4") == "open_router/opus-model"

    def test_parse_provider_type(self):
        """parse_provider_type extracts provider from model string."""
        from config.settings import Settings

        assert Settings.parse_provider_type("nvidia_nim/meta/llama") == "nvidia_nim"
        assert Settings.parse_provider_type("open_router/deepseek/r1") == "open_router"
        assert Settings.parse_provider_type("lmstudio/qwen") == "lmstudio"

    def test_parse_model_name(self):
        """parse_model_name extracts model name from model string."""
        from config.settings import Settings

        assert Settings.parse_model_name("nvidia_nim/meta/llama") == "meta/llama"
        assert Settings.parse_model_name("lmstudio/qwen") == "qwen"
Generated
+1 -1
View File
@@ -1,6 +1,6 @@
version = 1
revision = 3
requires-python = ">=3.14.2"
requires-python = ">=3.14.3"
[[package]]
name = "accelerate"