feat: make response prefix logic configurable

This commit is contained in:
Philipp Emanuel Weidmann
2026-04-07 13:24:48 +05:30
parent f612a48b9f
commit b08a0925c1
3 changed files with 80 additions and 48 deletions
+39
View File
@@ -142,6 +142,45 @@ class Settings(BaseSettings):
description="Maximum number of tokens to generate for each response.",
)
# Optional user-supplied response prefix. When this is left as None, the run
# logic detects a common prefix itself by generating responses for a sample of
# prompts and comparing them; a non-empty prefix is appended to every chat
# prompt before generation.
response_prefix: str | None = Field(
default=None,
description=(
"Common prefix to assume for all responses, so that evaluation happens "
"at the point where responses start to differ for different prompts. "
"If not set, the prefix is determined automatically by comparing multiple responses."
),
)
# Chain-of-Thought skip table. Each entry is (cot_initializer, closed_cot_block):
# when the automatically detected common response prefix starts with
# cot_initializer, the prefix is replaced by closed_cot_block (an already-closed
# reasoning block), so that generation begins at the actual response.
chain_of_thought_skips: list[tuple[str, str]] = Field(
default=[
# Most thinking models.
(
"<think>",
"<think></think>",
),
# gpt-oss.
(
"<|channel|>analysis<|message|>",
"<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>",
),
# Model family unknown; this pair was suggested by a user.
(
"<thought>",
"<thought></thought>",
),
# Model family unknown; this pair was suggested by a user.
(
"[THINK]",
"[THINK][/THINK]",
),
],
description=(
"List of pairs of the form (cot_initializer, closed_cot_block) used to skip "
"the Chain-of-Thought block in responses, so that evaluation happens "
"at the start of the actual response."
),
)
print_responses: bool = Field(
default=False,
description="Whether to print prompt/response pairs when counting refusals.",
+37 -45
View File
@@ -393,52 +393,44 @@ def run():
settings.batch_size = best_batch_size
print(f"* Chosen batch size: [bold]{settings.batch_size}[/]")
print()
print("Checking for common response prefix...")
prefix_check_prompts = good_prompts[:100] + bad_prompts[:100]
responses = model.get_responses_batched(prefix_check_prompts)
# Despite being located in os.path, commonprefix actually performs
# a naive string operation without any path-specific logic,
# which is exactly what we need here. Trailing spaces are removed
# to avoid issues where multiple different tokens that all start
# with a space character lead to the common prefix ending with
# a space, which would result in an uncommon tokenization.
model.response_prefix = commonprefix(responses).rstrip(" ")
# Suppress CoT output.
recheck_prefix = False
if model.response_prefix:
# When using any of the predefined prefixes below, we need to check that
# the prefix is actually complete (e.g. not missing a trailing newline).
recheck_prefix = True
if model.response_prefix.startswith("<think>"):
# Most thinking models.
model.response_prefix = "<think></think>"
elif model.response_prefix.startswith("<|channel|>analysis<|message|>"):
# gpt-oss.
model.response_prefix = "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>"
elif model.response_prefix.startswith("<thought>"):
# Unknown, suggested by user.
model.response_prefix = "<thought></thought>"
elif model.response_prefix.startswith("[THINK]"):
# Unknown, suggested by user.
model.response_prefix = "[THINK][/THINK]"
else:
recheck_prefix = False
if model.response_prefix:
print(f"* Prefix found: [bold]{model.response_prefix!r}[/]")
else:
print("* None found")
if recheck_prefix:
print("* Rechecking with prefix...")
if settings.response_prefix is None:
print()
print("Checking for common response prefix...")
prefix_check_prompts = good_prompts[:100] + bad_prompts[:100]
responses = model.get_responses_batched(prefix_check_prompts)
additional_prefix = commonprefix(responses).rstrip(" ")
if additional_prefix:
model.response_prefix += additional_prefix
print(f"* Extended prefix found: [bold]{model.response_prefix!r}[/]")
# Despite being located in os.path, commonprefix actually performs
# a naive string operation without any path-specific logic,
# which is exactly what we need here. Trailing spaces are removed
# to avoid issues where multiple different tokens that all start
# with a space character lead to the common prefix ending with
# a space, which would result in an uncommon tokenization.
settings.response_prefix = commonprefix(responses).rstrip(" ")
if settings.response_prefix:
print(f"* Prefix found: [bold]{settings.response_prefix!r}[/]")
for cot_initializer, closed_cot_block in settings.chain_of_thought_skips:
if settings.response_prefix.startswith(cot_initializer):
settings.response_prefix = closed_cot_block
print(
f"* Closed Chain-of-Thought block: [bold]{settings.response_prefix!r}[/]"
)
# When using a Chain-of-Thought skip, we need to check that the prefix
# is actually complete (e.g. not missing a trailing newline).
print("* Rechecking with prefix...")
responses = model.get_responses_batched(prefix_check_prompts)
additional_prefix = commonprefix(responses).rstrip(" ")
if additional_prefix:
settings.response_prefix += additional_prefix
print(
f"* Extended prefix found: [bold]{settings.response_prefix!r}[/]"
)
break
else:
print("* None found")
evaluator = Evaluator(settings, model)
+4 -3
View File
@@ -59,7 +59,6 @@ class Model:
def __init__(self, settings: Settings):
self.settings = settings
self.response_prefix = ""
self.needs_reload = False
print()
@@ -565,10 +564,12 @@ class Model:
),
)
if self.response_prefix:
if self.settings.response_prefix:
# Append the common response prefix to the prompts so that evaluation happens
# at the point where responses start to differ for different prompts.
chat_prompts = [prompt + self.response_prefix for prompt in chat_prompts]
chat_prompts = [
prompt + self.settings.response_prefix for prompt in chat_prompts
]
inputs = self.tokenizer(
chat_prompts,