Phase 7 & 8: Routes optimization refactor and request utils split

Phase 7 - Optimization handlers:
- Create api/optimization_handlers.py with try_prefix_detection,
  try_quota_mock, try_title_skip, try_suggestion_skip, try_filepath_mock
- Add try_optimizations() that runs handlers in order
- Refactor routes.create_message to use try_optimizations()
- Update test_routes_optimizations patch targets

Phase 8 - Request utils split:
- Create api/detection.py: is_quota_check_request, is_title_generation_request,
  is_prefix_detection_request, is_suggestion_mode_request,
  is_filepath_extraction_request
- Create api/command_utils.py: extract_command_prefix, extract_filepaths_from_command
- Slim request_utils.py to get_token_count + re-exports for backward compat

Co-authored-by: Ali Khokhar <alishahryar2@gmail.com>
This commit is contained in:
Cursor Agent
2026-02-15 01:41:35 +00:00
parent 0bab393c05
commit 25c7123e33
6 changed files with 450 additions and 405 deletions
+139
View File
@@ -0,0 +1,139 @@
"""Command parsing utilities for API optimizations."""
import shlex
def extract_command_prefix(command: str) -> str:
    """Extract the command prefix for fast prefix detection.

    The command is tokenised with ``shlex`` and reduced to a short prefix
    made of any leading ``KEY=value`` assignments, the first command word
    and, for a fixed set of multi-command tools, the sub-command after it.

    Args:
        command: The shell command string to analyse.

    Returns:
        ``"command_injection_detected"`` if the command contains a backtick
        or ``$(``; ``"none"`` if no command word is present; otherwise the
        prefix, e.g. ``"git"``, ``"git commit"`` or ``"FOO=bar go test"``.
    """
    # Backticks and $( ... ) embed sub-commands; never derive a prefix
    # from such input.
    if "`" in command or "$(" in command:
        return "command_injection_detected"

    try:
        # posix=False keeps backslashes literal so Windows-style paths such
        # as C:\tmp\a.txt are not mangled by escape processing.
        parts = shlex.split(command, posix=False)
    except ValueError:
        # shlex could not tokenise the input (e.g. an unclosed quote):
        # fall back to plain whitespace splitting.
        words = command.split()
        return words[0] if words else "none"

    # Count the leading KEY=value tokens (environment assignments).
    cmd_start = 0
    for part in parts:
        if "=" not in part or part.startswith("-"):
            break
        cmd_start += 1

    env_prefix = parts[:cmd_start]
    cmd_parts = parts[cmd_start:]
    if not cmd_parts:
        # Empty command, or nothing but environment assignments.
        return "none"

    # Tools whose first argument is a sub-command worth keeping.
    two_word_commands = {
        "git",
        "npm",
        "docker",
        "kubectl",
        "cargo",
        "go",
        "pip",
        "yarn",
    }

    prefix_words = [cmd_parts[0]]
    if (
        cmd_parts[0] in two_word_commands
        and len(cmd_parts) > 1
        and not cmd_parts[1].startswith("-")
    ):
        prefix_words.append(cmd_parts[1])

    # Environment assignments are kept on every path so that
    # "FOO=bar go test" and "FOO=bar ls" are treated consistently.
    return " ".join(env_prefix + prefix_words)
def _strip_matching_quotes(token: str) -> str:
    """Remove one pair of matching surrounding quotes from *token*, if any."""
    if len(token) >= 2 and token[0] == token[-1] and token[0] in "'\"":
        return token[1:-1]
    return token


def _format_filepaths(paths: list[str]) -> str:
    """Render *paths*, one per line, inside a ``<filepaths>`` element."""
    if not paths:
        return "<filepaths>\n</filepaths>"
    joined = "\n".join(paths)
    return f"<filepaths>\n{joined}\n</filepaths>"


