From 9d1734855d48ad42ae5f4b21cdd52e95727fb035 Mon Sep 17 00:00:00 2001 From: Spiky Moth Date: Sun, 14 Dec 2025 09:56:48 +0100 Subject: [PATCH] feat: avoid excessive low divergence iteration (#73) * feat: adjust scoring to avoid useless iteration Adjusts the scoring function to avoid targeting meaninglessly low KL divergences. Below a threshold value, the KL divergence score switches to the refusal count. Adds config option kl_divergence_target (defaulting to 0.01). * fix: Clean up parameter selection in objective Create variables for num_layers and last_layer_index * Improves readability and makes choices explicit * feat: Print the parameters of the selected model --- config.default.toml | 4 ++++ src/heretic/config.py | 8 ++++++++ src/heretic/evaluator.py | 14 ++++++++++++-- src/heretic/main.py | 38 +++++++++++++++++++++++++++++--------- 4 files changed, 53 insertions(+), 11 deletions(-) diff --git a/config.default.toml b/config.default.toml index 18e0c6a..2e0b861 100644 --- a/config.default.toml +++ b/config.default.toml @@ -49,6 +49,10 @@ residual_plot_style = "dark_background" # This is used to ensure balanced co-optimization of KL divergence and refusal count. kl_divergence_scale = 1.0 +# The KL divergence to target. Below this value, an objective based on the refusal count is used. +# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing". +kl_divergence_target = 0.01 + # Number of abliteration trials to run during optimization. n_trials = 200 diff --git a/src/heretic/config.py b/src/heretic/config.py index c786349..6819779 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -119,6 +119,14 @@ class Settings(BaseSettings): ), ) + kl_divergence_target: float = Field( + default=0.01, + description=( + "The KL divergence to target. Below this value, an objective based on the refusal count is used. " + 'This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".' 
+ ), + ) + n_trials: int = Field( default=200, description="Number of abliteration trials to run during optimization.", diff --git a/src/heretic/evaluator.py b/src/heretic/evaluator.py index eb91038..7306cbf 100644 --- a/src/heretic/evaluator.py +++ b/src/heretic/evaluator.py @@ -76,9 +76,19 @@ class Evaluator: refusals = self.count_refusals() print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}") + kl_divergence_scale = self.settings.kl_divergence_scale + kl_divergence_target = self.settings.kl_divergence_target + + refusals_score = refusals / self.base_refusals + + if kl_divergence >= kl_divergence_target: + kld_score = kl_divergence / kl_divergence_scale + else: + kld_score = refusals_score * kl_divergence_target / kl_divergence_scale + score = ( - (kl_divergence / self.settings.kl_divergence_scale), - (refusals / self.base_refusals), + kld_score, + refusals_score, ) return score, kl_divergence, refusals diff --git a/src/heretic/main.py b/src/heretic/main.py index 8c187eb..3628df1 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -27,6 +27,7 @@ from optuna import Trial, TrialPruned from optuna.exceptions import ExperimentalWarning from optuna.samplers import TPESampler from optuna.study import StudyDirection +from optuna.trial import TrialState from pydantic import ValidationError from questionary import Choice from rich.traceback import install @@ -264,6 +265,8 @@ def run(): ], ) + last_layer_index = len(model.get_layers()) - 1 + # Discrimination between "harmful" and "harmless" inputs is usually strongest # in layers slightly past the midpoint of the layer stack. See the original # abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis. @@ -273,8 +276,8 @@ def run(): # work with conditional or variable-range parameters. 
direction_index = trial.suggest_float( "direction_index", - 0.4 * (len(model.get_layers()) - 1), - 0.9 * (len(model.get_layers()) - 1), + 0.4 * last_layer_index, + 0.9 * last_layer_index, ) if direction_scope == "per layer": @@ -293,8 +296,8 @@ def run(): ) max_weight_position = trial.suggest_float( f"{component}.max_weight_position", - 0.6 * (len(model.get_layers()) - 1), - len(model.get_layers()) - 1, + 0.6 * last_layer_index, + 1.0 * last_layer_index, ) # For sampling purposes, min_weight is expressed as a fraction of max_weight, # again because multivariate TPE doesn't support variable-range parameters. @@ -307,7 +310,7 @@ def run(): min_weight_distance = trial.suggest_float( f"{component}.min_weight_distance", 1.0, - 0.6 * (len(model.get_layers()) - 1), + 0.6 * last_layer_index, ) parameters[component] = AbliterationParameters( @@ -378,13 +381,27 @@ def run(): # If no trials at all have been evaluated, the study must have been stopped # by pressing Ctrl+C while the first trial was running. In this case, we just # re-raise the interrupt to invoke the standard handler defined below. - if not study.best_trials: + completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE] + if not completed_trials: raise KeyboardInterrupt - best_trials = sorted( - study.best_trials, - key=lambda trial: trial.user_attrs["refusals"], + # Get the Pareto front of trials. We can't use study.best_trials directly + # as get_score() doesn't return the pure KL divergence and refusal count. + # Note: Unlike study.best_trials, this does not handle objective constraints. 
+ sorted_trials = sorted( + completed_trials, + key=lambda trial: ( + trial.user_attrs["refusals"], + trial.user_attrs["kl_divergence"], + ), ) + min_divergence = math.inf + best_trials = [] + for trial in sorted_trials: + kl_divergence = trial.user_attrs["kl_divergence"] + if kl_divergence < min_divergence: + min_divergence = kl_divergence + best_trials.append(trial) choices = [ Choice( @@ -426,6 +443,9 @@ def run(): print() print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...") + print("* Parameters:") + for name, value in get_trial_parameters(trial).items(): + print(f" * {name} = [bold]{value}[/]") print("* Reloading model...") model.reload_model() print("* Abliterating...")