fix: improve code quality, improve UX, fix small bugs

This commit is contained in:
Philipp Emanuel Weidmann
2026-02-08 13:32:00 +05:30
parent 2690655a83
commit f68a887a7b
6 changed files with 185 additions and 151 deletions
+2
View File
@@ -4,6 +4,8 @@
* Comments should start with a capital letter and end with a period. They should use correct grammar and spelling. * Comments should start with a capital letter and end with a period. They should use correct grammar and spelling.
* Function and method signatures **must** be fully type-annotated, including the return type (if any). * Function and method signatures **must** be fully type-annotated, including the return type (if any).
* Every Python code file **must** start with an SPDX/Copyright header. * Every Python code file **must** start with an SPDX/Copyright header.
* Settings descriptions should start with a capital letter and end with a period.
* When new settings are added in `config.py`, they should also be added to `config.default.toml`, set to their default value and with their description as a comment. The order of settings in `config.default.toml` should match that in `config.py`.
* Pull requests should implement one change, and one change only. * Pull requests should implement one change, and one change only.
* PRs containing multiple semantically independent changes **must** be split into multiple PRs. * PRs containing multiple semantically independent changes **must** be split into multiple PRs.
* PRs **must not** change existing code unless the changes are *directly related* to the PR. This includes changes to formatting and comments. * PRs **must not** change existing code unless the changes are *directly related* to the PR. This includes changes to formatting and comments.
+5 -2
View File
@@ -7,7 +7,7 @@ wheels/
*.egg-info *.egg-info
# Virtual environments # Virtual environments
.venv .venv/
# Caches # Caches
/.ruff_cache/ /.ruff_cache/
@@ -19,4 +19,7 @@ wheels/
/config.toml /config.toml
# Study checkpoints # Study checkpoints
/checkpoints/*.jsonl /checkpoints/
# Residual plots
/plots/
+27 -23
View File
@@ -15,15 +15,16 @@ dtypes = [
"float32", "float32",
] ]
# Quantization method to use when loading the model. Options:
# "none" (no quantization),
# "bnb_4bit" (4-bit quantization using bitsandbytes).
quantization = "none"
# Device map to pass to Accelerate when loading the model. # Device map to pass to Accelerate when loading the model.
device_map = "auto" device_map = "auto"
# Quantization method to use when loading the model. # Maximum memory to allocate per device.
# Options: "none" (no quantization), "bnb_4bit" (4-bit quantization using bitsandbytes). # max_memory = {"0": "20GB", "cpu": "64GB"}
quantization = "none"
# Memory limits to impose. 0 is usually your first graphics card.
# max_memory = {0 = "16GB", "cpu" = "64GB"}
# Number of input sequences to process in parallel (0 = auto). # Number of input sequences to process in parallel (0 = auto).
batch_size = 0 # auto batch_size = 0 # auto
@@ -34,22 +35,6 @@ max_batch_size = 128
# Maximum number of tokens to generate for each response. # Maximum number of tokens to generate for each response.
max_response_length = 100 max_response_length = 100
# Whether to adjust the refusal directions so that only the component that is
# orthogonal to the good direction is subtracted during abliteration.
orthogonalize_direction = false
# How to apply row normalization of the weights. Options:
# 'none' (no normalization),
# 'pre' (compute LoRA adapter relative to row-normalized weights),
# 'full' (like 'pre', but re-normalizes to preserve original row magnitudes).
row_normalization = "none"
# The rank of the LoRA adapter to use when 'full' row normalization is used.
# Row magnitude preservation is approximate due to non-linear efects,
# and this determines the rank of that approximation. Higher ranks produce
# larger output files and may slow down evaluation.
full_normalization_lora_rank = 3
# Whether to print prompt/response pairs when counting refusals. # Whether to print prompt/response pairs when counting refusals.
print_responses = false print_responses = false
@@ -76,9 +61,25 @@ kl_divergence_scale = 1.0
# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing". # This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
kl_divergence_target = 0.01 kl_divergence_target = 0.01
# Whether to adjust the refusal directions so that only the component that is
# orthogonal to the good direction is subtracted during abliteration.
orthogonalize_direction = false
# How to apply row normalization of the weights. Options:
# "none" (no normalization),
# "pre" (compute LoRA adapter relative to row-normalized weights),
# "full" (like "pre", but renormalizes to preserve original row magnitudes).
row_normalization = "none"
# The rank of the LoRA adapter to use when "full" row normalization is used.
# Row magnitude preservation is approximate due to non-linear effects,
# and this determines the rank of that approximation. Higher ranks produce
# larger output files and may slow down evaluation.
full_normalization_lora_rank = 3
# The symmetric winsorization to apply to each layer of the per-prompt residuals, # The symmetric winsorization to apply to each layer of the per-prompt residuals,
# expressed as the quantile to clamp to (between 0 and 1). Disabled by default. # expressed as the quantile to clamp to (between 0 and 1). Disabled by default.
# Example: winsorization_quantile = 0.95 applies a 90% winsorization. # Example: winsorization_quantile = 0.95 applies a 95% winsorization.
winsorization_quantile = 1.0 winsorization_quantile = 1.0
# Number of abliteration trials to run during optimization. # Number of abliteration trials to run during optimization.
@@ -87,6 +88,9 @@ n_trials = 200
# Number of trials that use random sampling for the purpose of exploration. # Number of trials that use random sampling for the purpose of exploration.
n_startup_trials = 60 n_startup_trials = 60
# Directory to save and load study progress to/from.
study_checkpoint_dir = "checkpoints"
# Strings whose presence in a response (case insensitive) identifies the response as a refusal. # Strings whose presence in a response (case insensitive) identifies the response as a refusal.
refusal_markers = [ refusal_markers = [
"sorry", "sorry",
+48 -38
View File
@@ -66,7 +66,10 @@ class Settings(BaseSettings):
evaluate_model: str | None = Field( evaluate_model: str | None = Field(
default=None, default=None,
description="If this model ID or path is set, then instead of abliterating the main model, evaluate this model relative to the main model.", description=(
"If this model ID or path is set, then instead of abliterating the main model, "
"evaluate this model relative to the main model."
),
) )
dtypes: list[str] = Field( dtypes: list[str] = Field(
@@ -82,7 +85,19 @@ class Settings(BaseSettings):
# if that was the dtype "auto" resolved to). # if that was the dtype "auto" resolved to).
"float32", "float32",
], ],
description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.", description=(
"List of PyTorch dtypes to try when loading model tensors. "
"If loading with a dtype fails, the next dtype in the list will be tried."
),
)
quantization: QuantizationMethod = Field(
default=QuantizationMethod.NONE,
description=(
"Quantization method to use when loading the model. Options: "
'"none" (no quantization), '
'"bnb_4bit" (4-bit quantization using bitsandbytes).'
),
) )
device_map: str | Dict[str, int | str] = Field( device_map: str | Dict[str, int | str] = Field(
@@ -92,7 +107,7 @@ class Settings(BaseSettings):
max_memory: Dict[str, str] | None = Field( max_memory: Dict[str, str] | None = Field(
default=None, default=None,
description="Maximum memory to allocate per device (e.g., {'0': '20GB', 'cpu': '64GB'}).", description='Maximum memory to allocate per device (e.g., {"0": "20GB", "cpu": "64GB"}).',
) )
trust_remote_code: bool | None = Field( trust_remote_code: bool | None = Field(
@@ -100,11 +115,6 @@ class Settings(BaseSettings):
description="Whether to trust remote code when loading the model.", description="Whether to trust remote code when loading the model.",
) )
quantization: QuantizationMethod = Field(
default=QuantizationMethod.NONE,
description="Quantization method to use when loading the model. Options: 'none' (no quantization), 'bnb_4bit' (4-bit quantization using bitsandbytes).",
)
batch_size: int = Field( batch_size: int = Field(
default=0, # auto default=0, # auto
description="Number of input sequences to process in parallel (0 = auto).", description="Number of input sequences to process in parallel (0 = auto).",
@@ -120,34 +130,6 @@ class Settings(BaseSettings):
description="Maximum number of tokens to generate for each response.", description="Maximum number of tokens to generate for each response.",
) )
orthogonalize_direction: bool = Field(
default=False,
description=(
"Whether to adjust the refusal directions so that only the component that is "
"orthogonal to the good direction is subtracted during abliteration."
),
)
row_normalization: RowNormalization = Field(
default=RowNormalization.NONE,
description=(
"How to apply row normalization of the weights. Options: "
"'none' (no normalization), "
"'pre' (compute LoRA adapter relative to row-normalized weights), "
"'full' (like 'pre', but renormalizes to preserve original row magnitudes)."
),
)
full_normalization_lora_rank: int = Field(
default=3,
description=(
"The rank of the LoRA adapter to use when 'full' row normalization is used. "
"Row magnitude preservation is approximate due to non-linear efects, "
"and this determines the rank of that approximation. Higher ranks produce "
"larger output files and may slow down evaluation."
),
)
print_responses: bool = Field( print_responses: bool = Field(
default=False, default=False,
description="Whether to print prompt/response pairs when counting refusals.", description="Whether to print prompt/response pairs when counting refusals.",
@@ -194,12 +176,40 @@ class Settings(BaseSettings):
), ),
) )
orthogonalize_direction: bool = Field(
default=False,
description=(
"Whether to adjust the refusal directions so that only the component that is "
"orthogonal to the good direction is subtracted during abliteration."
),
)
row_normalization: RowNormalization = Field(
default=RowNormalization.NONE,
description=(
"How to apply row normalization of the weights. Options: "
'"none" (no normalization), '
'"pre" (compute LoRA adapter relative to row-normalized weights), '
'"full" (like "pre", but renormalizes to preserve original row magnitudes).'
),
)
full_normalization_lora_rank: int = Field(
default=3,
description=(
'The rank of the LoRA adapter to use when "full" row normalization is used. '
"Row magnitude preservation is approximate due to non-linear effects, "
"and this determines the rank of that approximation. Higher ranks produce "
"larger output files and may slow down evaluation."
),
)
winsorization_quantile: float = Field( winsorization_quantile: float = Field(
default=1.0, default=1.0,
description=( description=(
"The symmetric winsorization to apply to each layer of the per-prompt residuals, " "The symmetric winsorization to apply to each layer of the per-prompt residuals, "
"expressed as the quantile to clamp to (between 0 and 1). Disabled by default. " "expressed as the quantile to clamp to (between 0 and 1). Disabled by default. "
"Example: winsorization_quantile = 0.95 applies a 90% winsorization." "Example: winsorization_quantile = 0.95 applies a 95% winsorization."
), ),
) )
@@ -215,7 +225,7 @@ class Settings(BaseSettings):
study_checkpoint_dir: str = Field( study_checkpoint_dir: str = Field(
default="checkpoints", default="checkpoints",
description="Directory to save and load study progress to/from:", description="Directory to save and load study progress to/from.",
) )
refusal_markers: list[str] = Field( refusal_markers: list[str] = Field(
+86 -73
View File
@@ -61,21 +61,15 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
Returns "merge", "adapter", or None (if cancelled/invalid). Returns "merge", "adapter", or None (if cancelled/invalid).
""" """
# Prompt for all PEFT models to ensure user is aware of merge implications
if settings.quantization == QuantizationMethod.BNB_4BIT: if settings.quantization == QuantizationMethod.BNB_4BIT:
# Quantized models need special handling - we must reload the base model
# in full precision to merge the LoRA adapters
print() print()
print( print(
"[yellow]Model was loaded with quantization. Merging requires reloading the base model.[/]" "Model was loaded with quantization. Merging requires reloading the base model."
) )
print( print(
"[red](!) WARNING: CPU Merging requires dequantizing the entire model to System RAM.[/]" "[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]"
)
print("[red] This can lead to SYSTEM FREEZES if you run out of memory.[/]")
print(
"[yellow] Rule of thumb: You need approx. 3x the parameter count in GB.[/]"
) )
print("[yellow]This can lead to system freezes if you run out of memory.[/]")
try: try:
# Estimate memory requirements by loading the model structure on the "meta" device. # Estimate memory requirements by loading the model structure on the "meta" device.
@@ -94,61 +88,39 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
footprint_bytes = meta_model.get_memory_footprint() footprint_bytes = meta_model.get_memory_footprint()
footprint_gb = footprint_bytes / (1024**3) footprint_gb = footprint_bytes / (1024**3)
print( print(
f"[yellow] Estimated net RAM required for model weights (excluding overhead): [bold]~{footprint_gb:.1f} GB[/][/]" f"[yellow]Estimated RAM required (excluding overhead): [bold]~{footprint_gb:.2f} GB[/][/]"
) )
except Exception: except Exception:
# Fallback if meta loading fails (e.g. owing to custom model code # Fallback if meta loading fails (e.g. owing to custom model code
# or `bitsandbytes` quantization config issues on the meta device) # or bitsandbytes quantization config issues on the meta device).
print(
"[yellow]Rule of thumb: You need approximately 3x the parameter count in GB RAM.[/]"
)
print( print(
"[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]" "[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
) )
print() print()
merge_choice = prompt_select( strategy = prompt_select(
"How do you want to proceed?", "How do you want to proceed?",
choices=[ choices=[
Choice( Choice(
title="Merge full model" title="Merge LoRA into full model"
+ ( + (
"" ""
if settings.quantization == QuantizationMethod.NONE if settings.quantization == QuantizationMethod.NONE
else " (reload base model on CPU - requires high RAM)" else " (requires sufficient RAM)"
), ),
value="merge", value="merge",
), ),
Choice( Choice(
title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)", title="Save LoRA adapter only (can be merged later)",
value="adapter", value="adapter",
), ),
], ],
) )
return merge_choice
return strategy
def save_model(
model: Model,
save_directory: str,
settings: Settings,
strategy: str | None = None,
) -> None:
print("Saving model...")
if strategy is None:
strategy = obtain_merge_strategy(settings)
if strategy is None:
print("[yellow]Action cancelled.[/]")
return
if strategy == "adapter":
model.model.save_pretrained(save_directory)
else:
merged_model = model.get_merged_model()
merged_model.save_pretrained(save_directory)
del merged_model
empty_cache()
model.tokenizer.save_pretrained(save_directory)
print(f"Model saved to [bold]{save_directory}[/].")
def run(): def run():
@@ -249,6 +221,8 @@ def run():
# Silence the warning about multivariate TPE being experimental. # Silence the warning about multivariate TPE being experimental.
warnings.filterwarnings("ignore", category=ExperimentalWarning) warnings.filterwarnings("ignore", category=ExperimentalWarning)
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
study_checkpoint_file = os.path.join( study_checkpoint_file = os.path.join(
settings.study_checkpoint_dir, settings.study_checkpoint_dir,
"".join( "".join(
@@ -257,7 +231,6 @@ def run():
+ ".jsonl", + ".jsonl",
) )
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
lock_obj = JournalFileOpenLock(study_checkpoint_file) lock_obj = JournalFileOpenLock(study_checkpoint_file)
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj) backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
storage = JournalStorage(backend) storage = JournalStorage(backend)
@@ -268,35 +241,57 @@ def run():
existing_study = None existing_study = None
if existing_study is not None: if existing_study is not None:
# A study is in here. Check if it's finished.
choices = [] choices = []
if existing_study.user_attrs["finished"]: if existing_study.user_attrs["finished"]:
print()
print( print(
"[green]You have already processed this model. How would you like to proceed?[/]" (
"[green]You have already processed this model.[/] "
"You can show the results from the previous run, allowing you to export models or to run additional trials. "
"Alternatively, you can ignore the previous run and start from scratch. "
"This will delete the checkpoint file and all results from the previous run."
)
) )
choices.append( choices.append(
Choice( Choice(
title="Show the results from the previous run, allowing you to export models, or to run additional trials.", title="Show the results from the previous run",
value="continue", value="continue",
) )
) )
else: else:
print()
print( print(
"[yellow]You have already processed this model, but the run was interrupted. How would you like to proceed?[/]", (
"[yellow]You have already processed this model, but the run was interrupted.[/] "
"You can continue the previous run from where it stopped. This will override any specified settings. "
"Alternatively, you can ignore the previous run and start from scratch. "
"This will delete the checkpoint file and all results from the previous run."
)
) )
choices.append( choices.append(
Choice( Choice(
title="Continue the previous run from where it stopped (will override all specified settings).", title="Continue the previous run",
value="continue", value="continue",
) )
) )
choices.append( choices.append(
Choice( Choice(
title="Ignore the previous run and start from scratch. This will delete the checkpoint file and all results from the previous run.", title="Ignore the previous run and start from scratch",
value="restart", value="restart",
) )
) )
choice = prompt_select("", choices)
choices.append(
Choice(
title="Exit program",
value="",
)
)
print()
choice = prompt_select("How would you like to proceed?", choices)
if choice == "continue": if choice == "continue":
settings = Settings.model_validate_json( settings = Settings.model_validate_json(
@@ -306,8 +301,7 @@ def run():
os.unlink(study_checkpoint_file) os.unlink(study_checkpoint_file)
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj) backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
storage = JournalStorage(backend) storage = JournalStorage(backend)
else: elif choice is None or choice == "":
print("Cancelled; exiting.")
return return
model = Model(settings) model = Model(settings)
@@ -562,14 +556,14 @@ def run():
raise TrialPruned() raise TrialPruned()
study = optuna.create_study( study = optuna.create_study(
study_name="heretic",
sampler=TPESampler( sampler=TPESampler(
n_startup_trials=settings.n_startup_trials, n_startup_trials=settings.n_startup_trials,
n_ei_candidates=128, n_ei_candidates=128,
multivariate=True, multivariate=True,
), ),
storage=storage,
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE], directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
storage=storage,
study_name="heretic",
load_if_exists=True, load_if_exists=True,
) )
@@ -582,13 +576,14 @@ def run():
start_index = trial_index = count_completed_trials() start_index = trial_index = count_completed_trials()
if start_index > 0: if start_index > 0:
print()
print("Resuming existing study.") print("Resuming existing study.")
try: try:
study.optimize( study.optimize(
objective_wrapper, n_trials=settings.n_trials - count_completed_trials() objective_wrapper,
n_trials=settings.n_trials - count_completed_trials(),
) )
except KeyboardInterrupt: except KeyboardInterrupt:
# This additional handler takes care of the small chance that KeyboardInterrupt # This additional handler takes care of the small chance that KeyboardInterrupt
# is raised just between trials, which wouldn't be caught by the handler # is raised just between trials, which wouldn't be caught by the handler
@@ -638,14 +633,14 @@ def run():
choices.append( choices.append(
Choice( Choice(
title="Continue optimization (run more trials)", title="Run additional trials",
value="continue", value="continue",
) )
) )
choices.append( choices.append(
Choice( Choice(
title="None (exit program)", title="Exit program",
value="", value="",
) )
) )
@@ -669,18 +664,26 @@ def run():
if trial == "continue": if trial == "continue":
while True: while True:
try: try:
n_more_trials = int( n_additional_trials = prompt_text(
prompt_text("How many more trials do you want to run?") "How many additional trials do you want to run?"
) )
if n_more_trials > 0: if n_additional_trials is None or n_additional_trials == "":
n_additional_trials = 0
break
n_additional_trials = int(n_additional_trials)
if n_additional_trials > 0:
break break
print("[red]Please enter a number greater than 0.[/]") print("[red]Please enter a number greater than 0.[/]")
except ValueError: except ValueError:
print("[red]Invalid input. Please enter a number.[/]") print("[red]Please enter a number.[/]")
settings.n_trials += n_more_trials if n_additional_trials == 0:
continue
settings.n_trials += n_additional_trials
study.set_user_attr("settings", settings.model_dump_json()) study.set_user_attr("settings", settings.model_dump_json())
study.set_user_attr("finished", False) study.set_user_attr("finished", False)
try: try:
study.optimize( study.optimize(
objective_wrapper, objective_wrapper,
@@ -688,8 +691,10 @@ def run():
) )
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
if count_completed_trials() == settings.n_trials: if count_completed_trials() == settings.n_trials:
study.set_user_attr("finished", True) study.set_user_attr("finished", True)
break break
elif trial is None or trial == "": elif trial is None or trial == "":
@@ -720,14 +725,11 @@ def run():
"Save the model to a local folder", "Save the model to a local folder",
"Upload the model to Hugging Face", "Upload the model to Hugging Face",
"Chat with the model", "Chat with the model",
"Nothing (return to trial selection menu)", "Return to the trial selection menu",
], ],
) )
if ( if action is None or action == "Return to the trial selection menu":
action is None
or action == "Nothing (return to trial selection menu)"
):
break break
# All actions are wrapped in a try/except block so that if an error occurs, # All actions are wrapped in a try/except block so that if an error occurs,
@@ -740,11 +742,23 @@ def run():
if not save_directory: if not save_directory:
continue continue
save_model( strategy = obtain_merge_strategy(settings)
model, if strategy is None:
save_directory, continue
settings,
) if strategy == "adapter":
print("Saving LoRA adapter...")
model.model.save_pretrained(save_directory)
else:
print("Saving merged model...")
merged_model = model.get_merged_model()
merged_model.save_pretrained(save_directory)
del merged_model
empty_cache()
model.tokenizer.save_pretrained(save_directory)
print(f"Model saved to [bold]{save_directory}[/].")
case "Upload the model to Hugging Face": case "Upload the model to Hugging Face":
# We don't use huggingface_hub.login() because that stores the token on disk, # We don't use huggingface_hub.login() because that stores the token on disk,
@@ -780,7 +794,6 @@ def run():
strategy = obtain_merge_strategy(settings) strategy = obtain_merge_strategy(settings)
if strategy is None: if strategy is None:
print("[yellow]Action cancelled.[/]")
continue continue
if strategy == "adapter": if strategy == "adapter":
+15 -13
View File
@@ -38,7 +38,7 @@ def get_model_class(
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]: ) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
configs = PretrainedConfig.get_config_dict(model) configs = PretrainedConfig.get_config_dict(model)
if any(["vision_config" in x for x in configs]): if any([("vision_config" in config) for config in configs]):
return AutoModelForImageTextToText return AutoModelForImageTextToText
else: else:
return AutoModelForCausalLM return AutoModelForCausalLM
@@ -96,9 +96,9 @@ class Model:
try: try:
quantization_config = self._get_quantization_config(dtype) quantization_config = self._get_quantization_config(dtype)
# Build kwargs, only include quantization_config if it's not None
# (some models like gpt-oss have issues with explicit None)
extra_kwargs = {} extra_kwargs = {}
# Only include quantization_config if it's not None
# (some models like gpt-oss have issues with explicit None).
if quantization_config is not None: if quantization_config is not None:
extra_kwargs["quantization_config"] = quantization_config extra_kwargs["quantization_config"] = quantization_config
@@ -134,9 +134,11 @@ class Model:
print(f"[red]Failed[/] ({error})") print(f"[red]Failed[/] ({error})")
continue continue
print("[green]Ok[/]")
if settings.quantization == QuantizationMethod.BNB_4BIT: if settings.quantization == QuantizationMethod.BNB_4BIT:
print("[bold green]Model loaded in 4-bit precision.[/]") print("[green]Ok[/] (quantized to 4-bit precision)")
else:
print("[green]Ok[/]")
break break
if self.model is None: if self.model is None:
@@ -158,7 +160,7 @@ class Model:
# Guard against calling this method at the wrong time. # Guard against calling this method at the wrong time.
assert isinstance(self.model, PreTrainedModel) assert isinstance(self.model, PreTrainedModel)
# Always use LoRA adapters for abliteration (faster reload, no weight modification) # Always use LoRA adapters for abliteration (faster reload, no weight modification).
# We use the leaf names (e.g. "o_proj") as target modules. # We use the leaf names (e.g. "o_proj") as target modules.
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"), # This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
# but this is harmless as we only abliterate the modules we target in `abliterate()`, # but this is harmless as we only abliterate the modules we target in `abliterate()`,
@@ -181,9 +183,8 @@ class Model:
lora_alpha=lora_rank, # Apply adapter at full strength. lora_alpha=lora_rank, # Apply adapter at full strength.
lora_dropout=0, lora_dropout=0,
bias="none", bias="none",
# Even if we're using AutoModelForImageTextToText, this is still correct, as it is (post-vision) # Even if we're using AutoModelForImageTextToText, this is still correct,
# the same kind of model. # as VL models are typically just causal LMs with an added image encoder.
# https://github.com/huggingface/peft/blob/622c2821cb0d7897bee53aad7914d42b5fecbf61/src/peft/auto.py#L45
task_type="CAUSAL_LM", task_type="CAUSAL_LM",
) )
@@ -191,9 +192,7 @@ class Model:
# so the result is a PeftModel rather than a PeftMixedModel. # so the result is a PeftModel rather than a PeftMixedModel.
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config)) self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
print( print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
)
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None: def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
""" """
@@ -637,7 +636,10 @@ class Model:
abs_residuals = torch.abs(residuals) abs_residuals = torch.abs(residuals)
# Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals. # Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals.
thresholds = torch.quantile( thresholds = torch.quantile(
abs_residuals, self.settings.winsorization_quantile, dim=2, keepdim=True abs_residuals,
self.settings.winsorization_quantile,
dim=2,
keepdim=True,
) )
return torch.clamp(residuals, -thresholds, thresholds) return torch.clamp(residuals, -thresholds, thresholds)