def _grep_file_args(args: list[str]) -> list[str]:
    """Return the positional grep arguments that name files.

    Flags are skipped, together with the value of flags that take one.
    When the pattern is supplied through ``-e``/``-f`` every positional
    argument is a file; otherwise the first positional is the pattern.
    """
    flags_with_args = {"-e", "-f", "-m", "-A", "-B", "-C"}
    pattern_flags = {"-e", "-f"}
    pattern_provided_via_flag = False
    positional: list[str] = []
    skip_next = False
    for arg in args:
        if skip_next:
            # This token is the value of the preceding flag.
            skip_next = False
            continue
        if arg.startswith("-"):
            if arg in flags_with_args:
                if arg in pattern_flags:
                    pattern_provided_via_flag = True
                skip_next = True
            continue
        positional.append(arg)
    return positional if pattern_provided_via_flag else positional[1:]


def extract_filepaths_from_command(command: str, output: str) -> str:
    """Extract file paths from a command locally without API call.

    Only commands that read file contents yield paths: ``cat``-like
    commands contribute every non-flag argument and ``grep`` contributes
    its file arguments. Listing/navigation commands and any unrecognised
    command produce an empty result.

    Args:
        command: The shell command that was executed.
        output: The command's output. Not consulted by this function; it
            is part of the signature only.

    Returns:
        The paths, one per line, wrapped in ``<filepaths>`` tags; an empty
        ``<filepaths>`` element when there is nothing to report or the
        command cannot be tokenised.
    """
    # Commands that list or navigate but never read file contents.
    listing_commands = {
        "ls",
        "dir",
        "find",
        "tree",
        "pwd",
        "cd",
        "mkdir",
        "rmdir",
        "rm",
    }
    # Commands whose non-flag arguments are files being read.
    reading_commands = {"cat", "head", "tail", "less", "more", "bat", "type"}

    try:
        # posix=False keeps backslashes literal so Windows-style paths
        # survive tokenisation; it also leaves quote characters in tokens.
        parts = shlex.split(command, posix=False)
    except ValueError:
        # Untokenisable input (e.g. unclosed quote): best effort is empty.
        return _format_filepaths([])

    if not parts:
        return _format_filepaths([])

    # Reduce "/bin/cat" or "C:\tools\cat" to the bare command name.
    base_cmd = parts[0].split("/")[-1].split("\\")[-1].lower()
    args = parts[1:]

    if base_cmd in listing_commands:
        return _format_filepaths([])

    if base_cmd in reading_commands:
        candidates = [arg for arg in args if not arg.startswith("-")]
    elif base_cmd == "grep":
        candidates = _grep_file_args(args)
    else:
        # Unknown command: report nothing.
        candidates = []

    # Tokens keep their surrounding quotes under posix=False; strip them so
    # the reported value is the path itself, and drop empty results.
    filepaths = [
        path for path in map(_strip_matching_quotes, candidates) if path
    ]
    return _format_filepaths(filepaths)
+122
View File
@@ -0,0 +1,122 @@
"""Request detection utilities for API optimizations.
Detects quota checks, title generation, prefix detection, suggestion mode,
and filepath extraction requests to enable fast-path responses.
"""
from typing import Tuple
from utils.text import extract_text_from_content
from .models.anthropic import MessagesRequest
def is_quota_check_request(request_data: MessagesRequest) -> bool:
    """Report whether the request looks like a quota probe.

    A request qualifies when ``max_tokens`` is 1, it carries exactly one
    message, that message has the ``user`` role, and its extracted text
    contains "quota" (case-insensitive).
    """
    if request_data.max_tokens != 1:
        return False
    if len(request_data.messages) != 1:
        return False

    sole_message = request_data.messages[0]
    if sole_message.role != "user":
        return False

    message_text = extract_text_from_content(sole_message.content)
    return "quota" in message_text.lower()
def is_title_generation_request(request_data: MessagesRequest) -> bool:
    """Report whether the request asks for a conversation title.

    True when the last message has the ``user`` role and its extracted
    text contains "write a 5-10 word title" (case-insensitive).
    """
    if len(request_data.messages) == 0:
        return False

    final_message = request_data.messages[-1]
    if final_message.role != "user":
        return False

    message_text = extract_text_from_content(final_message.content)
    return "write a 5-10 word title" in message_text.lower()
