# Source listing: free-claude-code/providers/registry.py
# Viewer metadata: 2026-04-30 00:33:45 -07:00 · 275 lines · 9.7 KiB · Python

"""Provider descriptors, factory, and runtime registry."""
from __future__ import annotations
import asyncio
from collections import defaultdict
from collections.abc import Callable, MutableMapping
import httpx
from loguru import logger
from config.provider_catalog import (
PROVIDER_CATALOG,
SUPPORTED_PROVIDER_IDS,
ProviderDescriptor,
)
from config.settings import ConfiguredChatModelRef, Settings
from providers.base import BaseProvider, ProviderConfig
from providers.exceptions import (
AuthenticationError,
ModelListResponseError,
ProviderError,
ServiceUnavailableError,
UnknownProviderTypeError,
)
# Signature shared by every entry in ``PROVIDER_FACTORIES``: a factory receives
# the resolved ``ProviderConfig`` plus the full ``Settings`` object and returns
# a provider instance.
ProviderFactory = Callable[[ProviderConfig, Settings], BaseProvider]
# Backwards-compatible name for the catalog. The single source of truth is
# ``config.provider_catalog.PROVIDER_CATALOG``; this binds the same object
# under a second name rather than copying it.
PROVIDER_DESCRIPTORS: dict[str, ProviderDescriptor] = PROVIDER_CATALOG
def _create_nvidia_nim(config: ProviderConfig, settings: Settings) -> BaseProvider:
    """Build the NVIDIA NIM provider, forwarding the NIM-specific settings.

    The provider module is imported when this factory is called, not when the
    registry module is loaded.
    """
    from providers.nvidia_nim import NvidiaNimProvider as provider_cls

    nim_settings = settings.nim
    return provider_cls(config, nim_settings=nim_settings)
