mirror of
https://github.com/p-e-w/heretic.git
synced 2026-06-02 05:03:33 +02:00
fix: improve code quality, improve UX, fix small bugs
This commit is contained in:
@@ -4,6 +4,8 @@
|
|||||||
* Comments should start with a capital letter and end with a period. They should use correct grammar and spelling.
|
* Comments should start with a capital letter and end with a period. They should use correct grammar and spelling.
|
||||||
* Function and method signatures **must** be fully type-annotated, including the return type (if any).
|
* Function and method signatures **must** be fully type-annotated, including the return type (if any).
|
||||||
* Every Python code file **must** start with an SPDX/Copyright header.
|
* Every Python code file **must** start with an SPDX/Copyright header.
|
||||||
|
* Settings descriptions should start with a capital letter and end with a period.
|
||||||
|
* When new settings are added in `config.py`, they should also be added to `config.default.toml`, set to their default value and with their description as a comment. The order of settings in `config.default.toml` should match that in `config.py`.
|
||||||
* Pull requests should implement one change, and one change only.
|
* Pull requests should implement one change, and one change only.
|
||||||
* PRs containing multiple semantically independent changes **must** be split into multiple PRs.
|
* PRs containing multiple semantically independent changes **must** be split into multiple PRs.
|
||||||
* PRs **must not** change existing code unless the changes are *directly related* to the PR. This includes changes to formatting and comments.
|
* PRs **must not** change existing code unless the changes are *directly related* to the PR. This includes changes to formatting and comments.
|
||||||
|
|||||||
+5
-2
@@ -7,7 +7,7 @@ wheels/
|
|||||||
*.egg-info
|
*.egg-info
|
||||||
|
|
||||||
# Virtual environments
|
# Virtual environments
|
||||||
.venv
|
.venv/
|
||||||
|
|
||||||
# Caches
|
# Caches
|
||||||
/.ruff_cache/
|
/.ruff_cache/
|
||||||
@@ -19,4 +19,7 @@ wheels/
|
|||||||
/config.toml
|
/config.toml
|
||||||
|
|
||||||
# Study checkpoints
|
# Study checkpoints
|
||||||
/checkpoints/*.jsonl
|
/checkpoints/
|
||||||
|
|
||||||
|
# Residual plots
|
||||||
|
/plots/
|
||||||
|
|||||||
+27
-23
@@ -15,15 +15,16 @@ dtypes = [
|
|||||||
"float32",
|
"float32",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Quantization method to use when loading the model. Options:
|
||||||
|
# "none" (no quantization),
|
||||||
|
# "bnb_4bit" (4-bit quantization using bitsandbytes).
|
||||||
|
quantization = "none"
|
||||||
|
|
||||||
# Device map to pass to Accelerate when loading the model.
|
# Device map to pass to Accelerate when loading the model.
|
||||||
device_map = "auto"
|
device_map = "auto"
|
||||||
|
|
||||||
# Quantization method to use when loading the model.
|
# Maximum memory to allocate per device.
|
||||||
# Options: "none" (no quantization), "bnb_4bit" (4-bit quantization using bitsandbytes).
|
# max_memory = {"0" = "20GB", "cpu" = "64GB"}
|
||||||
quantization = "none"
|
|
||||||
|
|
||||||
# Memory limits to impose. 0 is usually your first graphics card.
|
|
||||||
# max_memory = {0 = "16GB", "cpu" = "64GB"}
|
|
||||||
|
|
||||||
# Number of input sequences to process in parallel (0 = auto).
|
# Number of input sequences to process in parallel (0 = auto).
|
||||||
batch_size = 0 # auto
|
batch_size = 0 # auto
|
||||||
@@ -34,22 +35,6 @@ max_batch_size = 128
|
|||||||
# Maximum number of tokens to generate for each response.
|
# Maximum number of tokens to generate for each response.
|
||||||
max_response_length = 100
|
max_response_length = 100
|
||||||
|
|
||||||
# Whether to adjust the refusal directions so that only the component that is
|
|
||||||
# orthogonal to the good direction is subtracted during abliteration.
|
|
||||||
orthogonalize_direction = false
|
|
||||||
|
|
||||||
# How to apply row normalization of the weights. Options:
|
|
||||||
# 'none' (no normalization),
|
|
||||||
# 'pre' (compute LoRA adapter relative to row-normalized weights),
|
|
||||||
# 'full' (like 'pre', but re-normalizes to preserve original row magnitudes).
|
|
||||||
row_normalization = "none"
|
|
||||||
|
|
||||||
# The rank of the LoRA adapter to use when 'full' row normalization is used.
|
|
||||||
# Row magnitude preservation is approximate due to non-linear efects,
|
|
||||||
# and this determines the rank of that approximation. Higher ranks produce
|
|
||||||
# larger output files and may slow down evaluation.
|
|
||||||
full_normalization_lora_rank = 3
|
|
||||||
|
|
||||||
# Whether to print prompt/response pairs when counting refusals.
|
# Whether to print prompt/response pairs when counting refusals.
|
||||||
print_responses = false
|
print_responses = false
|
||||||
|
|
||||||
@@ -76,9 +61,25 @@ kl_divergence_scale = 1.0
|
|||||||
# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
|
# This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".
|
||||||
kl_divergence_target = 0.01
|
kl_divergence_target = 0.01
|
||||||
|
|
||||||
|
# Whether to adjust the refusal directions so that only the component that is
|
||||||
|
# orthogonal to the good direction is subtracted during abliteration.
|
||||||
|
orthogonalize_direction = false
|
||||||
|
|
||||||
|
# How to apply row normalization of the weights. Options:
|
||||||
|
# "none" (no normalization),
|
||||||
|
# "pre" (compute LoRA adapter relative to row-normalized weights),
|
||||||
|
# "full" (like "pre", but renormalizes to preserve original row magnitudes).
|
||||||
|
row_normalization = "none"
|
||||||
|
|
||||||
|
# The rank of the LoRA adapter to use when "full" row normalization is used.
|
||||||
|
# Row magnitude preservation is approximate due to non-linear effects,
|
||||||
|
# and this determines the rank of that approximation. Higher ranks produce
|
||||||
|
# larger output files and may slow down evaluation.
|
||||||
|
full_normalization_lora_rank = 3
|
||||||
|
|
||||||
# The symmetric winsorization to apply to each layer of the per-prompt residuals,
|
# The symmetric winsorization to apply to each layer of the per-prompt residuals,
|
||||||
# expressed as the quantile to clamp to (between 0 and 1). Disabled by default.
|
# expressed as the quantile to clamp to (between 0 and 1). Disabled by default.
|
||||||
# Example: winsorization_quantile = 0.95 applies a 90% winsorization.
|
# Example: winsorization_quantile = 0.95 applies a 95% winsorization.
|
||||||
winsorization_quantile = 1.0
|
winsorization_quantile = 1.0
|
||||||
|
|
||||||
# Number of abliteration trials to run during optimization.
|
# Number of abliteration trials to run during optimization.
|
||||||
@@ -87,6 +88,9 @@ n_trials = 200
|
|||||||
# Number of trials that use random sampling for the purpose of exploration.
|
# Number of trials that use random sampling for the purpose of exploration.
|
||||||
n_startup_trials = 60
|
n_startup_trials = 60
|
||||||
|
|
||||||
|
# Directory to save and load study progress to/from.
|
||||||
|
study_checkpoint_dir = "checkpoints"
|
||||||
|
|
||||||
# Strings whose presence in a response (case insensitive) identifies the response as a refusal.
|
# Strings whose presence in a response (case insensitive) identifies the response as a refusal.
|
||||||
refusal_markers = [
|
refusal_markers = [
|
||||||
"sorry",
|
"sorry",
|
||||||
|
|||||||
+49
-39
@@ -66,7 +66,10 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
evaluate_model: str | None = Field(
|
evaluate_model: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="If this model ID or path is set, then instead of abliterating the main model, evaluate this model relative to the main model.",
|
description=(
|
||||||
|
"If this model ID or path is set, then instead of abliterating the main model, "
|
||||||
|
"evaluate this model relative to the main model."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
dtypes: list[str] = Field(
|
dtypes: list[str] = Field(
|
||||||
@@ -82,7 +85,19 @@ class Settings(BaseSettings):
|
|||||||
# if that was the dtype "auto" resolved to).
|
# if that was the dtype "auto" resolved to).
|
||||||
"float32",
|
"float32",
|
||||||
],
|
],
|
||||||
description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.",
|
description=(
|
||||||
|
"List of PyTorch dtypes to try when loading model tensors. "
|
||||||
|
"If loading with a dtype fails, the next dtype in the list will be tried."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
quantization: QuantizationMethod = Field(
|
||||||
|
default=QuantizationMethod.NONE,
|
||||||
|
description=(
|
||||||
|
"Quantization method to use when loading the model. Options: "
|
||||||
|
'"none" (no quantization), '
|
||||||
|
'"bnb_4bit" (4-bit quantization using bitsandbytes).'
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
device_map: str | Dict[str, int | str] = Field(
|
device_map: str | Dict[str, int | str] = Field(
|
||||||
@@ -92,7 +107,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
max_memory: Dict[str, str] | None = Field(
|
max_memory: Dict[str, str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Maximum memory to allocate per device (e.g., {'0': '20GB', 'cpu': '64GB'}).",
|
description='Maximum memory to allocate per device (e.g., {"0": "20GB", "cpu": "64GB"}).',
|
||||||
)
|
)
|
||||||
|
|
||||||
trust_remote_code: bool | None = Field(
|
trust_remote_code: bool | None = Field(
|
||||||
@@ -100,11 +115,6 @@ class Settings(BaseSettings):
|
|||||||
description="Whether to trust remote code when loading the model.",
|
description="Whether to trust remote code when loading the model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
quantization: QuantizationMethod = Field(
|
|
||||||
default=QuantizationMethod.NONE,
|
|
||||||
description="Quantization method to use when loading the model. Options: 'none' (no quantization), 'bnb_4bit' (4-bit quantization using bitsandbytes).",
|
|
||||||
)
|
|
||||||
|
|
||||||
batch_size: int = Field(
|
batch_size: int = Field(
|
||||||
default=0, # auto
|
default=0, # auto
|
||||||
description="Number of input sequences to process in parallel (0 = auto).",
|
description="Number of input sequences to process in parallel (0 = auto).",
|
||||||
@@ -120,34 +130,6 @@ class Settings(BaseSettings):
|
|||||||
description="Maximum number of tokens to generate for each response.",
|
description="Maximum number of tokens to generate for each response.",
|
||||||
)
|
)
|
||||||
|
|
||||||
orthogonalize_direction: bool = Field(
|
|
||||||
default=False,
|
|
||||||
description=(
|
|
||||||
"Whether to adjust the refusal directions so that only the component that is "
|
|
||||||
"orthogonal to the good direction is subtracted during abliteration."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
row_normalization: RowNormalization = Field(
|
|
||||||
default=RowNormalization.NONE,
|
|
||||||
description=(
|
|
||||||
"How to apply row normalization of the weights. Options: "
|
|
||||||
"'none' (no normalization), "
|
|
||||||
"'pre' (compute LoRA adapter relative to row-normalized weights), "
|
|
||||||
"'full' (like 'pre', but renormalizes to preserve original row magnitudes)."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
full_normalization_lora_rank: int = Field(
|
|
||||||
default=3,
|
|
||||||
description=(
|
|
||||||
"The rank of the LoRA adapter to use when 'full' row normalization is used. "
|
|
||||||
"Row magnitude preservation is approximate due to non-linear efects, "
|
|
||||||
"and this determines the rank of that approximation. Higher ranks produce "
|
|
||||||
"larger output files and may slow down evaluation."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
print_responses: bool = Field(
|
print_responses: bool = Field(
|
||||||
default=False,
|
default=False,
|
||||||
description="Whether to print prompt/response pairs when counting refusals.",
|
description="Whether to print prompt/response pairs when counting refusals.",
|
||||||
@@ -189,17 +171,45 @@ class Settings(BaseSettings):
|
|||||||
kl_divergence_target: float = Field(
|
kl_divergence_target: float = Field(
|
||||||
default=0.01,
|
default=0.01,
|
||||||
description=(
|
description=(
|
||||||
"The KL divergence to target. Below this value, an objective based on the refusal count is used."
|
"The KL divergence to target. Below this value, an objective based on the refusal count is used. "
|
||||||
'This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".'
|
'This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".'
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
orthogonalize_direction: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"Whether to adjust the refusal directions so that only the component that is "
|
||||||
|
"orthogonal to the good direction is subtracted during abliteration."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
row_normalization: RowNormalization = Field(
|
||||||
|
default=RowNormalization.NONE,
|
||||||
|
description=(
|
||||||
|
"How to apply row normalization of the weights. Options: "
|
||||||
|
'"none" (no normalization), '
|
||||||
|
'"pre" (compute LoRA adapter relative to row-normalized weights), '
|
||||||
|
'"full" (like "pre", but renormalizes to preserve original row magnitudes).'
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
full_normalization_lora_rank: int = Field(
|
||||||
|
default=3,
|
||||||
|
description=(
|
||||||
|
'The rank of the LoRA adapter to use when "full" row normalization is used. '
|
||||||
|
"Row magnitude preservation is approximate due to non-linear effects, "
|
||||||
|
"and this determines the rank of that approximation. Higher ranks produce "
|
||||||
|
"larger output files and may slow down evaluation."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
winsorization_quantile: float = Field(
|
winsorization_quantile: float = Field(
|
||||||
default=1.0,
|
default=1.0,
|
||||||
description=(
|
description=(
|
||||||
"The symmetric winsorization to apply to each layer of the per-prompt residuals, "
|
"The symmetric winsorization to apply to each layer of the per-prompt residuals, "
|
||||||
"expressed as the quantile to clamp to (between 0 and 1). Disabled by default. "
|
"expressed as the quantile to clamp to (between 0 and 1). Disabled by default. "
|
||||||
"Example: winsorization_quantile = 0.95 applies a 90% winsorization."
|
"Example: winsorization_quantile = 0.95 applies a 95% winsorization."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -215,7 +225,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
study_checkpoint_dir: str = Field(
|
study_checkpoint_dir: str = Field(
|
||||||
default="checkpoints",
|
default="checkpoints",
|
||||||
description="Directory to save and load study progress to/from:",
|
description="Directory to save and load study progress to/from.",
|
||||||
)
|
)
|
||||||
|
|
||||||
refusal_markers: list[str] = Field(
|
refusal_markers: list[str] = Field(
|
||||||
|
|||||||
+87
-74
@@ -61,21 +61,15 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
|
|||||||
Returns "merge", "adapter", or None (if cancelled/invalid).
|
Returns "merge", "adapter", or None (if cancelled/invalid).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Prompt for all PEFT models to ensure user is aware of merge implications
|
|
||||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||||
# Quantized models need special handling - we must reload the base model
|
|
||||||
# in full precision to merge the LoRA adapters
|
|
||||||
print()
|
print()
|
||||||
print(
|
print(
|
||||||
"[yellow]Model was loaded with quantization. Merging requires reloading the base model.[/]"
|
"Model was loaded with quantization. Merging requires reloading the base model."
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
"[red](!) WARNING: CPU Merging requires dequantizing the entire model to System RAM.[/]"
|
"[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]"
|
||||||
)
|
|
||||||
print("[red] This can lead to SYSTEM FREEZES if you run out of memory.[/]")
|
|
||||||
print(
|
|
||||||
"[yellow] Rule of thumb: You need approx. 3x the parameter count in GB.[/]"
|
|
||||||
)
|
)
|
||||||
|
print("[yellow]This can lead to system freezes if you run out of memory.[/]")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Estimate memory requirements by loading the model structure on the "meta" device.
|
# Estimate memory requirements by loading the model structure on the "meta" device.
|
||||||
@@ -94,61 +88,39 @@ def obtain_merge_strategy(settings: Settings) -> str | None:
|
|||||||
footprint_bytes = meta_model.get_memory_footprint()
|
footprint_bytes = meta_model.get_memory_footprint()
|
||||||
footprint_gb = footprint_bytes / (1024**3)
|
footprint_gb = footprint_bytes / (1024**3)
|
||||||
print(
|
print(
|
||||||
f"[yellow] Estimated net RAM required for model weights (excluding overhead): [bold]~{footprint_gb:.1f} GB[/][/]"
|
f"[yellow]Estimated RAM required (excluding overhead): [bold]~{footprint_gb:.2f} GB[/][/]"
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
# Fallback if meta loading fails (e.g. owing to custom model code
|
# Fallback if meta loading fails (e.g. owing to custom model code
|
||||||
# or `bitsandbytes` quantization config issues on the meta device)
|
# or bitsandbytes quantization config issues on the meta device).
|
||||||
print(
|
print(
|
||||||
"[yellow] Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
|
"[yellow]Rule of thumb: You need approximately 3x the parameter count in GB RAM.[/]"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
|
||||||
)
|
)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
merge_choice = prompt_select(
|
strategy = prompt_select(
|
||||||
"How do you want to proceed?",
|
"How do you want to proceed?",
|
||||||
choices=[
|
choices=[
|
||||||
Choice(
|
Choice(
|
||||||
title="Merge full model"
|
title="Merge LoRA into full model"
|
||||||
+ (
|
+ (
|
||||||
""
|
""
|
||||||
if settings.quantization == QuantizationMethod.NONE
|
if settings.quantization == QuantizationMethod.NONE
|
||||||
else " (reload base model on CPU - requires high RAM)"
|
else " (requires sufficient RAM)"
|
||||||
),
|
),
|
||||||
value="merge",
|
value="merge",
|
||||||
),
|
),
|
||||||
Choice(
|
Choice(
|
||||||
title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)",
|
title="Save LoRA adapter only (can be merged later)",
|
||||||
value="adapter",
|
value="adapter",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
return merge_choice
|
|
||||||
|
|
||||||
|
return strategy
|
||||||
def save_model(
|
|
||||||
model: Model,
|
|
||||||
save_directory: str,
|
|
||||||
settings: Settings,
|
|
||||||
strategy: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
print("Saving model...")
|
|
||||||
if strategy is None:
|
|
||||||
strategy = obtain_merge_strategy(settings)
|
|
||||||
if strategy is None:
|
|
||||||
print("[yellow]Action cancelled.[/]")
|
|
||||||
return
|
|
||||||
|
|
||||||
if strategy == "adapter":
|
|
||||||
model.model.save_pretrained(save_directory)
|
|
||||||
else:
|
|
||||||
merged_model = model.get_merged_model()
|
|
||||||
merged_model.save_pretrained(save_directory)
|
|
||||||
del merged_model
|
|
||||||
empty_cache()
|
|
||||||
|
|
||||||
model.tokenizer.save_pretrained(save_directory)
|
|
||||||
|
|
||||||
print(f"Model saved to [bold]{save_directory}[/].")
|
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
@@ -249,6 +221,8 @@ def run():
|
|||||||
# Silence the warning about multivariate TPE being experimental.
|
# Silence the warning about multivariate TPE being experimental.
|
||||||
warnings.filterwarnings("ignore", category=ExperimentalWarning)
|
warnings.filterwarnings("ignore", category=ExperimentalWarning)
|
||||||
|
|
||||||
|
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
|
||||||
|
|
||||||
study_checkpoint_file = os.path.join(
|
study_checkpoint_file = os.path.join(
|
||||||
settings.study_checkpoint_dir,
|
settings.study_checkpoint_dir,
|
||||||
"".join(
|
"".join(
|
||||||
@@ -257,7 +231,6 @@ def run():
|
|||||||
+ ".jsonl",
|
+ ".jsonl",
|
||||||
)
|
)
|
||||||
|
|
||||||
os.makedirs(settings.study_checkpoint_dir, exist_ok=True)
|
|
||||||
lock_obj = JournalFileOpenLock(study_checkpoint_file)
|
lock_obj = JournalFileOpenLock(study_checkpoint_file)
|
||||||
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
|
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
|
||||||
storage = JournalStorage(backend)
|
storage = JournalStorage(backend)
|
||||||
@@ -268,35 +241,57 @@ def run():
|
|||||||
existing_study = None
|
existing_study = None
|
||||||
|
|
||||||
if existing_study is not None:
|
if existing_study is not None:
|
||||||
# A study is in here. Check if it's finished.
|
|
||||||
choices = []
|
choices = []
|
||||||
|
|
||||||
if existing_study.user_attrs["finished"]:
|
if existing_study.user_attrs["finished"]:
|
||||||
|
print()
|
||||||
print(
|
print(
|
||||||
"[green]You have already processed this model. How would you like to proceed?[/]"
|
(
|
||||||
|
"[green]You have already processed this model.[/] "
|
||||||
|
"You can show the results from the previous run, allowing you to export models or to run additional trials. "
|
||||||
|
"Alternatively, you can ignore the previous run and start from scratch. "
|
||||||
|
"This will delete the checkpoint file and all results from the previous run."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
choices.append(
|
choices.append(
|
||||||
Choice(
|
Choice(
|
||||||
title="Show the results from the previous run, allowing you to export models, or to run additional trials.",
|
title="Show the results from the previous run",
|
||||||
value="continue",
|
value="continue",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
print()
|
||||||
print(
|
print(
|
||||||
"[yellow]You have already processed this model, but the run was interrupted. How would you like to proceed?[/]",
|
(
|
||||||
|
"[yellow]You have already processed this model, but the run was interrupted.[/] "
|
||||||
|
"You can continue the previous run from where it stopped. This will override any specified settings. "
|
||||||
|
"Alternatively, you can ignore the previous run and start from scratch. "
|
||||||
|
"This will delete the checkpoint file and all results from the previous run."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
choices.append(
|
choices.append(
|
||||||
Choice(
|
Choice(
|
||||||
title="Continue the previous run from where it stopped (will override all specified settings).",
|
title="Continue the previous run",
|
||||||
value="continue",
|
value="continue",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
choices.append(
|
choices.append(
|
||||||
Choice(
|
Choice(
|
||||||
title="Ignore the previous run and start from scratch. This will delete the checkpoint file and all results from the previous run.",
|
title="Ignore the previous run and start from scratch",
|
||||||
value="restart",
|
value="restart",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
choice = prompt_select("", choices)
|
|
||||||
|
choices.append(
|
||||||
|
Choice(
|
||||||
|
title="Exit program",
|
||||||
|
value="",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
print()
|
||||||
|
choice = prompt_select("How would you like to proceed?", choices)
|
||||||
|
|
||||||
if choice == "continue":
|
if choice == "continue":
|
||||||
settings = Settings.model_validate_json(
|
settings = Settings.model_validate_json(
|
||||||
@@ -306,8 +301,7 @@ def run():
|
|||||||
os.unlink(study_checkpoint_file)
|
os.unlink(study_checkpoint_file)
|
||||||
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
|
backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj)
|
||||||
storage = JournalStorage(backend)
|
storage = JournalStorage(backend)
|
||||||
else:
|
elif choice is None or choice == "":
|
||||||
print("Cancelled; exiting.")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
model = Model(settings)
|
model = Model(settings)
|
||||||
@@ -562,14 +556,14 @@ def run():
|
|||||||
raise TrialPruned()
|
raise TrialPruned()
|
||||||
|
|
||||||
study = optuna.create_study(
|
study = optuna.create_study(
|
||||||
study_name="heretic",
|
|
||||||
sampler=TPESampler(
|
sampler=TPESampler(
|
||||||
n_startup_trials=settings.n_startup_trials,
|
n_startup_trials=settings.n_startup_trials,
|
||||||
n_ei_candidates=128,
|
n_ei_candidates=128,
|
||||||
multivariate=True,
|
multivariate=True,
|
||||||
),
|
),
|
||||||
storage=storage,
|
|
||||||
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
||||||
|
storage=storage,
|
||||||
|
study_name="heretic",
|
||||||
load_if_exists=True,
|
load_if_exists=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -582,13 +576,14 @@ def run():
|
|||||||
|
|
||||||
start_index = trial_index = count_completed_trials()
|
start_index = trial_index = count_completed_trials()
|
||||||
if start_index > 0:
|
if start_index > 0:
|
||||||
|
print()
|
||||||
print("Resuming existing study.")
|
print("Resuming existing study.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
study.optimize(
|
study.optimize(
|
||||||
objective_wrapper, n_trials=settings.n_trials - count_completed_trials()
|
objective_wrapper,
|
||||||
|
n_trials=settings.n_trials - count_completed_trials(),
|
||||||
)
|
)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
# This additional handler takes care of the small chance that KeyboardInterrupt
|
# This additional handler takes care of the small chance that KeyboardInterrupt
|
||||||
# is raised just between trials, which wouldn't be caught by the handler
|
# is raised just between trials, which wouldn't be caught by the handler
|
||||||
@@ -638,14 +633,14 @@ def run():
|
|||||||
|
|
||||||
choices.append(
|
choices.append(
|
||||||
Choice(
|
Choice(
|
||||||
title="Continue optimization (run more trials)",
|
title="Run additional trials",
|
||||||
value="continue",
|
value="continue",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
choices.append(
|
choices.append(
|
||||||
Choice(
|
Choice(
|
||||||
title="None (exit program)",
|
title="Exit program",
|
||||||
value="",
|
value="",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -669,18 +664,26 @@ def run():
|
|||||||
if trial == "continue":
|
if trial == "continue":
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
n_more_trials = int(
|
n_additional_trials = prompt_text(
|
||||||
prompt_text("How many more trials do you want to run?")
|
"How many additional trials do you want to run?"
|
||||||
)
|
)
|
||||||
if n_more_trials > 0:
|
if n_additional_trials is None or n_additional_trials == "":
|
||||||
|
n_additional_trials = 0
|
||||||
|
break
|
||||||
|
n_additional_trials = int(n_additional_trials)
|
||||||
|
if n_additional_trials > 0:
|
||||||
break
|
break
|
||||||
print("[red]Please enter a number greater than 0.[/]")
|
print("[red]Please enter a number greater than 0.[/]")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
print("[red]Invalid input. Please enter a number.[/]")
|
print("[red]Please enter a number.[/]")
|
||||||
|
|
||||||
settings.n_trials += n_more_trials
|
if n_additional_trials == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
settings.n_trials += n_additional_trials
|
||||||
study.set_user_attr("settings", settings.model_dump_json())
|
study.set_user_attr("settings", settings.model_dump_json())
|
||||||
study.set_user_attr("finished", False)
|
study.set_user_attr("finished", False)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
study.optimize(
|
study.optimize(
|
||||||
objective_wrapper,
|
objective_wrapper,
|
||||||
@@ -688,8 +691,10 @@ def run():
|
|||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if count_completed_trials() == settings.n_trials:
|
if count_completed_trials() == settings.n_trials:
|
||||||
study.set_user_attr("finished", True)
|
study.set_user_attr("finished", True)
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
elif trial is None or trial == "":
|
elif trial is None or trial == "":
|
||||||
@@ -720,14 +725,11 @@ def run():
|
|||||||
"Save the model to a local folder",
|
"Save the model to a local folder",
|
||||||
"Upload the model to Hugging Face",
|
"Upload the model to Hugging Face",
|
||||||
"Chat with the model",
|
"Chat with the model",
|
||||||
"Nothing (return to trial selection menu)",
|
"Return to the trial selection menu",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if action is None or action == "Return to the trial selection menu":
|
||||||
action is None
|
|
||||||
or action == "Nothing (return to trial selection menu)"
|
|
||||||
):
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# All actions are wrapped in a try/except block so that if an error occurs,
|
# All actions are wrapped in a try/except block so that if an error occurs,
|
||||||
@@ -740,11 +742,23 @@ def run():
|
|||||||
if not save_directory:
|
if not save_directory:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
save_model(
|
strategy = obtain_merge_strategy(settings)
|
||||||
model,
|
if strategy is None:
|
||||||
save_directory,
|
continue
|
||||||
settings,
|
|
||||||
)
|
if strategy == "adapter":
|
||||||
|
print("Saving LoRA adapter...")
|
||||||
|
model.model.save_pretrained(save_directory)
|
||||||
|
else:
|
||||||
|
print("Saving merged model...")
|
||||||
|
merged_model = model.get_merged_model()
|
||||||
|
merged_model.save_pretrained(save_directory)
|
||||||
|
del merged_model
|
||||||
|
empty_cache()
|
||||||
|
|
||||||
|
model.tokenizer.save_pretrained(save_directory)
|
||||||
|
|
||||||
|
print(f"Model saved to [bold]{save_directory}[/].")
|
||||||
|
|
||||||
case "Upload the model to Hugging Face":
|
case "Upload the model to Hugging Face":
|
||||||
# We don't use huggingface_hub.login() because that stores the token on disk,
|
# We don't use huggingface_hub.login() because that stores the token on disk,
|
||||||
@@ -780,7 +794,6 @@ def run():
|
|||||||
|
|
||||||
strategy = obtain_merge_strategy(settings)
|
strategy = obtain_merge_strategy(settings)
|
||||||
if strategy is None:
|
if strategy is None:
|
||||||
print("[yellow]Action cancelled.[/]")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if strategy == "adapter":
|
if strategy == "adapter":
|
||||||
|
|||||||
+15
-13
@@ -38,7 +38,7 @@ def get_model_class(
|
|||||||
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
|
) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]:
|
||||||
configs = PretrainedConfig.get_config_dict(model)
|
configs = PretrainedConfig.get_config_dict(model)
|
||||||
|
|
||||||
if any(["vision_config" in x for x in configs]):
|
if any([("vision_config" in config) for config in configs]):
|
||||||
return AutoModelForImageTextToText
|
return AutoModelForImageTextToText
|
||||||
else:
|
else:
|
||||||
return AutoModelForCausalLM
|
return AutoModelForCausalLM
|
||||||
@@ -96,9 +96,9 @@ class Model:
|
|||||||
try:
|
try:
|
||||||
quantization_config = self._get_quantization_config(dtype)
|
quantization_config = self._get_quantization_config(dtype)
|
||||||
|
|
||||||
# Build kwargs, only include quantization_config if it's not None
|
|
||||||
# (some models like gpt-oss have issues with explicit None)
|
|
||||||
extra_kwargs = {}
|
extra_kwargs = {}
|
||||||
|
# Only include quantization_config if it's not None
|
||||||
|
# (some models like gpt-oss have issues with explicit None).
|
||||||
if quantization_config is not None:
|
if quantization_config is not None:
|
||||||
extra_kwargs["quantization_config"] = quantization_config
|
extra_kwargs["quantization_config"] = quantization_config
|
||||||
|
|
||||||
@@ -134,9 +134,11 @@ class Model:
|
|||||||
print(f"[red]Failed[/] ({error})")
|
print(f"[red]Failed[/] ({error})")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
print("[green]Ok[/]")
|
|
||||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||||
print("[bold green]Model loaded in 4-bit precision.[/]")
|
print("[green]Ok[/] (quantized to 4-bit precision)")
|
||||||
|
else:
|
||||||
|
print("[green]Ok[/]")
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
@@ -158,7 +160,7 @@ class Model:
|
|||||||
# Guard against calling this method at the wrong time.
|
# Guard against calling this method at the wrong time.
|
||||||
assert isinstance(self.model, PreTrainedModel)
|
assert isinstance(self.model, PreTrainedModel)
|
||||||
|
|
||||||
# Always use LoRA adapters for abliteration (faster reload, no weight modification)
|
# Always use LoRA adapters for abliteration (faster reload, no weight modification).
|
||||||
# We use the leaf names (e.g. "o_proj") as target modules.
|
# We use the leaf names (e.g. "o_proj") as target modules.
|
||||||
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
|
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
|
||||||
# but this is harmless as we only abliterate the modules we target in `abliterate()`,
|
# but this is harmless as we only abliterate the modules we target in `abliterate()`,
|
||||||
@@ -181,9 +183,8 @@ class Model:
|
|||||||
lora_alpha=lora_rank, # Apply adapter at full strength.
|
lora_alpha=lora_rank, # Apply adapter at full strength.
|
||||||
lora_dropout=0,
|
lora_dropout=0,
|
||||||
bias="none",
|
bias="none",
|
||||||
# Even if we're using AutoModelForImageTextToText, this is still correct, as it is (post-vision)
|
# Even if we're using AutoModelForImageTextToText, this is still correct,
|
||||||
# the same kind of model.
|
# as VL models are typically just causal LMs with an added image encoder.
|
||||||
# https://github.com/huggingface/peft/blob/622c2821cb0d7897bee53aad7914d42b5fecbf61/src/peft/auto.py#L45
|
|
||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -191,9 +192,7 @@ class Model:
|
|||||||
# so the result is a PeftModel rather than a PeftMixedModel.
|
# so the result is a PeftModel rather than a PeftMixedModel.
|
||||||
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
|
self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config))
|
||||||
|
|
||||||
print(
|
print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})")
|
||||||
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
|
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
|
||||||
"""
|
"""
|
||||||
@@ -637,7 +636,10 @@ class Model:
|
|||||||
abs_residuals = torch.abs(residuals)
|
abs_residuals = torch.abs(residuals)
|
||||||
# Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals.
|
# Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals.
|
||||||
thresholds = torch.quantile(
|
thresholds = torch.quantile(
|
||||||
abs_residuals, self.settings.winsorization_quantile, dim=2, keepdim=True
|
abs_residuals,
|
||||||
|
self.settings.winsorization_quantile,
|
||||||
|
dim=2,
|
||||||
|
keepdim=True,
|
||||||
)
|
)
|
||||||
return torch.clamp(residuals, -thresholds, thresholds)
|
return torch.clamp(residuals, -thresholds, thresholds)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user