def is_prefix_detection_request(request_data: MessagesRequest) -> Tuple[bool, str]:
    """Check if this is a fast prefix detection request.

    A request qualifies when it consists of a single ``user`` message whose
    extracted text contains both ``<policy_spec>`` and ``Command:``.

    Args:
        request_data: The incoming request.

    Returns:
        ``(True, command)`` where ``command`` is the stripped text after
        the last ``Command:`` marker, or ``(False, "")`` when the request
        does not match.
    """
    if len(request_data.messages) != 1 or request_data.messages[0].role != "user":
        return False, ""

    content = extract_text_from_content(request_data.messages[0].content)
    if "<policy_spec>" not in content or "Command:" not in content:
        return False, ""

    # rfind: the command under test follows the *last* "Command:" marker.
    # The marker is known to be present, so rfind cannot return -1 here and
    # no exception handling is needed around the slice.
    cmd_start = content.rfind("Command:") + len("Command:")
    return True, content[cmd_start:].strip()
def is_suggestion_mode_request(request_data: MessagesRequest) -> bool:
    """Report whether any user message is a suggestion-mode prompt.

    True as soon as one message with the ``user`` role has extracted text
    containing the literal marker "[SUGGESTION MODE:".
    """
    user_messages = (msg for msg in request_data.messages if msg.role == "user")
    return any(
        "[SUGGESTION MODE:" in extract_text_from_content(msg.content)
        for msg in user_messages
    )
def is_filepath_extraction_request(
    request_data: MessagesRequest,
) -> Tuple[bool, str, str]:
    """Check if this is a filepath extraction request.

    A request qualifies when it is a single ``user`` message, carries no
    tools, and its extracted text contains ``Command:``, a later
    ``Output:``, and the word "filepaths" (case-insensitive).

    Args:
        request_data: The incoming request.

    Returns:
        ``(True, command, output)`` on a match, where ``command`` is the
        text between the first ``Command:`` and the following ``Output:``
        and ``output`` is the text after ``Output:`` truncated at the first
        ``<`` and then at the first blank line. ``(False, "", "")``
        otherwise.
    """
    no_match = (False, "", "")

    if len(request_data.messages) != 1 or request_data.messages[0].role != "user":
        return no_match
    if request_data.tools:
        return no_match

    content = extract_text_from_content(request_data.messages[0].content)
    if "Command:" not in content or "Output:" not in content:
        return no_match
    # A single substring test suffices: "<filepaths>" contains "filepaths".
    if "filepaths" not in content.lower():
        return no_match

    cmd_start = content.find("Command:") + len("Command:")
    output_marker = content.find("Output:", cmd_start)
    if output_marker == -1:
        # "Output:" only appears before "Command:".
        return no_match

    command = content[cmd_start:output_marker].strip()
    output = content[output_marker + len("Output:") :].strip()

    # Stop at the next section: first at a tag opener, then at a blank line.
    for marker in ("<", "\n\n"):
        output = output.partition(marker)[0].strip()

    return True, command, output
+146
View File
@@ -0,0 +1,146 @@
"""Optimization handlers for fast-path API responses.
Each handler returns a MessagesResponse if the request matches and the
optimization is enabled, otherwise None.
"""
import logging
import uuid
from typing import Optional
from .models.anthropic import MessagesRequest
from .models.responses import MessagesResponse, Usage
from .detection import (
is_quota_check_request,
is_title_generation_request,
is_prefix_detection_request,
is_suggestion_mode_request,
is_filepath_extraction_request,
)
from .command_utils import extract_command_prefix, extract_filepaths_from_command
from config.settings import Settings
logger = logging.getLogger(__name__)
def try_prefix_detection(
    request_data: MessagesRequest, settings: Settings
) -> Optional[MessagesResponse]:
    """Answer prefix-detection requests locally.

    Returns None when ``settings.fast_prefix_detection`` is off or the
    request is not a prefix-detection request; otherwise a response whose
    single text block is the prefix computed by ``extract_command_prefix``.
    """
    if not settings.fast_prefix_detection:
        return None

    matched, command_text = is_prefix_detection_request(request_data)
    if not matched:
        return None

    # NOTE(review): unlike the other handlers, no explicit role is passed
    # here; the model's default applies.
    text_block = {"type": "text", "text": extract_command_prefix(command_text)}
    return MessagesResponse(
        id=f"msg_{uuid.uuid4()}",
        model=request_data.model,
        content=[text_block],
        stop_reason="end_turn",
        usage=Usage(input_tokens=100, output_tokens=5),
    )
