Validate configured models at startup

This commit is contained in:
Alishahryar1
2026-04-30 00:33:45 -07:00
parent d78869d39a
commit eb5516e53b
16 changed files with 639 additions and 6 deletions
+12 -3
View File
@@ -84,9 +84,18 @@ class AppRuntime:
logger.info("Starting Claude Code Proxy...")
self._provider_registry = ProviderRegistry()
self.app.state.provider_registry = self._provider_registry
warn_if_process_auth_token(self.settings)
await self._start_messaging_if_configured()
self._publish_state()
try:
warn_if_process_auth_token(self.settings)
await self._provider_registry.validate_configured_models(self.settings)
await self._start_messaging_if_configured()
self._publish_state()
except Exception:
await best_effort(
"provider_registry.cleanup",
self._provider_registry.cleanup(),
log_verbose_errors=self.settings.log_api_error_tracebacks,
)
raise
async def shutdown(self) -> None:
verbose = self.settings.log_api_error_tracebacks
+35
View File
@@ -2,6 +2,7 @@
import os
from collections.abc import Mapping
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import Any
@@ -15,6 +16,16 @@ from .nim import NimSettings
from .provider_ids import SUPPORTED_PROVIDER_IDS
@dataclass(frozen=True, slots=True)
class ConfiguredChatModelRef:
"""A unique configured chat model reference and the env keys that set it."""
model_ref: str
provider_id: str
model_id: str
sources: tuple[str, ...]
def _env_files() -> tuple[Path, ...]:
"""Return env file paths in priority order (later overrides earlier)."""
files: list[Path] = [
@@ -441,6 +452,30 @@ class Settings(BaseSettings):
return self.model_sonnet
return self.model
def configured_chat_model_refs(self) -> tuple[ConfiguredChatModelRef, ...]:
    """Collect the distinct chat model refs set on this settings object.

    Reads ``model``, ``model_opus``, ``model_sonnet`` and ``model_haiku``.
    Values that are ``None`` are ignored. Identical refs are merged into a
    single entry whose ``sources`` lists every env key that supplied it.
    Entries keep first-seen order.
    """
    env_keys_by_ref: dict[str, list[str]] = {}
    for env_key, configured in (
        ("MODEL", self.model),
        ("MODEL_OPUS", self.model_opus),
        ("MODEL_SONNET", self.model_sonnet),
        ("MODEL_HAIKU", self.model_haiku),
    ):
        if configured is not None:
            env_keys_by_ref.setdefault(configured, []).append(env_key)

    collected: list[ConfiguredChatModelRef] = []
    for configured, env_keys in env_keys_by_ref.items():
        collected.append(
            ConfiguredChatModelRef(
                model_ref=configured,
                provider_id=Settings.parse_provider_type(configured),
                model_id=Settings.parse_model_name(configured),
                sources=tuple(env_keys),
            )
        )
    return tuple(collected)
def resolve_thinking(self, claude_model_name: str) -> bool:
"""Resolve whether thinking is enabled for an incoming Claude model name."""
name_lower = claude_model_name.lower()
+2
View File
@@ -10,6 +10,7 @@ from .exceptions import (
APIError,
AuthenticationError,
InvalidRequestError,
ModelListResponseError,
OverloadedError,
ProviderError,
RateLimitError,
@@ -21,6 +22,7 @@ __all__ = [
"AuthenticationError",
"BaseProvider",
"InvalidRequestError",
"ModelListResponseError",
"OverloadedError",
"ProviderConfig",
"ProviderError",
+38
View File
@@ -26,11 +26,23 @@ from providers.error_mapping import (
map_error,
user_visible_message_for_mapped_provider_error,
)
from providers.exceptions import ModelListResponseError
from providers.model_listing import extract_openai_model_ids
from providers.rate_limit import GlobalRateLimiter
StreamChunkMode = Literal["line", "event"]
def _model_list_json(response: httpx.Response, *, provider_name: str) -> Any:
    """Decode a model-list response body as JSON.

    ``response.raise_for_status()`` runs first, so whatever it raises
    propagates unchanged. A body that fails to decode (``ValueError``) is
    re-raised as ``ModelListResponseError`` chained to the decode error.
    """
    response.raise_for_status()
    try:
        decoded = response.json()
    except ValueError as decode_error:
        detail = f"{provider_name} model-list response is malformed: invalid JSON"
        raise ModelListResponseError(detail) from decode_error
    return decoded
class AnthropicMessagesTransport(BaseProvider):
"""Base class for providers that stream from an Anthropic-compatible endpoint."""
@@ -68,6 +80,32 @@ class AnthropicMessagesTransport(BaseProvider):
"""Release HTTP client resources."""
await self._client.aclose()
async def list_model_ids(self) -> frozenset[str]:
    """Fetch the provider's model list and return the advertised model ids.

    The response is closed in ``finally`` whether parsing succeeds or raises.
    """
    response = await self._send_model_list_request()
    try:
        return self._extract_model_ids_from_model_list_payload(
            _model_list_json(response, provider_name=self._provider_name)
        )
    finally:
        await response.aclose()
async def _send_model_list_request(self) -> httpx.Response:
    """GET ``/models`` on the provider client using the model-list headers."""
    list_headers = self._model_list_headers()
    return await self._client.get("/models", headers=list_headers)
def _model_list_headers(self) -> dict[str, str]:
    """Return the headers sent with model-list requests (none by default)."""
    no_extra_headers: dict[str, str] = {}
    return no_extra_headers
def _extract_model_ids_from_model_list_payload(
    self, payload: Any
) -> frozenset[str]:
    """Turn a decoded model-list body into a set of model ids.

    Delegates to ``extract_openai_model_ids`` with this provider's name.
    """
    provider_name = self._provider_name
    return extract_openai_model_ids(payload, provider_name=provider_name)
def _request_headers(self) -> dict[str, str]:
    """Build the headers for the native messages request."""
    headers = {"Content-Type": "application/json"}
    return headers
+4
View File
@@ -105,6 +105,10 @@ class BaseProvider(ABC):
async def cleanup(self) -> None:
"""Release any resources held by this provider."""
@abstractmethod
async def list_model_ids(self) -> frozenset[str]:
    """Return the set of model ids this provider currently advertises.

    Abstract: the body here is only a docstring.
    """
@abstractmethod
async def stream_response(
self,
+14
View File
@@ -4,6 +4,8 @@ from __future__ import annotations
from typing import Any
import httpx
from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
from providers.defaults import DEEPSEEK_ANTHROPIC_DEFAULT_BASE
@@ -35,3 +37,15 @@ class DeepSeekProvider(AnthropicMessagesTransport):
"Content-Type": "application/json",
"x-api-key": self._api_key,
}
async def _send_model_list_request(self) -> httpx.Response:
    """Request DeepSeek's model list from ``/models`` at the URL root.

    The configured base URL is rewritten so its path becomes ``/models`` and
    any query string or fragment is dropped.
    """
    root_models_url = httpx.URL(self._base_url).copy_with(
        path="/models", query=None, fragment=None
    )
    return await self._client.get(
        str(root_models_url), headers=self._model_list_headers()
    )
def _model_list_headers(self) -> dict[str, str]:
    """Return bearer-token authorization headers for the model-list request."""
    bearer = f"Bearer {self._api_key}"
    return {"Authorization": bearer}
+4
View File
@@ -107,3 +107,7 @@ class ServiceUnavailableError(ProviderError):
error_type="api_error",
raw_error=raw_error,
)
class ModelListResponseError(ServiceUnavailableError):
    """Signals that a provider's model-list response could not be parsed safely.

    Subclasses ``ServiceUnavailableError`` and adds no behaviour of its own.
    """
+69
View File
@@ -0,0 +1,69 @@
"""Provider model-list response parsing helpers."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any
from providers.exceptions import ModelListResponseError
def extract_openai_model_ids(payload: Any, *, provider_name: str) -> frozenset[str]:
    """Collect the ``id`` of every entry in a ``/models`` payload's ``data`` list.

    Raises ``ModelListResponseError`` when ``data`` is not a non-string
    sequence, when any entry lacks a non-blank string ``id``, or when no ids
    were found at all.
    """
    entries = _field(payload, "data")
    if not _is_sequence(entries):
        raise _malformed(provider_name, "expected top-level data array")

    collected: list[str] = []
    for entry in entries:
        candidate = _field(entry, "id")
        if isinstance(candidate, str) and candidate.strip():
            collected.append(candidate)
            continue
        raise _malformed(provider_name, "expected every data item to include id")

    if len(collected) == 0:
        raise _malformed(provider_name, "response did not include any model ids")
    return frozenset(collected)
def extract_ollama_model_ids(payload: Any, *, provider_name: str) -> frozenset[str]:
    """Collect model ids from the ``models`` list of an ``/api/tags`` payload.

    For each entry, both the ``model`` and ``name`` values are kept when they
    are non-blank strings. Raises ``ModelListResponseError`` when ``models``
    is not a non-string sequence, when an entry yields neither value, or when
    the result would be empty.
    """
    entries = _field(payload, "models")
    if not _is_sequence(entries):
        raise _malformed(provider_name, "expected top-level models array")

    collected: set[str] = set()
    for entry in entries:
        usable = [
            value
            for value in (_field(entry, key) for key in ("model", "name"))
            if isinstance(value, str) and value.strip()
        ]
        if not usable:
            raise _malformed(
                provider_name,
                "expected every models item to include model or name",
            )
        collected.update(usable)

    if len(collected) == 0:
        raise _malformed(provider_name, "response did not include any model ids")
    return frozenset(collected)
def _field(item: Any, name: str) -> Any:
    """Look up ``name`` on a mapping (by key) or any other object (by attribute).

    Returns ``None`` when the key or attribute is absent.
    """
    if not isinstance(item, Mapping):
        return getattr(item, name, None)
    return item.get(name)
def _is_sequence(value: Any) -> bool:
    """Report whether ``value`` is a sequence other than str/bytes/bytearray."""
    if isinstance(value, (str, bytes, bytearray)):
        return False
    return isinstance(value, Sequence)
def _malformed(provider_name: str, reason: str) -> ModelListResponseError:
    """Build (without raising) the error for a malformed model-list response."""
    detail = f"{provider_name} model-list response is malformed: {reason}"
    return ModelListResponseError(detail)
+10
View File
@@ -5,6 +5,7 @@ import httpx
from providers.anthropic_messages import AnthropicMessagesTransport
from providers.base import ProviderConfig
from providers.defaults import OLLAMA_DEFAULT_BASE
from providers.model_listing import extract_ollama_model_ids
class OllamaProvider(AnthropicMessagesTransport):
@@ -27,3 +28,12 @@ class OllamaProvider(AnthropicMessagesTransport):
headers=self._request_headers(),
)
return await self._client.send(request, stream=True)
async def _send_model_list_request(self) -> httpx.Response:
    """GET ``/api/tags`` under the configured base URL."""
    tags_url = f"{self._base_url}/api/tags"
    return await self._client.get(tags_url)
def _extract_model_ids_from_model_list_payload(
    self, payload: object
) -> frozenset[str]:
    """Parse a decoded ``/api/tags`` body via ``extract_ollama_model_ids``."""
    provider_name = self._provider_name
    return extract_ollama_model_ids(payload, provider_name=provider_name)
+4
View File
@@ -51,6 +51,10 @@ class OpenRouterProvider(AnthropicMessagesTransport):
"anthropic-version": _ANTHROPIC_VERSION,
}
def _model_list_headers(self) -> dict[str, str]:
    """Return bearer-token authorization headers for the model-list request."""
    bearer = f"Bearer {self._api_key}"
    return {"Authorization": bearer}
def _new_stream_state(self, request: Any, *, thinking_enabled: bool) -> Any:
    """Return a fresh ``NativeSseBlockPolicyState`` for one stream.

    Neither ``request`` nor ``thinking_enabled`` is consulted here.
    """
    stream_state = NativeSseBlockPolicyState()
    return stream_state
+6
View File
@@ -28,6 +28,7 @@ from providers.error_mapping import (
map_error,
user_visible_message_for_mapped_provider_error,
)
from providers.model_listing import extract_openai_model_ids
from providers.rate_limit import GlobalRateLimiter
@@ -106,6 +107,11 @@ class OpenAIChatTransport(BaseProvider):
if client is not None:
await client.aclose()
async def list_model_ids(self) -> frozenset[str]:
    """List models through the SDK client and return their ids.

    The awaited ``models.list()`` result is handed to
    ``extract_openai_model_ids`` together with this provider's name.
    """
    model_page = await self._client.models.list()
    return extract_openai_model_ids(
        model_page,
        provider_name=self._provider_name,
    )
@abstractmethod
def _build_request_body(
self, request: Any, thinking_enabled: bool | None = None
+98 -2
View File
@@ -2,16 +2,27 @@
from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Callable, MutableMapping
import httpx
from loguru import logger
from config.provider_catalog import (
PROVIDER_CATALOG,
SUPPORTED_PROVIDER_IDS,
ProviderDescriptor,
)
from config.settings import Settings
from config.settings import ConfiguredChatModelRef, Settings
from providers.base import BaseProvider, ProviderConfig
from providers.exceptions import AuthenticationError, UnknownProviderTypeError
from providers.exceptions import (
AuthenticationError,
ModelListResponseError,
ProviderError,
ServiceUnavailableError,
UnknownProviderTypeError,
)
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
@@ -140,6 +151,41 @@ def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
return factory(config, settings)
def _format_provider_query_failures(
    refs: list[ConfiguredChatModelRef],
    exc: BaseException,
    settings: Settings,
) -> list[str]:
    """Produce one failure line per ref, all sharing the reason derived from ``exc``."""
    shared_reason = _provider_query_failure_reason(exc, settings)
    return [
        _format_model_validation_failure(configured, shared_reason)
        for configured in refs
    ]
def _format_missing_model_failure(ref: ConfiguredChatModelRef) -> str:
    """Format the failure line for a ref whose model id was not listed."""
    problem = "missing model"
    return _format_model_validation_failure(ref, problem)
def _format_model_validation_failure(ref: ConfiguredChatModelRef, problem: str) -> str:
    """Render one failure as space-separated ``key=value`` fields.

    Field order is ``sources`` (comma-joined), ``provider``, ``model``,
    ``problem``.
    """
    fields = (
        f"sources={','.join(ref.sources)}",
        f"provider={ref.provider_id}",
        f"model={ref.model_id}",
        f"problem={problem}",
    )
    return " ".join(fields)
def _provider_query_failure_reason(
    exc: BaseException,
    settings: Settings,
) -> str:
    """Summarise why a provider's model-list query failed.

    Checked in order: a malformed model-list response reports its message;
    an HTTP status error reports the status code; an authentication error
    reports its message; any other ``ProviderError`` reports its message only
    when ``settings.log_api_error_tracebacks`` is truthy. Everything else is
    reduced to the exception's class name.
    """
    if isinstance(exc, ModelListResponseError):
        return f"malformed model-list response: {exc.message}"

    if isinstance(exc, httpx.HTTPStatusError):
        detail = f"HTTP {exc.response.status_code}"
    elif isinstance(exc, AuthenticationError) or (
        isinstance(exc, ProviderError) and settings.log_api_error_tracebacks
    ):
        detail = exc.message
    else:
        detail = type(exc).__name__
    return f"query failure: {detail}"
class ProviderRegistry:
"""Cache and clean up provider instances by provider id."""
@@ -155,6 +201,56 @@ class ProviderRegistry:
self._providers[provider_id] = create_provider(provider_id, settings)
return self._providers[provider_id]
async def validate_configured_models(self, settings: Settings) -> None:
    """Check every configured chat model against its provider's model list.

    Refs are grouped by provider and each provider's model list is requested
    in its own task so the queries overlap. Every problem is collected before
    reporting: a provider that cannot be obtained, a listing that raises, or
    a model id absent from the listing.

    Raises:
        ServiceUnavailableError: one or more problems were found; the message
            lists each of them on its own line.
        asyncio.CancelledError: re-raised when a listing task was cancelled.
    """
    configured = settings.configured_chat_model_refs()

    grouped: dict[str, list[ConfiguredChatModelRef]] = defaultdict(list)
    for ref in configured:
        grouped[ref.provider_id].append(ref)

    problems: list[str] = []
    pending: dict[str, asyncio.Task[frozenset[str]]] = {}
    for provider_id, provider_refs in grouped.items():
        try:
            provider = self.get(provider_id, settings)
        except Exception as exc:
            # A provider that cannot be obtained fails all of its refs.
            problems.extend(
                _format_provider_query_failures(provider_refs, exc, settings)
            )
        else:
            pending[provider_id] = asyncio.create_task(provider.list_model_ids())

    if pending:
        # return_exceptions=True keeps one provider's failure from hiding the
        # outcomes of the others.
        outcomes = await asyncio.gather(*pending.values(), return_exceptions=True)
        for provider_id, outcome in zip(pending, outcomes, strict=True):
            provider_refs = grouped[provider_id]
            if isinstance(outcome, asyncio.CancelledError):
                raise outcome
            if isinstance(outcome, BaseException):
                problems.extend(
                    _format_provider_query_failures(provider_refs, outcome, settings)
                )
                continue
            problems.extend(
                _format_missing_model_failure(ref)
                for ref in provider_refs
                if ref.model_id not in outcome
            )

    if problems:
        bullet_lines = "\n".join(f"- {problem}" for problem in problems)
        raise ServiceUnavailableError(
            "Configured model validation failed:\n" + bullet_lines
        )

    logger.info(
        "Configured provider models validated: models={} providers={}",
        len(configured),
        len(grouped),
    )
async def cleanup(self) -> None:
"""Call ``cleanup`` on every cached provider, then clear the cache.
+5 -1
View File
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
@@ -30,6 +30,10 @@ def client():
"""HTTP client with provider resolution stubbed; patch only for this file."""
with (
patch("api.dependencies.resolve_provider", return_value=mock_provider),
patch(
"providers.registry.ProviderRegistry.validate_configured_models",
new_callable=AsyncMock,
),
TestClient(app) as test_client,
):
yield test_client
+42
View File
@@ -4,9 +4,11 @@ from typing import cast
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from config.settings import Settings
from providers.exceptions import ServiceUnavailableError
from providers.registry import ProviderRegistry
_RUNTIME_EXTRAS = {
@@ -26,6 +28,7 @@ _RUNTIME_EXTRAS = {
"log_raw_messaging_content": False,
"log_raw_cli_diagnostics": False,
"log_messaging_error_details": False,
"configured_chat_model_refs": lambda: (),
}
@@ -348,6 +351,45 @@ def test_app_lifespan_cleanup_continues_if_platform_stop_raises(tmp_path):
registry_cleanup.assert_awaited_once()
@pytest.mark.asyncio
async def test_runtime_startup_validation_blocks_messaging_and_cleans_up(tmp_path):
    """A failed model validation aborts startup before messaging is created.

    The registry's cleanup must still be awaited and the validation error must
    reach the caller.
    """
    import api.runtime as api_runtime_mod

    settings = _app_settings(
        messaging_platform="telegram",
        telegram_bot_token="token",
        allowed_telegram_user_id="123",
        discord_bot_token=None,
        allowed_discord_channels=None,
        allowed_dir=str(tmp_path / "workspace"),
        claude_workspace=str(tmp_path / "data"),
        host="127.0.0.1",
        port=8082,
        log_file=str(tmp_path / "server.log"),
    )
    runtime = api_runtime_mod.AppRuntime(
        app=FastAPI(),
        settings=cast(Settings, settings),
    )
    failing_validation = AsyncMock(side_effect=ServiceUnavailableError("bad model"))
    registry_cleanup = AsyncMock()

    with (
        patch.object(
            ProviderRegistry, "validate_configured_models", new=failing_validation
        ),
        patch.object(ProviderRegistry, "cleanup", new=registry_cleanup),
        patch(
            "messaging.platforms.factory.create_messaging_platform"
        ) as platform_factory,
        pytest.raises(ServiceUnavailableError, match="bad model"),
    ):
        await runtime.startup()

    failing_validation.assert_awaited_once_with(settings)
    registry_cleanup.assert_awaited_once()
    platform_factory.assert_not_called()
def test_app_lifespan_messaging_import_error_no_crash(tmp_path, caplog):
"""Messaging import failure logs warning and continues without crash."""
from api.app import create_app
+27
View File
@@ -657,3 +657,30 @@ class TestPerModelMapping:
assert Settings.parse_model_name("lmstudio/qwen") == "qwen"
assert Settings.parse_model_name("llamacpp/model") == "model"
assert Settings.parse_model_name("ollama/llama3.1") == "llama3.1"
def test_configured_chat_model_refs_collects_unique_models_with_sources(
    self, monkeypatch
):
    """Only MODEL* settings feed the refs; duplicate refs merge their sources."""
    from config.settings import Settings

    # Model-like env vars that are not chat model settings; the assertions
    # below show they do not appear in the refs.
    monkeypatch.setenv("FCC_SMOKE_MODEL_NVIDIA_NIM", "nvidia_nim/smoke")
    monkeypatch.setenv("WHISPER_MODEL", "openai/whisper-large-v3")

    s = Settings()
    s.model = "nvidia_nim/fallback"
    s.model_opus = "open_router/anthropic/claude-opus"
    s.model_sonnet = "nvidia_nim/fallback"
    s.model_haiku = None

    refs = s.configured_chat_model_refs()

    assert [ref.model_ref for ref in refs] == [
        "nvidia_nim/fallback",
        "open_router/anthropic/claude-opus",
    ]
    shared, opus_only = refs
    assert shared.provider_id == "nvidia_nim"
    assert shared.model_id == "fallback"
    assert shared.sources == ("MODEL", "MODEL_SONNET")
    assert opus_only.provider_id == "open_router"
    assert opus_only.model_id == "anthropic/claude-opus"
    assert opus_only.sources == ("MODEL_OPUS",)
+269
View File
@@ -0,0 +1,269 @@
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterator
from types import SimpleNamespace
from typing import Any
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from config.nim import NimSettings
from config.settings import Settings
from providers.base import BaseProvider, ProviderConfig
from providers.deepseek import DeepSeekProvider
from providers.exceptions import ModelListResponseError, ServiceUnavailableError
from providers.lmstudio import LMStudioProvider
from providers.nvidia_nim import NvidiaNimProvider
from providers.ollama import OllamaProvider
from providers.registry import ProviderRegistry
def _settings(
    *,
    model: str = "nvidia_nim/nim-model",
    model_opus: str | None = None,
    model_sonnet: str | None = None,
    model_haiku: str | None = None,
) -> Settings:
    """Build a ``Settings`` via ``model_construct`` with only the model fields set.

    ``log_api_error_tracebacks`` is fixed to ``False``.
    """
    field_values = {
        "model": model,
        "model_opus": model_opus,
        "model_sonnet": model_sonnet,
        "model_haiku": model_haiku,
        "log_api_error_tracebacks": False,
    }
    return Settings.model_construct(**field_values)
def _response(status_code: int, payload: object) -> httpx.Response:
    """Create an ``httpx.Response`` with a JSON body and an attached GET request."""
    originating_request = httpx.Request("GET", "https://example.test/models")
    return httpx.Response(status_code, json=payload, request=originating_request)
@pytest.mark.asyncio
async def test_nim_lists_openai_compatible_model_ids() -> None:
    """NIM returns the ids found on the SDK ``models.list`` result."""
    config = ProviderConfig(api_key="test-key")
    with patch("providers.openai_compat.AsyncOpenAI"):
        provider = NvidiaNimProvider(config, nim_settings=NimSettings())

    sdk_listing = SimpleNamespace(data=[SimpleNamespace(id="nvidia/model")])
    with patch.object(
        provider._client.models,
        "list",
        new_callable=AsyncMock,
        return_value=sdk_listing,
    ):
        listed = await provider.list_model_ids()
        assert listed == frozenset({"nvidia/model"})
@pytest.mark.asyncio
async def test_native_openai_compatible_provider_lists_model_ids() -> None:
    """LM Studio lists ids via GET ``/models`` with empty headers."""
    provider = LMStudioProvider(
        ProviderConfig(api_key="lm-studio", base_url="http://localhost:1234/v1")
    )
    canned = _response(200, {"data": [{"id": "local/model"}]})

    with patch.object(
        provider._client, "get", new_callable=AsyncMock, return_value=canned
    ) as mock_get:
        listed = await provider.list_model_ids()
        assert listed == frozenset({"local/model"})

    mock_get.assert_awaited_once_with("/models", headers={})
@pytest.mark.asyncio
async def test_deepseek_lists_models_from_root_endpoint() -> None:
    """DeepSeek requests the root ``/models`` URL with a bearer token header."""
    provider = DeepSeekProvider(ProviderConfig(api_key="deepseek-key"))
    canned = _response(200, {"data": [{"id": "deepseek-chat"}]})

    with patch.object(
        provider._client, "get", new_callable=AsyncMock, return_value=canned
    ) as mock_get:
        listed = await provider.list_model_ids()
        assert listed == frozenset({"deepseek-chat"})

    mock_get.assert_awaited_once_with(
        "https://api.deepseek.com/models",
        headers={"Authorization": "Bearer deepseek-key"},
    )
@pytest.mark.asyncio
async def test_ollama_lists_native_tag_model_ids() -> None:
    """Ollama reads ids from ``/api/tags`` using both ``model`` and ``name``."""
    provider = OllamaProvider(
        ProviderConfig(api_key="ollama", base_url="http://localhost:11434")
    )
    tags_payload = {
        "models": [
            {"name": "llama3.1:latest", "model": "llama3.1:latest"},
            {"name": "qwen3"},
        ]
    }

    with patch.object(
        provider._client,
        "get",
        new_callable=AsyncMock,
        return_value=_response(200, tags_payload),
    ) as mock_get:
        listed = await provider.list_model_ids()
        assert listed == frozenset({"llama3.1:latest", "qwen3"})

    mock_get.assert_awaited_once_with("http://localhost:11434/api/tags")
@pytest.mark.asyncio
async def test_model_listing_rejects_malformed_payload() -> None:
    """A ``data`` item without an ``id`` raises ``ModelListResponseError``."""
    provider = LMStudioProvider(
        ProviderConfig(api_key="lm-studio", base_url="http://localhost:1234/v1")
    )
    item_without_id = _response(200, {"data": [{}]})

    with (
        patch.object(
            provider._client,
            "get",
            new_callable=AsyncMock,
            return_value=item_without_id,
        ),
        pytest.raises(ModelListResponseError, match="malformed"),
    ):
        await provider.list_model_ids()
@pytest.mark.asyncio
async def test_model_listing_raises_http_status_errors() -> None:
    """A 503 model-list response surfaces as ``httpx.HTTPStatusError``."""
    provider = LMStudioProvider(
        ProviderConfig(api_key="lm-studio", base_url="http://localhost:1234/v1")
    )
    unavailable = _response(503, {"error": "down"})

    with (
        patch.object(
            provider._client,
            "get",
            new_callable=AsyncMock,
            return_value=unavailable,
        ),
        pytest.raises(httpx.HTTPStatusError),
    ):
        await provider.list_model_ids()
class FakeProvider(BaseProvider):
    """In-memory provider double for registry validation tests.

    ``list_model_ids`` returns a fixed id set or raises a preset error. The
    optional events let two fakes wait on each other's start.
    """

    def __init__(
        self,
        model_ids: frozenset[str] | None = None,
        *,
        error: BaseException | None = None,
        started: asyncio.Event | None = None,
        peer_started: asyncio.Event | None = None,
    ):
        super().__init__(ProviderConfig(api_key="test"))
        self._model_ids = model_ids if model_ids else frozenset()
        self._error = error
        self._started = started
        self._peer_started = peer_started
        self.cleaned = False

    async def cleanup(self) -> None:
        """Record that cleanup was requested."""
        self.cleaned = True

    async def list_model_ids(self) -> frozenset[str]:
        """Signal start, wait for the peer if any, then raise or return ids."""
        own_signal, peer_signal = self._started, self._peer_started
        if own_signal is not None:
            own_signal.set()
        if peer_signal is not None:
            await peer_signal.wait()
        if self._error is None:
            return self._model_ids
        raise self._error

    async def stream_response(
        self,
        request: Any,
        input_tokens: int = 0,
        *,
        request_id: str | None = None,
        thinking_enabled: bool | None = None,
    ) -> AsyncIterator[str]:
        """Yield nothing."""
        return
        yield ""  # unreachable; its presence makes this an async generator
@pytest.mark.asyncio
async def test_registry_validation_succeeds_for_all_configured_models() -> None:
    """Validation passes when each provider lists its configured model."""
    providers = {
        "nvidia_nim": FakeProvider(frozenset({"nim-model"})),
        "open_router": FakeProvider(frozenset({"anthropic/claude-opus"})),
    }
    registry = ProviderRegistry(providers)

    await registry.validate_configured_models(
        _settings(model_opus="open_router/anthropic/claude-opus")
    )
@pytest.mark.asyncio
async def test_registry_validation_reports_missing_model_with_sources() -> None:
    """A model absent from the listing is reported with every source env key."""
    registry = ProviderRegistry(
        {"nvidia_nim": FakeProvider(frozenset({"different-model"}))}
    )

    with pytest.raises(ServiceUnavailableError) as exc_info:
        await registry.validate_configured_models(
            _settings(model_sonnet="nvidia_nim/nim-model")
        )

    message = exc_info.value.message
    for expected_fragment in (
        "sources=MODEL,MODEL_SONNET",
        "provider=nvidia_nim",
        "model=nim-model",
        "problem=missing model",
    ):
        assert expected_fragment in message
@pytest.mark.asyncio
async def test_registry_validation_aggregates_multiple_failures() -> None:
    """A missing model and a malformed listing are reported in one error."""
    malformed_listing = FakeProvider(
        error=ModelListResponseError("bad model-list shape")
    )
    registry = ProviderRegistry(
        {
            "nvidia_nim": FakeProvider(frozenset({"different-model"})),
            "open_router": malformed_listing,
        }
    )

    with pytest.raises(ServiceUnavailableError) as exc_info:
        await registry.validate_configured_models(
            _settings(model_opus="open_router/anthropic/claude-opus")
        )

    message = exc_info.value.message
    for expected_fragment in (
        "sources=MODEL provider=nvidia_nim model=nim-model",
        "problem=missing model",
        "sources=MODEL_OPUS provider=open_router model=anthropic/claude-opus",
        "problem=malformed model-list response",
    ):
        assert expected_fragment in message
@pytest.mark.asyncio
async def test_registry_validation_queries_providers_concurrently() -> None:
    """Both listings must be in flight at once.

    Each fake waits for the other's start event, so validation can only finish
    within the timeout if the two queries overlap.
    """
    nim_started = asyncio.Event()
    router_started = asyncio.Event()
    nim = FakeProvider(
        frozenset({"nim-model"}),
        started=nim_started,
        peer_started=router_started,
    )
    router = FakeProvider(
        frozenset({"anthropic/claude-opus"}),
        started=router_started,
        peer_started=nim_started,
    )
    registry = ProviderRegistry({"nvidia_nim": nim, "open_router": router})

    validation = registry.validate_configured_models(
        _settings(model_opus="open_router/anthropic/claude-opus")
    )
    await asyncio.wait_for(validation, timeout=1.0)