feat: avoid excessive low divergence iteration (#73)

* feat: adjust scoring to avoid useless iteration

Adjusts the scoring function to avoid targeting meaninglessly low KL divergences.
Below a threshold value, the KL divergence score switches to the refusal count.
Adds config option kl_divergence_target (defaulting to 0.01).

* fix: Clean up parameter selection in objective

Create variables for num_layers and last_layer_index
* Improves readability and makes choices explicit

* feat: Print the parameters of the selected model
This commit is contained in:
Spiky Moth
2025-12-14 09:56:48 +01:00
committed by GitHub
parent 740aab61ba
commit 9d1734855d
4 changed files with 53 additions and 11 deletions
+4
View File
@@ -49,6 +49,10 @@ residual_plot_style = "dark_background"
# This is used to ensure balanced co-optimization of KL divergence and refusal count.
kl_divergence_scale = 1.0
# The KL divergence to target. Below this value, an objective based on the refusal count is used.
# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
kl_divergence_target = 0.01
# Number of abliteration trials to run during optimization.
n_trials = 200
+8
View File
@@ -119,6 +119,14 @@ class Settings(BaseSettings):
),
)
kl_divergence_target: float = Field(
default=0.01,
description=(
"The KL divergence to target. Below this value, an objective based on the refusal count is used."
'This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".'
),
)
n_trials: int = Field(
default=200,
description="Number of abliteration trials to run during optimization.",
+12 -2
View File
@@ -76,9 +76,19 @@ class Evaluator:
refusals = self.count_refusals()
print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}")
kl_divergence_scale = self.settings.kl_divergence_scale
kl_divergence_target = self.settings.kl_divergence_target
refusals_score = refusals / self.base_refusals
if kl_divergence >= kl_divergence_target:
kld_score = kl_divergence / kl_divergence_scale
else:
kld_score = refusals_score * kl_divergence_target / kl_divergence_scale
score = (
(kl_divergence / self.settings.kl_divergence_scale),
(refusals / self.base_refusals),
kld_score,
refusals_score,
)
return score, kl_divergence, refusals
+29 -9
View File
@@ -27,6 +27,7 @@ from optuna import Trial, TrialPruned
from optuna.exceptions import ExperimentalWarning
from optuna.samplers import TPESampler
from optuna.study import StudyDirection
from optuna.trial import TrialState
from pydantic import ValidationError
from questionary import Choice
from rich.traceback import install
@@ -264,6 +265,8 @@ def run():
],
)
last_layer_index = len(model.get_layers()) - 1
# Discrimination between "harmful" and "harmless" inputs is usually strongest
# in layers slightly past the midpoint of the layer stack. See the original
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
@@ -273,8 +276,8 @@ def run():
# work with conditional or variable-range parameters.
direction_index = trial.suggest_float(
"direction_index",
0.4 * (len(model.get_layers()) - 1),
0.9 * (len(model.get_layers()) - 1),
0.4 * last_layer_index,
0.9 * last_layer_index,
)
if direction_scope == "per layer":
@@ -293,8 +296,8 @@ def run():
)
max_weight_position = trial.suggest_float(
f"{component}.max_weight_position",
0.6 * (len(model.get_layers()) - 1),
len(model.get_layers()) - 1,
0.6 * last_layer_index,
1.0 * last_layer_index,
)
# For sampling purposes, min_weight is expressed as a fraction of max_weight,
# again because multivariate TPE doesn't support variable-range parameters.
@@ -307,7 +310,7 @@ def run():
min_weight_distance = trial.suggest_float(
f"{component}.min_weight_distance",
1.0,
0.6 * (len(model.get_layers()) - 1),
0.6 * last_layer_index,
)
parameters[component] = AbliterationParameters(
@@ -378,13 +381,27 @@ def run():
# If no trials at all have been evaluated, the study must have been stopped
# by pressing Ctrl+C while the first trial was running. In this case, we just
# re-raise the interrupt to invoke the standard handler defined below.
if not study.best_trials:
completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
if not completed_trials:
raise KeyboardInterrupt
best_trials = sorted(
study.best_trials,
key=lambda trial: trial.user_attrs["refusals"],
# Get the Pareto front of trials. We can't use study.best_trials directly
# as get_score() doesn't return the pure KL divergence and refusal count.
# Note: Unlike study.best_trials, this does not handle objective constraints.
sorted_trials = sorted(
completed_trials,
key=lambda trial: (
trial.user_attrs["refusals"],
trial.user_attrs["kl_divergence"],
),
)
min_divergence = math.inf
best_trials = []
for trial in sorted_trials:
kl_divergence = trial.user_attrs["kl_divergence"]
if kl_divergence < min_divergence:
min_divergence = kl_divergence
best_trials.append(trial)
choices = [
Choice(
@@ -426,6 +443,9 @@ def run():
print()
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
print("* Parameters:")
for name, value in get_trial_parameters(trial).items():
print(f" * {name} = [bold]{value}[/]")
print("* Reloading model...")
model.reload_model()
print("* Abliterating...")