def try_quota_mock(
    request_data: MessagesRequest, settings: Settings
) -> Optional[MessagesResponse]:
    """Answer quota probe requests with a canned reply.

    Returns None when ``settings.enable_network_probe_mock`` is off or the
    request is not a quota check; otherwise logs the interception and
    returns a fixed "Quota check passed." assistant response.
    """
    should_mock = settings.enable_network_probe_mock and is_quota_check_request(
        request_data
    )
    if not should_mock:
        return None

    logger.info("Optimization: Intercepted and mocked quota probe")
    reply_block = {"type": "text", "text": "Quota check passed."}
    return MessagesResponse(
        id=f"msg_{uuid.uuid4()}",
        model=request_data.model,
        role="assistant",
        content=[reply_block],
        stop_reason="end_turn",
        usage=Usage(input_tokens=10, output_tokens=5),
    )
def try_title_skip(
    request_data: MessagesRequest, settings: Settings
) -> Optional[MessagesResponse]:
    """Answer title generation requests with a placeholder title.

    Returns None when ``settings.enable_title_generation_skip`` is off or
    the request is not a title generation request; otherwise logs the skip
    and returns an assistant response with the text "Conversation".
    """
    should_skip = settings.enable_title_generation_skip and (
        is_title_generation_request(request_data)
    )
    if not should_skip:
        return None

    logger.info("Optimization: Skipped title generation request")
    title_block = {"type": "text", "text": "Conversation"}
    return MessagesResponse(
        id=f"msg_{uuid.uuid4()}",
        model=request_data.model,
        role="assistant",
        content=[title_block],
        stop_reason="end_turn",
        usage=Usage(input_tokens=100, output_tokens=5),
    )
def try_suggestion_skip(
    request_data: MessagesRequest, settings: Settings
) -> Optional[MessagesResponse]:
    """Answer suggestion mode requests with an empty reply.

    Returns None when ``settings.enable_suggestion_mode_skip`` is off or
    the request is not a suggestion mode request; otherwise logs the skip
    and returns an assistant response whose text is the empty string.
    """
    should_skip = settings.enable_suggestion_mode_skip and (
        is_suggestion_mode_request(request_data)
    )
    if not should_skip:
        return None

    logger.info("Optimization: Skipped suggestion mode request")
    empty_block = {"type": "text", "text": ""}
    return MessagesResponse(
        id=f"msg_{uuid.uuid4()}",
        model=request_data.model,
        role="assistant",
        content=[empty_block],
        stop_reason="end_turn",
        usage=Usage(input_tokens=100, output_tokens=1),
    )
def try_filepath_mock(
    request_data: MessagesRequest, settings: Settings
) -> Optional[MessagesResponse]:
    """Answer filepath extraction requests locally.

    Returns None when ``settings.enable_filepath_extraction_mock`` is off
    or the request is not a filepath extraction request; otherwise computes
    the ``<filepaths>`` text with ``extract_filepaths_from_command``, logs
    the mock, and returns it as an assistant response.
    """
    if not settings.enable_filepath_extraction_mock:
        return None

    matched, command_text, command_output = is_filepath_extraction_request(
        request_data
    )
    if not matched:
        return None

    filepaths_text = extract_filepaths_from_command(command_text, command_output)
    logger.info("Optimization: Mocked filepath extraction")
    filepaths_block = {"type": "text", "text": filepaths_text}
    return MessagesResponse(
        id=f"msg_{uuid.uuid4()}",
        model=request_data.model,
        role="assistant",
        content=[filepaths_block],
        stop_reason="end_turn",
        usage=Usage(input_tokens=100, output_tokens=10),
    )
