mirror of
https://github.com/p-e-w/heretic.git
synced 2026-06-02 05:03:33 +02:00
feat: avoid excessive low divergence iteration (#73)
* feat: adjust scoring to avoid useless iteration

  Adjusts the scoring function to avoid targeting meaninglessly low KL divergences. Below a threshold value, the KL divergence score switches to the refusal count.

  Adds config option kl_divergence_target (defaulting to 0.01).

* fix: Clean up parameter selection in objective

  Create variables for num_layers and last_layer_index
  * Improves readability and makes choices explicit

* feat: Print the parameters of the selected model
This commit is contained in:
@@ -49,6 +49,10 @@ residual_plot_style = "dark_background"
|
||||
# This is used to ensure balanced co-optimization of KL divergence and refusal count.
|
||||
kl_divergence_scale = 1.0
|
||||
|
||||
# The KL divergence to target. Below this value, an objective based on the refusal count is used.
|
||||
# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
|
||||
kl_divergence_target = 0.01
|
||||
|
||||
# Number of abliteration trials to run during optimization.
|
||||
n_trials = 200
|
||||
|
||||
|
||||
@@ -119,6 +119,14 @@ class Settings(BaseSettings):
|
||||
),
|
||||
)
|
||||
|
||||
kl_divergence_target: float = Field(
|
||||
default=0.01,
|
||||
description=(
|
||||
"The KL divergence to target. Below this value, an objective based on the refusal count is used."
|
||||
' This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".'
|
||||
),
|
||||
)
|
||||
|
||||
n_trials: int = Field(
|
||||
default=200,
|
||||
description="Number of abliteration trials to run during optimization.",
|
||||
|
||||
@@ -76,9 +76,19 @@ class Evaluator:
|
||||
refusals = self.count_refusals()
|
||||
print(f" * Refusals: [bold]{refusals}[/]/{len(self.bad_prompts)}")
|
||||
|
||||
kl_divergence_scale = self.settings.kl_divergence_scale
|
||||
kl_divergence_target = self.settings.kl_divergence_target
|
||||
|
||||
refusals_score = refusals / self.base_refusals
|
||||
|
||||
if kl_divergence >= kl_divergence_target:
|
||||
kld_score = kl_divergence / kl_divergence_scale
|
||||
else:
|
||||
kld_score = refusals_score * kl_divergence_target / kl_divergence_scale
|
||||
|
||||
score = (
|
||||
(kl_divergence / self.settings.kl_divergence_scale),
|
||||
(refusals / self.base_refusals),
|
||||
kld_score,
|
||||
refusals_score,
|
||||
)
|
||||
|
||||
return score, kl_divergence, refusals
|
||||
|
||||
+29
-9
@@ -27,6 +27,7 @@ from optuna import Trial, TrialPruned
|
||||
from optuna.exceptions import ExperimentalWarning
|
||||
from optuna.samplers import TPESampler
|
||||
from optuna.study import StudyDirection
|
||||
from optuna.trial import TrialState
|
||||
from pydantic import ValidationError
|
||||
from questionary import Choice
|
||||
from rich.traceback import install
|
||||
@@ -264,6 +265,8 @@ def run():
|
||||
],
|
||||
)
|
||||
|
||||
last_layer_index = len(model.get_layers()) - 1
|
||||
|
||||
# Discrimination between "harmful" and "harmless" inputs is usually strongest
|
||||
# in layers slightly past the midpoint of the layer stack. See the original
|
||||
# abliteration paper (https://arxiv.org/abs/2406.11717) for a deeper analysis.
|
||||
@@ -273,8 +276,8 @@ def run():
|
||||
# work with conditional or variable-range parameters.
|
||||
direction_index = trial.suggest_float(
|
||||
"direction_index",
|
||||
0.4 * (len(model.get_layers()) - 1),
|
||||
0.9 * (len(model.get_layers()) - 1),
|
||||
0.4 * last_layer_index,
|
||||
0.9 * last_layer_index,
|
||||
)
|
||||
|
||||
if direction_scope == "per layer":
|
||||
@@ -293,8 +296,8 @@ def run():
|
||||
)
|
||||
max_weight_position = trial.suggest_float(
|
||||
f"{component}.max_weight_position",
|
||||
0.6 * (len(model.get_layers()) - 1),
|
||||
len(model.get_layers()) - 1,
|
||||
0.6 * last_layer_index,
|
||||
1.0 * last_layer_index,
|
||||
)
|
||||
# For sampling purposes, min_weight is expressed as a fraction of max_weight,
|
||||
# again because multivariate TPE doesn't support variable-range parameters.
|
||||
@@ -307,7 +310,7 @@ def run():
|
||||
min_weight_distance = trial.suggest_float(
|
||||
f"{component}.min_weight_distance",
|
||||
1.0,
|
||||
0.6 * (len(model.get_layers()) - 1),
|
||||
0.6 * last_layer_index,
|
||||
)
|
||||
|
||||
parameters[component] = AbliterationParameters(
|
||||
@@ -378,13 +381,27 @@ def run():
|
||||
# If no trials at all have been evaluated, the study must have been stopped
|
||||
# by pressing Ctrl+C while the first trial was running. In this case, we just
|
||||
# re-raise the interrupt to invoke the standard handler defined below.
|
||||
if not study.best_trials:
|
||||
completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]
|
||||
if not completed_trials:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
best_trials = sorted(
|
||||
study.best_trials,
|
||||
key=lambda trial: trial.user_attrs["refusals"],
|
||||
# Get the Pareto front of trials. We can't use study.best_trials directly
|
||||
# as get_score() doesn't return the pure KL divergence and refusal count.
|
||||
# Note: Unlike study.best_trials, this does not handle objective constraints.
|
||||
sorted_trials = sorted(
|
||||
completed_trials,
|
||||
key=lambda trial: (
|
||||
trial.user_attrs["refusals"],
|
||||
trial.user_attrs["kl_divergence"],
|
||||
),
|
||||
)
|
||||
min_divergence = math.inf
|
||||
best_trials = []
|
||||
for trial in sorted_trials:
|
||||
kl_divergence = trial.user_attrs["kl_divergence"]
|
||||
if kl_divergence < min_divergence:
|
||||
min_divergence = kl_divergence
|
||||
best_trials.append(trial)
|
||||
|
||||
choices = [
|
||||
Choice(
|
||||
@@ -426,6 +443,9 @@ def run():
|
||||
|
||||
print()
|
||||
print(f"Restoring model from trial [bold]{trial.user_attrs['index']}[/]...")
|
||||
print("* Parameters:")
|
||||
for name, value in get_trial_parameters(trial).items():
|
||||
print(f" * {name} = [bold]{value}[/]")
|
||||
print("* Reloading model...")
|
||||
model.reload_model()
|
||||
print("* Abliterating...")
|
||||
|
||||
Reference in New Issue
Block a user