mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-06-01 22:09:04 +02:00
Add fireworks AI support (#476)
This commit is contained in:
@@ -14,3 +14,5 @@ debug-*.log
|
||||
.coverage
|
||||
llama_cache
|
||||
.smoke-results
|
||||
.vscode
|
||||
server.*
|
||||
@@ -19,6 +19,7 @@ WAFER_DEFAULT_BASE = "https://pass.wafer.ai/v1"
|
||||
DEEPSEEK_ANTHROPIC_DEFAULT_BASE = "https://api.deepseek.com/anthropic"
|
||||
# Historical export name: DeepSeek upstream is the native Anthropic path above.
|
||||
DEEPSEEK_DEFAULT_BASE = DEEPSEEK_ANTHROPIC_DEFAULT_BASE
|
||||
FIREWORKS_DEFAULT_BASE = "https://api.fireworks.ai/inference/v1"
|
||||
OPENROUTER_DEFAULT_BASE = "https://openrouter.ai/api/v1"
|
||||
LMSTUDIO_DEFAULT_BASE = "http://localhost:1234/v1"
|
||||
LLAMACPP_DEFAULT_BASE = "http://localhost:8080/v1"
|
||||
@@ -145,6 +146,16 @@ PROVIDER_CATALOG: dict[str, ProviderDescriptor] = {
|
||||
proxy_attr="zai_proxy",
|
||||
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
|
||||
),
|
||||
"fireworks": ProviderDescriptor(
|
||||
provider_id="fireworks",
|
||||
transport_type="openai_chat",
|
||||
credential_env="FIREWORKS_API_KEY",
|
||||
credential_url="https://fireworks.ai/account/api-keys",
|
||||
credential_attr="fireworks_api_key",
|
||||
default_base_url=FIREWORKS_DEFAULT_BASE,
|
||||
proxy_attr="fireworks_proxy",
|
||||
capabilities=("chat", "streaming", "tools", "thinking", "rate_limit"),
|
||||
),
|
||||
}
|
||||
|
||||
# Order matches docs / historical error text; must match PROVIDER_CATALOG keys.
|
||||
|
||||
@@ -125,6 +125,9 @@ class Settings(BaseSettings):
|
||||
# ==================== Z.ai Config ====================
|
||||
zai_api_key: str = Field(default="", validation_alias="ZAI_API_KEY")
|
||||
|
||||
# ==================== Fireworks AI Config ====================
|
||||
fireworks_api_key: str = Field(default="", validation_alias="FIREWORKS_API_KEY")
|
||||
|
||||
# ==================== Messaging Platform Selection ====================
|
||||
# Valid: "telegram" | "discord" | "none"
|
||||
messaging_platform: str = Field(
|
||||
@@ -178,6 +181,7 @@ class Settings(BaseSettings):
|
||||
wafer_proxy: str = Field(default="", validation_alias="WAFER_PROXY")
|
||||
opencode_proxy: str = Field(default="", validation_alias="OPENCODE_PROXY")
|
||||
zai_proxy: str = Field(default="", validation_alias="ZAI_PROXY")
|
||||
fireworks_proxy: str = Field(default="", validation_alias="FIREWORKS_PROXY")
|
||||
|
||||
# ==================== Provider Rate Limiting ====================
|
||||
provider_rate_limit: int = Field(default=40, validation_alias="PROVIDER_RATE_LIMIT")
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Fireworks AI provider exports."""
|
||||
|
||||
from .client import FIREWORKS_BASE_URL, FireworksProvider
|
||||
|
||||
__all__ = ["FIREWORKS_BASE_URL", "FireworksProvider"]
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Fireworks AI provider implementation."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from providers.base import ProviderConfig
|
||||
from providers.openai_compat import OpenAIChatTransport
|
||||
|
||||
from .request import build_request_body
|
||||
|
||||
FIREWORKS_BASE_URL = "https://api.fireworks.ai/inference/v1"
|
||||
|
||||
|
||||
class FireworksProvider(OpenAIChatTransport):
    """OpenAI-compatible chat-completions provider for Fireworks AI.

    Construction is delegated to ``OpenAIChatTransport``; this subclass only
    supplies the provider name, the endpoint fallback and the request-body
    builder.
    """

    def __init__(self, config: ProviderConfig):
        # A falsy ``config.base_url`` falls back to the module-level default.
        endpoint = config.base_url or FIREWORKS_BASE_URL
        super().__init__(
            config,
            provider_name="FIREWORKS",
            base_url=endpoint,
            api_key=config.api_key,
        )

    def _build_request_body(
        self, request: Any, thinking_enabled: bool | None = None
    ) -> dict:
        """Return the Fireworks AI request body for ``request``.

        When ``thinking_enabled`` is ``None`` the flag is resolved through
        ``self._is_thinking_enabled(request)``; an explicit ``True``/``False``
        is forwarded unchanged.
        """
        resolved = (
            self._is_thinking_enabled(request)
            if thinking_enabled is None
            else thinking_enabled
        )
        return build_request_body(request, thinking_enabled=resolved)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Request builder for Fireworks AI provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from core.anthropic import ReasoningReplayMode, build_base_request_body
|
||||
from core.anthropic.conversion import OpenAIConversionError
|
||||
from providers.exceptions import InvalidRequestError
|
||||
|
||||
|
||||
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
    """Build OpenAI-format request body from Anthropic request for Fireworks AI.

    Args:
        request_data: Anthropic-style request object. ``model``, ``messages``
            and ``extra_body`` are read with ``getattr``, so they may be
            absent.
        thinking_enabled: Thinking flag resolved by the calling provider.

    Returns:
        The body produced by ``build_base_request_body``. When the request
        carries a non-empty ``extra_body``, a copy of it is stored under the
        ``"extra_body"`` key.

    Raises:
        InvalidRequestError: If the conversion raises
            ``OpenAIConversionError``.
    """
    # NOTE(review): ``thinking_enabled`` is accepted but never consulted in
    # this function; reasoning replay is always REASONING_CONTENT. Confirm
    # whether the flag should influence the conversion.
    logger.debug(
        "FIREWORKS_REQUEST: conversion start model={} msgs={}",
        getattr(request_data, "model", "?"),
        # ``messages`` may exist but be None; a debug line must never raise.
        len(getattr(request_data, "messages", None) or ()),
    )
    try:
        body = build_base_request_body(
            request_data,
            reasoning_replay=ReasoningReplayMode.REASONING_CONTENT,
        )
    except OpenAIConversionError as exc:
        raise InvalidRequestError(str(exc)) from exc

    # Copy the caller's extras so the request's own dict is not aliased.
    extra_body: dict[str, Any] = {}
    request_extra = getattr(request_data, "extra_body", None)
    if request_extra:
        extra_body.update(request_extra)

    if extra_body:
        body["extra_body"] = extra_body

    logger.debug(
        "FIREWORKS_REQUEST: conversion done model={} msgs={} tools={}",
        body.get("model"),
        # Same None-tolerance as above for keys that are present but None.
        len(body.get("messages") or ()),
        len(body.get("tools") or ()),
    )
    return body
|
||||
@@ -92,6 +92,12 @@ def _create_zai(config: ProviderConfig, _settings: Settings) -> BaseProvider:
|
||||
return ZaiProvider(config)
|
||||
|
||||
|
||||
def _create_fireworks(config: ProviderConfig, _settings: Settings) -> BaseProvider:
    """Factory for the Fireworks AI provider; ``_settings`` is not used."""
    # The provider module is imported only when this factory is called.
    from providers.fireworks import FireworksProvider as provider_cls

    return provider_cls(config)
|
||||
|
||||
|
||||
PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
|
||||
"nvidia_nim": _create_nvidia_nim,
|
||||
"open_router": _create_open_router,
|
||||
@@ -103,6 +109,7 @@ PROVIDER_FACTORIES: dict[str, ProviderFactory] = {
|
||||
"wafer": _create_wafer,
|
||||
"opencode": _create_opencode,
|
||||
"zai": _create_zai,
|
||||
"fireworks": _create_fireworks,
|
||||
}
|
||||
|
||||
if set(PROVIDER_DESCRIPTORS) != set(SUPPORTED_PROVIDER_IDS) or set(
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
"""Tests for Fireworks AI provider."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from providers.base import ProviderConfig
|
||||
from providers.fireworks import FIREWORKS_BASE_URL, FireworksProvider
|
||||
|
||||
|
||||
class MockMessage:
    """Bare message stub exposing ``role`` and ``content`` attributes."""

    def __init__(self, role, content):
        self.role, self.content = role, content
|
||||
|
||||
|
||||
class MockBlock:
    """Attribute bag: each keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])
|
||||
|
||||
|
||||
class MockRequest:
    """Request stub with Fireworks-flavoured defaults.

    Any keyword argument overrides the default of the same name or adds a
    new attribute.
    """

    def __init__(self, **kwargs):
        thinking = MagicMock()
        thinking.enabled = True
        defaults = {
            "model": "accounts/fireworks/models/glm-5p1",
            "messages": [MockMessage("user", "Hello")],
            "max_tokens": 100,
            "temperature": 0.5,
            "top_p": 0.9,
            "system": "System prompt",
            "stop_sequences": None,
            "tools": [],
            "extra_body": {},
            "thinking": thinking,
        }
        # Caller-supplied values win over the defaults.
        for name, value in {**defaults, **kwargs}.items():
            setattr(self, name, value)
|
||||
|
||||
|
||||
@pytest.fixture
def fireworks_config():
    """Provider config for the default Fireworks endpoint with thinking enabled."""
    settings = {
        "api_key": "test_fireworks_key",
        "base_url": FIREWORKS_BASE_URL,
        "rate_limit": 10,
        "rate_window": 60,
        "enable_thinking": True,
    }
    return ProviderConfig(**settings)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_rate_limiter():
    """Patch the global rate limiter with a stand-in that never waits."""

    async def _call_through(fn, *args, **kwargs):
        # Await the wrapped callable directly, with no retry handling.
        return await fn(*args, **kwargs)

    @asynccontextmanager
    async def _open_slot():
        # Concurrency slot that is granted immediately.
        yield

    with patch("providers.openai_compat.GlobalRateLimiter") as limiter_cls:
        limiter = limiter_cls.get_scoped_instance.return_value
        limiter.concurrency_slot.side_effect = _open_slot
        limiter.execute_with_retry = AsyncMock(side_effect=_call_through)
        yield limiter
|
||||
|
||||
|
||||
@pytest.fixture
def fireworks_provider(fireworks_config):
    """FireworksProvider built from the ``fireworks_config`` fixture."""
    provider = FireworksProvider(fireworks_config)
    return provider
|
||||
|
||||
|
||||
def test_init(fireworks_config):
    """Initialization stores key and base URL and builds one OpenAI client."""
    with patch("providers.openai_compat.AsyncOpenAI") as client_cls:
        built = FireworksProvider(fireworks_config)
        assert built._api_key == "test_fireworks_key"
        assert built._base_url == FIREWORKS_BASE_URL
        client_cls.assert_called_once()
|
||||
|
||||
|
||||
def test_base_url_constant():
    """The exported base URL is the Fireworks AI inference endpoint."""
    expected_endpoint = "https://api.fireworks.ai/inference/v1"
    assert FIREWORKS_BASE_URL == expected_endpoint
|
||||
|
||||
|
||||
def test_build_request_body_basic(fireworks_provider):
    """A default request converts to a body with the model and a leading system message."""
    body = fireworks_provider._build_request_body(MockRequest())

    assert body["model"] == "accounts/fireworks/models/glm-5p1"
    leading_message = body["messages"][0]
    assert leading_message["role"] == "system"
|
||||
|
||||
|
||||
def test_build_request_body_global_disable_blocks_thinking():
    """With thinking disabled in the provider config, no thinking field is sent."""
    disabled_config = ProviderConfig(
        api_key="test_fireworks_key",
        base_url=FIREWORKS_BASE_URL,
        rate_limit=10,
        rate_window=60,
        enable_thinking=False,
    )
    provider = FireworksProvider(disabled_config)
    body = provider._build_request_body(MockRequest())

    # A missing "extra_body" counts as "no thinking field" via the {} default.
    extras = body.get("extra_body", {})
    assert "thinking" not in extras
|
||||
|
||||
|
||||
def test_build_request_body_request_disable_blocks_thinking(fireworks_provider):
    """A request that disables thinking yields no thinking field, even if the config enables it."""
    request = MockRequest()
    request.thinking.enabled = False
    body = fireworks_provider._build_request_body(request)

    # A missing "extra_body" counts as "no thinking field" via the {} default.
    extras = body.get("extra_body", {})
    assert "thinking" not in extras
|
||||
|
||||
|
||||
def test_build_request_body_preserves_caller_extra_body(fireworks_provider):
    """Entries in the caller's extra_body survive into the built body."""
    caller_extras = {"custom_param": "value"}
    request = MockRequest(extra_body=caller_extras)

    body = fireworks_provider._build_request_body(request)

    assert body["extra_body"]["custom_param"] == "value"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_response_text(fireworks_provider):
    """A chunk carrying ``content`` produces a text_delta event with that text."""
    req = MockRequest()

    # One streamed chunk: plain text, no reasoning, no tool calls.
    delta = MagicMock(
        content="Hello back!",
        reasoning_content=None,
        tool_calls=None,
    )
    chunk = MagicMock()
    chunk.choices = [MagicMock(delta=delta, finish_reason="stop")]
    chunk.usage = MagicMock(completion_tokens=5, prompt_tokens=10)

    async def _single_chunk_stream():
        yield chunk

    completions = fireworks_provider._client.chat.completions
    with patch.object(completions, "create", new_callable=AsyncMock) as create_mock:
        create_mock.return_value = _single_chunk_stream()

        events = []
        async for event in fireworks_provider.stream_response(req):
            events.append(event)

        assert any(
            '"text_delta"' in event and "Hello back!" in event for event in events
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_response_reasoning_content(fireworks_provider):
    """A chunk carrying ``reasoning_content`` produces a thinking_delta event."""
    req = MockRequest()

    # One streamed chunk: reasoning only, no visible text, no tool calls.
    delta = MagicMock(
        content=None,
        reasoning_content="Thinking...",
        tool_calls=None,
    )
    chunk = MagicMock()
    chunk.choices = [MagicMock(delta=delta, finish_reason="stop")]
    chunk.usage = MagicMock(completion_tokens=2, prompt_tokens=10)

    async def _single_chunk_stream():
        yield chunk

    completions = fireworks_provider._client.chat.completions
    with patch.object(completions, "create", new_callable=AsyncMock) as create_mock:
        create_mock.return_value = _single_chunk_stream()

        events = []
        async for event in fireworks_provider.stream_response(req):
            events.append(event)

        assert any(
            '"thinking_delta"' in event and "Thinking..." in event for event in events
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cleanup(fireworks_provider):
    """``cleanup`` calls ``close`` on the provider's client exactly once."""
    stub_client = AsyncMock()
    fireworks_provider._client = stub_client

    await fireworks_provider.cleanup()

    # Read the client back off the provider, as the check is on what it holds now.
    fireworks_provider._client.close.assert_called_once()
|
||||
Reference in New Issue
Block a user