fix(gemini): nest google extra body for sdk

This commit is contained in:
Alishahryar1
2026-05-24 11:10:29 -07:00
parent 3e70a1c47e
commit 8ae7795961
2 changed files with 85 additions and 9 deletions
+23 -8
View File
@@ -2,7 +2,8 @@
from __future__ import annotations
from typing import Any
from copy import deepcopy
from typing import Any, cast
from loguru import logger
@@ -11,6 +12,24 @@ from core.anthropic.conversion import OpenAIConversionError
from providers.exceptions import InvalidRequestError
def _ensure_dict(container: dict[str, Any], key: str) -> dict[str, Any]:
value = container.get(key)
if isinstance(value, dict):
return cast(dict[str, Any], value)
nested: dict[str, Any] = {}
container[key] = nested
return nested
def _apply_thinking_config(extra_body: dict[str, Any]) -> None:
# OpenAI's SDK merges its ``extra_body`` argument into the request JSON.
# Google expects its extension fields under a literal JSON ``extra_body`` key.
literal_extra_body = _ensure_dict(extra_body, "extra_body")
google_section = _ensure_dict(literal_extra_body, "google")
thinking_cfg = _ensure_dict(google_section, "thinking_config")
thinking_cfg.setdefault("include_thoughts", True)
def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
"""Build OpenAI-format request body from an Anthropic request for Gemini."""
logger.debug(
@@ -30,16 +49,12 @@ def build_request_body(request_data: Any, *, thinking_enabled: bool) -> dict:
extra_body: dict[str, Any] = {}
request_extra = getattr(request_data, "extra_body", None)
if request_extra:
extra_body.update(request_extra)
if isinstance(request_extra, dict):
extra_body.update(deepcopy(request_extra))
if thinking_enabled:
body["reasoning_effort"] = "high"
google_section = extra_body.setdefault("google", {})
if isinstance(google_section, dict):
thinking_cfg = google_section.setdefault("thinking_config", {})
if isinstance(thinking_cfg, dict):
thinking_cfg.setdefault("include_thoughts", True)
_apply_thinking_config(extra_body)
else:
body["reasoning_effort"] = "none"
+62 -1
View File
@@ -31,6 +31,14 @@ class MockRequest:
setattr(self, key, value)
def _simulate_openai_sdk_wire_json(body: dict) -> dict:
wire = {key: value for key, value in body.items() if key != "extra_body"}
sdk_extra = body.get("extra_body")
if isinstance(sdk_extra, dict):
wire.update(sdk_extra)
return wire
@pytest.fixture
def gemini_config():
return ProviderConfig(
@@ -94,11 +102,31 @@ def test_build_request_body_basic(gemini_provider):
assert body["reasoning_effort"] == "high"
eb = body.get("extra_body")
assert isinstance(eb, dict)
gc = eb.get("google")
literal_extra_body = eb.get("extra_body")
assert isinstance(literal_extra_body, dict)
gc = literal_extra_body.get("google")
assert isinstance(gc, dict)
tc = gc.get("thinking_config")
assert isinstance(tc, dict)
assert tc.get("include_thoughts") is True
assert "google" not in eb
def test_build_request_body_sdk_wire_json_has_literal_extra_body(gemini_provider):
"""Regression for issue #542: SDK merge must not send top-level google."""
req = MockRequest()
body = gemini_provider._build_request_body(req)
wire_json = _simulate_openai_sdk_wire_json(body)
assert "google" not in wire_json
literal_extra_body = wire_json.get("extra_body")
assert isinstance(literal_extra_body, dict)
google = literal_extra_body.get("google")
assert isinstance(google, dict)
thinking_config = google.get("thinking_config")
assert isinstance(thinking_config, dict)
assert thinking_config.get("include_thoughts") is True
def test_build_request_body_global_disable_sets_reasoning_none():
@@ -128,6 +156,39 @@ def test_build_request_body_preserves_caller_extra_body(gemini_provider):
eb = body.get("extra_body")
assert isinstance(eb, dict)
assert eb.get("metadata") == {"user": "u1"}
literal_extra_body = eb.get("extra_body")
assert isinstance(literal_extra_body, dict)
google = literal_extra_body.get("google")
assert isinstance(google, dict)
def test_build_request_body_merges_caller_nested_google(gemini_provider):
req = MockRequest(
extra_body={
"metadata": {"user": "u1"},
"extra_body": {
"google": {
"thinking_config": {"budget_tokens": 128},
"cached_content": "cachedContents/example",
}
},
}
)
body = gemini_provider._build_request_body(req)
eb = body.get("extra_body")
assert isinstance(eb, dict)
assert eb.get("metadata") == {"user": "u1"}
literal_extra_body = eb.get("extra_body")
assert isinstance(literal_extra_body, dict)
google = literal_extra_body.get("google")
assert isinstance(google, dict)
assert google.get("cached_content") == "cachedContents/example"
thinking_config = google.get("thinking_config")
assert isinstance(thinking_config, dict)
assert thinking_config.get("budget_tokens") == 128
assert thinking_config.get("include_thoughts") is True
@pytest.mark.asyncio