mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-06-02 06:13:46 +02:00
feat(providers): native Anthropic Messages for Kimi, Fireworks, Z.ai
Route these providers through POST /messages with vendor headers and bases (including Kimi model list on OpenAI /v1/models). Remove Z.ai from OpenAI-chat server-tool rejection; extend tests and README.
This commit is contained in:
+3
-3
@@ -14,7 +14,7 @@ MISTRAL_API_KEY=""
|
||||
DEEPSEEK_API_KEY=""
|
||||
|
||||
|
||||
# Kimi Config (Moonshot OpenAI-compatible API)
|
||||
# Kimi Config (Anthropic-compatible Messages at api.moonshot.ai/anthropic/v1)
|
||||
KIMI_API_KEY=""
|
||||
|
||||
|
||||
@@ -26,11 +26,11 @@ WAFER_API_KEY=""
|
||||
OPENCODE_API_KEY=""
|
||||
|
||||
|
||||
# Z.ai Config (Anthropic-compatible Messages at api.z.ai/api/anthropic)
|
||||
# Z.ai Config (Anthropic-compatible Messages at api.z.ai/api/anthropic/v1)
|
||||
ZAI_API_KEY=""
|
||||
|
||||
|
||||
# Fireworks AI Config (OpenAI-compatible Chat Completions at api.fireworks.ai/inference/v1)
|
||||
# Fireworks AI Config (Anthropic-compatible Messages at api.fireworks.ai/inference/v1)
|
||||
FIREWORKS_API_KEY=""
|
||||
|
||||
|
||||
|
||||
@@ -158,6 +158,8 @@ Get a key at [platform.moonshot.ai/console/api-keys](https://platform.moonshot.a
|
||||
|
||||
In the Admin UI, paste it into `KIMI_API_KEY`, then set `MODEL` to a Kimi slug such as `kimi/kimi-k2.5`.
|
||||
|
||||
This provider calls Kimi's **Anthropic-compatible** Messages API (`https://api.moonshot.ai/anthropic/v1/messages`; model discovery uses OpenAI-compat `GET https://api.moonshot.ai/v1/models`). It is **not** the OpenAI Chat Completions path.
|
||||
|
||||
Browse models at [platform.moonshot.ai](https://platform.moonshot.ai).
|
||||
|
||||
### 6. [Wafer](https://wafer.ai/)
|
||||
@@ -239,7 +241,7 @@ Get an API key at [Z.ai/manage-apikey/apikey-list](https://z.ai/manage-apikey/ap
|
||||
|
||||
In the Admin UI, paste it into `ZAI_API_KEY`, then set `MODEL` to a Z.ai model slug such as `zai/glm-5.1`.
|
||||
|
||||
Z.ai provides GLM models through the OpenAI-compatible Coding Plan endpoint at `https://api.z.ai/api/coding/paas/v4`.
|
||||
This provider calls Z.ai's **Anthropic-compatible** Messages API (`https://api.z.ai/api/anthropic/v1/messages`). The former OpenAI Coding Plan base (`https://api.z.ai/api/coding/paas/v4`) is **not** used by this gateway.
|
||||
|
||||
Popular examples:
|
||||
|
||||
@@ -254,7 +256,7 @@ Get an API key at [fireworks.ai/account/api-keys](https://fireworks.ai/account/a
|
||||
|
||||
In the Admin UI, paste it into `FIREWORKS_API_KEY`, then set `MODEL` to a Fireworks model slug such as `fireworks/accounts/fireworks/models/llama-v3p3-70b-instruct`.
|
||||
|
||||
Fireworks exposes an OpenAI-compatible Chat Completions API at `https://api.fireworks.ai/inference/v1`.
|
||||
Fireworks exposes an **Anthropic-compatible** Messages API at `https://api.fireworks.ai/inference/v1/messages` (same inference host as before; Chat Completions is not used here). Vendor-specific JSON keys can still be merged from request `extra_body` when allowed.
|
||||
|
||||
Browse models at [fireworks.ai/models](https://fireworks.ai/models).
|
||||
|
||||
@@ -443,8 +445,8 @@ Important pieces:
|
||||
|
||||
- FastAPI exposes Anthropic-compatible routes such as `/v1/messages`, `/v1/messages/count_tokens`, and `/v1/models`.
|
||||
- Model routing resolves the Claude model name to `MODEL_OPUS`, `MODEL_SONNET`, `MODEL_HAIKU`, or `MODEL`.
|
||||
- NIM, OpenCode Zen, OpenCode Go, Z.ai use OpenAI chat streaming translated into Anthropic SSE.
|
||||
- Wafer, OpenRouter, DeepSeek, LM Studio, llama.cpp, and Ollama use Anthropic Messages style transports.
|
||||
- NIM, OpenCode Zen, and OpenCode Go use OpenAI chat streaming translated into Anthropic SSE.
|
||||
- Wafer, OpenRouter, DeepSeek, Kimi, Fireworks AI, Z.ai, LM Studio, llama.cpp, and Ollama use Anthropic Messages style transports where applicable (with provider-specific quirks and model-list URLs).
|
||||
- The proxy normalizes thinking blocks, tool calls, token usage metadata, and provider errors into the shape Claude Code expects.
|
||||
- Request optimizations answer trivial Claude Code probes locally to save latency and quota.
|
||||
|
||||
|
||||
+1
-1
@@ -34,7 +34,7 @@ TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], i
|
||||
ProviderGetter = Callable[[str], BaseProvider]
|
||||
|
||||
# Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages).
|
||||
_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "opencode", "opencode_go", "zai"})
|
||||
_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "opencode", "opencode_go"})
|
||||
|
||||
|
||||
def anthropic_sse_streaming_response(
|
||||
|
||||
@@ -79,7 +79,7 @@ def openai_chat_upstream_server_tool_error(
|
||||
)
|
||||
if not forced and has_listed_anthropic_server_tools(request):
|
||||
return (
|
||||
"OpenAI Chat upstreams (NVIDIA NIM) cannot use listed Anthropic server tools "
|
||||
"OpenAI Chat upstreams cannot use listed Anthropic server tools "
|
||||
"(web_search / web_fetch) without the local web server tool handler. Use a native "
|
||||
"Anthropic transport, set ENABLE_WEB_SERVER_TOOLS=true and force the tool with "
|
||||
"tool_choice, or remove these tools from the request."
|
||||
|
||||
@@ -13,7 +13,8 @@ TransportType = Literal["openai_chat", "anthropic_messages"]
|
||||
|
||||
# Default upstream base URLs (also re-exported via :mod:`providers.defaults`)
|
||||
NVIDIA_NIM_DEFAULT_BASE = "https://integrate.api.nvidia.com/v1"
|
||||
KIMI_DEFAULT_BASE = "https://api.moonshot.ai/v1"
|
||||
# Moonshot Kimi Anthropic-compatible Messages API (POST …/messages).
|
||||
KIMI_DEFAULT_BASE = "https://api.moonshot.ai/anthropic/v1"
|
||||
WAFER_DEFAULT_BASE = "https://pass.wafer.ai/v1"
|
||||
# DeepSeek Anthropic-compatible Messages API (not OpenAI ``/v1`` chat completions).
|
||||
DEEPSEEK_ANTHROPIC_DEFAULT_BASE = "https://api.deepseek.com/anthropic"
|
||||
@@ -27,7 +28,8 @@ LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1"
|
||||
OLLAMA_DEFAULT_BASE = "http://localhost:11434"
|
||||
OPENCODE_DEFAULT_BASE = "https://opencode.ai/zen/v1"
|
||||
OPENCODE_GO_DEFAULT_BASE = "https://opencode.ai/zen/go/v1"
|
||||
ZAI_DEFAULT_BASE = "https://api.z.ai/api/coding/paas/v4"
|
||||
# Z.ai Anthropic-compatible Messages API (not OpenAI Coding Plan chat completions).
|
||||
ZAI_DEFAULT_BASE = "https://api.z.ai/api/anthropic/v1"
|
||||
# Google AI Studio Gemini API OpenAI-compat layer (not Vertex AI).
|
||||
GEMINI_DEFAULT_BASE = "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||
GROQ_DEFAULT_BASE = "https://api.groq.com/openai/v1"
|
||||
@@ -92,13 +94,19 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = {
|
||||
),
|
||||
"kimi": ProviderDescriptor(
|
||||
provider_id="kimi",
|
||||
transport_type="openai_chat",
|
||||
transport_type="anthropic_messages",
|
||||
credential_env="KIMI_API_KEY",
|
||||
credential_url="https://platform.moonshot.cn/console/api-keys",
|
||||
credential_attr="kimi_api_key",
|
||||
default_base_url=KIMI_DEFAULT_BASE,
|
||||
proxy_attr="kimi_proxy",
|
||||
capabilities=("chat", "streaming", "tools"),
|
||||
capabilities=(
|
||||
"chat",
|
||||
"streaming",
|
||||
"tools",
|
||||
"thinking",
|
||||
"native_anthropic",
|
||||
),
|
||||
),
|
||||
"wafer": ProviderDescriptor(
|
||||
provider_id="wafer",
|
||||
@@ -165,22 +173,36 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = {
|
||||
),
|
||||
"zai": ProviderDescriptor(
|
||||
provider_id="zai",
|
||||
transport_type="openai_chat",
|
||||
transport_type="anthropic_messages",
|
||||
credential_env="ZAI_API_KEY",
|
||||
credential_attr="zai_api_key",
|
||||
default_base_url=ZAI_DEFAULT_BASE,
|
||||
proxy_attr="zai_proxy",
|
||||
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
|
||||
capabilities=(
|
||||
"chat",
|
||||
"streaming",
|
||||
"tools",
|
||||
"thinking",
|
||||
"native_anthropic",
|
||||
"rate_limit",
|
||||
),
|
||||
),
|
||||
"fireworks": ProviderDescriptor(
|
||||
provider_id="fireworks",
|
||||
transport_type="openai_chat",
|
||||
transport_type="anthropic_messages",
|
||||
credential_env="FIREWORKS_API_KEY",
|
||||
credential_url="https://fireworks.ai/account/api-keys",
|
||||
credential_attr="fireworks_api_key",
|
||||
default_base_url=FIREWORKS_DEFAULT_BASE,
|
||||
proxy_attr="fireworks_proxy",
|
||||
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
|
||||
capabilities=(
|
||||
"chat",
|
||||
"streaming",
|
||||
"tools",
|
||||
"thinking",
|
||||
"native_anthropic",
|
||||
"rate_limit",
|
||||
),
|
||||
),
|
||||
"gemini": ProviderDescriptor(
|
||||
provider_id="gemini",
|
||||
|
||||
@@ -1,33 +1,45 @@
|
||||
"""Fireworks AI provider implementation."""
|
||||
"""Fireworks AI provider using native Anthropic-compatible Messages."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from providers.anthropic_messages import AnthropicMessagesTransport
|
||||
from providers.base import ProviderConfig
|
||||
from providers.openai_compat import OpenAIChatTransport
|
||||
|
||||
from .request import build_request_body
|
||||
|
||||
FIREWORKS_BASE_URL = "https://api.fireworks.ai/inference/v1"
|
||||
_ANTHROPIC_VERSION = "2023-06-01"
|
||||
|
||||
|
||||
class FireworksProvider(OpenAIChatTransport):
|
||||
"""Fireworks AI provider using OpenAI-compatible chat completions."""
|
||||
class FireworksProvider(AnthropicMessagesTransport):
|
||||
"""Fireworks AI using Anthropic-compatible Messages."""
|
||||
|
||||
def __init__(self, config: ProviderConfig):
|
||||
super().__init__(
|
||||
config,
|
||||
provider_name="FIREWORKS",
|
||||
base_url=config.base_url or FIREWORKS_BASE_URL,
|
||||
api_key=config.api_key,
|
||||
default_base_url=FIREWORKS_BASE_URL,
|
||||
)
|
||||
|
||||
def _build_request_body(
|
||||
self, request: Any, thinking_enabled: bool | None = None
|
||||
) -> dict:
|
||||
"""Build request body for Fireworks AI."""
|
||||
if thinking_enabled is None:
|
||||
thinking_enabled = self._is_thinking_enabled(request)
|
||||
return build_request_body(
|
||||
request,
|
||||
thinking_enabled=thinking_enabled,
|
||||
)
|
||||
|
||||
def _request_headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Accept": "text/event-stream",
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"anthropic-version": _ANTHROPIC_VERSION,
|
||||
}
|
||||
|
||||
def _model_list_headers(self) -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {self._api_key}"}
|
||||
|
||||
@@ -1,39 +1,46 @@
|
||||
"""Request builder for Fireworks AI provider."""
|
||||
"""Native Anthropic Messages request builder for Fireworks AI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from core.anthropic import ReasoningReplayMode, build_base_request_body
|
||||
from core.anthropic.conversion import OpenAIConversionError
|
||||
from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
|
||||
from core.anthropic.native_messages_request import (
|
||||
OpenRouterExtraBodyError,
|
||||
build_base_native_anthropic_request_body,
|
||||
validate_openrouter_extra_body,
|
||||
)
|
||||
from providers.exceptions import InvalidRequestError
|
||||
|
||||
|
||||
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
|
||||
"""Build OpenAI-format request body from Anthropic request for Fireworks AI."""
|
||||
"""Build JSON for Fireworks Anthropic-compat ``POST …/messages``."""
|
||||
logger.debug(
|
||||
"FIREWORKS_REQUEST: conversion start model={} msgs={}",
|
||||
"FIREWORKS_REQUEST: native build model={} msgs={}",
|
||||
getattr(request_data, "model", "?"),
|
||||
len(getattr(request_data, "messages", [])),
|
||||
)
|
||||
try:
|
||||
body = build_base_request_body(
|
||||
request_data,
|
||||
reasoning_replay=ReasoningReplayMode.REASONING_CONTENT,
|
||||
)
|
||||
except OpenAIConversionError as exc:
|
||||
raise InvalidRequestError(str(exc)) from exc
|
||||
|
||||
extra_body: dict[str, Any] = {}
|
||||
request_extra = getattr(request_data, "extra_body", None)
|
||||
if request_extra:
|
||||
extra_body.update(request_extra)
|
||||
body = build_base_native_anthropic_request_body(
|
||||
request_data,
|
||||
default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
|
||||
thinking_enabled=thinking_enabled,
|
||||
)
|
||||
|
||||
if extra_body:
|
||||
body["extra_body"] = extra_body
|
||||
extra = getattr(request_data, "extra_body", None)
|
||||
if isinstance(extra, dict) and extra:
|
||||
try:
|
||||
validate_openrouter_extra_body(extra)
|
||||
except OpenRouterExtraBodyError as exc:
|
||||
raise InvalidRequestError(str(exc)) from exc
|
||||
body.update(extra)
|
||||
|
||||
body["stream"] = True
|
||||
|
||||
logger.debug(
|
||||
"FIREWORKS_REQUEST: conversion done model={} msgs={} tools={}",
|
||||
"FIREWORKS_REQUEST: build done model={} msgs={} tools={}",
|
||||
body.get("model"),
|
||||
len(body.get("messages", [])),
|
||||
len(body.get("tools", [])),
|
||||
|
||||
@@ -1,25 +1,29 @@
|
||||
"""Kimi (Moonshot) provider implementation."""
|
||||
"""Kimi (Moonshot) provider using native Anthropic-compatible Messages."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from providers.anthropic_messages import AnthropicMessagesTransport
|
||||
from providers.base import ProviderConfig
|
||||
from providers.defaults import KIMI_DEFAULT_BASE
|
||||
from providers.openai_compat import OpenAIChatTransport
|
||||
|
||||
from .request import build_request_body
|
||||
|
||||
_MOONSHOT_OPENAI_MODELS_URL = "https://api.moonshot.ai/v1/models"
|
||||
_ANTHROPIC_VERSION = "2023-06-01"
|
||||
|
||||
class KimiProvider(OpenAIChatTransport):
|
||||
"""Kimi provider using the OpenAI-compatible chat completions API."""
|
||||
|
||||
class KimiProvider(AnthropicMessagesTransport):
|
||||
"""Kimi provider using Anthropic-compatible Messages at api.moonshot.ai/anthropic/v1."""
|
||||
|
||||
def __init__(self, config: ProviderConfig):
|
||||
super().__init__(
|
||||
config,
|
||||
provider_name="KIMI",
|
||||
base_url=config.base_url or KIMI_DEFAULT_BASE,
|
||||
api_key=config.api_key,
|
||||
default_base_url=KIMI_DEFAULT_BASE,
|
||||
)
|
||||
|
||||
def _build_request_body(
|
||||
@@ -29,3 +33,21 @@ class KimiProvider(OpenAIChatTransport):
|
||||
request,
|
||||
thinking_enabled=self._is_thinking_enabled(request, thinking_enabled),
|
||||
)
|
||||
|
||||
def _request_headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Accept": "text/event-stream",
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"anthropic-version": _ANTHROPIC_VERSION,
|
||||
}
|
||||
|
||||
async def _send_model_list_request(self) -> httpx.Response:
|
||||
"""Models are listed from the OpenAI-compat root, not ``/anthropic/v1``."""
|
||||
return await self._client.get(
|
||||
_MOONSHOT_OPENAI_MODELS_URL,
|
||||
headers=self._model_list_headers(),
|
||||
)
|
||||
|
||||
def _model_list_headers(self) -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {self._api_key}"}
|
||||
|
||||
+21
-12
@@ -1,31 +1,40 @@
|
||||
"""Request builder for Kimi (Moonshot) provider."""
|
||||
"""Native Anthropic Messages request builder for Kimi (Moonshot)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from core.anthropic import ReasoningReplayMode, build_base_request_body
|
||||
from core.anthropic.conversion import OpenAIConversionError
|
||||
from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
|
||||
from core.anthropic.native_messages_request import (
|
||||
build_base_native_anthropic_request_body,
|
||||
)
|
||||
from providers.exceptions import InvalidRequestError
|
||||
|
||||
|
||||
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
|
||||
"""Build OpenAI-format request body from Anthropic request."""
|
||||
"""Build JSON for Kimi Anthropic-compat ``POST …/messages``."""
|
||||
logger.debug(
|
||||
"KIMI_REQUEST: conversion start model={} msgs={}",
|
||||
"KIMI_REQUEST: native build model={} msgs={}",
|
||||
getattr(request_data, "model", "?"),
|
||||
len(getattr(request_data, "messages", [])),
|
||||
)
|
||||
try:
|
||||
body = build_base_request_body(
|
||||
request_data,
|
||||
reasoning_replay=ReasoningReplayMode.REASONING_CONTENT,
|
||||
|
||||
body = build_base_native_anthropic_request_body(
|
||||
request_data,
|
||||
default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
|
||||
thinking_enabled=thinking_enabled,
|
||||
)
|
||||
extra = getattr(request_data, "extra_body", None)
|
||||
if extra:
|
||||
raise InvalidRequestError(
|
||||
"Kimi native Messages API does not support extra_body on requests."
|
||||
)
|
||||
except OpenAIConversionError as exc:
|
||||
raise InvalidRequestError(str(exc)) from exc
|
||||
body["stream"] = True
|
||||
|
||||
logger.debug(
|
||||
"KIMI_REQUEST: conversion done model={} msgs={} tools={}",
|
||||
"KIMI_REQUEST: build done model={} msgs={} tools={}",
|
||||
body.get("model"),
|
||||
len(body.get("messages", [])),
|
||||
len(body.get("tools", [])),
|
||||
|
||||
+21
-6
@@ -1,25 +1,26 @@
|
||||
"""Z.ai provider implementation (OpenAI-compatible Coding Plan API)."""
|
||||
"""Z.ai provider implementation (Anthropic-compatible Messages API)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from providers.anthropic_messages import AnthropicMessagesTransport
|
||||
from providers.base import ProviderConfig
|
||||
from providers.defaults import ZAI_DEFAULT_BASE
|
||||
from providers.openai_compat import OpenAIChatTransport
|
||||
|
||||
from .request import build_request_body
|
||||
|
||||
_ANTHROPIC_VERSION = "2023-06-01"
|
||||
|
||||
class ZaiProvider(OpenAIChatTransport):
|
||||
"""Z.ai using OpenAI-compatible Coding Plan API."""
|
||||
|
||||
class ZaiProvider(AnthropicMessagesTransport):
|
||||
"""Z.ai using Anthropic-compatible Messages at api.z.ai/api/anthropic/v1."""
|
||||
|
||||
def __init__(self, config: ProviderConfig):
|
||||
super().__init__(
|
||||
config,
|
||||
provider_name="ZAI",
|
||||
base_url=config.base_url or ZAI_DEFAULT_BASE,
|
||||
api_key=config.api_key,
|
||||
default_base_url=ZAI_DEFAULT_BASE,
|
||||
)
|
||||
|
||||
def _build_request_body(
|
||||
@@ -29,3 +30,17 @@ class ZaiProvider(OpenAIChatTransport):
|
||||
request,
|
||||
thinking_enabled=self._is_thinking_enabled(request, thinking_enabled),
|
||||
)
|
||||
|
||||
def _request_headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"Accept": "text/event-stream",
|
||||
"Content-Type": "application/json",
|
||||
"x-api-key": self._api_key,
|
||||
"anthropic-version": _ANTHROPIC_VERSION,
|
||||
}
|
||||
|
||||
def _model_list_headers(self) -> dict[str, str]:
|
||||
return {
|
||||
"x-api-key": self._api_key,
|
||||
"anthropic-version": _ANTHROPIC_VERSION,
|
||||
}
|
||||
|
||||
+21
-12
@@ -1,31 +1,40 @@
|
||||
"""Request builder for Z.ai OpenAI-compatible Coding Plan API."""
|
||||
"""Native Anthropic Messages request builder for Z.ai."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from core.anthropic import ReasoningReplayMode, build_base_request_body
|
||||
from core.anthropic.conversion import OpenAIConversionError
|
||||
from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
|
||||
from core.anthropic.native_messages_request import (
|
||||
build_base_native_anthropic_request_body,
|
||||
)
|
||||
from providers.exceptions import InvalidRequestError
|
||||
|
||||
|
||||
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
|
||||
"""Build OpenAI-format request body from Anthropic request."""
|
||||
"""Build JSON for Z.ai Anthropic-compat ``POST …/messages``."""
|
||||
logger.debug(
|
||||
"ZAI_REQUEST: conversion start model={} msgs={}",
|
||||
"ZAI_REQUEST: native build model={} msgs={}",
|
||||
getattr(request_data, "model", "?"),
|
||||
len(getattr(request_data, "messages", [])),
|
||||
)
|
||||
try:
|
||||
body = build_base_request_body(
|
||||
request_data,
|
||||
reasoning_replay=ReasoningReplayMode.REASONING_CONTENT,
|
||||
|
||||
body = build_base_native_anthropic_request_body(
|
||||
request_data,
|
||||
default_max_tokens=ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS,
|
||||
thinking_enabled=thinking_enabled,
|
||||
)
|
||||
extra = getattr(request_data, "extra_body", None)
|
||||
if extra:
|
||||
raise InvalidRequestError(
|
||||
"Z.ai native Messages API does not support extra_body on requests."
|
||||
)
|
||||
except OpenAIConversionError as exc:
|
||||
raise InvalidRequestError(str(exc)) from exc
|
||||
body["stream"] = True
|
||||
|
||||
logger.debug(
|
||||
"ZAI_REQUEST: conversion done model={} msgs={} tools={}",
|
||||
"ZAI_REQUEST: build done model={} msgs={} tools={}",
|
||||
body.get("model"),
|
||||
len(body.get("messages", [])),
|
||||
len(body.get("tools", [])),
|
||||
|
||||
@@ -619,3 +619,28 @@ def test_listed_server_tools_routed_on_open_router() -> None:
|
||||
)
|
||||
service.create_message(request)
|
||||
mock_provider.preflight_stream.assert_called()
|
||||
|
||||
|
||||
def test_listed_server_tools_routed_on_zai() -> None:
|
||||
"""Z.ai uses native Anthropic Messages; listed server tools are not OpenAI-chat blocked."""
|
||||
settings = Settings()
|
||||
|
||||
async def fake_stream(*_a, **_k):
|
||||
yield 'event: message_start\ndata: {"type":"message_start"}\n\n'
|
||||
yield 'event: message_stop\ndata: {"type":"message_stop"}\n\n'
|
||||
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.stream_response = fake_stream
|
||||
service = ClaudeProxyService(
|
||||
settings,
|
||||
provider_getter=lambda _: mock_provider,
|
||||
model_router=FixedProviderModelRouter(settings, "zai"),
|
||||
)
|
||||
request = MessagesRequest(
|
||||
model="m",
|
||||
max_tokens=20,
|
||||
messages=[Message(role="user", content="q")],
|
||||
tools=[Tool(name="web_search", type="web_search_20250305")],
|
||||
)
|
||||
service.create_message(request)
|
||||
mock_provider.preflight_stream.assert_called()
|
||||
|
||||
@@ -7,8 +7,10 @@ from messaging.platforms.factory import create_messaging_platform
|
||||
from providers.base import BaseProvider
|
||||
from providers.cerebras import CerebrasProvider
|
||||
from providers.deepseek import DeepSeekProvider
|
||||
from providers.fireworks import FireworksProvider
|
||||
from providers.gemini import GeminiProvider
|
||||
from providers.groq import GroqProvider
|
||||
from providers.kimi import KimiProvider
|
||||
from providers.llamacpp import LlamaCppProvider
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
from providers.mistral import MistralProvider
|
||||
@@ -78,6 +80,8 @@ def test_provider_and_platform_registries_include_advertised_builtins() -> None:
|
||||
"open_router": OpenRouterProvider,
|
||||
"mistral": MistralProvider,
|
||||
"deepseek": DeepSeekProvider,
|
||||
"kimi": KimiProvider,
|
||||
"fireworks": FireworksProvider,
|
||||
"lmstudio": LMStudioProvider,
|
||||
"llamacpp": LlamaCppProvider,
|
||||
"ollama": OllamaProvider,
|
||||
|
||||
+102
-128
@@ -1,43 +1,17 @@
|
||||
"""Tests for Fireworks AI provider."""
|
||||
"""Tests for Fireworks AI native Anthropic Messages provider."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.models.anthropic import Message, MessagesRequest
|
||||
from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
|
||||
from providers.base import ProviderConfig
|
||||
from providers.exceptions import InvalidRequestError
|
||||
from providers.fireworks import FIREWORKS_BASE_URL, FireworksProvider
|
||||
|
||||
|
||||
class MockMessage:
|
||||
def __init__(self, role, content):
|
||||
self.role = role
|
||||
self.content = content
|
||||
|
||||
|
||||
class MockBlock:
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
class MockRequest:
|
||||
def __init__(self, **kwargs):
|
||||
self.model = "accounts/fireworks/models/glm-5p1"
|
||||
self.messages = [MockMessage("user", "Hello")]
|
||||
self.max_tokens = 100
|
||||
self.temperature = 0.5
|
||||
self.top_p = 0.9
|
||||
self.system = "System prompt"
|
||||
self.stop_sequences = None
|
||||
self.tools = []
|
||||
self.extra_body = {}
|
||||
self.thinking = MagicMock()
|
||||
self.thinking.enabled = True
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fireworks_config():
|
||||
return ProviderConfig(
|
||||
@@ -51,13 +25,11 @@ def fireworks_config():
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_rate_limiter():
|
||||
"""Mock the global rate limiter to prevent waiting."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def _slot():
|
||||
yield
|
||||
|
||||
with patch("providers.openai_compat.GlobalRateLimiter") as mock:
|
||||
with patch("providers.anthropic_messages.GlobalRateLimiter") as mock:
|
||||
instance = mock.get_scoped_instance.return_value
|
||||
|
||||
async def _passthrough(fn, *args, **kwargs):
|
||||
@@ -74,136 +46,138 @@ def fireworks_provider(fireworks_config):
|
||||
|
||||
|
||||
def test_init(fireworks_config):
|
||||
"""Test provider initialization."""
|
||||
with patch("providers.openai_compat.AsyncOpenAI") as mock_openai:
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
provider = FireworksProvider(fireworks_config)
|
||||
assert provider._api_key == "test_fireworks_key"
|
||||
assert provider._base_url == FIREWORKS_BASE_URL
|
||||
mock_openai.assert_called_once()
|
||||
assert provider._api_key == "test_fireworks_key"
|
||||
assert provider._base_url == FIREWORKS_BASE_URL
|
||||
assert mock_client.called
|
||||
|
||||
|
||||
def test_base_url_constant():
|
||||
"""FIREWORKS_BASE_URL points to the Fireworks AI inference endpoint."""
|
||||
assert FIREWORKS_BASE_URL == "https://api.fireworks.ai/inference/v1"
|
||||
|
||||
|
||||
def test_build_request_body_basic(fireworks_provider):
|
||||
"""Basic request body conversion works for Fireworks AI."""
|
||||
req = MockRequest()
|
||||
body = fireworks_provider._build_request_body(req)
|
||||
def test_request_headers(fireworks_provider):
|
||||
h = fireworks_provider._request_headers()
|
||||
assert h["Authorization"] == "Bearer test_fireworks_key"
|
||||
assert h["anthropic-version"] == "2023-06-01"
|
||||
assert h["Accept"] == "text/event-stream"
|
||||
|
||||
|
||||
def test_build_request_body_native_shape(fireworks_provider):
|
||||
request = MessagesRequest(
|
||||
model="accounts/fireworks/models/glm-5p1",
|
||||
max_tokens=100,
|
||||
messages=[Message(role="user", content="Hello")],
|
||||
system="System prompt",
|
||||
)
|
||||
body = fireworks_provider._build_request_body(request)
|
||||
assert body["model"] == "accounts/fireworks/models/glm-5p1"
|
||||
assert body["messages"][0]["role"] == "system"
|
||||
assert body["stream"] is True
|
||||
assert body["max_tokens"] == 100
|
||||
assert body["system"] == "System prompt"
|
||||
assert body["messages"][0]["role"] == "user"
|
||||
|
||||
|
||||
def test_build_request_body_default_max_tokens(fireworks_provider):
|
||||
request = MessagesRequest(
|
||||
model="m",
|
||||
messages=[Message(role="user", content="x")],
|
||||
)
|
||||
body = fireworks_provider._build_request_body(request)
|
||||
assert body["max_tokens"] == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
|
||||
|
||||
|
||||
def test_build_request_body_global_disable_blocks_thinking():
|
||||
"""Global disable suppresses provider-side thinking."""
|
||||
provider = FireworksProvider(
|
||||
ProviderConfig(
|
||||
api_key="test_fireworks_key",
|
||||
api_key="k",
|
||||
base_url=FIREWORKS_BASE_URL,
|
||||
rate_limit=10,
|
||||
rate_window=60,
|
||||
rate_limit=1,
|
||||
rate_window=1,
|
||||
enable_thinking=False,
|
||||
)
|
||||
)
|
||||
req = MockRequest()
|
||||
body = provider._build_request_body(req)
|
||||
|
||||
# When thinking is disabled, no thinking-related fields should appear
|
||||
assert "extra_body" not in body or "thinking" not in body.get("extra_body", {})
|
||||
request = MessagesRequest.model_validate(
|
||||
{
|
||||
"model": "m",
|
||||
"messages": [{"role": "user", "content": "x"}],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 1},
|
||||
}
|
||||
)
|
||||
body = provider._build_request_body(request)
|
||||
assert "thinking" not in body
|
||||
|
||||
|
||||
def test_build_request_body_request_disable_blocks_thinking(fireworks_provider):
|
||||
"""Request-level disable suppresses provider-side thinking when global is enabled."""
|
||||
req = MockRequest()
|
||||
req.thinking.enabled = False
|
||||
body = fireworks_provider._build_request_body(req)
|
||||
|
||||
assert "extra_body" not in body or "thinking" not in body.get("extra_body", {})
|
||||
|
||||
|
||||
def test_build_request_body_preserves_caller_extra_body(fireworks_provider):
|
||||
"""Caller-provided extra_body should be preserved."""
|
||||
req = MockRequest(
|
||||
extra_body={"custom_param": "value"},
|
||||
request = MessagesRequest.model_validate(
|
||||
{
|
||||
"model": "m",
|
||||
"messages": [{"role": "user", "content": "x"}],
|
||||
"thinking": {"enabled": False},
|
||||
}
|
||||
)
|
||||
body = fireworks_provider._build_request_body(req)
|
||||
body = fireworks_provider._build_request_body(request)
|
||||
assert "thinking" not in body
|
||||
|
||||
assert body["extra_body"]["custom_param"] == "value"
|
||||
|
||||
def test_build_request_body_merges_safe_extra_body(fireworks_provider):
|
||||
request = MessagesRequest.model_validate(
|
||||
{
|
||||
"model": "m",
|
||||
"messages": [{"role": "user", "content": "x"}],
|
||||
"extra_body": {"custom_param": "value"},
|
||||
}
|
||||
)
|
||||
body = fireworks_provider._build_request_body(request)
|
||||
assert body["custom_param"] == "value"
|
||||
|
||||
|
||||
def test_build_request_body_rejects_reserved_extra_body_keys(fireworks_provider):
|
||||
request = MessagesRequest.model_validate(
|
||||
{
|
||||
"model": "m",
|
||||
"messages": [{"role": "user", "content": "x"}],
|
||||
"extra_body": {"temperature": 0.1},
|
||||
}
|
||||
)
|
||||
with pytest.raises(InvalidRequestError, match="extra_body must not override"):
|
||||
fireworks_provider._build_request_body(request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_text(fireworks_provider):
|
||||
"""Text content deltas are emitted as text blocks."""
|
||||
req = MockRequest()
|
||||
async def test_stream_uses_post_messages_path(fireworks_provider):
|
||||
request = MessagesRequest(
|
||||
model="m",
|
||||
messages=[Message(role="user", content="hi")],
|
||||
)
|
||||
called: dict[str, str] = {}
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.choices = [
|
||||
MagicMock(
|
||||
delta=MagicMock(
|
||||
content="Hello back!",
|
||||
reasoning_content=None,
|
||||
tool_calls=None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
]
|
||||
mock_chunk.usage = MagicMock(completion_tokens=5, prompt_tokens=10)
|
||||
async def fake_send(request, *args, **kwargs):
|
||||
called["path"] = request.url.path
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.is_closed = False
|
||||
mock_resp.raise_for_status = lambda: None
|
||||
|
||||
async def mock_stream():
|
||||
yield mock_chunk
|
||||
async def aiter():
|
||||
if False: # pragma: no cover
|
||||
yield ""
|
||||
|
||||
with patch.object(
|
||||
fireworks_provider._client.chat.completions, "create", new_callable=AsyncMock
|
||||
) as mock_create:
|
||||
mock_create.return_value = mock_stream()
|
||||
mock_resp.aiter_lines = aiter
|
||||
mock_resp.aclose = AsyncMock()
|
||||
return mock_resp
|
||||
|
||||
events = [event async for event in fireworks_provider.stream_response(req)]
|
||||
fireworks_provider._client.send = fake_send
|
||||
_ = [x async for x in fireworks_provider.stream_response(request, request_id="r1")]
|
||||
|
||||
assert any(
|
||||
'"text_delta"' in event and "Hello back!" in event for event in events
|
||||
)
|
||||
assert called["path"].endswith("/messages")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_response_reasoning_content(fireworks_provider):
|
||||
"""reasoning_content deltas are emitted as thinking blocks."""
|
||||
req = MockRequest()
|
||||
|
||||
mock_chunk = MagicMock()
|
||||
mock_chunk.choices = [
|
||||
MagicMock(
|
||||
delta=MagicMock(
|
||||
content=None,
|
||||
reasoning_content="Thinking...",
|
||||
tool_calls=None,
|
||||
),
|
||||
finish_reason="stop",
|
||||
)
|
||||
]
|
||||
mock_chunk.usage = MagicMock(completion_tokens=2, prompt_tokens=10)
|
||||
|
||||
async def mock_stream():
|
||||
yield mock_chunk
|
||||
|
||||
with patch.object(
|
||||
fireworks_provider._client.chat.completions, "create", new_callable=AsyncMock
|
||||
) as mock_create:
|
||||
mock_create.return_value = mock_stream()
|
||||
|
||||
events = [event async for event in fireworks_provider.stream_response(req)]
|
||||
|
||||
assert any(
|
||||
'"thinking_delta"' in event and "Thinking..." in event for event in events
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup(fireworks_provider):
|
||||
"""cleanup closes the OpenAI client."""
|
||||
async def test_cleanup_aclose(fireworks_provider):
|
||||
fireworks_provider._client = AsyncMock()
|
||||
|
||||
await fireworks_provider.cleanup()
|
||||
|
||||
fireworks_provider._client.close.assert_called_once()
|
||||
fireworks_provider._client.aclose.assert_awaited_once()
|
||||
|
||||
@@ -0,0 +1,121 @@
|
||||
"""Tests for Kimi (Moonshot) native Anthropic Messages provider."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.models.anthropic import Message, MessagesRequest
|
||||
from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
|
||||
from providers.base import ProviderConfig
|
||||
from providers.defaults import KIMI_DEFAULT_BASE
|
||||
from providers.exceptions import InvalidRequestError
|
||||
from providers.kimi import KimiProvider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kimi_config():
|
||||
return ProviderConfig(
|
||||
api_key="test_kimi_key",
|
||||
base_url=KIMI_DEFAULT_BASE,
|
||||
rate_limit=10,
|
||||
rate_window=60,
|
||||
enable_thinking=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_rate_limiter():
|
||||
@asynccontextmanager
|
||||
async def _slot():
|
||||
yield
|
||||
|
||||
with patch("providers.anthropic_messages.GlobalRateLimiter") as mock:
|
||||
instance = mock.get_scoped_instance.return_value
|
||||
|
||||
async def _passthrough(fn, *args, **kwargs):
|
||||
return await fn(*args, **kwargs)
|
||||
|
||||
instance.execute_with_retry = AsyncMock(side_effect=_passthrough)
|
||||
instance.concurrency_slot.side_effect = _slot
|
||||
yield instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def kimi_provider(kimi_config):
|
||||
return KimiProvider(kimi_config)
|
||||
|
||||
|
||||
def test_init(kimi_config):
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
provider = KimiProvider(kimi_config)
|
||||
assert provider._api_key == "test_kimi_key"
|
||||
assert provider._base_url == KIMI_DEFAULT_BASE
|
||||
assert mock_client.called
|
||||
|
||||
|
||||
def test_request_headers(kimi_provider):
|
||||
h = kimi_provider._request_headers()
|
||||
assert h["Authorization"] == "Bearer test_kimi_key"
|
||||
assert h["anthropic-version"] == "2023-06-01"
|
||||
|
||||
|
||||
def test_build_request_body_native(kimi_provider):
|
||||
request = MessagesRequest(
|
||||
model="kimi-k2.5",
|
||||
max_tokens=50,
|
||||
messages=[Message(role="user", content="hi")],
|
||||
)
|
||||
body = kimi_provider._build_request_body(request)
|
||||
assert body["model"] == "kimi-k2.5"
|
||||
assert body["stream"] is True
|
||||
assert body["messages"][0]["role"] == "user"
|
||||
|
||||
|
||||
def test_build_request_body_default_max_tokens(kimi_provider):
|
||||
request = MessagesRequest(
|
||||
model="m",
|
||||
messages=[Message(role="user", content="x")],
|
||||
)
|
||||
body = kimi_provider._build_request_body(request)
|
||||
assert body["max_tokens"] == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
|
||||
|
||||
|
||||
def test_build_request_body_rejects_extra_body(kimi_provider):
|
||||
request = MessagesRequest.model_validate(
|
||||
{
|
||||
"model": "m",
|
||||
"messages": [{"role": "user", "content": "x"}],
|
||||
"extra_body": {"x": 1},
|
||||
}
|
||||
)
|
||||
with pytest.raises(InvalidRequestError, match="does not support extra_body"):
|
||||
kimi_provider._build_request_body(request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_list_uses_moonshot_openai_url(kimi_provider):
|
||||
called: dict[str, str] = {}
|
||||
|
||||
async def fake_get(url: str, **_k):
|
||||
called["url"] = url
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.raise_for_status = lambda: None
|
||||
mock_resp.json = lambda: {"data": [{"id": "kimi-k2.5"}]}
|
||||
mock_resp.aclose = AsyncMock()
|
||||
return mock_resp
|
||||
|
||||
kimi_provider._client.get = fake_get
|
||||
|
||||
await kimi_provider.list_model_infos()
|
||||
|
||||
assert called["url"] == "https://api.moonshot.ai/v1/models"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_aclose(kimi_provider):
|
||||
kimi_provider._client = AsyncMock()
|
||||
|
||||
await kimi_provider.cleanup()
|
||||
|
||||
kimi_provider._client.aclose.assert_awaited_once()
|
||||
@@ -10,8 +10,10 @@ from config.provider_ids import SUPPORTED_PROVIDER_IDS
|
||||
from providers.cerebras import CerebrasProvider
|
||||
from providers.deepseek import DeepSeekProvider
|
||||
from providers.exceptions import UnknownProviderTypeError
|
||||
from providers.fireworks import FireworksProvider
|
||||
from providers.gemini import GeminiProvider
|
||||
from providers.groq import GroqProvider
|
||||
from providers.kimi import KimiProvider
|
||||
from providers.llamacpp import LlamaCppProvider
|
||||
from providers.lmstudio import LMStudioProvider
|
||||
from providers.mistral import MistralProvider
|
||||
@@ -49,12 +51,13 @@ def _make_settings(**overrides):
|
||||
mock.llamacpp_proxy = ""
|
||||
mock.mistral_proxy = ""
|
||||
mock.kimi_proxy = ""
|
||||
mock.kimi_api_key = "test_kimi_key"
|
||||
mock.wafer_proxy = ""
|
||||
mock.opencode_proxy = ""
|
||||
mock.opencode_go_proxy = ""
|
||||
mock.zai_proxy = ""
|
||||
mock.fireworks_proxy = ""
|
||||
mock.fireworks_api_key = ""
|
||||
mock.fireworks_api_key = "test_fireworks_key"
|
||||
mock.gemini_api_key = ""
|
||||
mock.gemini_proxy = ""
|
||||
mock.groq_api_key = ""
|
||||
@@ -162,11 +165,15 @@ def test_create_provider_instantiates_each_builtin():
|
||||
gemini_api_key="test_gemini_key",
|
||||
groq_api_key="test_groq_key",
|
||||
cerebras_api_key="test_cerebras_key",
|
||||
fireworks_api_key="test_fireworks_key",
|
||||
kimi_api_key="test_kimi_key",
|
||||
)
|
||||
cases = {
|
||||
"nvidia_nim": NvidiaNimProvider,
|
||||
"mistral": MistralProvider,
|
||||
"deepseek": DeepSeekProvider,
|
||||
"kimi": KimiProvider,
|
||||
"fireworks": FireworksProvider,
|
||||
"lmstudio": LMStudioProvider,
|
||||
"llamacpp": LlamaCppProvider,
|
||||
"ollama": OllamaProvider,
|
||||
|
||||
@@ -0,0 +1,107 @@
|
||||
"""Tests for Z.ai native Anthropic Messages provider."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.models.anthropic import Message, MessagesRequest
|
||||
from config.constants import ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
|
||||
from providers.base import ProviderConfig
|
||||
from providers.defaults import ZAI_DEFAULT_BASE
|
||||
from providers.exceptions import InvalidRequestError
|
||||
from providers.zai import ZaiProvider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zai_config():
|
||||
return ProviderConfig(
|
||||
api_key="test_zai_key",
|
||||
base_url=ZAI_DEFAULT_BASE,
|
||||
rate_limit=10,
|
||||
rate_window=60,
|
||||
enable_thinking=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_rate_limiter():
|
||||
@asynccontextmanager
|
||||
async def _slot():
|
||||
yield
|
||||
|
||||
with patch("providers.anthropic_messages.GlobalRateLimiter") as mock:
|
||||
instance = mock.get_scoped_instance.return_value
|
||||
|
||||
async def _passthrough(fn, *args, **kwargs):
|
||||
return await fn(*args, **kwargs)
|
||||
|
||||
instance.execute_with_retry = AsyncMock(side_effect=_passthrough)
|
||||
instance.concurrency_slot.side_effect = _slot
|
||||
yield instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def zai_provider(zai_config):
|
||||
return ZaiProvider(zai_config)
|
||||
|
||||
|
||||
def test_init(zai_config):
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
provider = ZaiProvider(zai_config)
|
||||
assert provider._api_key == "test_zai_key"
|
||||
assert provider._base_url == ZAI_DEFAULT_BASE
|
||||
assert mock_client.called
|
||||
|
||||
|
||||
def test_request_headers(zai_provider):
|
||||
h = zai_provider._request_headers()
|
||||
assert h["x-api-key"] == "test_zai_key"
|
||||
assert h["anthropic-version"] == "2023-06-01"
|
||||
|
||||
|
||||
def test_model_list_headers(zai_provider):
|
||||
h = zai_provider._model_list_headers()
|
||||
assert h["x-api-key"] == "test_zai_key"
|
||||
|
||||
|
||||
def test_build_request_body_native(zai_provider):
|
||||
request = MessagesRequest(
|
||||
model="glm-5.1",
|
||||
max_tokens=100,
|
||||
messages=[Message(role="user", content="Hello")],
|
||||
)
|
||||
body = zai_provider._build_request_body(request)
|
||||
assert body["model"] == "glm-5.1"
|
||||
assert body["stream"] is True
|
||||
assert body["max_tokens"] == 100
|
||||
|
||||
|
||||
def test_build_request_body_default_max_tokens(zai_provider):
|
||||
request = MessagesRequest(
|
||||
model="m",
|
||||
messages=[Message(role="user", content="x")],
|
||||
)
|
||||
body = zai_provider._build_request_body(request)
|
||||
assert body["max_tokens"] == ANTHROPIC_DEFAULT_MAX_OUTPUT_TOKENS
|
||||
|
||||
|
||||
def test_build_request_body_rejects_extra_body(zai_provider):
|
||||
request = MessagesRequest.model_validate(
|
||||
{
|
||||
"model": "m",
|
||||
"messages": [{"role": "user", "content": "x"}],
|
||||
"extra_body": {"x": 1},
|
||||
}
|
||||
)
|
||||
with pytest.raises(InvalidRequestError, match="does not support extra_body"):
|
||||
zai_provider._build_request_body(request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_aclose(zai_provider):
|
||||
zai_provider._client = AsyncMock()
|
||||
|
||||
await zai_provider.cleanup()
|
||||
|
||||
zai_provider._client.aclose.assert_awaited_once()
|
||||
Reference in New Issue
Block a user