diff --git a/.env.example b/.env.example index 3b3f7df..8357caf 100644 --- a/.env.example +++ b/.env.example @@ -10,20 +10,23 @@ OPENROUTER_API_KEY="" LM_STUDIO_BASE_URL="http://localhost:1234/v1" -# All Claude model requests are mapped to this model +# All Claude model requests are mapped to these models, plain model is fallback # Format: provider_type/model/name # Valid providers: "nvidia_nim" | "open_router" | "lmstudio" -MODEL="nvidia_nim/stepfun-ai/step-3.5-flash" +MODEL_OPUS="nvidia_nim/z-ai/glm4.7" +MODEL_SONNET="arcee-ai/trinity-large-preview:free" +MODEL_HAIKU="stepfun/step-3.5-flash:free" +MODEL="nvidia_nim/z-ai/glm4.7" -# Provider Config -PROVIDER_RATE_LIMIT=40 -PROVIDER_RATE_WINDOW=60 -PROVIDER_MAX_CONCURRENCY=5 +# Provider config +PROVIDER_RATE_LIMIT=1 +PROVIDER_RATE_WINDOW=3 +PROVIDER_MAX_CONCURRENCY=50 # HTTP client timeouts (seconds) for provider API requests -HTTP_READ_TIMEOUT=300 +HTTP_READ_TIMEOUT=60 HTTP_WRITE_TIMEOUT=10 HTTP_CONNECT_TIMEOUT=2 @@ -44,7 +47,7 @@ WHISPER_DEVICE="nvidia_nim" # - For cpu/cuda: Hugging Face ID or short name (tiny, base, small, medium, large-v2, large-v3, large-v3-turbo) # - For nvidia_nim: NVIDIA NIM model (e.g., "nvidia/parakeet-ctc-1.1b-asr", "openai/whisper-large-v3") # - For nvidia_nim, default to "openai/whisper-large-v3" for best performance -WHISPER_MODEL="openai/whisper-large-v3" +WHISPER_MODEL="openai/whisper-large-v3" HF_TOKEN="" diff --git a/AGENTS.md b/AGENTS.md index af1d413..d9e5439 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -3,7 +3,9 @@ > This file is identical to CLAUDE.md. Keep them in sync. ## CODING ENVIRONMENT + - Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed +- Install Python 3.14.3 using `uv python install 3.14.3` if not already installed - Always use `uv run` to run files instead of the global `python` command. - Current uv ruff formatter is set to py314 which has supports multiple exception types without paranthesis (except TypeError, ValueError:) - Read `.env.example` for environment variables. @@ -14,11 +16,13 @@ - All 5 checks are enforced in `tests.yml` on push/merge. ## IDENTITY & CONTEXT + - You are an expert Software Architect and Systems Engineer. - Goal: Zero-defect, root-cause-oriented engineering for bugs; test-driven engineering for new features. Think carefully; no need to rush. - Code: Write the simplest code possible. Keep the codebase minimal and modular. ## ARCHITECTURE PRINCIPLES (see PLAN.md) + - **Shared utilities**: Extract common logic into shared packages (e.g. `providers/common/`). Do not have one provider import from another provider's utils. - **DRY**: Extract shared base classes to eliminate duplication. Prefer composition over copy-paste. - **Encapsulation**: Use accessor methods for internal state (e.g. `set_current_task()`), not direct `_attribute` assignment from outside. @@ -30,6 +34,7 @@ - **Backward compatibility**: When moving modules, add re-exports from old locations so existing imports keep working. ## COGNITIVE WORKFLOW + 1. **ANALYZE**: Read relevant files. Do not guess. 2. **PLAN**: Map out the logic. Identify root cause or required changes. Order changes by dependency. 3. **EXECUTE**: Fix the cause, not the symptom. Execute incrementally with clear commits. @@ -38,8 +43,10 @@ 6. **PROPAGATION**: Changes impact multiple files; propagate updates correctly. ## SUMMARY STANDARDS + - Summaries must be technical and granular. - Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none). ## TOOLS -- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use. \ No newline at end of file + +- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use. diff --git a/CLAUDE.md b/CLAUDE.md index af1d413..d9e5439 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -3,7 +3,9 @@ > This file is identical to CLAUDE.md. Keep them in sync. ## CODING ENVIRONMENT + - Install astral uv using "curl -LsSf https://astral.sh/uv/install.sh | sh" if not already installed +- Install Python 3.14.3 using `uv python install 3.14.3` if not already installed - Always use `uv run` to run files instead of the global `python` command. - Current uv ruff formatter is set to py314 which has supports multiple exception types without paranthesis (except TypeError, ValueError:) - Read `.env.example` for environment variables. @@ -14,11 +16,13 @@ - All 5 checks are enforced in `tests.yml` on push/merge. ## IDENTITY & CONTEXT + - You are an expert Software Architect and Systems Engineer. - Goal: Zero-defect, root-cause-oriented engineering for bugs; test-driven engineering for new features. Think carefully; no need to rush. - Code: Write the simplest code possible. Keep the codebase minimal and modular. ## ARCHITECTURE PRINCIPLES (see PLAN.md) + - **Shared utilities**: Extract common logic into shared packages (e.g. `providers/common/`). Do not have one provider import from another provider's utils. - **DRY**: Extract shared base classes to eliminate duplication. Prefer composition over copy-paste. - **Encapsulation**: Use accessor methods for internal state (e.g. `set_current_task()`), not direct `_attribute` assignment from outside. @@ -30,6 +34,7 @@ - **Backward compatibility**: When moving modules, add re-exports from old locations so existing imports keep working. ## COGNITIVE WORKFLOW + 1. **ANALYZE**: Read relevant files. Do not guess. 2. **PLAN**: Map out the logic. Identify root cause or required changes. Order changes by dependency. 3. **EXECUTE**: Fix the cause, not the symptom. Execute incrementally with clear commits. @@ -38,8 +43,10 @@ 6. **PROPAGATION**: Changes impact multiple files; propagate updates correctly. ## SUMMARY STANDARDS + - Summaries must be technical and granular. - Include: [Files Changed], [Logic Altered], [Verification Method], [Residual Risks] (if no residual risks then say none). ## TOOLS -- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use. \ No newline at end of file + +- Prefer built-in tools (grep, read_file, etc.) over manual workflows. Check tool availability before use. diff --git a/README.md b/README.md index 27ac0df..368f53f 100644 --- a/README.md +++ b/README.md @@ -50,6 +50,7 @@ A lightweight proxy that routes Claude Code's Anthropic API calls to **NVIDIA NI - **LM Studio**: No API key needed. Run locally with [LM Studio](https://lmstudio.ai) 2. Install [Claude Code](https://github.com/anthropics/claude-code) 3. Install [uv](https://github.com/astral-sh/uv) +4. Install Python 3.14.3: `uv python install 3.14.3` ### Clone & Configure diff --git a/api/__init__.py b/api/__init__.py index 0a8e06d..459d019 100644 --- a/api/__init__.py +++ b/api/__init__.py @@ -1,7 +1,7 @@ """API layer for Claude Code Proxy.""" from .app import app, create_app -from .dependencies import get_provider +from .dependencies import get_provider, get_provider_for_type from .models import ( MessagesRequest, MessagesResponse, @@ -17,4 +17,5 @@ __all__ = [ "app", "create_app", "get_provider", + "get_provider_for_type", ] diff --git a/api/dependencies.py b/api/dependencies.py index 770d3af..cb9db64 100644 --- a/api/dependencies.py +++ b/api/dependencies.py @@ -12,8 +12,8 @@ from providers.lmstudio import LMStudioProvider from providers.nvidia_nim import NVIDIA_NIM_BASE_URL, NvidiaNimProvider from providers.open_router import OPENROUTER_BASE_URL, OpenRouterProvider -# Global provider instance (singleton) -_provider: BaseProvider | None = None +# Provider registry: keyed by provider type string, lazily populated +_providers: dict[str, BaseProvider] = {} def get_settings() -> Settings: @@ -21,9 +21,9 @@ def get_settings() -> Settings: return _get_settings() -def _create_provider(settings: Settings) -> BaseProvider: - """Construct and return a new provider instance from settings.""" - if settings.provider_type == "nvidia_nim": +def _create_provider_for_type(provider_type: str, settings: Settings) -> BaseProvider: + """Construct and return a new provider instance for the given provider type.""" + if provider_type == "nvidia_nim": if not settings.nvidia_nim_api_key or not settings.nvidia_nim_api_key.strip(): raise AuthenticationError( "NVIDIA_NIM_API_KEY is not set. Add it to your .env file. " @@ -39,8 +39,8 @@ def _create_provider(settings: Settings) -> BaseProvider: http_write_timeout=settings.http_write_timeout, http_connect_timeout=settings.http_connect_timeout, ) - provider = NvidiaNimProvider(config, nim_settings=settings.nim) - elif settings.provider_type == "open_router": + return NvidiaNimProvider(config, nim_settings=settings.nim) + if provider_type == "open_router": if not settings.open_router_api_key or not settings.open_router_api_key.strip(): raise AuthenticationError( "OPENROUTER_API_KEY is not set. Add it to your .env file. " @@ -56,8 +56,8 @@ def _create_provider(settings: Settings) -> BaseProvider: http_write_timeout=settings.http_write_timeout, http_connect_timeout=settings.http_connect_timeout, ) - provider = OpenRouterProvider(config) - elif settings.provider_type == "lmstudio": + return OpenRouterProvider(config) + if provider_type == "lmstudio": config = ProviderConfig( api_key="lm-studio", base_url=settings.lm_studio_base_url, @@ -68,37 +68,47 @@ def _create_provider(settings: Settings) -> BaseProvider: http_write_timeout=settings.http_write_timeout, http_connect_timeout=settings.http_connect_timeout, ) - provider = LMStudioProvider(config) - else: - logger.error( - "Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'", - settings.provider_type, - ) - raise ValueError( - f"Unknown provider_type: '{settings.provider_type}'. " - f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'" - ) - logger.info("Provider initialized: {}", settings.provider_type) - return provider + return LMStudioProvider(config) + logger.error( + "Unknown provider_type: '{}'. Supported: 'nvidia_nim', 'open_router', 'lmstudio'", + provider_type, + ) + raise ValueError( + f"Unknown provider_type: '{provider_type}'. " + f"Supported: 'nvidia_nim', 'open_router', 'lmstudio'" + ) -def get_provider() -> BaseProvider: - """Get or create the provider instance based on settings.provider_type.""" - global _provider - if _provider is None: +def get_provider_for_type(provider_type: str) -> BaseProvider: + """Get or create a provider for the given provider type. + + Providers are cached in the registry and reused across requests. + """ + if provider_type not in _providers: try: - _provider = _create_provider(get_settings()) + _providers[provider_type] = _create_provider_for_type( + provider_type, get_settings() + ) except AuthenticationError as e: raise HTTPException( status_code=503, detail=get_user_facing_error_message(e) ) from e - return _provider + logger.info("Provider initialized: {}", provider_type) + return _providers[provider_type] + + +def get_provider() -> BaseProvider: + """Get or create the default provider (based on MODEL env var). + + Backward-compatible convenience for health/root endpoints and tests. + """ + return get_provider_for_type(get_settings().provider_type) async def cleanup_provider(): - """Cleanup provider resources.""" - global _provider - if _provider: - await _provider.cleanup() - _provider = None + """Cleanup all provider resources.""" + global _providers + for provider in _providers.values(): + await provider.cleanup() + _providers = {} logger.debug("Provider cleanup completed") diff --git a/api/models/anthropic.py b/api/models/anthropic.py index f38a1ba..87f4698 100644 --- a/api/models/anthropic.py +++ b/api/models/anthropic.py @@ -6,13 +6,12 @@ from typing import Any, Literal from loguru import logger from pydantic import BaseModel, field_validator, model_validator -from config.settings import get_settings +from config.settings import Settings, get_settings + # ============================================================================= # Content Block Types # ============================================================================= - - class Role(StrEnum): user = "user" assistant = "assistant" @@ -55,8 +54,6 @@ class SystemContent(BaseModel): # ============================================================================= # Message Types # ============================================================================= - - class Message(BaseModel): role: Literal["user", "assistant"] content: ( @@ -85,8 +82,6 @@ class ThinkingConfig(BaseModel): # ============================================================================= # Request Models # ============================================================================= - - class MessagesRequest(BaseModel): model: str max_tokens: int | None = None @@ -103,15 +98,18 @@ class MessagesRequest(BaseModel): thinking: ThinkingConfig | None = None extra_body: dict[str, Any] | None = None original_model: str | None = None + resolved_provider_model: str | None = None @model_validator(mode="after") def map_model(self) -> MessagesRequest: - """Map any Claude model name to the configured model.""" + """Map any Claude model name to the configured model (tier-aware).""" settings = get_settings() if self.original_model is None: self.original_model = self.model - self.model = settings.model_name + resolved_full = settings.resolve_model(self.original_model) + self.resolved_provider_model = resolved_full + self.model = Settings.parse_model_name(resolved_full) if self.model != self.original_model: logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'") @@ -129,7 +127,8 @@ class TokenCountRequest(BaseModel): @field_validator("model") @classmethod - def validate_model_field(cls, v, info): - """Map any Claude model name to the configured model.""" + def validate_model_field(cls, v: str, info) -> str: + """Map any Claude model name to the configured model (tier-aware).""" settings = get_settings() - return settings.model_name + resolved_full = settings.resolve_model(v) + return Settings.parse_model_name(resolved_full) diff --git a/api/routes.py b/api/routes.py index 7513bb5..2c5f2df 100644 --- a/api/routes.py +++ b/api/routes.py @@ -9,12 +9,11 @@ from fastapi.responses import StreamingResponse from loguru import logger from config.settings import Settings -from providers.base import BaseProvider from providers.common import get_user_facing_error_message from providers.exceptions import InvalidRequestError, ProviderError from providers.logging_utils import build_request_summary, log_request_compact -from .dependencies import get_provider, get_settings +from .dependencies import get_provider_for_type, get_settings from .models.anthropic import MessagesRequest, TokenCountRequest from .models.responses import TokenCountResponse from .optimization_handlers import try_optimizations @@ -26,13 +25,10 @@ router = APIRouter() # ============================================================================= # Routes # ============================================================================= - - @router.post("/v1/messages") async def create_message( request_data: MessagesRequest, raw_request: Request, - provider: BaseProvider = Depends(get_provider), settings: Settings = Depends(get_settings), ): """Create a message (always streaming).""" @@ -45,6 +41,12 @@ async def create_message( if optimized is not None: return optimized + # Resolve provider from the tier-aware model mapping + provider_type = Settings.parse_provider_type( + request_data.resolved_provider_model or settings.model + ) + provider = get_provider_for_type(provider_type) + request_id = f"req_{uuid.uuid4().hex[:12]}" log_request_compact(logger, request_id, request_data) diff --git a/config/settings.py b/config/settings.py index 0a4102a..7c30507 100644 --- a/config/settings.py +++ b/config/settings.py @@ -33,10 +33,16 @@ class Settings(BaseSettings): ) # ==================== Model ==================== - # All Claude model requests are mapped to this single model + # All Claude model requests are mapped to this single model (fallback) # Format: provider_type/model/name model: str = "nvidia_nim/meta/llama3-70b-instruct" + # Per-tier model overrides (optional, falls back to MODEL) + # Each can use a different provider + model_opus: str | None = Field(default=None, validation_alias="MODEL_OPUS") + model_sonnet: str | None = Field(default=None, validation_alias="MODEL_SONNET") + model_haiku: str | None = Field(default=None, validation_alias="MODEL_HAIKU") + # ==================== Provider Rate Limiting ==================== provider_rate_limit: int = Field(default=40, validation_alias="PROVIDER_RATE_LIMIT") provider_rate_window: int = Field( @@ -124,9 +130,11 @@ class Settings(BaseSettings): ) return v - @field_validator("model") + @field_validator("model", "model_opus", "model_sonnet", "model_haiku") @classmethod - def validate_model_format(cls, v: str) -> str: + def validate_model_format(cls, v: str | None) -> str | None: + if v is None: + return None valid_providers = ("nvidia_nim", "open_router", "lmstudio") if "/" not in v: raise ValueError( @@ -157,14 +165,39 @@ class Settings(BaseSettings): @property def provider_type(self) -> str: - """Extract provider type from the model string.""" + """Extract provider type from the default model string.""" return self.model.split("/", 1)[0] @property def model_name(self) -> str: - """Extract the actual model name from the model string.""" + """Extract the actual model name from the default model string.""" return self.model.split("/", 1)[1] + def resolve_model(self, claude_model_name: str) -> str: + """Resolve a Claude model name to the configured provider/model string. + + Classifies the incoming model into a tier (opus/sonnet/haiku) and + returns the tier-specific model if configured, otherwise the fallback MODEL. + """ + name_lower = claude_model_name.lower() + if "opus" in name_lower and self.model_opus is not None: + return self.model_opus + if "haiku" in name_lower and self.model_haiku is not None: + return self.model_haiku + if "sonnet" in name_lower and self.model_sonnet is not None: + return self.model_sonnet + return self.model + + @staticmethod + def parse_provider_type(model_string: str) -> str: + """Extract provider type from any 'provider/model' string.""" + return model_string.split("/", 1)[0] + + @staticmethod + def parse_model_name(model_string: str) -> str: + """Extract model name from any 'provider/model' string.""" + return model_string.split("/", 1)[1] + model_config = SettingsConfigDict( env_file=".env", env_file_encoding="utf-8", diff --git a/pyproject.toml b/pyproject.toml index 57efda9..d0bf120 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "free-claude-code" version = "2.0.0" description = "Middleware between Claude Code CLI (Anthropic API) and NVIDIA NIM" readme = "README.md" -requires-python = ">=3.14.2" +requires-python = ">=3.14.3" dependencies = [ "fastapi[standard]>=0.115.11", "uvicorn>=0.34.0", diff --git a/tests/api/test_api.py b/tests/api/test_api.py index 7339f58..3b3305a 100644 --- a/tests/api/test_api.py +++ b/tests/api/test_api.py @@ -1,9 +1,8 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from fastapi.testclient import TestClient from api.app import app -from api.dependencies import get_provider from providers.nvidia_nim import NvidiaNimProvider # Mock provider @@ -22,12 +21,10 @@ async def _mock_stream_response(*args, **kwargs): mock_provider.stream_response = _mock_stream_response +# Patch get_provider_for_type to always return mock_provider +_patcher = patch("api.routes.get_provider_for_type", return_value=mock_provider) +_patcher.start() -def override_get_provider(): - return mock_provider - - -app.dependency_overrides[get_provider] = override_get_provider client = TestClient(app) diff --git a/tests/api/test_dependencies.py b/tests/api/test_dependencies.py index 270ce01..7d130b1 100644 --- a/tests/api/test_dependencies.py +++ b/tests/api/test_dependencies.py @@ -3,7 +3,12 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi import HTTPException -from api.dependencies import cleanup_provider, get_provider, get_settings +from api.dependencies import ( + cleanup_provider, + get_provider, + get_provider_for_type, + get_settings, +) from config.nim import NimSettings from providers.lmstudio import LMStudioProvider from providers.nvidia_nim import NvidiaNimProvider @@ -32,9 +37,13 @@ def _make_mock_settings(**overrides): @pytest.fixture(autouse=True) def reset_provider(): - """Reset the global _provider singleton between tests.""" - with patch("api.dependencies._provider", None): - yield + """Reset the global _providers registry between tests.""" + import api.dependencies + + saved = api.dependencies._providers + api.dependencies._providers = {} + yield + api.dependencies._providers = saved @pytest.mark.asyncio @@ -213,6 +222,70 @@ async def test_cleanup_provider_aclose_raises(): provider._client = AsyncMock() provider._client.aclose = AsyncMock(side_effect=RuntimeError("cleanup failed")) - # Should propagate the error (current behavior - no try/except) + # Should propagate the error with pytest.raises(RuntimeError, match="cleanup failed"): await cleanup_provider() + + +# --- Provider Registry Tests --- + + +@pytest.mark.asyncio +async def test_get_provider_for_type_caches(): + """get_provider_for_type returns cached provider on second call.""" + with patch("api.dependencies.get_settings") as mock_settings: + mock_settings.return_value = _make_mock_settings() + + p1 = get_provider_for_type("nvidia_nim") + p2 = get_provider_for_type("nvidia_nim") + + assert p1 is p2 + assert isinstance(p1, NvidiaNimProvider) + + +@pytest.mark.asyncio +async def test_get_provider_for_type_different_types(): + """get_provider_for_type creates separate providers per type.""" + with patch("api.dependencies.get_settings") as mock_settings: + mock_settings.return_value = _make_mock_settings() + + nim = get_provider_for_type("nvidia_nim") + lmstudio = get_provider_for_type("lmstudio") + + assert isinstance(nim, NvidiaNimProvider) + assert isinstance(lmstudio, LMStudioProvider) + assert nim is not lmstudio + + +@pytest.mark.asyncio +async def test_get_provider_for_type_missing_key_raises_503(): + """get_provider_for_type raises HTTPException 503 for missing API key.""" + with patch("api.dependencies.get_settings") as mock_settings: + mock_settings.return_value = _make_mock_settings(open_router_api_key="") + + with pytest.raises(HTTPException) as exc_info: + get_provider_for_type("open_router") + + assert exc_info.value.status_code == 503 + assert "OPENROUTER_API_KEY" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_cleanup_provider_cleans_all(): + """cleanup_provider cleans up all providers in the registry.""" + with patch("api.dependencies.get_settings") as mock_settings: + mock_settings.return_value = _make_mock_settings() + + nim = get_provider_for_type("nvidia_nim") + lmstudio = get_provider_for_type("lmstudio") + + assert isinstance(nim, NvidiaNimProvider) + assert isinstance(lmstudio, LMStudioProvider) + + nim._client = AsyncMock() + lmstudio._client = AsyncMock() + + await cleanup_provider() + + nim._client.aclose.assert_called_once() + lmstudio._client.aclose.assert_called_once() diff --git a/tests/api/test_models_validators.py b/tests/api/test_models_validators.py index 822c210..bd56701 100644 --- a/tests/api/test_models_validators.py +++ b/tests/api/test_models_validators.py @@ -61,3 +61,97 @@ def test_messages_request_model_mapping_logs(mock_settings): assert "MODEL MAPPING" in args assert "claude-2.1" in args assert "target-model-from-settings" in args + + +def test_messages_request_resolved_provider_model_default(mock_settings): + """resolved_provider_model is set to the full model string.""" + with patch("api.models.anthropic.get_settings", return_value=mock_settings): + request = MessagesRequest( + model="claude-3-opus", + max_tokens=100, + messages=[Message(role="user", content="hello")], + ) + assert ( + request.resolved_provider_model == "nvidia_nim/target-model-from-settings" + ) + + +def test_messages_request_tier_aware_opus_override(): + """Opus model routes to MODEL_OPUS when set.""" + settings = Settings() + settings.model = "nvidia_nim/fallback-model" + settings.model_opus = "open_router/deepseek/deepseek-r1" + + with patch("api.models.anthropic.get_settings", return_value=settings): + request = MessagesRequest( + model="claude-opus-4-20250514", + max_tokens=100, + messages=[Message(role="user", content="hello")], + ) + assert request.model == "deepseek/deepseek-r1" + assert request.resolved_provider_model == "open_router/deepseek/deepseek-r1" + assert request.original_model == "claude-opus-4-20250514" + + +def test_messages_request_tier_aware_haiku_override(): + """Haiku model routes to MODEL_HAIKU when set.""" + settings = Settings() + settings.model = "nvidia_nim/fallback-model" + settings.model_haiku = "lmstudio/qwen2.5-7b" + + with patch("api.models.anthropic.get_settings", return_value=settings): + request = MessagesRequest( + model="claude-3-haiku-20240307", + max_tokens=100, + messages=[Message(role="user", content="hello")], + ) + assert request.model == "qwen2.5-7b" + assert request.resolved_provider_model == "lmstudio/qwen2.5-7b" + + +def test_messages_request_tier_aware_sonnet_override(): + """Sonnet model routes to MODEL_SONNET when set.""" + settings = Settings() + settings.model = "nvidia_nim/fallback-model" + settings.model_sonnet = "nvidia_nim/meta/llama-3.3-70b-instruct" + + with patch("api.models.anthropic.get_settings", return_value=settings): + request = MessagesRequest( + model="claude-sonnet-4-20250514", + max_tokens=100, + messages=[Message(role="user", content="hello")], + ) + assert request.model == "meta/llama-3.3-70b-instruct" + assert ( + request.resolved_provider_model == "nvidia_nim/meta/llama-3.3-70b-instruct" + ) + + +def test_messages_request_tier_fallback_when_not_set(): + """When tier override is None, falls back to MODEL.""" + settings = Settings() + settings.model = "nvidia_nim/fallback-model" + # model_opus is None + + with patch("api.models.anthropic.get_settings", return_value=settings): + request = MessagesRequest( + model="claude-opus-4-20250514", + max_tokens=100, + messages=[Message(role="user", content="hello")], + ) + assert request.model == "fallback-model" + assert request.resolved_provider_model == "nvidia_nim/fallback-model" + + +def test_token_count_request_tier_aware(): + """TokenCountRequest also uses tier-aware resolution.""" + settings = Settings() + settings.model = "nvidia_nim/fallback-model" + settings.model_haiku = "lmstudio/qwen2.5-7b" + + with patch("api.models.anthropic.get_settings", return_value=settings): + request = TokenCountRequest( + model="claude-3-haiku-20240307", + messages=[Message(role="user", content="hello")], + ) + assert request.model == "qwen2.5-7b" diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 397e996..f45df73 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -106,8 +106,6 @@ class TestSettings: # --- NimSettings Validation Tests --- - - class TestNimSettingsValidBounds: """Test that valid values within bounds are accepted.""" @@ -310,3 +308,151 @@ class TestSettingsOptionalStr: monkeypatch.setenv("WHISPER_DEVICE", device) s = Settings() assert s.whisper_device == device + + +class TestPerTierModelMapping: + """Test per-tier model fields and resolve_model().""" + + def test_tier_fields_default_none(self): + """Per-tier model fields default to None.""" + from config.settings import Settings + + s = Settings() + assert s.model_opus is None + assert s.model_sonnet is None + assert s.model_haiku is None + + def test_model_opus_from_env(self, monkeypatch): + """MODEL_OPUS env var is loaded.""" + from config.settings import Settings + + monkeypatch.setenv("MODEL_OPUS", "open_router/deepseek/deepseek-r1") + s = Settings() + assert s.model_opus == "open_router/deepseek/deepseek-r1" + + def test_model_sonnet_from_env(self, monkeypatch): + """MODEL_SONNET env var is loaded.""" + from config.settings import Settings + + monkeypatch.setenv("MODEL_SONNET", "nvidia_nim/meta/llama-3.3-70b-instruct") + s = Settings() + assert s.model_sonnet == "nvidia_nim/meta/llama-3.3-70b-instruct" + + def test_model_haiku_from_env(self, monkeypatch): + """MODEL_HAIKU env var is loaded.""" + from config.settings import Settings + + monkeypatch.setenv("MODEL_HAIKU", "lmstudio/qwen2.5-7b") + s = Settings() + assert s.model_haiku == "lmstudio/qwen2.5-7b" + + def test_model_opus_invalid_provider_raises(self, monkeypatch): + """MODEL_OPUS with invalid provider prefix raises ValidationError.""" + from config.settings import Settings + + monkeypatch.setenv("MODEL_OPUS", "bad_provider/some-model") + with pytest.raises(ValidationError, match="Invalid provider"): + Settings() + + def test_model_opus_no_slash_raises(self, monkeypatch): + """MODEL_OPUS without provider prefix raises ValidationError.""" + from config.settings import Settings + + monkeypatch.setenv("MODEL_OPUS", "noprefix") + with pytest.raises(ValidationError, match="provider type"): + Settings() + + def test_model_haiku_invalid_provider_raises(self, monkeypatch): + """MODEL_HAIKU with invalid provider prefix raises ValidationError.""" + from config.settings import Settings + + monkeypatch.setenv("MODEL_HAIKU", "invalid/model") + with pytest.raises(ValidationError, match="Invalid provider"): + Settings() + + def test_resolve_model_opus_override(self): + """resolve_model returns model_opus for opus model names.""" + from config.settings import Settings + + s = Settings() + s.model_opus = "open_router/deepseek/deepseek-r1" + assert ( + s.resolve_model("claude-opus-4-20250514") + == "open_router/deepseek/deepseek-r1" + ) + assert s.resolve_model("claude-3-opus") == "open_router/deepseek/deepseek-r1" + assert ( + s.resolve_model("claude-3-opus-20240229") + == "open_router/deepseek/deepseek-r1" + ) + + def test_resolve_model_sonnet_override(self): + """resolve_model returns model_sonnet for sonnet model names.""" + from config.settings import Settings + + s = Settings() + s.model_sonnet = "nvidia_nim/meta/llama-3.3-70b-instruct" + assert ( + s.resolve_model("claude-sonnet-4-20250514") + == "nvidia_nim/meta/llama-3.3-70b-instruct" + ) + assert ( + s.resolve_model("claude-3-5-sonnet-20241022") + == "nvidia_nim/meta/llama-3.3-70b-instruct" + ) + + def test_resolve_model_haiku_override(self): + """resolve_model returns model_haiku for haiku model names.""" + from config.settings import Settings + + s = Settings() + s.model_haiku = "lmstudio/qwen2.5-7b" + assert s.resolve_model("claude-3-haiku-20240307") == "lmstudio/qwen2.5-7b" + assert s.resolve_model("claude-3-5-haiku-20241022") == "lmstudio/qwen2.5-7b" + assert s.resolve_model("claude-haiku-4-20250514") == "lmstudio/qwen2.5-7b" + + def test_resolve_model_fallback_when_tier_not_set(self): + """resolve_model falls back to MODEL when tier override is None.""" + from config.settings import Settings + + s = Settings() + s.model = "nvidia_nim/fallback-model" + # No tier overrides set + assert s.resolve_model("claude-opus-4-20250514") == "nvidia_nim/fallback-model" + assert ( + s.resolve_model("claude-sonnet-4-20250514") == "nvidia_nim/fallback-model" + ) + assert s.resolve_model("claude-3-haiku-20240307") == "nvidia_nim/fallback-model" + + def test_resolve_model_unknown_tier_falls_back(self): + """resolve_model falls back to MODEL for unrecognized model names.""" + from config.settings import Settings + + s = Settings() + s.model = "nvidia_nim/fallback-model" + s.model_opus = "open_router/opus-model" + assert s.resolve_model("claude-2.1") == "nvidia_nim/fallback-model" + assert s.resolve_model("some-unknown-model") == "nvidia_nim/fallback-model" + + def test_resolve_model_case_insensitive(self): + """Tier classification is case-insensitive.""" + from config.settings import Settings + + s = Settings() + s.model_opus = "open_router/opus-model" + assert s.resolve_model("Claude-OPUS-4") == "open_router/opus-model" + + def test_parse_provider_type(self): + """parse_provider_type extracts provider from model string.""" + from config.settings import Settings + + assert Settings.parse_provider_type("nvidia_nim/meta/llama") == "nvidia_nim" + assert Settings.parse_provider_type("open_router/deepseek/r1") == "open_router" + assert Settings.parse_provider_type("lmstudio/qwen") == "lmstudio" + + def test_parse_model_name(self): + """parse_model_name extracts model name from model string.""" + from config.settings import Settings + + assert Settings.parse_model_name("nvidia_nim/meta/llama") == "meta/llama" + assert Settings.parse_model_name("lmstudio/qwen") == "qwen" diff --git a/uv.lock b/uv.lock index a2a15f2..b80b9b9 100644 --- a/uv.lock +++ b/uv.lock @@ -1,6 +1,6 @@ version = 1 revision = 3 -requires-python = ">=3.14.2" +requires-python = ">=3.14.3" [[package]] name = "accelerate"