mirror of
https://github.com/p-e-w/heretic.git
synced 2026-06-02 05:03:33 +02:00
feat: Add 4-bit loading + LoRA support for low VRAM optimization (#60)
* Add files via upload * perf: optimize abliteration matrix op (#46) * perf: optimize abliteration matrix op * refactor: comments and var names correspond with arditi * refactor: fix comments and improve var notation * fix: accidental line change and improve comments --------- Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com> * Fix line endings to LF * Add hybrid approach for GPT-OSS compatibility - Check for LoRA adapters before attempting LoRA abliteration - Fall back to direct weight modification for nn.Parameter (GPT-OSS) - Ensures compatibility across all model architectures * Fix projector bug, update print statement, revert README * Revert README changes to match upstream * Fix import sorting for ruff * Fix reload_model for evaluate_model, add type hints and validation * Apply ruff formatting * Replace load_in_4bit with quantization enum * Fix precision loss: use FP32 refusal direction directly * Move r assignment into non-LoRA path * Fix linting: apply ruff formatting * Add auto-merge for LoRA adapters on save/upload * Fix linting: apply ruff formatting * Implement CPU-based merge for 4-bit models with OOM fallback * Remove use_lora flag (LoRA always on), add user prompt for 4-bit export * Fix: PEFT target_modules expects module names without path prefix * Fix linting: apply ruff formatting * Add LoRA fallback and fix quantization_config handling - Add try/except around LoRA initialization with fallback to direct weight modification - Only pass quantization_config when not None (fixes gpt-oss loading) - Use simple forward pass instead of generate() for model test (avoids chat template issues) - Reset non-LoRA models by reloading in reload_model() - Check self.use_lora before accessing LoRA adapters in abliterate() * Add 8-bit quantization support via bitsandbytes - Add BNB_8BIT option to QuantizationMethod enum - Add --load-in-8bit CLI support (auto via pydantic-settings) - Update documentation in config.py and config.default.toml 
- Useful for mid-range VRAM (12-16 GB) as balance between memory and numeric stability * Improve LoRA merge warning and fix linting * Apply final ruff formatting * Fix CI: apply ruff import sorting * Use tiny model for CI efficiency * Fix import sorting in test_lora.py * Fix formatting in test_lora.py * feat: Show merge warning for all models (requires high RAM) * style: Apply ruff fixes * Fix undefined Style import in main.py * Fix(model): Support MoE/3D tensors and enforce dtype safety in abliterate * Fix(ci): Format model.py with ruff * Fix(main): Remove invalid style argument from prompt_select and unused import * Fix logic errors, memory leak, and redundant merges in main.py * Fix linting and formatting issues (isort, ruff) * chore: Simplify .gitattributes as requested * refactor: Remove defensive try-except around LoRA initialization * chore: Update uv.lock with peft and bitsandbytes * chore: Regenerate uv.lock to include missing peft dependency * style: Fix import sorting (isort) for CI compliance * style: Simplify .gitattributes to single line as requested * Address PR #60 feedback: Remove caching, fix LoRA reload, global LoRA usage, style fixes * Address PR review comments: clarify code, fix quantization, rename method - Add explanatory comments for warning suppression and gc behavior - Remove redundant gc.collect() calls (empty_cache handles it) - Fix output message order (ask merge strategy before 'Uploading...') - Add comment explaining 8-bit quantization doesn't need compute_dtype - Remove extra newline after dtype comment - Add future-proofing note for hybrid layer support (#43) - Remove leftover comment in get_merged_model - Delete test_lora.py (debug script, not a real test) - Add comment explaining needs_reload flag purpose - Extract quantization config into _get_quantization_config() helper - Rename reload_model() to reset_model_for_trial() for clarity - Fix reload_model to respect quantization config (fixes evaluate_model bug) - Remove unused gc 
import * Restore gc.collect() before empty_cache() for large models * refactor: Remove LoRA fallback remnants, simplify code - Remove use_lora flag (always true since LoRA is always applied) - Remove isinstance(PeftModel) check in get_merged_model() (always true) - Simplify reset_model_for_trial() by removing defensive try/except - Remove redundant gc.collect() calls (empty_cache handles GC) - Remove unused gc import from main.py * Address p-e-w review feedback: rename reset_model, remove loaded_model_name, fix type hints, remove GPT-OSS MoE, update assertion * Restore skip logic for non-LoRA modules and fix 4-bit base_layer.weight access * Remove defensive lora_A check per review - get_layer_modules already filters * Fix try_add: nest component init inside Module check, add assert for unexpected types * Add note about module.weight assumption for type checking * Change 'Reloading model' to 'Resetting model' in logging --------- Co-authored-by: accemlcc <accemlcc@users.noreply.github.com> Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com> Co-authored-by: Hager <Michael.Hager@bruker.com>
This commit is contained in:
@@ -0,0 +1 @@
|
|||||||
|
* text eol=lf
|
||||||
@@ -18,6 +18,10 @@ dtypes = [
|
|||||||
# Device map to pass to Accelerate when loading the model.
|
# Device map to pass to Accelerate when loading the model.
|
||||||
device_map = "auto"
|
device_map = "auto"
|
||||||
|
|
||||||
|
# Quantization method to use when loading the model.
|
||||||
|
# Options: "none" (no quantization), "bnb_4bit" (4-bit quantization using bitsandbytes).
|
||||||
|
quantization = "none"
|
||||||
|
|
||||||
# Memory limits to impose. 0 is usually your first graphics card.
|
# Memory limits to impose. 0 is usually your first graphics card.
|
||||||
# max_memory = {0 = "16GB", "cpu" = "64GB"}
|
# max_memory = {0 = "16GB", "cpu" = "64GB"}
|
||||||
|
|
||||||
|
|||||||
@@ -23,10 +23,12 @@ classifiers = [
|
|||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"accelerate>=1.10.0",
|
"accelerate>=1.10.0",
|
||||||
|
"bitsandbytes>=0.45.0",
|
||||||
"datasets>=4.0.0",
|
"datasets>=4.0.0",
|
||||||
"hf-transfer>=0.1.9",
|
"hf-transfer>=0.1.9",
|
||||||
"huggingface-hub>=0.34.4",
|
"huggingface-hub>=0.34.4",
|
||||||
"optuna>=4.5.0",
|
"optuna>=4.5.0",
|
||||||
|
"peft>=0.14.0",
|
||||||
"pydantic-settings>=2.10.1",
|
"pydantic-settings>=2.10.1",
|
||||||
"questionary>=2.1.1",
|
"questionary>=2.1.1",
|
||||||
"rich>=14.1.0",
|
"rich>=14.1.0",
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -12,6 +13,11 @@ from pydantic_settings import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QuantizationMethod(str, Enum):
    """Quantization schemes that can be requested when loading the model.

    The class mixes in ``str``, so every member compares equal to the plain
    string value used for the ``quantization`` setting.
    """

    # Load the weights without quantization.
    NONE = "none"
    # 4-bit quantization using bitsandbytes.
    BNB_4BIT = "bnb_4bit"
|
||||||
|
|
||||||
|
|
||||||
class DatasetSpecification(BaseModel):
|
class DatasetSpecification(BaseModel):
|
||||||
dataset: str = Field(
|
dataset: str = Field(
|
||||||
description="Hugging Face dataset ID, or path to dataset on disk."
|
description="Hugging Face dataset ID, or path to dataset on disk."
|
||||||
@@ -71,6 +77,11 @@ class Settings(BaseSettings):
|
|||||||
description="Whether to trust remote code when loading the model.",
|
description="Whether to trust remote code when loading the model.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
quantization: QuantizationMethod = Field(
|
||||||
|
default=QuantizationMethod.NONE,
|
||||||
|
description="Quantization method to use when loading the model. Options: 'none' (no quantization), 'bnb_4bit' (4-bit quantization using bitsandbytes).",
|
||||||
|
)
|
||||||
|
|
||||||
batch_size: int = Field(
|
batch_size: int = Field(
|
||||||
default=0, # auto
|
default=0, # auto
|
||||||
description="Number of input sequences to process in parallel (0 = auto).",
|
description="Number of input sequences to process in parallel (0 = auto).",
|
||||||
|
|||||||
+109
-13
@@ -31,9 +31,10 @@ from optuna.trial import TrialState
|
|||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from questionary import Choice
|
from questionary import Choice
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
from transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
from .analyzer import Analyzer
|
from .analyzer import Analyzer
|
||||||
from .config import Settings
|
from .config import QuantizationMethod, Settings
|
||||||
from .evaluator import Evaluator
|
from .evaluator import Evaluator
|
||||||
from .model import AbliterationParameters, Model
|
from .model import AbliterationParameters, Model
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@@ -50,6 +51,73 @@ from .utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def obtain_merge_strategy(settings: Settings) -> str | None:
    """
    Decides how the abliterated model should be exported.

    Non-quantized models are always merged, so "merge" is returned without
    prompting. For models loaded with bitsandbytes 4-bit quantization, merging
    requires reloading the base model in full precision on the CPU, so the
    user is warned about the RAM cost and asked how to proceed.

    Args:
        settings: Active settings; `quantization` and `model` are read.

    Returns:
        "merge" to export the fully merged model, "adapter" to export only the
        LoRA adapter, or whatever `prompt_select` returns when no choice is
        made (callers treat None as "cancelled").
    """
    if settings.quantization != QuantizationMethod.BNB_4BIT:
        # Default for non-quantized models: nothing to warn about.
        return "merge"

    # Quantized models need special handling - we must reload the base model
    # in full precision to merge the LoRA adapters.
    print()
    print(
        "[yellow]Model was loaded with quantization. Merging requires reloading the base model.[/]"
    )
    print(
        "[red](!) WARNING: CPU Merging requires dequantizing the entire model to System RAM.[/]"
    )
    print("[red] This can lead to SYSTEM FREEZES if you run out of memory.[/]")
    print(
        "[yellow] Rule of thumb: You need approx. 3x the parameter count in GB.[/]"
    )

    # A memory estimate is a convenience and must never execute repository
    # code the user has not opted into, so honour the configured trust flag
    # instead of forcing it on.
    # NOTE(review): assumes the setting is exposed as `trust_remote_code`;
    # if it is absent or unset, remote code is not trusted and models that
    # need it fall through to the generic example below.
    trust_remote_code = bool(getattr(settings, "trust_remote_code", False))

    try:
        # Estimate memory requirements by loading the model structure on the
        # "meta" device, which allows inspecting the parameter count/dtype
        # without materializing the weights.
        #
        # Warnings raised during meta loading (e.g. "Some weights were not
        # initialized") are suppressed: only the structure is inspected here,
        # no inference is run.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            meta_model = AutoModelForCausalLM.from_pretrained(
                settings.model,
                device_map="meta",
                dtype=torch.bfloat16,
                trust_remote_code=trust_remote_code,
            )
        footprint_gb = meta_model.get_memory_footprint() / (1024**3)
        del meta_model
        print(
            f"[yellow] Estimated net RAM required for model weights (excluding overhead): [bold]~{footprint_gb:.1f} GB[/][/]"
        )
    except Exception:
        # Deliberate best-effort fallback if meta loading fails (e.g. owing to
        # custom model code or `bitsandbytes` quantization config issues on
        # the meta device): show a generic example instead of an estimate.
        print(
            "[yellow] Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
        )
    print()

    return prompt_select(
        "How do you want to proceed?",
        choices=[
            Choice(
                title="Merge full model (reload base model on CPU - requires high RAM)",
                value="merge",
            ),
            Choice(
                title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)",
                value="adapter",
            ),
        ],
    )
|
||||||
|
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
# Enable expandable segments to reduce memory fragmentation on multi-GPU setups.
|
# Enable expandable segments to reduce memory fragmentation on multi-GPU setups.
|
||||||
if (
|
if (
|
||||||
@@ -220,7 +288,7 @@ def run():
|
|||||||
print()
|
print()
|
||||||
print(f"Loading model [bold]{settings.evaluate_model}[/]...")
|
print(f"Loading model [bold]{settings.evaluate_model}[/]...")
|
||||||
settings.model = settings.evaluate_model
|
settings.model = settings.evaluate_model
|
||||||
model.reload_model()
|
model.reset_model()
|
||||||
print("* Evaluating...")
|
print("* Evaluating...")
|
||||||
evaluator.get_score()
|
evaluator.get_score()
|
||||||
return
|
return
|
||||||
@@ -330,8 +398,8 @@ def run():
|
|||||||
print("* Parameters:")
|
print("* Parameters:")
|
||||||
for name, value in get_trial_parameters(trial).items():
|
for name, value in get_trial_parameters(trial).items():
|
||||||
print(f" * {name} = [bold]{value}[/]")
|
print(f" * {name} = [bold]{value}[/]")
|
||||||
print("* Reloading model...")
|
print("* Resetting model...")
|
||||||
model.reload_model()
|
model.reset_model()
|
||||||
print("* Abliterating...")
|
print("* Abliterating...")
|
||||||
model.abliterate(refusal_directions, direction_index, parameters)
|
model.abliterate(refusal_directions, direction_index, parameters)
|
||||||
print("* Evaluating...")
|
print("* Evaluating...")
|
||||||
@@ -446,8 +514,8 @@ def run():
|
|||||||
print("* Parameters:")
|
print("* Parameters:")
|
||||||
for name, value in get_trial_parameters(trial).items():
|
for name, value in get_trial_parameters(trial).items():
|
||||||
print(f" * {name} = [bold]{value}[/]")
|
print(f" * {name} = [bold]{value}[/]")
|
||||||
print("* Reloading model...")
|
print("* Resetting model...")
|
||||||
model.reload_model()
|
model.reset_model()
|
||||||
print("* Abliterating...")
|
print("* Abliterating...")
|
||||||
model.abliterate(
|
model.abliterate(
|
||||||
refusal_directions,
|
refusal_directions,
|
||||||
@@ -481,7 +549,19 @@ def run():
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
print("Saving model...")
|
print("Saving model...")
|
||||||
model.model.save_pretrained(save_directory)
|
strategy = obtain_merge_strategy(settings)
|
||||||
|
if strategy is None:
|
||||||
|
print("[yellow]Action cancelled.[/]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if strategy == "adapter":
|
||||||
|
model.model.save_pretrained(save_directory)
|
||||||
|
else:
|
||||||
|
merged_model = model.get_merged_model()
|
||||||
|
merged_model.save_pretrained(save_directory)
|
||||||
|
del merged_model
|
||||||
|
empty_cache()
|
||||||
|
|
||||||
model.tokenizer.save_pretrained(save_directory)
|
model.tokenizer.save_pretrained(save_directory)
|
||||||
print(f"Model saved to [bold]{save_directory}[/].")
|
print(f"Model saved to [bold]{save_directory}[/].")
|
||||||
|
|
||||||
@@ -517,13 +597,29 @@ def run():
|
|||||||
)
|
)
|
||||||
private = visibility == "Private"
|
private = visibility == "Private"
|
||||||
|
|
||||||
print("Uploading model...")
|
strategy = obtain_merge_strategy(settings)
|
||||||
|
if strategy is None:
|
||||||
|
print("[yellow]Action cancelled.[/]")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if strategy == "adapter":
|
||||||
|
print("Uploading LoRA adapter...")
|
||||||
|
model.model.push_to_hub(
|
||||||
|
repo_id,
|
||||||
|
private=private,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("Uploading merged model...")
|
||||||
|
merged_model = model.get_merged_model()
|
||||||
|
merged_model.push_to_hub(
|
||||||
|
repo_id,
|
||||||
|
private=private,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
del merged_model
|
||||||
|
empty_cache()
|
||||||
|
|
||||||
model.model.push_to_hub(
|
|
||||||
repo_id,
|
|
||||||
private=private,
|
|
||||||
token=token,
|
|
||||||
)
|
|
||||||
model.tokenizer.push_to_hub(
|
model.tokenizer.push_to_hub(
|
||||||
repo_id,
|
repo_id,
|
||||||
private=private,
|
private=private,
|
||||||
|
|||||||
+241
-52
@@ -6,22 +6,31 @@ from contextlib import suppress
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from peft import LoraConfig, PeftModel, get_peft_model
|
||||||
from torch import LongTensor, Tensor
|
from torch import LongTensor, Tensor
|
||||||
from torch.nn import ModuleList
|
from torch.nn import ModuleList
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
BatchEncoding,
|
BatchEncoding,
|
||||||
|
BitsAndBytesConfig,
|
||||||
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
TextStreamer,
|
TextStreamer,
|
||||||
)
|
)
|
||||||
from transformers.generation.utils import GenerateOutput
|
from transformers.generation import (
|
||||||
|
GenerateDecoderOnlyOutput,
|
||||||
|
GenerateEncoderDecoderOutput,
|
||||||
|
)
|
||||||
|
|
||||||
from .config import Settings
|
from .config import QuantizationMethod, Settings
|
||||||
from .utils import batchify, empty_cache, print
|
from .utils import batchify, empty_cache, print
|
||||||
|
|
||||||
|
GenerateOutput = GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AbliterationParameters:
|
class AbliterationParameters:
|
||||||
@@ -35,6 +44,7 @@ class Model:
|
|||||||
def __init__(self, settings: Settings):
|
def __init__(self, settings: Settings):
|
||||||
self.settings = settings
|
self.settings = settings
|
||||||
self.response_prefix = ""
|
self.response_prefix = ""
|
||||||
|
self.needs_reload = False
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(f"Loading model [bold]{settings.model}[/]...")
|
print(f"Loading model [bold]{settings.model}[/]...")
|
||||||
@@ -68,12 +78,21 @@ class Model:
|
|||||||
print(f"* Trying dtype [bold]{dtype}[/]... ", end="")
|
print(f"* Trying dtype [bold]{dtype}[/]... ", end="")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
quantization_config = self._get_quantization_config(dtype)
|
||||||
|
|
||||||
|
# Build kwargs, only include quantization_config if it's not None
|
||||||
|
# (some models like gpt-oss have issues with explicit None)
|
||||||
|
extra_kwargs = {}
|
||||||
|
if quantization_config is not None:
|
||||||
|
extra_kwargs["quantization_config"] = quantization_config
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
settings.model,
|
settings.model,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device_map=settings.device_map,
|
device_map=settings.device_map,
|
||||||
max_memory=self.max_memory,
|
max_memory=self.max_memory,
|
||||||
trust_remote_code=self.trusted_models.get(settings.model),
|
trust_remote_code=self.trusted_models.get(settings.model),
|
||||||
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If we reach this point and the model requires trust_remote_code,
|
# If we reach this point and the model requires trust_remote_code,
|
||||||
@@ -92,103 +111,241 @@ class Model:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
print("[green]Ok[/]")
|
print("[green]Ok[/]")
|
||||||
|
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||||
|
print("[bold green]Model loaded in 4-bit precision.[/]")
|
||||||
break
|
break
|
||||||
|
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
raise Exception("Failed to load model with all configured dtypes.")
|
raise Exception("Failed to load model with all configured dtypes.")
|
||||||
|
|
||||||
|
self._apply_lora()
|
||||||
|
|
||||||
|
# LoRA B matrices are initialized to zero by default in PEFT,
|
||||||
|
# so we don't need to do anything manually.
|
||||||
|
|
||||||
print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
|
print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
|
||||||
print("* Abliterable components:")
|
print("* Abliterable components:")
|
||||||
for component, matrices in self.get_layer_matrices(0).items():
|
for component, modules in self.get_layer_modules(0).items():
|
||||||
print(
|
print(
|
||||||
f" * [bold]{component}[/]: [bold]{len(matrices)}[/] matrices per layer"
|
f" * [bold]{component}[/]: [bold]{len(modules)}[/] modules per layer"
|
||||||
)
|
)
|
||||||
|
|
||||||
def reload_model(self):
|
def _apply_lora(self):
    """
    Wraps `self.model` in a PEFT model with rank-1 LoRA adapters attached to
    the abliterable modules.

    LoRA is always used for abliteration (faster reset, no modification of
    the base weights). Targets are given to PEFT as leaf module names.
    """
    # Component keys such as "mlp.down_proj" are logical labels. In some
    # architectures the module behind that label has a different attribute
    # name (e.g. "w2" or "output_linear"), so derive the leaf names from the
    # modules themselves rather than from the labels. Otherwise PEFT would
    # not wrap those modules and `abliterate()` would fail when it accesses
    # their `lora_A`/`lora_B`.
    wanted = {
        id(module)
        for modules in self.get_layer_modules(0).values()
        for module in modules
    }
    first_layer = self.get_layers()[0]
    # dict.fromkeys() removes duplicates (e.g. one entry per MoE expert)
    # while keeping the discovery order stable.
    target_modules = list(
        dict.fromkeys(
            name.rsplit(".", 1)[-1]
            for name, module in first_layer.named_modules()
            if id(module) in wanted
        )
    )

    # Matching on leaf names may cause LoRA adapters to be attached to
    # unrelated modules (e.g. "conv.o_proj"), but this is harmless as we only
    # abliterate the modules we target in `abliterate()`, leaving the others
    # at their default (identity) state.
    # NOTE: This will need to be updated when hybrid layer support (#43) is merged.
    peft_config = LoraConfig(
        r=1,  # Rank 1 is sufficient for directional ablation
        target_modules=target_modules,
        lora_alpha=1,
        lora_dropout=0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    self.model = get_peft_model(self.model, peft_config)

    print(
        f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
    )
|
||||||
|
|
||||||
|
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
    """
    Builds the bitsandbytes configuration matching the current settings.

    Args:
        dtype: The dtype string (e.g., "auto", "bfloat16").

    Returns:
        A 4-bit BitsAndBytesConfig when 4-bit quantization is selected,
        otherwise None.
    """
    if self.settings.quantization != QuantizationMethod.BNB_4BIT:
        return None

    # BitsAndBytesConfig expects a torch.dtype, not a string;
    # "auto" is mapped to bfloat16.
    compute_dtype = torch.bfloat16 if dtype == "auto" else getattr(torch, dtype)

    return BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    )
|
||||||
|
|
||||||
|
def get_merged_model(self) -> PreTrainedModel:
    """
    Returns the model with the LoRA adapters merged into the base weights.

    For non-quantized models the live model is merged directly; this
    destroys the adapters, so `needs_reload` is set to force a full reload
    before the next trial.

    For 4-bit quantized models the base model is reloaded on the CPU, the
    adapter weights are copied onto a fresh set of adapters and the merge is
    performed there. The live (quantized) model is left untouched.

    Raises:
        RuntimeError: If some adapter weights of the live model have no
            counterpart in the reloaded CPU model.
    """
    if self.settings.quantization != QuantizationMethod.BNB_4BIT:
        # Non-quantized model - can merge directly.
        print("* Merging LoRA adapters into base model...")
        merged_model = self.model.merge_and_unload()
        # merge_and_unload() modifies self.model in-place, destroying LoRA adapters.
        # Mark for full reload if user switches trials later.
        self.needs_reload = True
        return merged_model

    # Snapshot the adapter weights on the CPU before loading anything else.
    adapter_state = {
        name: param.data.clone().cpu()
        for name, param in self.model.named_parameters()
        if "lora_" in name
    }

    # Load base model in full precision on CPU to avoid VRAM issues.
    print("* Loading base model on CPU (this may take a while)...")
    base_model = AutoModelForCausalLM.from_pretrained(
        self.settings.model,
        dtype=self.model.dtype,
        device_map="cpu",
        trust_remote_code=self.trusted_models.get(self.settings.model),
    )

    print("* Applying LoRA adapters...")
    # Reuse the exact target list of the live adapters (adapter name
    # "default"). Passing the logical component keys such as "attn.o_proj"
    # here would not match real module names like "self_attn.o_proj", so
    # those adapters would never be recreated on the CPU model.
    target_modules = self.model.peft_config["default"].target_modules
    peft_config = LoraConfig(
        r=1,
        target_modules=target_modules,
        lora_alpha=1,
        lora_dropout=0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    peft_model = get_peft_model(base_model, peft_config)

    # Copy the computed adapter weights, tracking which ones found a home.
    unmatched = set(adapter_state)
    for name, param in peft_model.named_parameters():
        if name in adapter_state:
            param.data = adapter_state[name].to(param.device)
            unmatched.discard(name)

    if unmatched:
        # Merging now would silently export a model that is missing part of
        # the abliteration, so fail loudly instead.
        raise RuntimeError(
            f"{len(unmatched)} LoRA adapter tensors could not be applied to "
            f"the reloaded base model (e.g. {sorted(unmatched)[0]})."
        )

    print("* Merging LoRA adapters into base model...")
    return peft_model.merge_and_unload()
|
||||||
|
|
||||||
|
def reset_model(self):
|
||||||
|
"""
|
||||||
|
Resets the model to a clean state for the next trial or evaluation.
|
||||||
|
|
||||||
|
Behavior:
|
||||||
|
- Fast path: If the same model is loaded and doesn't need full reload,
|
||||||
|
resets LoRA adapter weights to zero (identity transformation).
|
||||||
|
- Slow path: If switching models or after merge_and_unload(),
|
||||||
|
performs full model reload with quantization config.
|
||||||
|
"""
|
||||||
|
current_model = getattr(self.model.config, "name_or_path", None)
|
||||||
|
if current_model == self.settings.model and not self.needs_reload:
|
||||||
|
# Reset LoRA adapters to zero (identity transformation)
|
||||||
|
for name, module in self.model.named_modules():
|
||||||
|
if "lora_B" in name and hasattr(module, "weight"):
|
||||||
|
torch.nn.init.zeros_(module.weight)
|
||||||
|
return
|
||||||
|
|
||||||
dtype = self.model.dtype
|
dtype = self.model.dtype
|
||||||
|
|
||||||
# Purge existing model object from memory to make space.
|
# Purge existing model object from memory to make space.
|
||||||
self.model = None
|
self.model = None
|
||||||
empty_cache()
|
empty_cache()
|
||||||
|
|
||||||
|
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
|
||||||
|
|
||||||
|
# Build kwargs, only include quantization_config if it's not None
|
||||||
|
extra_kwargs = {}
|
||||||
|
if quantization_config is not None:
|
||||||
|
extra_kwargs["quantization_config"] = quantization_config
|
||||||
|
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
self.settings.model,
|
self.settings.model,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device_map=self.settings.device_map,
|
device_map=self.settings.device_map,
|
||||||
max_memory=self.max_memory,
|
max_memory=self.max_memory,
|
||||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||||
|
**extra_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.trusted_models.get(self.settings.model) is None:
|
self._apply_lora()
|
||||||
self.trusted_models[self.settings.model] = True
|
|
||||||
|
self.needs_reload = False
|
||||||
|
|
||||||
def get_layers(self) -> ModuleList:
|
def get_layers(self) -> ModuleList:
|
||||||
|
model = self.model
|
||||||
|
|
||||||
|
# Unwrap PeftModel (always true after _apply_lora)
|
||||||
|
if isinstance(model, PeftModel):
|
||||||
|
model = model.base_model.model
|
||||||
|
|
||||||
# Most multimodal models.
|
# Most multimodal models.
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
return self.model.model.language_model.layers
|
return model.model.language_model.layers
|
||||||
|
|
||||||
# Text-only models.
|
# Text-only models.
|
||||||
return self.model.model.layers
|
return model.model.layers
|
||||||
|
|
||||||
def get_layer_matrices(self, layer_index: int) -> dict[str, list[Tensor]]:
|
def get_layer_modules(self, layer_index: int) -> dict[str, list[torch.nn.Module]]:
|
||||||
layer = self.get_layers()[layer_index]
|
layer = self.get_layers()[layer_index]
|
||||||
|
|
||||||
matrices = {}
|
modules = {}
|
||||||
|
|
||||||
def try_add(component: str, matrix: Any):
|
def try_add(component: str, module: Any):
|
||||||
# Handle Triton tensors (e.g., from MXFP4 quantization) by extracting
|
# Only add if it's a proper nn.Module (PEFT can wrap these with LoRA)
|
||||||
# the underlying PyTorch tensor via the .data attribute.
|
if isinstance(module, torch.nn.Module):
|
||||||
if hasattr(matrix, "data") and torch.is_tensor(matrix.data):
|
if component not in modules:
|
||||||
matrix = matrix.data
|
modules[component] = []
|
||||||
|
modules[component].append(module)
|
||||||
assert torch.is_tensor(matrix)
|
else:
|
||||||
|
# Assert for unexpected types (catches architecture changes)
|
||||||
if component not in matrices:
|
assert not isinstance(module, torch.Tensor), (
|
||||||
matrices[component] = []
|
f"Unexpected Tensor in {component} - expected nn.Module"
|
||||||
|
)
|
||||||
matrices[component].append(matrix)
|
|
||||||
|
|
||||||
# Exceptions aren't suppressed here, because there is currently
|
# Exceptions aren't suppressed here, because there is currently
|
||||||
# no alternative location for the attention out-projection.
|
# no alternative location for the attention out-projection.
|
||||||
try_add("attn.o_proj", layer.self_attn.o_proj.weight)
|
try_add("attn.o_proj", layer.self_attn.o_proj)
|
||||||
|
|
||||||
# Most dense models.
|
# Most dense models.
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
try_add("mlp.down_proj", layer.mlp.down_proj.weight)
|
try_add("mlp.down_proj", layer.mlp.down_proj)
|
||||||
|
|
||||||
# Some MoE models (e.g. Qwen3).
|
# Some MoE models (e.g. Qwen3).
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
for expert in layer.mlp.experts:
|
for expert in layer.mlp.experts:
|
||||||
try_add("mlp.down_proj", expert.down_proj.weight)
|
try_add("mlp.down_proj", expert.down_proj)
|
||||||
|
|
||||||
# Phi-3.5-MoE (and possibly others).
|
# Phi-3.5-MoE (and possibly others).
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
for expert in layer.block_sparse_moe.experts:
|
for expert in layer.block_sparse_moe.experts:
|
||||||
try_add("mlp.down_proj", expert.w2.weight)
|
try_add("mlp.down_proj", expert.w2)
|
||||||
|
|
||||||
# gpt-oss MoE.
|
|
||||||
with suppress(Exception):
|
|
||||||
# The implementation of gpt-oss in Transformers differs from many other MoE models
|
|
||||||
# in that it stores the down-projections for all experts in a single 3D tensor,
|
|
||||||
# but thanks to PyTorch's broadcasting magic, it all just works anyway.
|
|
||||||
try_add("mlp.down_proj", layer.mlp.experts.down_proj)
|
|
||||||
|
|
||||||
# Granite MoE Hybrid - attention layers with shared_mlp.
|
# Granite MoE Hybrid - attention layers with shared_mlp.
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
try_add("mlp.down_proj", layer.shared_mlp.output_linear.weight)
|
try_add("mlp.down_proj", layer.shared_mlp.output_linear)
|
||||||
|
|
||||||
# Granite MoE Hybrid - MoE layers with experts.
|
# Granite MoE Hybrid - MoE layers with experts.
|
||||||
with suppress(Exception):
|
with suppress(Exception):
|
||||||
for expert in layer.moe.experts:
|
for expert in layer.moe.experts:
|
||||||
try_add("mlp.down_proj", expert.output_linear.weight)
|
try_add("mlp.down_proj", expert.output_linear)
|
||||||
|
|
||||||
# We need at least one MLP down-projection.
|
# We need at least one module across all components for abliteration to work.
|
||||||
assert matrices["mlp.down_proj"]
|
total_modules = sum(len(mods) for mods in modules.values())
|
||||||
|
assert total_modules > 0, "No abliterable modules found in layer"
|
||||||
|
|
||||||
return matrices
|
return modules
|
||||||
|
|
||||||
def get_abliterable_components(self) -> list[str]:
|
def get_abliterable_components(self) -> list[str]:
|
||||||
return list(self.get_layer_matrices(0).keys())
|
return list(self.get_layer_modules(0).keys())
|
||||||
|
|
||||||
def abliterate(
|
def abliterate(
|
||||||
self,
|
self,
|
||||||
@@ -214,7 +371,7 @@ class Model:
|
|||||||
# Note that some implementations of abliteration also orthogonalize
|
# Note that some implementations of abliteration also orthogonalize
|
||||||
# the embedding matrix, but it's unclear if that has any benefits.
|
# the embedding matrix, but it's unclear if that has any benefits.
|
||||||
for layer_index in range(len(self.get_layers())):
|
for layer_index in range(len(self.get_layers())):
|
||||||
for component, matrices in self.get_layer_matrices(layer_index).items():
|
for component, modules in self.get_layer_modules(layer_index).items():
|
||||||
params = parameters[component]
|
params = parameters[component]
|
||||||
|
|
||||||
distance = abs(layer_index - params.max_weight_position)
|
distance = abs(layer_index - params.max_weight_position)
|
||||||
@@ -237,18 +394,50 @@ class Model:
|
|||||||
else:
|
else:
|
||||||
layer_refusal_direction = refusal_direction
|
layer_refusal_direction = refusal_direction
|
||||||
|
|
||||||
# Projects any right-multiplied vector(s) onto the subspace
|
for module in modules:
|
||||||
# spanned by the refusal direction.
|
# LoRA abliteration: delta W = -lambda * v * (v^T W)
|
||||||
projector = torch.outer(
|
# lora_B = -lambda * v
|
||||||
layer_refusal_direction,
|
# lora_A = v^T W
|
||||||
layer_refusal_direction,
|
|
||||||
).to(self.model.dtype)
|
|
||||||
|
|
||||||
for matrix in matrices:
|
# Use the FP32 refusal direction directly (no downcast/upcast)
|
||||||
# Ensure projector is on the same device as the matrix for multi-GPU support.
|
# and move to the correct device.
|
||||||
device_projector = projector.to(matrix.device)
|
# NOTE: Assumes module has .weight (true for Linear layers we target)
|
||||||
# In-place subtraction is safe as we're not using Autograd.
|
v = layer_refusal_direction.to(module.weight.device)
|
||||||
matrix.sub_(weight * (device_projector @ matrix))
|
|
||||||
|
# Get W (dequantize if necessary)
|
||||||
|
# For LoRA-wrapped modules, the quantized weights are in base_layer
|
||||||
|
base_weight = (
|
||||||
|
module.base_layer.weight
|
||||||
|
if hasattr(module, "base_layer")
|
||||||
|
else module.weight
|
||||||
|
)
|
||||||
|
quant_state = getattr(base_weight, "quant_state", None)
|
||||||
|
|
||||||
|
if quant_state is not None:
|
||||||
|
# 4-bit quantization
|
||||||
|
W = bnb.functional.dequantize_4bit(
|
||||||
|
base_weight.data, quant_state
|
||||||
|
).to(torch.float32)
|
||||||
|
else:
|
||||||
|
W = base_weight.to(torch.float32)
|
||||||
|
|
||||||
|
# Calculate lora_A = v^T W
|
||||||
|
# v is (d_out,), W is (d_out, d_in)
|
||||||
|
# v @ W -> (d_in,)
|
||||||
|
lora_A = (v @ W).view(1, -1)
|
||||||
|
|
||||||
|
# Calculate lora_B = -weight * v
|
||||||
|
# v is (d_out,)
|
||||||
|
lora_B = (-weight * v).view(-1, 1)
|
||||||
|
|
||||||
|
# Assign to adapters
|
||||||
|
# We assume the default adapter name "default"
|
||||||
|
module.lora_A["default"].weight.data = lora_A.to(
|
||||||
|
module.lora_A["default"].weight.dtype
|
||||||
|
)
|
||||||
|
module.lora_B["default"].weight.data = lora_B.to(
|
||||||
|
module.lora_B["default"].weight.dtype
|
||||||
|
)
|
||||||
|
|
||||||
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
||||||
return [
|
return [
|
||||||
|
|||||||
Reference in New Issue
Block a user