mirror of
https://github.com/Alishahryar1/free-claude-code.git
synced 2026-06-02 06:13:46 +02:00
added more tests
This commit is contained in:
@@ -15,7 +15,7 @@ import json
|
||||
import logging
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional, Union, Literal
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel, field_validator, model_validator
|
||||
from providers.nvidia_nim import NvidiaNimProvider, ProviderConfig
|
||||
from providers.exceptions import ProviderError
|
||||
import uvicorn
|
||||
@@ -124,30 +124,26 @@ class MessagesRequest(BaseModel):
|
||||
extra_body: Optional[Dict[str, Any]] = None
|
||||
original_model: Optional[str] = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model_field(cls, v, info):
|
||||
original_model = v
|
||||
clean_v = v
|
||||
@model_validator(mode="after")
|
||||
def map_model(self) -> "MessagesRequest":
|
||||
if self.original_model is None:
|
||||
self.original_model = self.model
|
||||
|
||||
clean_v = self.model
|
||||
for prefix in ["anthropic/", "openai/", "gemini/"]:
|
||||
if clean_v.startswith(prefix):
|
||||
clean_v = clean_v[len(prefix) :]
|
||||
break
|
||||
|
||||
if "haiku" in clean_v.lower():
|
||||
new_model = SMALL_MODEL
|
||||
self.model = SMALL_MODEL
|
||||
elif "sonnet" in clean_v.lower() or "opus" in clean_v.lower():
|
||||
new_model = BIG_MODEL
|
||||
else:
|
||||
new_model = v
|
||||
self.model = BIG_MODEL
|
||||
|
||||
if new_model != original_model:
|
||||
logger.debug(f"MODEL MAPPING: '{original_model}' -> '{new_model}'")
|
||||
if self.model != self.original_model:
|
||||
logger.debug(f"MODEL MAPPING: '{self.original_model}' -> '{self.model}'")
|
||||
|
||||
if isinstance(info.data, dict):
|
||||
info.data["original_model"] = original_model
|
||||
|
||||
return new_model
|
||||
return self
|
||||
|
||||
|
||||
class TokenCountRequest(BaseModel):
|
||||
@@ -464,6 +460,9 @@ async def create_message(
|
||||
response_json = await provider.complete(request_data)
|
||||
return provider.convert_response(response_json, request_data)
|
||||
|
||||
except ProviderError:
|
||||
# Re-raise ProviderError to be handled by the specialized exception handler
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
|
||||
@@ -53,3 +53,51 @@ def test_create_message_non_stream():
|
||||
assert response.status_code == 200
|
||||
assert response.json()["content"][0]["text"] == "Hello"
|
||||
mock_provider.complete.assert_called_once()
|
||||
|
||||
|
||||
def test_model_mapping():
    """A Haiku model name is remapped before reaching the provider.

    The request handed to ``provider.complete`` must carry a model name that
    differs from the one the client sent, while ``original_model`` keeps the
    client's value.
    """
    requested_model = "claude-3-haiku-20240307"
    body = {
        "model": requested_model,
        "messages": [{"role": "user", "content": "Hi"}],
        "max_tokens": 100,
    }

    client.post("/v1/messages", json=body)

    # First positional argument of the most recent provider call is the
    # request object that was forwarded.
    forwarded_request = mock_provider.complete.call_args[0][0]

    # The provider must not see the client's model name...
    assert forwarded_request.model != "claude-3-haiku-20240307"
    # ...but the request must still remember it.
    assert forwarded_request.original_model == "claude-3-haiku-20240307"
|
||||
|
||||
|
||||
def test_error_fallbacks():
    """Provider exceptions are translated into the matching HTTP errors.

    For each provider exception raised by the mocked ``complete`` call, the
    endpoint must answer with the expected status code and ``error.type``.
    The mock's previous ``side_effect`` is restored afterwards so the
    injected failure cannot leak into tests that run later.
    """
    from providers.exceptions import (
        AuthenticationError,
        RateLimitError,
        OverloadedError,
    )

    # (exception to inject, expected HTTP status, expected error type)
    cases = [
        (AuthenticationError("Invalid Key"), 401, "authentication_error"),
        (RateLimitError("Too Many Requests"), 429, "rate_limit_error"),
        (OverloadedError("Server Overloaded"), 529, "overloaded_error"),
    ]
    payload = {"model": "test", "messages": [], "max_tokens": 10}

    # mock_provider is not local to this test; remember whatever side effect
    # it had so it can be put back even if an assertion below fails.
    previous_side_effect = mock_provider.complete.side_effect
    try:
        for error, expected_status, expected_type in cases:
            mock_provider.complete.side_effect = error
            response = client.post("/v1/messages", json=payload)
            assert response.status_code == expected_status
            assert response.json()["error"]["type"] == expected_type
    finally:
        mock_provider.complete.side_effect = previous_side_effect
|
||||
|
||||
@@ -71,3 +71,57 @@ def test_heuristic_tool_parser_flush():
|
||||
assert len(tools) == 1
|
||||
assert tools[0]["name"] == "Bash"
|
||||
assert tools[0]["input"] == {"command": "ls -la"}
|
||||
|
||||
|
||||
def test_interleaved_thinking_and_tools():
    """A think block followed by a tool call is split by the two parsers.

    The think-tag parser extracts the thinking content, and the text it
    leaves behind is fed to the heuristic tool parser, which must find
    exactly one ``Grep`` call.
    """
    think_parser = ThinkTagParser()
    tool_parser = HeuristicToolParser()

    raw = "<think>I need to search for a file.</think> ● <function=Grep><parameter=pattern>test</parameter>"

    # Step 1: run the full text through the think-tag parser and separate
    # thinking chunks from plain-text chunks.
    parsed = list(think_parser.feed(raw))
    thinking_chunks = []
    text_parts = []
    for chunk in parsed:
        if chunk.type == ContentType.THINKING:
            thinking_chunks.append(chunk)
        if chunk.type == ContentType.TEXT:
            text_parts.append(chunk.content)
    leftover_text = "".join(text_parts)

    assert len(thinking_chunks) == 1
    assert thinking_chunks[0].content == "I need to search for a file."

    # Step 2: the leftover text goes to the tool parser; flush() collects
    # any call still pending after the last feed.
    _, detected = tool_parser.feed(leftover_text)
    detected += tool_parser.flush()

    assert len(detected) == 1
    grep_call = detected[0]
    assert grep_call["name"] == "Grep"
    assert grep_call["input"] == {"pattern": "test"}
|
||||
|
||||
|
||||
def test_partial_interleaved_streaming():
    """Thinking and a tool call split across three streamed chunks.

    The think tag closes mid-chunk and the tool call's opening tag is cut in
    half between the second and third chunk. The tool parser must report
    nothing until the call is complete.
    """
    think_parser = ThinkTagParser()
    tool_parser = HeuristicToolParser()

    # Chunk 1: an opened think tag with partial content; the test expects
    # that content to be emitted right away as a single thinking chunk.
    first = list(think_parser.feed("<think>Part 1"))
    assert len(first) == 1
    assert first[0].type == ContentType.THINKING
    assert first[0].content == "Part 1"

    # Chunk 2: thinking closes and the beginning of a tool call follows.
    second = list(think_parser.feed(" ends</think> ● <func"))
    assert len(second) == 2
    thinking_tail, trailing = second
    assert thinking_tail.type == ContentType.THINKING
    assert thinking_tail.content == " ends"

    # Only a fragment of the tool call has arrived, so no tool yet.
    _, early_tools = tool_parser.feed(trailing.content)
    assert early_tools == []

    # Chunk 3: the rest of the tool call arrives.
    third = list(think_parser.feed("tion=Read><parameter=path>test.py</parameter>"))
    remaining_text = "".join(piece.content for piece in third)
    _, final_tools = tool_parser.feed(remaining_text)
    final_tools += tool_parser.flush()

    assert len(final_tools) == 1
    read_call = final_tools[0]
    assert read_call["name"] == "Read"
    assert read_call["input"] == {"path": "test.py"}
|
||||
|
||||
Reference in New Issue
Block a user