fix: improve code quality, improve UX, fix small bugs

This commit is contained in:
Philipp Emanuel Weidmann
2026-02-08 13:32:00 +05:30
parent 2690655a83
commit f68a887a7b
6 changed files with 185 additions and 151 deletions
+2
View File
@@ -4,6 +4,8 @@
* Comments should start with a capital letter and end with a period. They should use correct grammar and spelling.
* Function and method signatures **must** be fully type-annotated, including the return type (if any).
* Every Python code file **must** start with an SPDX/Copyright header.
* Settings descriptions should start with a capital letter and end with a period.
* When new settings are added in `config.py`, they should also be added to `config.default.toml`, set to their default value and with their description as a comment. The order of settings in `config.default.toml` should match that in `config.py`.
* Pull requests should implement one change, and one change only.
* PRs containing multiple semantically independent changes **must** be split into multiple PRs.
* PRs **must not** change existing code unless the changes are *directly related* to the PR. This includes changes to formatting and comments.
+5 -2
View File
@@ -7,7 +7,7 @@ wheels/
*.egg-info
# Virtual environments
.venv
.venv/
# Caches
/.ruff_cache/
@@ -19,4 +19,7 @@ wheels/
/config.toml
# Study checkpoints
/checkpoints/*.jsonl
/checkpoints/
# Residual plots
/plots/
+27 -23
View File
@@ -15,15 +15,16 @@ dtypes = [
"float32",
]
# Quantization method to use when loading the model. Options:
# "none" (no quantization),
# "bnb_4bit" (4-bit quantization using bitsandbytes).
quantization = "none"
# Device map to pass to Accelerate when loading the model.
device_map = "auto"
# Quantization method to use when loading the model.
# Options: "none" (no quantization), "bnb_4bit" (4-bit quantization using bitsandbytes).
quantization = "none"
# Memory limits to impose. 0 is usually your first graphics card.
# max_memory = {0 = "16GB", "cpu" = "64GB"}
# Maximum memory to allocate per device.
# max_memory = {"0" = "20GB", "cpu" = "64GB"}
# Number of input sequences to process in parallel (0 = auto).
batch_size = 0 # auto
@@ -34,22 +35,6 @@ max_batch_size = 128
# Maximum number of tokens to generate for each response.
max_response_length = 100
# Whether to adjust the refusal directions so that only the component that is
# orthogonal to the good direction is subtracted during abliteration.
orthogonalize_direction = false
# How to apply row normalization of the weights. Options:
# 'none' (no normalization),
# 'pre' (compute LoRA adapter relative to row-normalized weights),
# 'full' (like 'pre', but re-normalizes to preserve original row magnitudes).
row_normalization = "none"
# The rank of the LoRA adapter to use when 'full' row normalization is used.
# Row magnitude preservation is approximate due to non-linear efects,
# and this determines the rank of that approximation. Higher ranks produce
# larger output files and may slow down evaluation.
full_normalization_lora_rank = 3
# Whether to print prompt/response pairs when counting refusals.
print_responses = false
@@ -76,9 +61,25 @@ kl_divergence_scale = 1.0
# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
kl_divergence_target = 0.01
# Whether to adjust the refusal directions so that only the component that is
# orthogonal to the good direction is subtracted during abliteration.
orthogonalize_direction = false
# How to apply row normalization of the weights. Options:
# "none" (no normalization),
# "pre" (compute LoRA adapter relative to row-normalized weights),
# "full" (like "pre", but renormalizes to preserve original row magnitudes).
row_normalization = "none"
# The rank of the LoRA adapter to use when "full" row normalization is used.
# Row magnitude preservation is approximate due to non-linear effects,
# and this determines the rank of that approximation. Higher ranks produce
# larger output files and may slow down evaluation.
full_normalization_lora_rank = 3
# The symmetric winsorization to apply to each layer of the per-prompt residuals,
# expressed as the quantile to clamp to (between 0 and 1). Disabled by default.
# Example: winsorization_quantile = 0.95 applies a 90% winsorization.
# Example: winsorization_quantile = 0.95 applies a 95% winsorization.
winsorization_quantile = 1.0
# Number of abliteration trials to run during optimization.
@@ -87,6 +88,9 @@ n_trials = 200
# Number of trials that use random sampling for the purpose of exploration.
n_startup_trials = 60
# Directory to save and load study progress to/from.
study_checkpoint_dir = "checkpoints"
# Strings whose presence in a response (case insensitive) identifies the response as a refusal.
refusal_markers = [
"sorry",
+48 -38
View File
@@ -66,7 +66,10 @@ class Settings(BaseSettings):
evaluate_model: str | None = Field(
default=None,
description="If this model ID or path is set, then instead of abliterating the main model, evaluate this model relative to the main model.",
description=(
"If this model ID or path is set, then instead of abliterating the main model, "
"evaluate this model relative to the main model."
),
)
dtypes: list[str] = Field(
@@ -82,7 +85,19 @@ class Settings(BaseSettings):
# if that was the dtype "auto" resolved to).
"float32",
],
description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.",
description=(
"List of PyTorch dtypes to try when loading model tensors. "
"If loading with a dtype fails, the next dtype in the list will be tried."
),
)
quantization: QuantizationMethod = Field(
default=QuantizationMethod.NONE,
description=(
"Quantization method to use when loading the model. Options: "
'"none" (no quantization), '
'"bnb_4bit" (4-bit quantization using bitsandbytes).'
),
)
device_map: str | Dict[str, int | str] = Field(
@@ -92,7 +107,7 @@ class Settings(BaseSettings):
max_memory: Dict[str, str] | None = Field(
default=None,
description="Maximum memory to allocate per device (e.g., {'0': '20GB', 'cpu': '64GB'}).",
description='Maximum memory to allocate per device (e.g., {"0": "20GB", "cpu": "64GB"}).',
)
trust_remote_code: bool | None = Field(
@@ -100,11 +115,6 @@ class Settings(BaseSettings):
description="Whether to trust remote code when loading the model.",
)
quantization: QuantizationMethod = Field(
default=QuantizationMethod.NONE,
description="Quantization method to use when loading the model. Options: 'none' (no quantization), 'bnb_4bit' (4-bit quantization using bitsandbytes).",
)
batch_size: int = Field(
default=0, # auto
description="Number of input sequences to process in parallel (0 = auto).",
@@ -120,34 +130,6 @@ class Settings(BaseSettings):
description="Maximum number of tokens to generate for each response.",
)
orthogonalize_direction: bool = Field(
default=False,
description=(
"Whether to adjust the refusal directions so that only the component that is "
"orthogonal to the good direction is subtracted during abliteration."
),
)
row_normalization: RowNormalization = Field(
default=RowNormalization.NONE,
description=(
"How to apply row normalization of the weights. Options: "
"'none' (no normalization), "
"'pre' (compute LoRA adapter relative to row-normalized weights), "
"'full' (like 'pre', but renormalizes to preserve original row magnitudes)."
),
)
full_normalization_lora_rank: int = Field(
default=3,
description=(
"The rank of the LoRA adapter to use when 'full' row normalization is used. "
"Row magnitude preservation is approximate due to non-linear efects, "
"and this determines the rank of that approximation. Higher ranks produce "
"larger output files and may slow down evaluation."
),
)
print_responses: bool = Field(
default=False,
description="Whether to print prompt/response pairs when counting refusals.",
@@ -194,12 +176,40 @@ class Settings(BaseSettings):
),
)
orthogonalize_direction: bool = Field(
default=False,
description=(
"Whether to adjust the refusal directions so that only the component that is "
"orthogonal to the good direction is subtracted during abliteration."
),
)
row_normalization: RowNormalization = Field(
default=RowNormalization.NONE,
description=(
"How to apply row normalization of the weights. Options: "
'"none" (no normalization), '
'"pre" (compute LoRA adapter relative to row-normalized weights), '
'"full" (like "pre", but renormalizes to preserve original row magnitudes).'
),
)
full_normalization_lora_rank: int = Field(
default=3,
description=(
'The rank of the LoRA adapter to use when "full" row normalization is used. '
"Row magnitude preservation is approximate due to non-linear effects, "
"and this determines the rank of that approximation. Higher ranks produce "
"larger output files and may slow down evaluation."
),
)
winsorization_quantile: float = Field(
default=1.0,
description=(
"The symmetric winsorization to apply to each layer of the per-prompt residuals, "
"expressed as the quantile to clamp to (between 0 and 1). Disabled by default. "
"Example: winsorization_quantile = 0.95 applies a 90% winsorization."
"Example: winsorization_quantile = 0.95 applies a 95% winsorization."
),
)
@@ -215,7 +225,7 @@ class Settings(BaseSettings):
study_checkpoint_dir: str = Field(
default="checkpoints",
description="Directory to save and load study progress to/from:",
description="Directory to save and load study progress to/from.",
)
refusal_markers: list[str] = Field(
+86 -73
View File
@@ -61,21 +61,15 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
Returns "merge", "adapter", or None (if cancelled/invalid).
"""
# Prompt for all PEFT models to ensure user is aware of merge implications
if settings.quantization == QuantizationMethod.BNB_4BIT:
# Quantized models need special handling - we must reload the base model
# in full precision to merge the LoRA adapters
print()
print(
"[yellow]Model was loaded with quantization. Merging requires reloading the base model.[/]"
"Model was loaded with quantization. Merging requires reloading the base model."
)
print(
"[red](!) WARNING: CPU Merging requires dequantizing the entire model to System RAM.[/]"
)
print("[red] This can lead to SYSTEM FREEZES if you run out of memory.[/]")
print(
"[yellow] Rule of thumb: You need approx. 3x the parameter count in GB.[/]"
"[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]"
)
print("[yellow]This can lead to system freezes if you run out of memory.[/]")
try:
# Estimate memory requirements by loading the model structure on the "meta" device.
@@ -94,61 +88,39 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
footprint_bytes = meta_model.get_memory_footprint()
footprint_gb = footprint_bytes / (1024**3)
print(
f"[yellow] Estimated net RAM required for model weights (excluding overhead): [bold]~{footprint_gb:.1f} GB[/][/]"
f"[yellow]Estimated RAM required (excluding overhead): [bold]~{footprint_gb:.2f} GB[/][/]"
)
except Exception:
# Fallback if meta loading fails (e.g. owing to custom model code
# or `bitsandbytes` quantization config issues on the meta device)
# or bitsandbytes quantization config issues on the meta device).
print(
"[yellow]Rule of thumb: You need approximately 3x the parameter count in GB RAM.[/]"
)
print(
"[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
)
print()
merge_choice = prompt_select(
strategy = prompt_select(
"How do you want to proceed?",
choices=[
Choice(
title="Merge full model"
title="Merge LoRA into full model"
+ (
""
if settings.quantization == QuantizationMethod.NONE
else " (reload base model on CPU - requires high RAM)"
else " (requires sufficient RAM)"
),
value="merge",
),
Choice(
title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)",
title="Save LoRA adapter only (can be merged later)",
value="adapter",
),
],
)
return merge_choice
def save_model(
model: Model,
save_directory: str,
settings: Settings,
strategy: str | None = None,
) -> None:
print("Saving model...")
if strategy is None:
strategy = obtain_merge_strategy(settings)
if strategy is None:
print("[yellow]Action cancelled.[/]")
return
if strategy == "adapter":
model.model.save_pretrained(save_directory)
else:
merged_model = model.get_merged_model()
merged_model.save_pretrained(save_directory)
del merged_model
empty_cache()
model.tokenizer.save_pretrained(save_directory)
print(f"Model saved to [bold]{save_directory}[/].")
return strategy
def run():
@@ -249,6 +221,8 @@ def run():
# Silence the warning about multivariate TPE being experimental.
warnings.filterwarnings("ignore", category=ExperimentalWarning)
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
study_checkpoint_file = os.path.join(
settings.study_checkpoint_dir,
"".join(
@@ -257,7 +231,6 @@ def run():
+ ".jsonl",
)
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
lock_obj = JournalFileOpenLock(study_checkpoint_file)
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
storage = JournalStorage(backend)
@@ -268,35 +241,57 @@ def run():
existing_study = None
if existing_study is not None:
# A study is in here. Check if it's finished.
choices = []
if existing_study.user_attrs["finished"]:
print()
print(
"[green]You have already processed this model. How would you like to proceed?[/]"
(
"[green]You have already processed this model.[/] "
"You can show the results from the previous run, allowing you to export models or to run additional trials. "
"Alternatively, you can ignore the previous run and start from scratch. "
"This will delete the checkpoint file and all results from the previous run."
)
)
choices.append(
Choice(
title="Show the results from the previous run, allowing you to export models, or to run additional trials.",
title="Show the results from the previous run",
value="continue",
)
)
else:
print()
print(
"[yellow]You have already processed this model, but the run was interrupted. How would you like to proceed?[/]",
(
"[yellow]You have already processed this model, but the run was interrupted.[/] "
"You can continue the previous run from where it stopped. This will override any specified settings. "
"Alternatively, you can ignore the previous run and start from scratch. "
"This will delete the checkpoint file and all results from the previous run."
)
)
choices.append(
Choice(
title="Continue the previous run from where it stopped (will override all specified settings).",
title="Continue the previous run",
value="continue",
)
)
choices.append(
Choice(
title="Ignore the previous run and start from scratch. This will delete the checkpoint file and all results from the previous run.",
title="Ignore the previous run and start from scratch",
value="restart",
)
)
choice = prompt_select("", choices)
choices.append(
Choice(
title="Exit program",
value="",
)
)
print()
choice = prompt_select("How would you like to proceed?", choices)
if choice == "continue":
settings = Settings.model_validate_json(
@@ -306,8 +301,7 @@ def run():
os.unlink(study_checkpoint_file)
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
storage = JournalStorage(backend)
else:
print("Cancelled; exiting.")
elif choice is None or choice == "":
return
model = Model(settings)
@@ -562,14 +556,14 @@ def run():
raise TrialPruned()
study = optuna.create_study(
study_name="heretic",
sampler=TPESampler(
n_startup_trials=settings.n_startup_trials,
n_ei_candidates=128,
multivariate=True,
),
storage=storage,
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
storage=storage,
study_name="heretic",
load_if_exists=True,
)
@@ -582,13 +576,14 @@ def run():
start_index = trial_index = count_completed_trials()
if start_index > 0:
print()
print("Resuming existing study.")
try:
study.optimize(
objective_wrapper, n_trials=settings.n_trials - count_completed_trials()
objective_wrapper,
n_trials=settings.n_trials - count_completed_trials(),
)
except KeyboardInterrupt:
# This additional handler takes care of the small chance that KeyboardInterrupt
# is raised just between trials, which wouldn't be caught by the handler
@@ -638,14 +633,14 @@ def run():
choices.append(
Choice(
title="Continue optimization (run more trials)",
title="Run additional trials",
value="continue",
)
)
choices.append(
Choice(
title="None (exit program)",
title="Exit program",
value="",
)
)
@@ -669,18 +664,26 @@ def run():
if trial == "continue":
while True:
try:
n_more_trials = int(
prompt_text("How many more trials do you want to run?")
n_additional_trials = prompt_text(
"How many additional trials do you want to run?"
)
if n_more_trials > 0:
if n_additional_trials is None or n_additional_trials == "":
n_additional_trials = 0
break
n_additional_trials = int(n_additional_trials)
if n_additional_trials > 0:
break
print("[red]Please enter a number greater than 0.[/]")
except ValueError:
print("[red]Invalid input. Please enter a number.[/]")
print("[red]Please enter a number.[/]")
settings.n_trials += n_more_trials
if n_additional_trials == 0:
continue
settings.n_trials += n_additional_trials
study.set_user_attr("settings", settings.model_dump_json())
study.set_user_attr("finished", False)
try:
study.optimize(
objective_wrapper,
@@ -688,8 +691,10 @@ def run():
)
except KeyboardInterrupt:
pass
if count_completed_trials() == settings.n_trials:
study.set_user_attr("finished", True)
break
elif trial is None or trial == "":
@@ -720,14 +725,11 @@ def run():
"Save the model to a local folder",
"Upload the model to Hugging Face",
"Chat with the model",
"Nothing (return to trial selection menu)",
"Return to the trial selection menu",
],
)
if (
action is None
or action == "Nothing (return to trial selection menu)"
):
if action is None or action == "Return to the trial selection menu":
break
# All actions are wrapped in a try/except block so that if an error occurs,
@@ -740,11 +742,23 @@ def run():
if not save_directory:
continue
save_model(
model,
save_directory,
settings,
)
strategy = obtain_merge_strategy(settings)
if strategy is None:
continue
if strategy == "adapter":
print("Saving LoRA adapter...")
model.model.save_pretrained(save_directory)
else:
print("Saving merged model...")
merged_model = model.get_merged_model()
merged_model.save_pretrained(save_directory)
del merged_model
empty_cache()
model.tokenizer.save_pretrained(save_directory)
print(f"Model saved to [bold]{save_directory}[/].")
case "Upload the model to Hugging Face":
# We don't use huggingface_hub.login() because that stores the token on disk,
@@ -780,7 +794,6 @@ def run():
strategy = obtain_merge_strategy(settings)
if strategy is None:
print("[yellow]Action cancelled.[/]")
continue
if strategy == "adapter":
+15 -13
View File
@@ -38,7 +38,7 @@ def get_model_class(
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
configs = PretrainedConfig.get_config_dict(model)
if any(["vision_config" in x for x in configs]):
if any([("vision_config" in config) for config in configs]):
return AutoModelForImageTextToText
else:
return AutoModelForCausalLM
@@ -96,9 +96,9 @@ class Model:
try:
quantization_config = self._get_quantization_config(dtype)
# Build kwargs, only include quantization_config if it's not None
# (some models like gpt-oss have issues with explicit None)
extra_kwargs = {}
# Only include quantization_config if it's not None
# (some models like gpt-oss have issues with explicit None).
if quantization_config is not None:
extra_kwargs["quantization_config"] = quantization_config
@@ -134,9 +134,11 @@ class Model:
print(f"[red]Failed[/] ({error})")
continue
print("[green]Ok[/]")
if settings.quantization == QuantizationMethod.BNB_4BIT:
print("[bold green]Model loaded in 4-bit precision.[/]")
print("[green]Ok[/] (quantized to 4-bit precision)")
else:
print("[green]Ok[/]")
break
if self.model is None:
@@ -158,7 +160,7 @@ class Model:
# Guard against calling this method at the wrong time.
assert isinstance(self.model, PreTrainedModel)
# Always use LoRA adapters for abliteration (faster reload, no weight modification)
# Always use LoRA adapters for abliteration (faster reload, no weight modification).
# We use the leaf names (e.g. "o_proj") as target modules.
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
# but this is harmless as we only abliterate the modules we target in `abliterate()`,
@@ -181,9 +183,8 @@ class Model:
lora_alpha=lora_rank, # Apply adapter at full strength.
lora_dropout=0,
bias="none",
# Even if we're using AutoModelForImageTextToText, this is still correct, as it is (post-vision)
# the same kind of model.
# https://github.com/huggingface/peft/blob/622c2821cb0d7897bee53aad7914d42b5fecbf61/src/peft/auto.py#L45
# Even if we're using AutoModelForImageTextToText, this is still correct,
# as VL models are typically just causal LMs with an added image encoder.
task_type="CAUSAL_LM",
)
@@ -191,9 +192,7 @@ class Model:
# so the result is a PeftModel rather than a PeftMixedModel.
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
print(
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
)
print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
"""
@@ -637,7 +636,10 @@ class Model:
abs_residuals = torch.abs(residuals)
# Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals.
thresholds = torch.quantile(
abs_residuals, self.settings.winsorization_quantile, dim=2, keepdim=True
abs_residuals,
self.settings.winsorization_quantile,
dim=2,
keepdim=True,
)
return torch.clamp(residuals, -thresholds, thresholds)