def _create_open_router(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the OpenRouter provider from *config*.

    ``_settings`` is accepted only to satisfy the ``ProviderFactory`` signature
    and is not read. The provider module is imported at call time.
    """
    from providers.open_router import OpenRouterProvider as provider_cls

    return provider_cls(config)
def _create_deepseek(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the DeepSeek provider from *config*.

    ``_settings`` is accepted only to satisfy the ``ProviderFactory`` signature
    and is not read. The provider module is imported at call time.
    """
    from providers.deepseek import DeepSeekProvider as provider_cls

    return provider_cls(config)
def _create_lmstudio(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the LM Studio provider from *config*.

    ``_settings`` is accepted only to satisfy the ``ProviderFactory`` signature
    and is not read. The provider module is imported at call time.
    """
    from providers.lmstudio import LMStudioProvider as provider_cls

    return provider_cls(config)
def _create_llamacpp(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the llama.cpp provider from *config*.

    ``_settings`` is accepted only to satisfy the ``ProviderFactory`` signature
    and is not read. The provider module is imported at call time.
    """
    from providers.llamacpp import LlamaCppProvider as provider_cls

    return provider_cls(config)
def _create_ollama(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Build the Ollama provider from *config*.

    ``_settings`` is accepted only to satisfy the ``ProviderFactory`` signature
    and is not read. The provider module is imported at call time.
    """
    from providers.ollama import OllamaProvider as provider_cls

    return provider_cls(config)
# Provider id -> factory callable. ``create_provider`` looks factories up here
# by the same id used as the key in ``PROVIDER_DESCRIPTORS``.
PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
    "nvidia_nim": _create_nvidia_nim,
    "open_router": _create_open_router,
    "deepseek": _create_deepseek,
    "lmstudio": _create_lmstudio,
    "llamacpp": _create_llamacpp,
    "ollama": _create_ollama,
}
# Import-time guard: the descriptor catalog, the factory table, and the
# supported-id collection must name exactly the same set of providers.
# An explicit ``raise`` (unlike an ``assert`` statement) is not stripped when
# Python runs with ``-O``.
if not (
    set(PROVIDER_DESCRIPTORS) == set(SUPPORTED_PROVIDER_IDS)
    and set(PROVIDER_FACTORIES) == set(SUPPORTED_PROVIDER_IDS)
):
    raise AssertionError(
        "PROVIDER_DESCRIPTORS, PROVIDER_FACTORIES, and SUPPORTED_PROVIDER_IDS are out of sync: "
        f"descriptors={set(PROVIDER_DESCRIPTORS)!r} factories={set(PROVIDER_FACTORIES)!r} "
        f"ids={set(SUPPORTED_PROVIDER_IDS)!r}"
    )
def _string_attr(settings: Settings, attr_name: str | None, default: str = "") -> str:
if attr_name is None:
return default
value = getattr(settings, attr_name, default)
return value if isinstance(value, str) else default
def _credential_for(descriptor: ProviderDescriptor, settings: Settings) -> str:
    """Resolve the credential string for the provider described by *descriptor*.

    Precedence: a non-``None`` ``static_credential`` on the descriptor wins;
    otherwise the settings attribute named by ``credential_attr`` is read (via
    ``_string_attr``); otherwise the empty string is returned.
    """
    fixed = descriptor.static_credential
    if fixed is not None:
        return fixed
    attr_name = descriptor.credential_attr
    return _string_attr(settings, attr_name) if attr_name else ""
def _require_credential(descriptor: ProviderDescriptor, credential: str) -> None:
if descriptor.credential_env is None:
return
if credential and credential.strip():
return
message = f"{descriptor.credential_env} is not set. Add it to your .env file."
if descriptor.credential_url:
message = f"{message} Get a key at {descriptor.credential_url}"
raise AuthenticationError(message)
def build_provider_config(
    descriptor: ProviderDescriptor, settings: Settings
) -> ProviderConfig:
    """Assemble the ``ProviderConfig`` for one provider from global settings.

    The credential is resolved and checked first, so a missing required
    credential raises ``AuthenticationError`` (from ``_require_credential``)
    before any other field is read. The base URL comes from the settings
    attribute named by ``descriptor.base_url_attr``; an unset or empty value
    falls back to ``descriptor.default_base_url``. All remaining fields are
    copied straight from *settings*.
    """
    api_key = _credential_for(descriptor, settings)
    _require_credential(descriptor, api_key)

    fallback_url = descriptor.default_base_url
    configured_url = _string_attr(
        settings, descriptor.base_url_attr, fallback_url or ""
    )
    proxy_url = _string_attr(settings, descriptor.proxy_attr)

    return ProviderConfig(
        api_key=api_key,
        # An empty configured URL still falls back to the descriptor default.
        base_url=configured_url or fallback_url,
        rate_limit=settings.provider_rate_limit,
        rate_window=settings.provider_rate_window,
        max_concurrency=settings.provider_max_concurrency,
        http_read_timeout=settings.http_read_timeout,
        http_write_timeout=settings.http_write_timeout,
        http_connect_timeout=settings.http_connect_timeout,
        enable_thinking=settings.enable_model_thinking,
        proxy=proxy_url,
        log_raw_sse_events=settings.log_raw_sse_events,
        log_api_error_tracebacks=settings.log_api_error_tracebacks,
    )
def create_provider(provider_id: str, settings: Settings) -> BaseProvider:
    """Instantiate the provider registered under *provider_id*.

    Raises:
        UnknownProviderTypeError: *provider_id* has no descriptor; the message
            lists the supported ids.
        AssertionError: a descriptor exists but no factory is registered.
    Errors raised by ``build_provider_config`` propagate unchanged.
    """
    if (descriptor := PROVIDER_DESCRIPTORS.get(provider_id)) is None:
        supported = "', '".join(PROVIDER_DESCRIPTORS)
        raise UnknownProviderTypeError(
            f"Unknown provider_type: '{provider_id}'. Supported: '{supported}'"
        )

    # The config is resolved before the factory lookup, so configuration
    # errors are reported ahead of a missing-factory error.
    provider_config = build_provider_config(descriptor, settings)
    if (build := PROVIDER_FACTORIES.get(provider_id)) is None:
        raise AssertionError(f"Unhandled provider descriptor: {provider_id}")
    return build(provider_config, settings)
def _format_provider_query_failures(
    refs: list[ConfiguredChatModelRef],
    exc: BaseException,
    settings: Settings,
) -> list[str]:
    """Return one failure line per ref, all sharing the reason derived from *exc*.

    The reason text is computed once and reused for every ref in *refs*.
    """
    shared_reason = _provider_query_failure_reason(exc, settings)
    return [
        _format_model_validation_failure(ref, problem=shared_reason) for ref in refs
    ]
def _format_missing_model_failure(ref: ConfiguredChatModelRef) -> str:
    """Return the failure line for a ref whose model id was not found."""
    return _format_model_validation_failure(ref, problem="missing model")
def _format_model_validation_failure(ref: ConfiguredChatModelRef, problem: str) -> str:
return (
f"sources={','.join(ref.sources)} provider={ref.provider_id} "
f"model={ref.model_id} problem={problem}"
)
def _provider_query_failure_reason(
exc: BaseException,
settings: Settings,
) -> str:
if isinstance(exc, ModelListResponseError):
return f"malformed model-list response: {exc.message}"
if isinstance(exc, httpx.HTTPStatusError):
return f"query failure: HTTP {exc.response.status_code}"
if isinstance(exc, AuthenticationError):
return f"query failure: {exc.message}"
if isinstance(exc, ProviderError) and settings.log_api_error_tracebacks:
return f"query failure: {exc.message}"
return f"query failure: {type(exc).__name__}"
class ProviderRegistry:
"""Cache and clean up provider instances by provider id."""
def __init__(self, providers: MutableMapping[str, BaseProvider] | None = None):
self._providers = providers if providers is not None else {}
def is_cached(self, provider_id: str) -> bool:
"""Return whether a provider for this id is already in the cache."""
return provider_id in self._providers
def get(self, provider_id: str, settings: Settings) -> BaseProvider:
if provider_id not in self._providers:
self._providers[provider_id] = create_provider(provider_id, settings)
return self._providers[provider_id]
async def validate_configured_models(self, settings: Settings) -> None:
"""Fail fast unless every configured chat model exists upstream."""
refs = settings.configured_chat_model_refs()
refs_by_provider: dict[str, list[ConfiguredChatModelRef]] = defaultdict(list)
for ref in refs:
refs_by_provider[ref.provider_id].append(ref)
failures: list[str] = []
tasks: dict[str, asyncio.Task[frozenset[str]]] = {}
for provider_id, provider_refs in refs_by_provider.items():
try:
provider = self.get(provider_id, settings)
except Exception as exc:
failures.extend(
_format_provider_query_failures(provider_refs, exc, settings)
)
continue
tasks[provider_id] = asyncio.create_task(provider.list_model_ids())
if tasks:
results = await asyncio.gather(*tasks.values(), return_exceptions=True)
for (provider_id, _task), result in zip(
tasks.items(), results, strict=True
):
provider_refs = refs_by_provider[provider_id]
if isinstance(result, BaseException):
if isinstance(result, asyncio.CancelledError):
raise result
failures.extend(
_format_provider_query_failures(provider_refs, result, settings)
)
continue
failures.extend(
_format_missing_model_failure(ref)
for ref in provider_refs
if ref.model_id not in result
)
if failures:
message = "Configured model validation failed:\n" + "\n".join(
f"- {failure}" for failure in failures
)
raise ServiceUnavailableError(message)
logger.info(
"Configured provider models validated: models={} providers={}",
len(refs),
len(refs_by_provider),
)
async def cleanup(self) -> None:
"""Call ``cleanup`` on every cached provider, then clear the cache.
Attempts all providers even if one fails. A single failure is re-raised
as-is; multiple failures are wrapped in :exc:`ExceptionGroup`.
"""
items = list(self._providers.items())
errors: list[Exception] = []
try:
for _pid, provider in items:
try:
await provider.cleanup()
except Exception as e:
errors.append(e)
finally:
self._providers.clear()
if len(errors) == 1:
raise errors[0]
if len(errors) > 1:
msg = "One or more provider cleanups failed"
raise ExceptionGroup(msg, errors)