added more tests

This commit is contained in:
Alishahryar1
2026-01-28 14:47:37 -08:00
parent b8e0360b37
commit ba2159340a
3 changed files with 117 additions and 16 deletions
+15 -16
View File
@@ -15,7 +15,7 @@ import json
import logging
import uuid
from typing import List, Dict, Any, Optional, Union, Literal
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, field_validator, model_validator
from providers.nvidia_nim import NvidiaNimProvider, ProviderConfig
from providers.exceptions import ProviderError
import uvicorn
@@ -124,30 +124,26 @@ class MessagesRequest(BaseModel):
extra_body: Optional[Dict[str, Any]] = None
original_model: Optional[str] = None
@field_validator("model")
@classmethod
def validate_model_field(cls, v, info):
original_model = v
clean_v = v
@model_validator(mode="after")
def map_model(self) -> "MessagesRequest":
if self.original_model is None:
self.original_model = self.model
clean_v = self.model
for prefix in ["anthropic/", "openai/", "gemini/"]:
if clean_v.startswith(prefix):
clean_v = clean_v[len(prefix) :]
break
if "haiku" in clean_v.lower():
new_model = SMALL_MODEL
self.model = SMALL_MODEL
elif "sonnet" in clean_v.lower() or "opus" in clean_v.lower():
new_model = BIG_MODEL
else:
new_model = v
self.model = BIG_MODEL
if new_model != original_model:
logger.debug(f"MODEL MAPPING: '{original_model}' -> '{new_model}'")
if self.model != self.original_model:
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
if isinstance(info.data, dict):
info.data["original_model"] = original_model
return new_model
return self
class TokenCountRequest(BaseModel):
@@ -464,6 +460,9 @@ async def create_message(
response_json = await provider.complete(request_data)
return provider.convert_response(response_json, request_data)
except ProviderError:
# Re-raise ProviderError to be handled by the specialized exception handler
raise
except Exception as e:
import traceback
+48
View File
@@ -53,3 +53,51 @@ def test_create_message_non_stream():
assert response.status_code == 200
assert response.json()["content"][0]["text"] == "Hello"
mock_provider.complete.assert_called_once()
def test_model_mapping():
    """Check that a Haiku model name is remapped before reaching the provider.

    A request naming ``claude-3-haiku-20240307`` is posted to ``/v1/messages``.
    The first positional argument handed to ``mock_provider.complete`` must
    carry a ``model`` different from the requested name, while its
    ``original_model`` must still hold the requested name.
    """
    requested_model = "claude-3-haiku-20240307"
    payload_haiku = {
        "model": requested_model,
        "messages": [{"role": "user", "content": "Hi"}],
        "max_tokens": 100,
    }

    # Clear call history left behind by earlier tests so that ``call_args``
    # below can only describe the request made here. ``reset_mock()`` keeps
    # the configured return_value/side_effect by default.
    mock_provider.complete.reset_mock()

    client.post("/v1/messages", json=payload_haiku)

    # The request must have reached the provider exactly once; without this
    # check a request rejected earlier would surface as an obscure unpacking
    # error on ``call_args`` instead of a clear assertion failure.
    mock_provider.complete.assert_called_once()
    args, _ = mock_provider.complete.call_args
    request_data = args[0]

    # The provider must not see the model name the client asked for...
    assert request_data.model != requested_model
    # ...but the requested name must be preserved on the request object.
    assert request_data.original_model == requested_model
def test_error_fallbacks():
    """Check that provider errors map to the expected HTTP status and type.

    For each provider exception, ``mock_provider.complete`` is made to raise
    it, a request is posted to ``/v1/messages``, and the response status code
    and ``error.type`` field are compared with the expected values.
    """
    from providers.exceptions import (
        AuthenticationError,
        RateLimitError,
        OverloadedError,
    )

    # (exception raised by the provider, expected status, expected error type)
    cases = [
        (AuthenticationError("Invalid Key"), 401, "authentication_error"),
        (RateLimitError("Too Many Requests"), 429, "rate_limit_error"),
        (OverloadedError("Server Overloaded"), 529, "overloaded_error"),
    ]
    payload = {"model": "test", "messages": [], "max_tokens": 10}

    # ``mock_provider`` is shared with other tests; remember its current
    # side effect so the last injected error does not leak into them.
    previous_side_effect = mock_provider.complete.side_effect
    try:
        for error, expected_status, expected_type in cases:
            mock_provider.complete.side_effect = error
            response = client.post("/v1/messages", json=payload)
            assert response.status_code == expected_status
            assert response.json()["error"]["type"] == expected_type
    finally:
        mock_provider.complete.side_effect = previous_side_effect
+54
View File
@@ -71,3 +71,57 @@ def test_heuristic_tool_parser_flush():
assert len(tools) == 1
assert tools[0]["name"] == "Bash"
assert tools[0]["input"] == {"command": "ls -la"}
def test_interleaved_thinking_and_tools():
    """Run one string through the think-tag parser, then the tool parser.

    The input holds a ``<think>`` block followed by a function-call fragment.
    The thinking content is checked first; the remaining plain text is then
    fed to the heuristic tool parser, which must yield a single ``Grep`` call.
    """
    think_parser = ThinkTagParser()
    tool_parser = HeuristicToolParser()
    raw = (
        "<think>I need to search for a file.</think>"
        " ● <function=Grep><parameter=pattern>test</parameter>"
    )

    # Stage 1: split the parser output into thinking chunks and plain text.
    thinking_chunks = []
    plain_pieces = []
    for chunk in list(think_parser.feed(raw)):
        if chunk.type == ContentType.THINKING:
            thinking_chunks.append(chunk)
        elif chunk.type == ContentType.TEXT:
            plain_pieces.append(chunk.content)
    leftover_text = "".join(plain_pieces)

    assert len(thinking_chunks) == 1
    assert thinking_chunks[0].content == "I need to search for a file."

    # Stage 2: the text outside the think block goes to the tool parser;
    # flush() collects any call still buffered after the last feed.
    _, tools = tool_parser.feed(leftover_text)
    tools += tool_parser.flush()

    assert len(tools) == 1
    grep_call = tools[0]
    assert grep_call["name"] == "Grep"
    assert grep_call["input"] == {"pattern": "test"}
def test_partial_interleaved_streaming():
    """Feed a think block and a tool call in three partial chunks.

    The think tag is closed in the second chunk and the function tag is split
    between the second and third chunks. The tool parser must report no tool
    after the partial tag and exactly one ``Read`` call once it is complete.
    """
    think_parser = ThinkTagParser()
    tool_parser = HeuristicToolParser()

    # Chunk 1: the think tag is opened but not closed; the content seen so
    # far is expected back as a single THINKING chunk.
    first = list(think_parser.feed("<think>Part 1"))
    assert len(first) == 1
    opening = first[0]
    assert opening.type == ContentType.THINKING
    assert opening.content == "Part 1"

    # Chunk 2: thinking ends and the beginning of a function tag follows.
    second = list(think_parser.feed(" ends</think> ● <func"))
    assert len(second) == 2
    thinking_tail, after_think = second
    assert thinking_tail.type == ContentType.THINKING
    assert thinking_tail.content == " ends"

    # Only a fragment of the tag has arrived, so no tool is reported yet.
    _, early_tools = tool_parser.feed(after_think.content)
    assert early_tools == []

    # Chunk 3: the rest of the function tag and its parameter arrive.
    third = list(
        think_parser.feed("tion=Read><parameter=path>test.py</parameter>")
    )
    remaining_text = "".join(piece.content for piece in third)
    _, final_tools = tool_parser.feed(remaining_text)
    final_tools += tool_parser.flush()

    assert len(final_tools) == 1
    read_call = final_tools[0]
    assert read_call["name"] == "Read"
    assert read_call["input"] == {"path": "test.py"}