# Handlers are tried in this order; the first one that returns a response wins.
OPTIMIZATION_HANDLERS = [
    try_prefix_detection,
    try_quota_mock,
    try_title_skip,
    try_suggestion_skip,
    try_filepath_mock,
]


def try_optimizations(
    request_data: MessagesRequest, settings: Settings
) -> Optional[MessagesResponse]:
    """Return the first handler response for the request, or None.

    Each entry of ``OPTIMIZATION_HANDLERS`` is called in list order with
    ``(request_data, settings)``; evaluation stops at the first handler
    that returns something other than None.
    """
    # The generator is lazy, so handlers after the first match never run.
    responses = (
        handler(request_data, settings) for handler in OPTIMIZATION_HANDLERS
    )
    return next(
        (response for response in responses if response is not None), None
    )
+22 -325
View File
@@ -1,328 +1,37 @@
"""Request utility functions for API route handlers.
This module contains optimization functions, quota detection, title generation detection,
prefix detection, and token counting utilities.
Contains token counting and re-exports detection/command utilities.
"""
import json
import logging
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union
import tiktoken
from .models.anthropic import MessagesRequest
from utils.text import extract_text_from_content
from .detection import (
is_quota_check_request,
is_title_generation_request,
is_prefix_detection_request,
is_suggestion_mode_request,
is_filepath_extraction_request,
)
from .command_utils import extract_command_prefix, extract_filepaths_from_command
logger = logging.getLogger(__name__)
ENCODER = tiktoken.get_encoding("cl100k_base")
def is_quota_check_request(request_data: MessagesRequest) -> bool:
"""Check if this is a quota probe request.
Quota checks are typically simple requests with max_tokens=1
and a single message containing the word "quota".
Args:
request_data: The incoming request data
Returns:
True if this is a quota probe request
"""
if (
request_data.max_tokens == 1
and len(request_data.messages) == 1
and request_data.messages[0].role == "user"
):
text = extract_text_from_content(request_data.messages[0].content)
if "quota" in text.lower():
return True
return False
def is_title_generation_request(request_data: MessagesRequest) -> bool:
"""Check if this is a conversation title generation request.
Title generation requests typically contain the phrase
"write a 5-10 word title" in the user's message.
Args:
request_data: The incoming request data
Returns:
True if this is a title generation request
"""
if len(request_data.messages) > 0 and request_data.messages[-1].role == "user":
text = extract_text_from_content(request_data.messages[-1].content)
if "write a 5-10 word title" in text.lower():
return True
return False
def extract_command_prefix(command: str) -> str:
"""Extract the command prefix for fast prefix detection.
Parses a shell command safely, handling environment variables and
command injection attempts. Returns the command prefix suitable
for quick identification.
Args:
command: The command string to analyze
Returns:
Command prefix (e.g., "git", "git commit", "npm install")
or "none" if no valid command found
"""
import shlex
# Quick check for command injection patterns
if "`" in command or "$(" in command:
return "command_injection_detected"
try:
# On Windows, shlex(posix=True) treats backslashes as escapes (e.g. \t),
# which corrupts paths like C:\tmp\a.txt. posix=False preserves them.
parts = shlex.split(command, posix=False)
if not parts:
return "none"
# Handle environment variable prefixes (e.g., KEY=value command)
env_prefix = []
cmd_start = 0
for i, part in enumerate(parts):
if "=" in part and not part.startswith("-"):
env_prefix.append(part)
cmd_start = i + 1
else:
break
if cmd_start >= len(parts):
return "none"
cmd_parts = parts[cmd_start:]
if not cmd_parts:
return "none"
first_word = cmd_parts[0]
two_word_commands = {
"git",
"npm",
"docker",
"kubectl",
"cargo",
"go",
"pip",
"yarn",
}
# For compound commands, include the subcommand (e.g., "git commit")
if first_word in two_word_commands and len(cmd_parts) > 1:
second_word = cmd_parts[1]
if not second_word.startswith("-"):
return f"{first_word} {second_word}"
return first_word
return first_word if not env_prefix else " ".join(env_prefix) + " " + first_word
except ValueError:
# Fall back to simple split if shlex fails
return command.split()[0] if command.split() else "none"
def is_prefix_detection_request(request_data: MessagesRequest) -> Tuple[bool, str]:
"""Check if this is a fast prefix detection request.
Prefix detection requests contain a policy_spec block and
a Command: section for extracting shell command prefixes.
Args:
request_data: The incoming request data
Returns:
Tuple of (is_prefix_request, command_string)
"""
if len(request_data.messages) != 1 or request_data.messages[0].role != "user":
return False, ""
content = extract_text_from_content(request_data.messages[0].content)
if "<policy_spec>" in content and "Command:" in content:
try:
cmd_start = content.rfind("Command:") + len("Command:")
return True, content[cmd_start:].strip()
except Exception:
pass
return False, ""
def is_suggestion_mode_request(request_data: MessagesRequest) -> bool:
"""Check if this is a suggestion mode request.
Suggestion mode requests contain "[SUGGESTION MODE:" in the user's message,
used for auto-suggesting what the user might type next.
Args:
request_data: The incoming request data
Returns:
True if this is a suggestion mode request
"""
for msg in request_data.messages:
if msg.role == "user":
text = extract_text_from_content(msg.content)
if "[SUGGESTION MODE:" in text:
return True
return False
def is_filepath_extraction_request(
request_data: MessagesRequest,
) -> Tuple[bool, str, str]:
"""Check if this is a filepath extraction request.
Filepath extraction requests have a single user message with
"Command:" and "Output:" sections, asking to extract file paths
from command output.
Args:
request_data: The incoming request data
Returns:
Tuple of (is_filepath_request, command, output)
"""
# Must be single message, no tools
if len(request_data.messages) != 1 or request_data.messages[0].role != "user":
return False, "", ""
if request_data.tools:
return False, "", ""
content = extract_text_from_content(request_data.messages[0].content)
# Must have Command: and Output: markers
if "Command:" not in content or "Output:" not in content:
return False, "", ""
# Must ask for filepath extraction
if "filepaths" not in content.lower() and "<filepaths>" not in content.lower():
return False, "", ""
try:
# Extract command and output
cmd_start = content.find("Command:") + len("Command:")
output_marker = content.find("Output:", cmd_start)
if output_marker == -1:
return False, "", ""
command = content[cmd_start:output_marker].strip()
output = content[output_marker + len("Output:") :].strip()
# Clean up output - stop at next section marker if present
for marker in ["<", "\n\n"]:
if marker in output:
output = output.split(marker)[0].strip()
return True, command, output
except Exception:
return False, "", ""
def extract_filepaths_from_command(command: str, output: str) -> str:
"""Extract file paths from a command locally without API call.
Determines if the command reads file contents and extracts paths accordingly.
Commands like ls/dir/find just list files, so return empty.
Commands like cat/head/tail actually read contents, so extract the file path.
Args:
command: The shell command that was executed
output: The command's output
Returns:
Filepath extraction result in <filepaths> format
"""
import shlex
# Commands that just list files (don't read contents)
listing_commands = {
"ls",
"dir",
"find",
"tree",
"pwd",
"cd",
"mkdir",
"rmdir",
"rm",
}
# Commands that read file contents
reading_commands = {"cat", "head", "tail", "less", "more", "bat", "type"}
try:
# Use Windows-style splitting to preserve backslashes in paths (e.g. C:\tmp\a.txt).
parts = shlex.split(command, posix=False)
if not parts:
return "<filepaths>\n</filepaths>"
# Get base command (handle paths like /bin/cat)
base_cmd = parts[0].split("/")[-1].split("\\")[-1].lower()
# Listing commands - return empty
if base_cmd in listing_commands:
return "<filepaths>\n</filepaths>"
# Reading commands - extract file arguments
if base_cmd in reading_commands:
filepaths = []
for part in parts[1:]:
# Skip flags
if part.startswith("-"):
continue
# This is likely a file path
filepaths.append(part)
if filepaths:
paths_str = "\n".join(filepaths)
return f"<filepaths>\n{paths_str}\n</filepaths>"
return "<filepaths>\n</filepaths>"
# grep with file argument
if base_cmd == "grep":
# Basic parsing:
# - Skip flags (and args for flags that take an argument)
# - If -e/-f is used, pattern is provided via flag so all remaining
# positional args are treated as file paths.
# - Otherwise, first positional arg is pattern, remainder are file paths.
flags_with_args = {"-e", "-f", "-m", "-A", "-B", "-C"}
pattern_provided_via_flag = False
positional: list[str] = []
skip_next = False
for part in parts[1:]:
if skip_next:
skip_next = False
continue
if part.startswith("-"):
if part in flags_with_args:
if part in {"-e", "-f"}:
pattern_provided_via_flag = True
skip_next = True
continue
positional.append(part)
filepaths = positional if pattern_provided_via_flag else positional[1:]
if filepaths:
paths_str = "\n".join(filepaths)
return f"<filepaths>\n{paths_str}\n</filepaths>"
return "<filepaths>\n</filepaths>"
# Default - return empty for unknown commands
return "<filepaths>\n</filepaths>"
except Exception:
return "<filepaths>\n</filepaths>"
__all__ = [
"is_quota_check_request",
"is_title_generation_request",
"is_prefix_detection_request",
"is_suggestion_mode_request",
"is_filepath_extraction_request",
"extract_command_prefix",
"extract_filepaths_from_command",
"get_token_count",
]
def get_token_count(
@@ -334,18 +43,9 @@ def get_token_count(
Uses tiktoken cl100k_base encoding to estimate token usage.
Includes system prompt, messages, tools, and per-message overhead.
Args:
messages: List of message objects with content
system: Optional system prompt (str or list of blocks)
tools: Optional list of tool definitions
Returns:
Estimated total token count
"""
total_tokens = 0
# Count system prompt tokens
if system:
if isinstance(system, str):
total_tokens += len(ENCODER.encode(system))
@@ -354,7 +54,6 @@ def get_token_count(
if hasattr(block, "text"):
total_tokens += len(ENCODER.encode(block.text))
# Count message tokens
for msg in messages:
if isinstance(msg.content, str):
total_tokens += len(ENCODER.encode(msg.content))
@@ -371,16 +70,15 @@ def get_token_count(
inp = getattr(block, "input", {})
total_tokens += len(ENCODER.encode(name))
total_tokens += len(ENCODER.encode(json.dumps(inp)))
total_tokens += 10 # Tool use overhead
total_tokens += 10
elif b_type == "tool_result":
content = getattr(block, "content", "")
if isinstance(content, str):
total_tokens += len(ENCODER.encode(content))
else:
total_tokens += len(ENCODER.encode(json.dumps(content)))
total_tokens += 5 # Tool result overhead
total_tokens += 5
# Count tool definition tokens
if tools:
for tool in tools:
tool_str = (
@@ -388,7 +86,6 @@ def get_token_count(
)
total_tokens += len(ENCODER.encode(tool_str))
# Add per-message overhead
total_tokens += len(messages) * 3
if tools:
total_tokens += len(tools) * 5
+6 -76
View File
@@ -7,18 +7,10 @@ from fastapi import APIRouter, Request, Depends, HTTPException
from fastapi.responses import StreamingResponse
from .models.anthropic import MessagesRequest, TokenCountRequest
from .models.responses import MessagesResponse, TokenCountResponse, Usage
from .models.responses import MessagesResponse, TokenCountResponse
from .dependencies import get_provider, get_settings
from .request_utils import (
is_quota_check_request,
is_title_generation_request,
is_prefix_detection_request,
is_suggestion_mode_request,
is_filepath_extraction_request,
extract_command_prefix,
extract_filepaths_from_command,
get_token_count,
)
from .request_utils import get_token_count
from .optimization_handlers import try_optimizations
from config.settings import Settings
from providers.base import BaseProvider
from providers.exceptions import ProviderError
@@ -44,71 +36,9 @@ async def create_message(
"""Create a message (streaming or non-streaming)."""
try:
if settings.fast_prefix_detection:
is_prefix_req, command = is_prefix_detection_request(request_data)
if is_prefix_req:
return MessagesResponse(
id=f"msg_{uuid.uuid4()}",
model=request_data.model,
content=[{"type": "text", "text": extract_command_prefix(command)}],
stop_reason="end_turn",
usage=Usage(input_tokens=100, output_tokens=5),
)
# Optimization: Mock network probe/quota requests
if settings.enable_network_probe_mock and is_quota_check_request(request_data):
logger.info("Optimization: Intercepted and mocked quota probe")
return MessagesResponse(
id=f"msg_{uuid.uuid4()}",
model=request_data.model,
role="assistant",
content=[{"type": "text", "text": "Quota check passed."}],
stop_reason="end_turn",
usage=Usage(input_tokens=10, output_tokens=5),
)
# Optimization: Skip title generation requests
if settings.enable_title_generation_skip and is_title_generation_request(
request_data
):
logger.info("Optimization: Skipped title generation request")
return MessagesResponse(
id=f"msg_{uuid.uuid4()}",
model=request_data.model,
role="assistant",
content=[{"type": "text", "text": "Conversation"}],
stop_reason="end_turn",
usage=Usage(input_tokens=100, output_tokens=5),
)
# Optimization: Skip suggestion mode requests
if settings.enable_suggestion_mode_skip and is_suggestion_mode_request(
request_data
):
logger.info("Optimization: Skipped suggestion mode request")
return MessagesResponse(
id=f"msg_{uuid.uuid4()}",
model=request_data.model,
role="assistant",
content=[{"type": "text", "text": ""}],
stop_reason="end_turn",
usage=Usage(input_tokens=100, output_tokens=1),
)
# Optimization: Mock filepath extraction requests
if settings.enable_filepath_extraction_mock:
is_fp, cmd, output = is_filepath_extraction_request(request_data)
if is_fp:
filepaths = extract_filepaths_from_command(cmd, output)
logger.info("Optimization: Mocked filepath extraction")
return MessagesResponse(
id=f"msg_{uuid.uuid4()}",
model=request_data.model,
role="assistant",
content=[{"type": "text", "text": filepaths}],
stop_reason="end_turn",
usage=Usage(input_tokens=100, output_tokens=10),
)
optimized = try_optimizations(request_data, settings)
if optimized is not None:
return optimized
request_id = f"req_{uuid.uuid4().hex[:12]}"
log_request_compact(logger, request_id, request_data)
+15 -4
View File
@@ -29,8 +29,15 @@ def test_create_message_fast_prefix_detection(client, mock_settings):
"messages": [{"role": "user", "content": "What is the prefix?"}],
}
with patch("api.routes.is_prefix_detection_request", return_value=(True, "/ask")):
response = client.post("/v1/messages", json=payload)
with patch(
"api.optimization_handlers.is_prefix_detection_request",
return_value=(True, "/ask"),
):
with patch(
"api.optimization_handlers.extract_command_prefix",
return_value="/ask",
):
response = client.post("/v1/messages", json=payload)
assert response.status_code == 200
data = response.json()
@@ -48,7 +55,9 @@ def test_create_message_quota_check_mock(client, mock_settings):
"messages": [{"role": "user", "content": "quota check"}],
}
with patch("api.routes.is_quota_check_request", return_value=True):
with patch(
"api.optimization_handlers.is_quota_check_request", return_value=True
):
response = client.post("/v1/messages", json=payload)
assert response.status_code == 200
@@ -66,7 +75,9 @@ def test_create_message_title_generation_skip(client, mock_settings):
"messages": [{"role": "user", "content": "generate title"}],
}
with patch("api.routes.is_title_generation_request", return_value=True):
with patch(
"api.optimization_handlers.is_title_generation_request", return_value=True
):
response = client.post("/v1/messages", json=payload)
assert response.status_code == 200