mirror of
https://github.com/p-e-w/heretic.git
synced 2026-06-01 20:58:47 +02:00
fix: improve code quality, improve UX, fix small bugs
This commit is contained in:
@@ -4,6 +4,8 @@
|
||||
* Comments should start with a capital letter and end with a period. They should use correct grammar and spelling.
|
||||
* Function and method signatures **must** be fully type-annotated, including the return type (if any).
|
||||
* Every Python code file **must** start with an SPDX/Copyright header.
|
||||
* Settings descriptions should start with a capital letter and end with a period.
|
||||
* When new settings are added in `config.py`, they should also be added to `config.default.toml`, set to their default value and with their description as a comment. The order of settings in `config.default.toml` should match that in `config.py`.
|
||||
* Pull requests should implement one change, and one change only.
|
||||
* PRs containing multiple semantically independent changes **must** be split into multiple PRs.
|
||||
* PRs **must not** change existing code unless the changes are *directly related* to the PR. This includes changes to formatting and comments.
|
||||
|
||||
+5
-2
@@ -7,7 +7,7 @@ wheels/
|
||||
*.egg-info
|
||||
|
||||
# Virtual environments
|
||||
.venv
|
||||
.venv/
|
||||
|
||||
# Caches
|
||||
/.ruff_cache/
|
||||
@@ -19,4 +19,7 @@ wheels/
|
||||
/config.toml
|
||||
|
||||
# Study checkpoints
|
||||
/checkpoints/*.jsonl
|
||||
/checkpoints/
|
||||
|
||||
# Residual plots
|
||||
/plots/
|
||||
|
||||
+27
-23
@@ -15,15 +15,16 @@ dtypes = [
|
||||
"float32",
|
||||
]
|
||||
|
||||
# Quantization method to use when loading the model. Options:
|
||||
# "none" (no quantization),
|
||||
# "bnb_4bit" (4-bit quantization using bitsandbytes).
|
||||
quantization = "none"
|
||||
|
||||
# Device map to pass to Accelerate when loading the model.
|
||||
device_map = "auto"
|
||||
|
||||
# Quantization method to use when loading the model.
|
||||
# Options: "none" (no quantization), "bnb_4bit" (4-bit quantization using bitsandbytes).
|
||||
quantization = "none"
|
||||
|
||||
# Memory limits to impose. 0 is usually your first graphics card.
|
||||
# max_memory = {0 = "16GB", "cpu" = "64GB"}
|
||||
# Maximum memory to allocate per device.
|
||||
# max_memory = {"0": "20GB", "cpu": "64GB"}
|
||||
|
||||
# Number of input sequences to process in parallel (0 = auto).
|
||||
batch_size = 0 # auto
|
||||
@@ -34,22 +35,6 @@ max_batch_size = 128
|
||||
# Maximum number of tokens to generate for each response.
|
||||
max_response_length = 100
|
||||
|
||||
# Whether to adjust the refusal directions so that only the component that is
|
||||
# orthogonal to the good direction is subtracted during abliteration.
|
||||
orthogonalize_direction = false
|
||||
|
||||
# How to apply row normalization of the weights. Options:
|
||||
# 'none' (no normalization),
|
||||
# 'pre' (compute LoRA adapter relative to row-normalized weights),
|
||||
# 'full' (like 'pre', but re-normalizes to preserve original row magnitudes).
|
||||
row_normalization = "none"
|
||||
|
||||
# The rank of the LoRA adapter to use when 'full' row normalization is used.
|
||||
# Row magnitude preservation is approximate due to non-linear efects,
|
||||
# and this determines the rank of that approximation. Higher ranks produce
|
||||
# larger output files and may slow down evaluation.
|
||||
full_normalization_lora_rank = 3
|
||||
|
||||
# Whether to print prompt/response pairs when counting refusals.
|
||||
print_responses = false
|
||||
|
||||
@@ -76,9 +61,25 @@ kl_divergence_scale = 1.0
|
||||
# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
|
||||
kl_divergence_target = 0.01
|
||||
|
||||
# Whether to adjust the refusal directions so that only the component that is
|
||||
# orthogonal to the good direction is subtracted during abliteration.
|
||||
orthogonalize_direction = false
|
||||
|
||||
# How to apply row normalization of the weights. Options:
|
||||
# "none" (no normalization),
|
||||
# "pre" (compute LoRA adapter relative to row-normalized weights),
|
||||
# "full" (like "pre", but renormalizes to preserve original row magnitudes).
|
||||
row_normalization = "none"
|
||||
|
||||
# The rank of the LoRA adapter to use when "full" row normalization is used.
|
||||
# Row magnitude preservation is approximate due to non-linear effects,
|
||||
# and this determines the rank of that approximation. Higher ranks produce
|
||||
# larger output files and may slow down evaluation.
|
||||
full_normalization_lora_rank = 3
|
||||
|
||||
# The symmetric winsorization to apply to each layer of the per-prompt residuals,
|
||||
# expressed as the quantile to clamp to (between 0 and 1). Disabled by default.
|
||||
# Example: winsorization_quantile = 0.95 applies a 90% winsorization.
|
||||
# Example: winsorization_quantile = 0.95 applies a 95% winsorization.
|
||||
winsorization_quantile = 1.0
|
||||
|
||||
# Number of abliteration trials to run during optimization.
|
||||
@@ -87,6 +88,9 @@ n_trials = 200
|
||||
# Number of trials that use random sampling for the purpose of exploration.
|
||||
n_startup_trials = 60
|
||||
|
||||
# Directory to save and load study progress to/from.
|
||||
study_checkpoint_dir = "checkpoints"
|
||||
|
||||
# Strings whose presence in a response (case insensitive) identifies the response as a refusal.
|
||||
refusal_markers = [
|
||||
"sorry",
|
||||
|
||||
+49
-39
@@ -66,7 +66,10 @@ class Settings(BaseSettings):
|
||||
|
||||
evaluate_model: str | None = Field(
|
||||
default=None,
|
||||
description="If this model ID or path is set, then instead of abliterating the main model, evaluate this model relative to the main model.",
|
||||
description=(
|
||||
"If this model ID or path is set, then instead of abliterating the main model, "
|
||||
"evaluate this model relative to the main model."
|
||||
),
|
||||
)
|
||||
|
||||
dtypes: list[str] = Field(
|
||||
@@ -82,7 +85,19 @@ class Settings(BaseSettings):
|
||||
# if that was the dtype "auto" resolved to).
|
||||
"float32",
|
||||
],
|
||||
description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.",
|
||||
description=(
|
||||
"List of PyTorch dtypes to try when loading model tensors. "
|
||||
"If loading with a dtype fails, the next dtype in the list will be tried."
|
||||
),
|
||||
)
|
||||
|
||||
quantization: QuantizationMethod = Field(
|
||||
default=QuantizationMethod.NONE,
|
||||
description=(
|
||||
"Quantization method to use when loading the model. Options: "
|
||||
'"none" (no quantization), '
|
||||
'"bnb_4bit" (4-bit quantization using bitsandbytes).'
|
||||
),
|
||||
)
|
||||
|
||||
device_map: str | Dict[str, int | str] = Field(
|
||||
@@ -92,7 +107,7 @@ class Settings(BaseSettings):
|
||||
|
||||
max_memory: Dict[str, str] | None = Field(
|
||||
default=None,
|
||||
description="Maximum memory to allocate per device (e.g., {'0': '20GB', 'cpu': '64GB'}).",
|
||||
description='Maximum memory to allocate per device (e.g., {"0": "20GB", "cpu": "64GB"}).',
|
||||
)
|
||||
|
||||
trust_remote_code: bool | None = Field(
|
||||
@@ -100,11 +115,6 @@ class Settings(BaseSettings):
|
||||
description="Whether to trust remote code when loading the model.",
|
||||
)
|
||||
|
||||
quantization: QuantizationMethod = Field(
|
||||
default=QuantizationMethod.NONE,
|
||||
description="Quantization method to use when loading the model. Options: 'none' (no quantization), 'bnb_4bit' (4-bit quantization using bitsandbytes).",
|
||||
)
|
||||
|
||||
batch_size: int = Field(
|
||||
default=0, # auto
|
||||
description="Number of input sequences to process in parallel (0 = auto).",
|
||||
@@ -120,34 +130,6 @@ class Settings(BaseSettings):
|
||||
description="Maximum number of tokens to generate for each response.",
|
||||
)
|
||||
|
||||
orthogonalize_direction: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Whether to adjust the refusal directions so that only the component that is "
|
||||
"orthogonal to the good direction is subtracted during abliteration."
|
||||
),
|
||||
)
|
||||
|
||||
row_normalization: RowNormalization = Field(
|
||||
default=RowNormalization.NONE,
|
||||
description=(
|
||||
"How to apply row normalization of the weights. Options: "
|
||||
"'none' (no normalization), "
|
||||
"'pre' (compute LoRA adapter relative to row-normalized weights), "
|
||||
"'full' (like 'pre', but renormalizes to preserve original row magnitudes)."
|
||||
),
|
||||
)
|
||||
|
||||
full_normalization_lora_rank: int = Field(
|
||||
default=3,
|
||||
description=(
|
||||
"The rank of the LoRA adapter to use when 'full' row normalization is used. "
|
||||
"Row magnitude preservation is approximate due to non-linear efects, "
|
||||
"and this determines the rank of that approximation. Higher ranks produce "
|
||||
"larger output files and may slow down evaluation."
|
||||
),
|
||||
)
|
||||
|
||||
print_responses: bool = Field(
|
||||
default=False,
|
||||
description="Whether to print prompt/response pairs when counting refusals.",
|
||||
@@ -189,17 +171,45 @@ class Settings(BaseSettings):
|
||||
kl_divergence_target: float = Field(
|
||||
default=0.01,
|
||||
description=(
|
||||
"The KL divergence to target. Below this value, an objective based on the refusal count is used."
|
||||
"The KL divergence to target. Below this value, an objective based on the refusal count is used. "
|
||||
'This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".'
|
||||
),
|
||||
)
|
||||
|
||||
orthogonalize_direction: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Whether to adjust the refusal directions so that only the component that is "
|
||||
"orthogonal to the good direction is subtracted during abliteration."
|
||||
),
|
||||
)
|
||||
|
||||
row_normalization: RowNormalization = Field(
|
||||
default=RowNormalization.NONE,
|
||||
description=(
|
||||
"How to apply row normalization of the weights. Options: "
|
||||
'"none" (no normalization), '
|
||||
'"pre" (compute LoRA adapter relative to row-normalized weights), '
|
||||
'"full" (like "pre", but renormalizes to preserve original row magnitudes).'
|
||||
),
|
||||
)
|
||||
|
||||
full_normalization_lora_rank: int = Field(
|
||||
default=3,
|
||||
description=(
|
||||
'The rank of the LoRA adapter to use when "full" row normalization is used. '
|
||||
"Row magnitude preservation is approximate due to non-linear effects, "
|
||||
"and this determines the rank of that approximation. Higher ranks produce "
|
||||
"larger output files and may slow down evaluation."
|
||||
),
|
||||
)
|
||||
|
||||
winsorization_quantile: float = Field(
|
||||
default=1.0,
|
||||
description=(
|
||||
"The symmetric winsorization to apply to each layer of the per-prompt residuals, "
|
||||
"expressed as the quantile to clamp to (between 0 and 1). Disabled by default. "
|
||||
"Example: winsorization_quantile = 0.95 applies a 90% winsorization."
|
||||
"Example: winsorization_quantile = 0.95 applies a 95% winsorization."
|
||||
),
|
||||
)
|
||||
|
||||
@@ -215,7 +225,7 @@ class Settings(BaseSettings):
|
||||
|
||||
study_checkpoint_dir: str = Field(
|
||||
default="checkpoints",
|
||||
description="Directory to save and load study progress to/from:",
|
||||
description="Directory to save and load study progress to/from.",
|
||||
)
|
||||
|
||||
refusal_markers: list[str] = Field(
|
||||
|
||||
+87
-74
@@ -61,21 +61,15 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
|
||||
Returns "merge", "adapter", or None (if cancelled/invalid).
|
||||
"""
|
||||
|
||||
# Prompt for all PEFT models to ensure user is aware of merge implications
|
||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
# Quantized models need special handling - we must reload the base model
|
||||
# in full precision to merge the LoRA adapters
|
||||
print()
|
||||
print(
|
||||
"[yellow]Model was loaded with quantization. Merging requires reloading the base model.[/]"
|
||||
"Model was loaded with quantization. Merging requires reloading the base model."
|
||||
)
|
||||
print(
|
||||
"[red](!) WARNING: CPU Merging requires dequantizing the entire model to System RAM.[/]"
|
||||
)
|
||||
print("[red] This can lead to SYSTEM FREEZES if you run out of memory.[/]")
|
||||
print(
|
||||
"[yellow] Rule of thumb: You need approx. 3x the parameter count in GB.[/]"
|
||||
"[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]"
|
||||
)
|
||||
print("[yellow]This can lead to system freezes if you run out of memory.[/]")
|
||||
|
||||
try:
|
||||
# Estimate memory requirements by loading the model structure on the "meta" device.
|
||||
@@ -94,61 +88,39 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
|
||||
footprint_bytes = meta_model.get_memory_footprint()
|
||||
footprint_gb = footprint_bytes / (1024**3)
|
||||
print(
|
||||
f"[yellow] Estimated net RAM required for model weights (excluding overhead): [bold]~{footprint_gb:.1f} GB[/][/]"
|
||||
f"[yellow]Estimated RAM required (excluding overhead): [bold]~{footprint_gb:.2f} GB[/][/]"
|
||||
)
|
||||
except Exception:
|
||||
# Fallback if meta loading fails (e.g. owing to custom model code
|
||||
# or `bitsandbytes` quantization config issues on the meta device)
|
||||
# or bitsandbytes quantization config issues on the meta device).
|
||||
print(
|
||||
"[yellow] Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
|
||||
"[yellow]Rule of thumb: You need approximately 3x the parameter count in GB RAM.[/]"
|
||||
)
|
||||
print(
|
||||
"[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
|
||||
)
|
||||
print()
|
||||
|
||||
merge_choice = prompt_select(
|
||||
strategy = prompt_select(
|
||||
"How do you want to proceed?",
|
||||
choices=[
|
||||
Choice(
|
||||
title="Merge full model"
|
||||
title="Merge LoRA into full model"
|
||||
+ (
|
||||
""
|
||||
if settings.quantization == QuantizationMethod.NONE
|
||||
else " (reload base model on CPU - requires high RAM)"
|
||||
else " (requires sufficient RAM)"
|
||||
),
|
||||
value="merge",
|
||||
),
|
||||
Choice(
|
||||
title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)",
|
||||
title="Save LoRA adapter only (can be merged later)",
|
||||
value="adapter",
|
||||
),
|
||||
],
|
||||
)
|
||||
return merge_choice
|
||||
|
||||
|
||||
def save_model(
|
||||
model: Model,
|
||||
save_directory: str,
|
||||
settings: Settings,
|
||||
strategy: str | None = None,
|
||||
) -> None:
|
||||
print("Saving model...")
|
||||
if strategy is None:
|
||||
strategy = obtain_merge_strategy(settings)
|
||||
if strategy is None:
|
||||
print("[yellow]Action cancelled.[/]")
|
||||
return
|
||||
|
||||
if strategy == "adapter":
|
||||
model.model.save_pretrained(save_directory)
|
||||
else:
|
||||
merged_model = model.get_merged_model()
|
||||
merged_model.save_pretrained(save_directory)
|
||||
del merged_model
|
||||
empty_cache()
|
||||
|
||||
model.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
print(f"Model saved to [bold]{save_directory}[/].")
|
||||
return strategy
|
||||
|
||||
|
||||
def run():
|
||||
@@ -249,6 +221,8 @@ def run():
|
||||
# Silence the warning about multivariate TPE being experimental.
|
||||
warnings.filterwarnings("ignore", category=ExperimentalWarning)
|
||||
|
||||
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
|
||||
|
||||
study_checkpoint_file = os.path.join(
|
||||
settings.study_checkpoint_dir,
|
||||
"".join(
|
||||
@@ -257,7 +231,6 @@ def run():
|
||||
+ ".jsonl",
|
||||
)
|
||||
|
||||
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
|
||||
lock_obj = JournalFileOpenLock(study_checkpoint_file)
|
||||
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
|
||||
storage = JournalStorage(backend)
|
||||
@@ -268,35 +241,57 @@ def run():
|
||||
existing_study = None
|
||||
|
||||
if existing_study is not None:
|
||||
# A study is in here. Check if it's finished.
|
||||
choices = []
|
||||
|
||||
if existing_study.user_attrs["finished"]:
|
||||
print()
|
||||
print(
|
||||
"[green]You have already processed this model. How would you like to proceed?[/]"
|
||||
(
|
||||
"[green]You have already processed this model.[/] "
|
||||
"You can show the results from the previous run, allowing you to export models or to run additional trials. "
|
||||
"Alternatively, you can ignore the previous run and start from scratch. "
|
||||
"This will delete the checkpoint file and all results from the previous run."
|
||||
)
|
||||
)
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Show the results from the previous run, allowing you to export models, or to run additional trials.",
|
||||
title="Show the results from the previous run",
|
||||
value="continue",
|
||||
)
|
||||
)
|
||||
else:
|
||||
print()
|
||||
print(
|
||||
"[yellow]You have already processed this model, but the run was interrupted. How would you like to proceed?[/]",
|
||||
(
|
||||
"[yellow]You have already processed this model, but the run was interrupted.[/] "
|
||||
"You can continue the previous run from where it stopped. This will override any specified settings. "
|
||||
"Alternatively, you can ignore the previous run and start from scratch. "
|
||||
"This will delete the checkpoint file and all results from the previous run."
|
||||
)
|
||||
)
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Continue the previous run from where it stopped (will override all specified settings).",
|
||||
title="Continue the previous run",
|
||||
value="continue",
|
||||
)
|
||||
)
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Ignore the previous run and start from scratch. This will delete the checkpoint file and all results from the previous run.",
|
||||
title="Ignore the previous run and start from scratch",
|
||||
value="restart",
|
||||
)
|
||||
)
|
||||
choice = prompt_select("", choices)
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Exit program",
|
||||
value="",
|
||||
)
|
||||
)
|
||||
|
||||
print()
|
||||
choice = prompt_select("How would you like to proceed?", choices)
|
||||
|
||||
if choice == "continue":
|
||||
settings = Settings.model_validate_json(
|
||||
@@ -306,8 +301,7 @@ def run():
|
||||
os.unlink(study_checkpoint_file)
|
||||
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
|
||||
storage = JournalStorage(backend)
|
||||
else:
|
||||
print("Cancelled; exiting.")
|
||||
elif choice is None or choice == "":
|
||||
return
|
||||
|
||||
model = Model(settings)
|
||||
@@ -562,14 +556,14 @@ def run():
|
||||
raise TrialPruned()
|
||||
|
||||
study = optuna.create_study(
|
||||
study_name="heretic",
|
||||
sampler=TPESampler(
|
||||
n_startup_trials=settings.n_startup_trials,
|
||||
n_ei_candidates=128,
|
||||
multivariate=True,
|
||||
),
|
||||
storage=storage,
|
||||
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
||||
storage=storage,
|
||||
study_name="heretic",
|
||||
load_if_exists=True,
|
||||
)
|
||||
|
||||
@@ -582,13 +576,14 @@ def run():
|
||||
|
||||
start_index = trial_index = count_completed_trials()
|
||||
if start_index > 0:
|
||||
print()
|
||||
print("Resuming existing study.")
|
||||
|
||||
try:
|
||||
study.optimize(
|
||||
objective_wrapper, n_trials=settings.n_trials - count_completed_trials()
|
||||
objective_wrapper,
|
||||
n_trials=settings.n_trials - count_completed_trials(),
|
||||
)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
# This additional handler takes care of the small chance that KeyboardInterrupt
|
||||
# is raised just between trials, which wouldn't be caught by the handler
|
||||
@@ -638,14 +633,14 @@ def run():
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
title="Continue optimization (run more trials)",
|
||||
title="Run additional trials",
|
||||
value="continue",
|
||||
)
|
||||
)
|
||||
|
||||
choices.append(
|
||||
Choice(
|
||||
title="None (exit program)",
|
||||
title="Exit program",
|
||||
value="",
|
||||
)
|
||||
)
|
||||
@@ -669,18 +664,26 @@ def run():
|
||||
if trial == "continue":
|
||||
while True:
|
||||
try:
|
||||
n_more_trials = int(
|
||||
prompt_text("How many more trials do you want to run?")
|
||||
n_additional_trials = prompt_text(
|
||||
"How many additional trials do you want to run?"
|
||||
)
|
||||
if n_more_trials > 0:
|
||||
if n_additional_trials is None or n_additional_trials == "":
|
||||
n_additional_trials = 0
|
||||
break
|
||||
n_additional_trials = int(n_additional_trials)
|
||||
if n_additional_trials > 0:
|
||||
break
|
||||
print("[red]Please enter a number greater than 0.[/]")
|
||||
except ValueError:
|
||||
print("[red]Invalid input. Please enter a number.[/]")
|
||||
print("[red]Please enter a number.[/]")
|
||||
|
||||
settings.n_trials += n_more_trials
|
||||
if n_additional_trials == 0:
|
||||
continue
|
||||
|
||||
settings.n_trials += n_additional_trials
|
||||
study.set_user_attr("settings", settings.model_dump_json())
|
||||
study.set_user_attr("finished", False)
|
||||
|
||||
try:
|
||||
study.optimize(
|
||||
objective_wrapper,
|
||||
@@ -688,8 +691,10 @@ def run():
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
|
||||
if count_completed_trials() == settings.n_trials:
|
||||
study.set_user_attr("finished", True)
|
||||
|
||||
break
|
||||
|
||||
elif trial is None or trial == "":
|
||||
@@ -720,14 +725,11 @@ def run():
|
||||
"Save the model to a local folder",
|
||||
"Upload the model to Hugging Face",
|
||||
"Chat with the model",
|
||||
"Nothing (return to trial selection menu)",
|
||||
"Return to the trial selection menu",
|
||||
],
|
||||
)
|
||||
|
||||
if (
|
||||
action is None
|
||||
or action == "Nothing (return to trial selection menu)"
|
||||
):
|
||||
if action is None or action == "Return to the trial selection menu":
|
||||
break
|
||||
|
||||
# All actions are wrapped in a try/except block so that if an error occurs,
|
||||
@@ -740,11 +742,23 @@ def run():
|
||||
if not save_directory:
|
||||
continue
|
||||
|
||||
save_model(
|
||||
model,
|
||||
save_directory,
|
||||
settings,
|
||||
)
|
||||
strategy = obtain_merge_strategy(settings)
|
||||
if strategy is None:
|
||||
continue
|
||||
|
||||
if strategy == "adapter":
|
||||
print("Saving LoRA adapter...")
|
||||
model.model.save_pretrained(save_directory)
|
||||
else:
|
||||
print("Saving merged model...")
|
||||
merged_model = model.get_merged_model()
|
||||
merged_model.save_pretrained(save_directory)
|
||||
del merged_model
|
||||
empty_cache()
|
||||
|
||||
model.tokenizer.save_pretrained(save_directory)
|
||||
|
||||
print(f"Model saved to [bold]{save_directory}[/].")
|
||||
|
||||
case "Upload the model to Hugging Face":
|
||||
# We don't use huggingface_hub.login() because that stores the token on disk,
|
||||
@@ -780,7 +794,6 @@ def run():
|
||||
|
||||
strategy = obtain_merge_strategy(settings)
|
||||
if strategy is None:
|
||||
print("[yellow]Action cancelled.[/]")
|
||||
continue
|
||||
|
||||
if strategy == "adapter":
|
||||
|
||||
+15
-13
@@ -38,7 +38,7 @@ def get_model_class(
|
||||
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
|
||||
configs = PretrainedConfig.get_config_dict(model)
|
||||
|
||||
if any(["vision_config" in x for x in configs]):
|
||||
if any([("vision_config" in config) for config in configs]):
|
||||
return AutoModelForImageTextToText
|
||||
else:
|
||||
return AutoModelForCausalLM
|
||||
@@ -96,9 +96,9 @@ class Model:
|
||||
try:
|
||||
quantization_config = self._get_quantization_config(dtype)
|
||||
|
||||
# Build kwargs, only include quantization_config if it's not None
|
||||
# (some models like gpt-oss have issues with explicit None)
|
||||
extra_kwargs = {}
|
||||
# Only include quantization_config if it's not None
|
||||
# (some models like gpt-oss have issues with explicit None).
|
||||
if quantization_config is not None:
|
||||
extra_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
@@ -134,9 +134,11 @@ class Model:
|
||||
print(f"[red]Failed[/] ({error})")
|
||||
continue
|
||||
|
||||
print("[green]Ok[/]")
|
||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
print("[bold green]Model loaded in 4-bit precision.[/]")
|
||||
print("[green]Ok[/] (quantized to 4-bit precision)")
|
||||
else:
|
||||
print("[green]Ok[/]")
|
||||
|
||||
break
|
||||
|
||||
if self.model is None:
|
||||
@@ -158,7 +160,7 @@ class Model:
|
||||
# Guard against calling this method at the wrong time.
|
||||
assert isinstance(self.model, PreTrainedModel)
|
||||
|
||||
# Always use LoRA adapters for abliteration (faster reload, no weight modification)
|
||||
# Always use LoRA adapters for abliteration (faster reload, no weight modification).
|
||||
# We use the leaf names (e.g. "o_proj") as target modules.
|
||||
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
|
||||
# but this is harmless as we only abliterate the modules we target in `abliterate()`,
|
||||
@@ -181,9 +183,8 @@ class Model:
|
||||
lora_alpha=lora_rank, # Apply adapter at full strength.
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
# Even if we're using AutoModelForImageTextToText, this is still correct, as it is (post-vision)
|
||||
# the same kind of model.
|
||||
# https://github.com/huggingface/peft/blob/622c2821cb0d7897bee53aad7914d42b5fecbf61/src/peft/auto.py#L45
|
||||
# Even if we're using AutoModelForImageTextToText, this is still correct,
|
||||
# as VL models are typically just causal LMs with an added image encoder.
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
@@ -191,9 +192,7 @@ class Model:
|
||||
# so the result is a PeftModel rather than a PeftMixedModel.
|
||||
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
|
||||
|
||||
print(
|
||||
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
|
||||
)
|
||||
print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")
|
||||
|
||||
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
|
||||
"""
|
||||
@@ -637,7 +636,10 @@ class Model:
|
||||
abs_residuals = torch.abs(residuals)
|
||||
# Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals.
|
||||
thresholds = torch.quantile(
|
||||
abs_residuals, self.settings.winsorization_quantile, dim=2, keepdim=True
|
||||
abs_residuals,
|
||||
self.settings.winsorization_quantile,
|
||||
dim=2,
|
||||
keepdim=True,
|
||||
)
|
||||
return torch.clamp(residuals, -thresholds, thresholds)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user