diff --git a/src/heretic/main.py b/src/heretic/main.py index 3e6f4a6..c480888 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -377,7 +377,8 @@ def run(): print() print("Checking for common response prefix...") - responses = model.get_responses_batched(good_prompts[:100] + bad_prompts[:100]) + prefix_check_prompts = good_prompts[:100] + bad_prompts[:100] + responses = model.get_responses_batched(prefix_check_prompts) # Despite being located in os.path, commonprefix actually performs # a naive string operation without any path-specific logic, @@ -388,24 +389,39 @@ def run(): model.response_prefix = commonprefix(responses).rstrip(" ") # Suppress CoT output. - if model.response_prefix.startswith(""): - # Most thinking models. - model.response_prefix = "" - elif model.response_prefix.startswith("<|channel|>analysis<|message|>"): - # gpt-oss. - model.response_prefix = "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>" - elif model.response_prefix.startswith(""): - # Unknown, suggested by user. - model.response_prefix = "" - elif model.response_prefix.startswith("[THINK]"): - # Unknown, suggested by user. - model.response_prefix = "[THINK][/THINK]" + recheck_prefix = False + if model.response_prefix: + # When using any of the predefined prefixes below, we need to check that + # the prefix is actually complete (e.g. not missing a trailing newline). + recheck_prefix = True + if model.response_prefix.startswith(""): + # Most thinking models. + model.response_prefix = "" + elif model.response_prefix.startswith("<|channel|>analysis<|message|>"): + # gpt-oss. + model.response_prefix = "<|channel|>analysis<|message|><|end|><|start|>assistant<|channel|>final<|message|>" + elif model.response_prefix.startswith(""): + # Unknown, suggested by user. + model.response_prefix = "" + elif model.response_prefix.startswith("[THINK]"): + # Unknown, suggested by user. + model.response_prefix = "[THINK][/THINK]" + else: + recheck_prefix = False if model.response_prefix: print(f"* Prefix found: [bold]{model.response_prefix!r}[/]") else: print("* None found") + if recheck_prefix: + print("* Rechecking with prefix...") + responses = model.get_responses_batched(prefix_check_prompts) + additional_prefix = commonprefix(responses).rstrip(" ") + if additional_prefix: + model.response_prefix += additional_prefix + print(f"* Extended prefix found: [bold]{model.response_prefix!r}[/]") + evaluator = Evaluator(settings, model) if settings.evaluate_model is not None: