mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-06-02 06:13:46 +02:00
Validate configured models at startup
This commit is contained in:
+12
-3
@@ -84,9 +84,18 @@ class AppRuntime:
|
||||
logger.info("Starting Claude Code Proxy...")
|
||||
self._provider_registry = ProviderRegistry()
|
||||
self.app.state.provider_registry = self._provider_registry
|
||||
warn_if_process_auth_token(self.settings)
|
||||
await self._start_messaging_if_configured()
|
||||
self._publish_state()
|
||||
try:
|
||||
warn_if_process_auth_token(self.settings)
|
||||
await self._provider_registry.validate_configured_models(self.settings)
|
||||
await self._start_messaging_if_configured()
|
||||
self._publish_state()
|
||||
except Exception:
|
||||
await best_effort(
|
||||
"provider_registry.cleanup",
|
||||
self._provider_registry.cleanup(),
|
||||
log_verbose_errors=self.settings.log_api_error_tracebacks,
|
||||
)
|
||||
raise
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
verbose = self.settings.log_api_error_tracebacks
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
@@ -15,6 +16,16 @@ from .nim import NimSettings
|
||||
from .provider_ids import SUPPORTED_PROVIDER_IDS
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ConfiguredChatModelRef:
|
||||
"""A unique configured chat model reference and the env keys that set it."""
|
||||
|
||||
model_ref: str
|
||||
provider_id: str
|
||||
model_id: str
|
||||
sources: tuple[str, ...]
|
||||
|
||||
|
||||
def _env_files() -> tuple[Path, ...]:
|
||||
"""Return env file paths in priority order (later overrides earlier)."""
|
||||
files: list[Path] = [
|
||||
@@ -441,6 +452,30 @@ class Settings(BaseSettings):
|
||||
return self.model_sonnet
|
||||
return self.model
|
||||
|
||||
def configured_chat_model_refs(self) -> tuple[ConfiguredChatModelRef, ...]:
|
||||
"""Return unique configured chat provider/model refs with source env keys."""
|
||||
candidates = (
|
||||
("MODEL", self.model),
|
||||
("MODEL_OPUS", self.model_opus),
|
||||
("MODEL_SONNET", self.model_sonnet),
|
||||
("MODEL_HAIKU", self.model_haiku),
|
||||
)
|
||||
sources_by_ref: dict[str, list[str]] = {}
|
||||
for source, model_ref in candidates:
|
||||
if model_ref is None:
|
||||
continue
|
||||
sources_by_ref.setdefault(model_ref, []).append(source)
|
||||
|
||||
return tuple(
|
||||
ConfiguredChatModelRef(
|
||||
model_ref=model_ref,
|
||||
provider_id=Settings.parse_provider_type(model_ref),
|
||||
model_id=Settings.parse_model_name(model_ref),
|
||||
sources=tuple(sources),
|
||||
)
|
||||
for model_ref, sources in sources_by_ref.items()
|
||||
)
|
||||
|
||||
def resolve_thinking(self, claude_model_name: str) -> bool:
|
||||
"""Resolve whether thinking is enabled for an incoming Claude model name."""
|
||||
name_lower = claude_model_name.lower()
|
||||
|
||||
@@ -10,6 +10,7 @@ from .exceptions import (
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
InvalidRequestError,
|
||||
ModelListResponseError,
|
||||
OverloadedError,
|
||||
ProviderError,
|
||||
RateLimitError,
|
||||
@@ -21,6 +22,7 @@ __all__ = [
|
||||
"AuthenticationError",
|
||||
"BaseProvider",
|
||||
"InvalidRequestError",
|
||||
"ModelListResponseError",
|
||||
"OverloadedError",
|
||||
"ProviderConfig",
|
||||
"ProviderError",
|
||||
|
||||
@@ -26,11 +26,23 @@ from providers.error_mapping import (
|
||||
map_error,
|
||||
user_visible_message_for_mapped_provider_error,
|
||||
)
|
||||
from providers.exceptions import ModelListResponseError
|
||||
from providers.model_listing import extract_openai_model_ids
|
||||
from providers.rate_limit import GlobalRateLimiter
|
||||
|
||||
StreamChunkMode = Literal["line", "event"]
|
||||
|
||||
|
||||
def _model_list_json(response: httpx.Response, *, provider_name: str) -> Any:
|
||||
response.raise_for_status()
|
||||
try:
|
||||
return response.json()
|
||||
except ValueError as exc:
|
||||
raise ModelListResponseError(
|
||||
f"{provider_name} model-list response is malformed: invalid JSON"
|
||||
) from exc
|
||||
|
||||
|
||||
class AnthropicMessagesTransport(BaseProvider):
|
||||
"""Base class for providers that stream from an Anthropic-compatible endpoint."""
|
||||
|
||||
@@ -68,6 +80,32 @@ class AnthropicMessagesTransport(BaseProvider):
|
||||
"""Release HTTP client resources."""
|
||||
await self._client.aclose()
|
||||
|
||||
async def list_model_ids(self) -> frozenset[str]:
|
||||
"""Return model ids from an OpenAI-compatible ``/models`` endpoint."""
|
||||
response = await self._send_model_list_request()
|
||||
try:
|
||||
payload = _model_list_json(response, provider_name=self._provider_name)
|
||||
return self._extract_model_ids_from_model_list_payload(payload)
|
||||
finally:
|
||||
await response.aclose()
|
||||
|
||||
async def _send_model_list_request(self) -> httpx.Response:
|
||||
"""Query the provider endpoint that advertises available model ids."""
|
||||
return await self._client.get(
|
||||
"/models",
|
||||
headers=self._model_list_headers(),
|
||||
)
|
||||
|
||||
def _model_list_headers(self) -> dict[str, str]:
|
||||
"""Return headers for model-list requests."""
|
||||
return {}
|
||||
|
||||
def _extract_model_ids_from_model_list_payload(
|
||||
self, payload: Any
|
||||
) -> frozenset[str]:
|
||||
"""Parse the provider model-list response body."""
|
||||
return extract_openai_model_ids(payload, provider_name=self._provider_name)
|
||||
|
||||
def _request_headers(self) -> dict[str, str]:
|
||||
"""Return headers for the native messages request."""
|
||||
return {"Content-Type": "application/json"}
|
||||
|
||||
@@ -105,6 +105,10 @@ class BaseProvider(ABC):
|
||||
async def cleanup(self) -> None:
|
||||
"""Release any resources held by this provider."""
|
||||
|
||||
@abstractmethod
|
||||
async def list_model_ids(self) -> frozenset[str]:
|
||||
"""Return the model ids currently advertised by this provider."""
|
||||
|
||||
@abstractmethod
|
||||
async def stream_response(
|
||||
self,
|
||||
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from providers.anthropic_messages import AnthropicMessagesTransport
|
||||
from providers.base import ProviderConfig
|
||||
from providers.defaults import DEEPSEEK_ANTHROPIC_DEFAULT_BASE
|
||||
@@ -35,3 +37,15 @@ class DeepSeekProvider(AnthropicMessagesTransport):
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": self._api_key,
|
||||
}
|
||||
|
||||
async def _send_model_list_request(self) -> httpx.Response:
|
||||
"""DeepSeek lists models from the OpenAI-format root, not /anthropic."""
|
||||
url = str(
|
||||
httpx.URL(self._base_url).copy_with(
|
||||
path="/models", query=None, fragment=None
|
||||
)
|
||||
)
|
||||
return await self._client.get(url, headers=self._model_list_headers())
|
||||
|
||||
def _model_list_headers(self) -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {self._api_key}"}
|
||||
|
||||
@@ -107,3 +107,7 @@ class ServiceUnavailableError(ProviderError):
|
||||
error_type="api_error",
|
||||
raw_error=raw_error,
|
||||
)
|
||||
|
||||
|
||||
class ModelListResponseError(ServiceUnavailableError):
|
||||
"""Raised when a provider model-list response cannot be parsed safely."""
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Provider model-list response parsing helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from providers.exceptions import ModelListResponseError
|
||||
|
||||
|
||||
def extract_openai_model_ids(payload: Any, *, provider_name: str) -> frozenset[str]:
|
||||
"""Extract model ids from an OpenAI-compatible ``/models`` response."""
|
||||
data = _field(payload, "data")
|
||||
if not _is_sequence(data):
|
||||
raise _malformed(provider_name, "expected top-level data array")
|
||||
|
||||
model_ids: set[str] = set()
|
||||
for item in data:
|
||||
model_id = _field(item, "id")
|
||||
if not isinstance(model_id, str) or not model_id.strip():
|
||||
raise _malformed(provider_name, "expected every data item to include id")
|
||||
model_ids.add(model_id)
|
||||
|
||||
if not model_ids:
|
||||
raise _malformed(provider_name, "response did not include any model ids")
|
||||
return frozenset(model_ids)
|
||||
|
||||
|
||||
def extract_ollama_model_ids(payload: Any, *, provider_name: str) -> frozenset[str]:
|
||||
"""Extract model ids from Ollama's native ``/api/tags`` response."""
|
||||
models = _field(payload, "models")
|
||||
if not _is_sequence(models):
|
||||
raise _malformed(provider_name, "expected top-level models array")
|
||||
|
||||
model_ids: set[str] = set()
|
||||
for item in models:
|
||||
item_ids: list[str] = []
|
||||
for key in ("model", "name"):
|
||||
value = _field(item, key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
item_ids.append(value)
|
||||
if not item_ids:
|
||||
raise _malformed(
|
||||
provider_name,
|
||||
"expected every models item to include model or name",
|
||||
)
|
||||
model_ids.update(item_ids)
|
||||
|
||||
if not model_ids:
|
||||
raise _malformed(provider_name, "response did not include any model ids")
|
||||
return frozenset(model_ids)
|
||||
|
||||
|
||||
def _field(item: Any, name: str) -> Any:
|
||||
if isinstance(item, Mapping):
|
||||
return item.get(name)
|
||||
return getattr(item, name, None)
|
||||
|
||||
|
||||
def _is_sequence(value: Any) -> bool:
|
||||
return isinstance(value, Sequence) and not isinstance(
|
||||
value, str | bytes | bytearray
|
||||
)
|
||||
|
||||
|
||||
def _malformed(provider_name: str, reason: str) -> ModelListResponseError:
|
||||
return ModelListResponseError(
|
||||
f"{provider_name} model-list response is malformed: {reason}"
|
||||
)
|
||||
@@ -5,6 +5,7 @@ import httpx
|
||||
from providers.anthropic_messages import AnthropicMessagesTransport
|
||||
from providers.base import ProviderConfig
|
||||
from providers.defaults import OLLAMA_DEFAULT_BASE
|
||||
from providers.model_listing import extract_ollama_model_ids
|
||||
|
||||
|
||||
class OllamaProvider(AnthropicMessagesTransport):
|
||||
@@ -27,3 +28,12 @@ class OllamaProvider(AnthropicMessagesTransport):
|
||||
headers=self._request_headers(),
|
||||
)
|
||||
return await self._client.send(request, stream=True)
|
||||
|
||||
async def _send_model_list_request(self) -> httpx.Response:
|
||||
"""Query Ollama's native local model-list endpoint."""
|
||||
return await self._client.get(f"{self._base_url}/api/tags")
|
||||
|
||||
def _extract_model_ids_from_model_list_payload(
|
||||
self, payload: object
|
||||
) -> frozenset[str]:
|
||||
return extract_ollama_model_ids(payload, provider_name=self._provider_name)
|
||||
|
||||
@@ -51,6 +51,10 @@ class OpenRouterProvider(AnthropicMessagesTransport):
|
||||
"anthropic-version": _ANTHROPIC_VERSION,
|
||||
}
|
||||
|
||||
def _model_list_headers(self) -> dict[str, str]:
|
||||
"""Return OpenRouter's OpenAI-compatible model-list headers."""
|
||||
return {"Authorization": f"Bearer {self._api_key}"}
|
||||
|
||||
def _new_stream_state(self, request: Any, *, thinking_enabled: bool) -> Any:
|
||||
"""Create per-stream state for thinking block filtering."""
|
||||
return NativeSseBlockPolicyState()
|
||||
|
||||
@@ -28,6 +28,7 @@ from providers.error_mapping import (
|
||||
map_error,
|
||||
user_visible_message_for_mapped_provider_error,
|
||||
)
|
||||
from providers.model_listing import extract_openai_model_ids
|
||||
from providers.rate_limit import GlobalRateLimiter
|
||||
|
||||
|
||||
@@ -106,6 +107,11 @@ class OpenAIChatTransport(BaseProvider):
|
||||
if client is not None:
|
||||
await client.aclose()
|
||||
|
||||
async def list_model_ids(self) -> frozenset[str]:
|
||||
"""Return model ids from the provider's OpenAI-compatible models endpoint."""
|
||||
payload = await self._client.models.list()
|
||||
return extract_openai_model_ids(payload, provider_name=self._provider_name)
|
||||
|
||||
@abstractmethod
|
||||
def _build_request_body(
|
||||
self, request: Any, thinking_enabled: bool | None = None
|
||||
|
||||
+98
-2
@@ -2,16 +2,27 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, MutableMapping
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from config.provider_catalog import (
|
||||
PROVIDER_CATALOG,
|
||||
SUPPORTED_PROVIDER_IDS,
|
||||
ProviderDescriptor,
|
||||
)
|
||||
from config.settings import Settings
|
||||
from config.settings import ConfiguredChatModelRef, Settings
|
||||
from providers.base import BaseProvider, ProviderConfig
|
||||
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
|
||||
from providers.exceptions import (
|
||||
AuthenticationError,
|
||||
ModelListResponseError,
|
||||
ProviderError,
|
||||
ServiceUnavailableError,
|
||||
UnknownProviderTypeError,
|
||||
)
|
||||
|
||||
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
|
||||
|
||||
@@ -140,6 +151,41 @@ def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
|
||||
return factory(config, settings)
|
||||
|
||||
|
||||
def _format_provider_query_failures(
|
||||
refs: list[ConfiguredChatModelRef],
|
||||
exc: BaseException,
|
||||
settings: Settings,
|
||||
) -> list[str]:
|
||||
reason = _provider_query_failure_reason(exc, settings)
|
||||
return [_format_model_validation_failure(ref, reason) for ref in refs]
|
||||
|
||||
|
||||
def _format_missing_model_failure(ref: ConfiguredChatModelRef) -> str:
|
||||
return _format_model_validation_failure(ref, "missing model")
|
||||
|
||||
|
||||
def _format_model_validation_failure(ref: ConfiguredChatModelRef, problem: str) -> str:
|
||||
return (
|
||||
f"sources={','.join(ref.sources)} provider={ref.provider_id} "
|
||||
f"model={ref.model_id} problem={problem}"
|
||||
)
|
||||
|
||||
|
||||
def _provider_query_failure_reason(
|
||||
exc: BaseException,
|
||||
settings: Settings,
|
||||
) -> str:
|
||||
if isinstance(exc, ModelListResponseError):
|
||||
return f"malformed model-list response: {exc.message}"
|
||||
if isinstance(exc, httpx.HTTPStatusError):
|
||||
return f"query failure: HTTP {exc.response.status_code}"
|
||||
if isinstance(exc, AuthenticationError):
|
||||
return f"query failure: {exc.message}"
|
||||
if isinstance(exc, ProviderError) and settings.log_api_error_tracebacks:
|
||||
return f"query failure: {exc.message}"
|
||||
return f"query failure: {type(exc).__name__}"
|
||||
|
||||
|
||||
class ProviderRegistry:
|
||||
"""Cache and clean up provider instances by provider id."""
|
||||
|
||||
@@ -155,6 +201,56 @@ class ProviderRegistry:
|
||||
self._providers[provider_id] = create_provider(provider_id, settings)
|
||||
return self._providers[provider_id]
|
||||
|
||||
async def validate_configured_models(self, settings: Settings) -> None:
|
||||
"""Fail fast unless every configured chat model exists upstream."""
|
||||
refs = settings.configured_chat_model_refs()
|
||||
refs_by_provider: dict[str, list[ConfiguredChatModelRef]] = defaultdict(list)
|
||||
for ref in refs:
|
||||
refs_by_provider[ref.provider_id].append(ref)
|
||||
|
||||
failures: list[str] = []
|
||||
tasks: dict[str, asyncio.Task[frozenset[str]]] = {}
|
||||
for provider_id, provider_refs in refs_by_provider.items():
|
||||
try:
|
||||
provider = self.get(provider_id, settings)
|
||||
except Exception as exc:
|
||||
failures.extend(
|
||||
_format_provider_query_failures(provider_refs, exc, settings)
|
||||
)
|
||||
continue
|
||||
tasks[provider_id] = asyncio.create_task(provider.list_model_ids())
|
||||
|
||||
if tasks:
|
||||
results = await asyncio.gather(*tasks.values(), return_exceptions=True)
|
||||
for (provider_id, _task), result in zip(
|
||||
tasks.items(), results, strict=True
|
||||
):
|
||||
provider_refs = refs_by_provider[provider_id]
|
||||
if isinstance(result, BaseException):
|
||||
if isinstance(result, asyncio.CancelledError):
|
||||
raise result
|
||||
failures.extend(
|
||||
_format_provider_query_failures(provider_refs, result, settings)
|
||||
)
|
||||
continue
|
||||
failures.extend(
|
||||
_format_missing_model_failure(ref)
|
||||
for ref in provider_refs
|
||||
if ref.model_id not in result
|
||||
)
|
||||
|
||||
if failures:
|
||||
message = "Configured model validation failed:\n" + "\n".join(
|
||||
f"- {failure}" for failure in failures
|
||||
)
|
||||
raise ServiceUnavailableError(message)
|
||||
|
||||
logger.info(
|
||||
"Configured provider models validated: models={} providers={}",
|
||||
len(refs),
|
||||
len(refs_by_provider),
|
||||
)
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""Call ``cleanup`` on every cached provider, then clear the cache.
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
@@ -30,6 +30,10 @@ def client():
|
||||
"""HTTP client with provider resolution stubbed; patch only for this file."""
|
||||
with (
|
||||
patch("api.dependencies.resolve_provider", return_value=mock_provider),
|
||||
patch(
|
||||
"providers.registry.ProviderRegistry.validate_configured_models",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
TestClient(app) as test_client,
|
||||
):
|
||||
yield test_client
|
||||
|
||||
@@ -4,9 +4,11 @@ from typing import cast
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from config.settings import Settings
|
||||
from providers.exceptions import ServiceUnavailableError
|
||||
from providers.registry import ProviderRegistry
|
||||
|
||||
_RUNTIME_EXTRAS = {
|
||||
@@ -26,6 +28,7 @@ _RUNTIME_EXTRAS = {
|
||||
"log_raw_messaging_content": False,
|
||||
"log_raw_cli_diagnostics": False,
|
||||
"log_messaging_error_details": False,
|
||||
"configured_chat_model_refs": lambda: (),
|
||||
}
|
||||
|
||||
|
||||
@@ -348,6 +351,45 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
|
||||
registry_cleanup.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_startup_validation_blocks_messaging_and_cleans_up(tmp_path):
|
||||
import api.runtime as api_runtime_mod
|
||||
|
||||
settings = _app_settings(
|
||||
messaging_platform="telegram",
|
||||
telegram_bot_token="token",
|
||||
allowed_telegram_user_id="123",
|
||||
discord_bot_token=None,
|
||||
allowed_discord_channels=None,
|
||||
allowed_dir=str(tmp_path / "workspace"),
|
||||
claude_workspace=str(tmp_path / "data"),
|
||||
host="127.0.0.1",
|
||||
port=8082,
|
||||
log_file=str(tmp_path / "server.log"),
|
||||
)
|
||||
app = FastAPI()
|
||||
runtime = api_runtime_mod.AppRuntime(
|
||||
app=app,
|
||||
settings=cast(Settings, settings),
|
||||
)
|
||||
|
||||
validation = AsyncMock(side_effect=ServiceUnavailableError("bad model"))
|
||||
cleanup = AsyncMock()
|
||||
with (
|
||||
patch.object(ProviderRegistry, "validate_configured_models", new=validation),
|
||||
patch.object(ProviderRegistry, "cleanup", new=cleanup),
|
||||
patch(
|
||||
"messaging.platforms.factory.create_messaging_platform"
|
||||
) as create_platform,
|
||||
pytest.raises(ServiceUnavailableError, match="bad model"),
|
||||
):
|
||||
await runtime.startup()
|
||||
|
||||
validation.assert_awaited_once_with(settings)
|
||||
cleanup.assert_awaited_once()
|
||||
create_platform.assert_not_called()
|
||||
|
||||
|
||||
def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
|
||||
"""Messaging import failure logs warning and continues without crash."""
|
||||
from api.app import create_app
|
||||
|
||||
@@ -657,3 +657,30 @@ class TestPerModelMapping:
|
||||
assert Settings.parse_model_name("lmstudio/qwen") == "qwen"
|
||||
assert Settings.parse_model_name("llamacpp/model") == "model"
|
||||
assert Settings.parse_model_name("ollama/llama3.1") == "llama3.1"
|
||||
|
||||
def test_configured_chat_model_refs_collects_unique_models_with_sources(
|
||||
self, monkeypatch
|
||||
):
|
||||
"""Startup validation model collection is limited to configured chat refs."""
|
||||
from config.settings import Settings
|
||||
|
||||
monkeypatch.setenv("FCC_SMOKE_MODEL_NVIDIA_NIM", "nvidia_nim/smoke")
|
||||
monkeypatch.setenv("WHISPER_MODEL", "openai/whisper-large-v3")
|
||||
s = Settings()
|
||||
s.model = "nvidia_nim/fallback"
|
||||
s.model_opus = "open_router/anthropic/claude-opus"
|
||||
s.model_sonnet = "nvidia_nim/fallback"
|
||||
s.model_haiku = None
|
||||
|
||||
refs = s.configured_chat_model_refs()
|
||||
|
||||
assert [ref.model_ref for ref in refs] == [
|
||||
"nvidia_nim/fallback",
|
||||
"open_router/anthropic/claude-opus",
|
||||
]
|
||||
assert refs[0].provider_id == "nvidia_nim"
|
||||
assert refs[0].model_id == "fallback"
|
||||
assert refs[0].sources == ("MODEL", "MODEL_SONNET")
|
||||
assert refs[1].provider_id == "open_router"
|
||||
assert refs[1].model_id == "anthropic/claude-opus"
|
||||
assert refs[1].sources == ("MODEL_OPUS",)
|
||||
|
||||
@@ -0,0 +1,269 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from config.nim import NimSettings
|
||||
from config.settings import Settings
|
||||
from providers.base import BaseProvider, ProviderConfig
|
||||
from providers.deepseek import DeepSeekProvider
|
||||
from providers.exceptions import ModelListResponseError, ServiceUnavailableError
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
from providers.nvidia_nim import NvidiaNimProvider
|
||||
from providers.ollama import OllamaProvider
|
||||
from providers.registry import ProviderRegistry
|
||||
|
||||
|
||||
def _settings(
|
||||
*,
|
||||
model: str = "nvidia_nim/nim-model",
|
||||
model_opus: str | None = None,
|
||||
model_sonnet: str | None = None,
|
||||
model_haiku: str | None = None,
|
||||
) -> Settings:
|
||||
return Settings.model_construct(
|
||||
model=model,
|
||||
model_opus=model_opus,
|
||||
model_sonnet=model_sonnet,
|
||||
model_haiku=model_haiku,
|
||||
log_api_error_tracebacks=False,
|
||||
)
|
||||
|
||||
|
||||
def _response(status_code: int, payload: object) -> httpx.Response:
|
||||
return httpx.Response(
|
||||
status_code,
|
||||
json=payload,
|
||||
request=httpx.Request("GET", "https://example.test/models"),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nim_lists_openai_compatible_model_ids() -> None:
|
||||
config = ProviderConfig(api_key="test-key")
|
||||
with patch("providers.openai_compat.AsyncOpenAI"):
|
||||
provider = NvidiaNimProvider(config, nim_settings=NimSettings())
|
||||
|
||||
with patch.object(
|
||||
provider._client.models,
|
||||
"list",
|
||||
new_callable=AsyncMock,
|
||||
return_value=SimpleNamespace(data=[SimpleNamespace(id="nvidia/model")]),
|
||||
):
|
||||
assert await provider.list_model_ids() == frozenset({"nvidia/model"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_native_openai_compatible_provider_lists_model_ids() -> None:
|
||||
provider = LMStudioProvider(
|
||||
ProviderConfig(api_key="lm-studio", base_url="http://localhost:1234/v1")
|
||||
)
|
||||
with patch.object(
|
||||
provider._client,
|
||||
"get",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_response(200, {"data": [{"id": "local/model"}]}),
|
||||
) as mock_get:
|
||||
assert await provider.list_model_ids() == frozenset({"local/model"})
|
||||
|
||||
mock_get.assert_awaited_once_with("/models", headers={})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deepseek_lists_models_from_root_endpoint() -> None:
|
||||
provider = DeepSeekProvider(ProviderConfig(api_key="deepseek-key"))
|
||||
with patch.object(
|
||||
provider._client,
|
||||
"get",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_response(200, {"data": [{"id": "deepseek-chat"}]}),
|
||||
) as mock_get:
|
||||
assert await provider.list_model_ids() == frozenset({"deepseek-chat"})
|
||||
|
||||
mock_get.assert_awaited_once_with(
|
||||
"https://api.deepseek.com/models",
|
||||
headers={"Authorization": "Bearer deepseek-key"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ollama_lists_native_tag_model_ids() -> None:
|
||||
provider = OllamaProvider(
|
||||
ProviderConfig(api_key="ollama", base_url="http://localhost:11434")
|
||||
)
|
||||
with patch.object(
|
||||
provider._client,
|
||||
"get",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_response(
|
||||
200,
|
||||
{
|
||||
"models": [
|
||||
{"name": "llama3.1:latest", "model": "llama3.1:latest"},
|
||||
{"name": "qwen3"},
|
||||
]
|
||||
},
|
||||
),
|
||||
) as mock_get:
|
||||
assert await provider.list_model_ids() == frozenset(
|
||||
{"llama3.1:latest", "qwen3"}
|
||||
)
|
||||
|
||||
mock_get.assert_awaited_once_with("http://localhost:11434/api/tags")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_listing_rejects_malformed_payload() -> None:
|
||||
provider = LMStudioProvider(
|
||||
ProviderConfig(api_key="lm-studio", base_url="http://localhost:1234/v1")
|
||||
)
|
||||
with (
|
||||
patch.object(
|
||||
provider._client,
|
||||
"get",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_response(200, {"data": [{}]}),
|
||||
),
|
||||
pytest.raises(ModelListResponseError, match="malformed"),
|
||||
):
|
||||
await provider.list_model_ids()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_listing_raises_http_status_errors() -> None:
|
||||
provider = LMStudioProvider(
|
||||
ProviderConfig(api_key="lm-studio", base_url="http://localhost:1234/v1")
|
||||
)
|
||||
with (
|
||||
patch.object(
|
||||
provider._client,
|
||||
"get",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_response(503, {"error": "down"}),
|
||||
),
|
||||
pytest.raises(httpx.HTTPStatusError),
|
||||
):
|
||||
await provider.list_model_ids()
|
||||
|
||||
|
||||
class FakeProvider(BaseProvider):
|
||||
def __init__(
|
||||
self,
|
||||
model_ids: frozenset[str] | None = None,
|
||||
*,
|
||||
error: BaseException | None = None,
|
||||
started: asyncio.Event | None = None,
|
||||
peer_started: asyncio.Event | None = None,
|
||||
):
|
||||
super().__init__(ProviderConfig(api_key="test"))
|
||||
self._model_ids = model_ids or frozenset()
|
||||
self._error = error
|
||||
self._started = started
|
||||
self._peer_started = peer_started
|
||||
self.cleaned = False
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
self.cleaned = True
|
||||
|
||||
async def list_model_ids(self) -> frozenset[str]:
|
||||
if self._started is not None:
|
||||
self._started.set()
|
||||
if self._peer_started is not None:
|
||||
await self._peer_started.wait()
|
||||
if self._error is not None:
|
||||
raise self._error
|
||||
return self._model_ids
|
||||
|
||||
async def stream_response(
|
||||
self,
|
||||
request: Any,
|
||||
input_tokens: int = 0,
|
||||
*,
|
||||
request_id: str | None = None,
|
||||
thinking_enabled: bool | None = None,
|
||||
) -> AsyncIterator[str]:
|
||||
if False:
|
||||
yield ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_validation_succeeds_for_all_configured_models() -> None:
|
||||
registry = ProviderRegistry(
|
||||
{
|
||||
"nvidia_nim": FakeProvider(frozenset({"nim-model"})),
|
||||
"open_router": FakeProvider(frozenset({"anthropic/claude-opus"})),
|
||||
}
|
||||
)
|
||||
settings = _settings(model_opus="open_router/anthropic/claude-opus")
|
||||
|
||||
await registry.validate_configured_models(settings)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_validation_reports_missing_model_with_sources() -> None:
|
||||
registry = ProviderRegistry(
|
||||
{"nvidia_nim": FakeProvider(frozenset({"different-model"}))}
|
||||
)
|
||||
settings = _settings(model_sonnet="nvidia_nim/nim-model")
|
||||
|
||||
with pytest.raises(ServiceUnavailableError) as exc_info:
|
||||
await registry.validate_configured_models(settings)
|
||||
|
||||
message = exc_info.value.message
|
||||
assert "sources=MODEL,MODEL_SONNET" in message
|
||||
assert "provider=nvidia_nim" in message
|
||||
assert "model=nim-model" in message
|
||||
assert "problem=missing model" in message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_validation_aggregates_multiple_failures() -> None:
|
||||
registry = ProviderRegistry(
|
||||
{
|
||||
"nvidia_nim": FakeProvider(frozenset({"different-model"})),
|
||||
"open_router": FakeProvider(
|
||||
error=ModelListResponseError("bad model-list shape")
|
||||
),
|
||||
}
|
||||
)
|
||||
settings = _settings(model_opus="open_router/anthropic/claude-opus")
|
||||
|
||||
with pytest.raises(ServiceUnavailableError) as exc_info:
|
||||
await registry.validate_configured_models(settings)
|
||||
|
||||
message = exc_info.value.message
|
||||
assert "sources=MODEL provider=nvidia_nim model=nim-model" in message
|
||||
assert "problem=missing model" in message
|
||||
assert "sources=MODEL_OPUS provider=open_router model=anthropic/claude-opus" in (
|
||||
message
|
||||
)
|
||||
assert "problem=malformed model-list response" in message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_registry_validation_queries_providers_concurrently() -> None:
|
||||
nim_started = asyncio.Event()
|
||||
router_started = asyncio.Event()
|
||||
registry = ProviderRegistry(
|
||||
{
|
||||
"nvidia_nim": FakeProvider(
|
||||
frozenset({"nim-model"}),
|
||||
started=nim_started,
|
||||
peer_started=router_started,
|
||||
),
|
||||
"open_router": FakeProvider(
|
||||
frozenset({"anthropic/claude-opus"}),
|
||||
started=router_started,
|
||||
peer_started=nim_started,
|
||||
),
|
||||
}
|
||||
)
|
||||
settings = _settings(model_opus="open_router/anthropic/claude-opus")
|
||||
|
||||
await asyncio.wait_for(registry.validate_configured_models(settings), timeout=1.0)
|
||||
Reference in New Issue
Block a user