mirror of
https://github.com/p-e-w/heretic.git
synced 2026-06-01 20:58:47 +02:00
feat: Add 4-bit loading + LoRA support for low VRAM optimization (#60)
* Add files via upload * perf: optimize abliteration matrix op (#46) * perf: optimize abliteration matrix op * refactor: comments and var names correspond with arditi * refactor: fix comments and improve var notation * fix: accidental line change and improve comments --------- Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com> * Fix line endings to LF * Add hybrid approach for GPT-OSS compatibility - Check for LoRA adapters before attempting LoRA abliteration - Fall back to direct weight modification for nn.Parameter (GPT-OSS) - Ensures compatibility across all model architectures * Fix projector bug, update print statement, revert README * Revert README changes to match upstream * Fix import sorting for ruff * Fix reload_model for evaluate_model, add type hints and validation * Apply ruff formatting * Replace load_in_4bit with quantization enum * Fix precision loss: use FP32 refusal direction directly * Move r assignment into non-LoRA path * Fix linting: apply ruff formatting * Add auto-merge for LoRA adapters on save/upload * Fix linting: apply ruff formatting * Implement CPU-based merge for 4-bit models with OOM fallback * Remove use_lora flag (LoRA always on), add user prompt for 4-bit export * Fix: PEFT target_modules expects module names without path prefix * Fix linting: apply ruff formatting * Add LoRA fallback and fix quantization_config handling - Add try/except around LoRA initialization with fallback to direct weight modification - Only pass quantization_config when not None (fixes gpt-oss loading) - Use simple forward pass instead of generate() for model test (avoids chat template issues) - Reset non-LoRA models by reloading in reload_model() - Check self.use_lora before accessing LoRA adapters in abliterate() * Add 8-bit quantization support via bitsandbytes - Add BNB_8BIT option to QuantizationMethod enum - Add --load-in-8bit CLI support (auto via pydantic-settings) - Update documentation in config.py and config.default.toml - Useful for mid-range VRAM (12-16 GB) as balance between memory and numeric stability * Improve LoRA merge warning and fix linting * Apply final ruff formatting * Fix CI: apply ruff import sorting * Use tiny model for CI efficiency * Fix import sorting in test_lora.py * Fix formatting in test_lora.py * feat: Show merge warning for all models (requires high RAM) * style: Apply ruff fixes * Fix undefined Style import in main.py * Fix(model): Support MoE/3D tensors and enforce dtype safety in abliterate * Fix(ci): Format model.py with ruff * Fix(main): Remove invalid style argument from prompt_select and unused import * Fix logic errors, memory leak, and redundant merges in main.py * Fix linting and formatting issues (isort, ruff) * chore: Simplify .gitattributes as requested * refactor: Remove defensive try-except around LoRA initialization * chore: Update uv.lock with peft and bitsandbytes * chore: Regenerate uv.lock to include missing peft dependency * style: Fix import sorting (isort) for CI compliance * style: Simplify .gitattributes to single line as requested * Address PR #60 feedback: Remove caching, fix LoRA reload, global LoRA usage, style fixes * Address PR review comments: clarify code, fix quantization, rename method - Add explanatory comments for warning suppression and gc behavior - Remove redundant gc.collect() calls (empty_cache handles it) - Fix output message order (ask merge strategy before 'Uploading...') - Add comment explaining 8-bit quantization doesn't need compute_dtype - Remove extra newline after dtype comment - Add future-proofing note for hybrid layer support (#43) - Remove leftover comment in get_merged_model - Delete test_lora.py (debug script, not a real test) - Add comment explaining needs_reload flag purpose - Extract quantization config into _get_quantization_config() helper - Rename reload_model() to reset_model_for_trial() for clarity - Fix reload_model to respect quantization config (fixes evaluate_model bug) - Remove unused gc import * Restore gc.collect() before empty_cache() for large models * refactor: Remove LoRA fallback remnants, simplify code - Remove use_lora flag (always true since LoRA is always applied) - Remove isinstance(PeftModel) check in get_merged_model() (always true) - Simplify reset_model_for_trial() by removing defensive try/except - Remove redundant gc.collect() calls (empty_cache handles GC) - Remove unused gc import from main.py * Address p-e-w review feedback: rename reset_model, remove loaded_model_name, fix type hints, remove GPT-OSS MoE, update assertion * Restore skip logic for non-LoRA modules and fix 4-bit base_layer.weight access * Remove defensive lora_A check per review - get_layer_modules already filters * Fix try_add: nest component init inside Module check, add assert for unexpected types * Add note about module.weight assumption for type checking * Change 'Reloading model' to 'Resetting model' in logging --------- Co-authored-by: accemlcc <accemlcc@users.noreply.github.com> Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com> Co-authored-by: Hager <Michael.Hager@bruker.com>
This commit is contained in:
@@ -0,0 +1 @@
|
||||
* text eol=lf
|
||||
@@ -18,6 +18,10 @@ dtypes = [
|
||||
# Device map to pass to Accelerate when loading the model.
|
||||
device_map = "auto"
|
||||
|
||||
# Quantization method to use when loading the model.
|
||||
# Options: "none" (no quantization), "bnb_4bit" (4-bit quantization using bitsandbytes).
|
||||
quantization = "none"
|
||||
|
||||
# Memory limits to impose. 0 is usually your first graphics card.
|
||||
# max_memory = {0 = "16GB", "cpu" = "64GB"}
|
||||
|
||||
|
||||
@@ -23,10 +23,12 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"accelerate>=1.10.0",
|
||||
"bitsandbytes>=0.45.0",
|
||||
"datasets>=4.0.0",
|
||||
"hf-transfer>=0.1.9",
|
||||
"huggingface-hub>=0.34.4",
|
||||
"optuna>=4.5.0",
|
||||
"peft>=0.14.0",
|
||||
"pydantic-settings>=2.10.1",
|
||||
"questionary>=2.1.1",
|
||||
"rich>=14.1.0",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -12,6 +13,11 @@ from pydantic_settings import (
|
||||
)
|
||||
|
||||
|
||||
class QuantizationMethod(str, Enum):
|
||||
NONE = "none"
|
||||
BNB_4BIT = "bnb_4bit"
|
||||
|
||||
|
||||
class DatasetSpecification(BaseModel):
|
||||
dataset: str = Field(
|
||||
description="Hugging Face dataset ID, or path to dataset on disk."
|
||||
@@ -71,6 +77,11 @@ class Settings(BaseSettings):
|
||||
description="Whether to trust remote code when loading the model.",
|
||||
)
|
||||
|
||||
quantization: QuantizationMethod = Field(
|
||||
default=QuantizationMethod.NONE,
|
||||
description="Quantization method to use when loading the model. Options: 'none' (no quantization), 'bnb_4bit' (4-bit quantization using bitsandbytes).",
|
||||
)
|
||||
|
||||
batch_size: int = Field(
|
||||
default=0, # auto
|
||||
description="Number of input sequences to process in parallel (0 = auto).",
|
||||
|
||||
+103
-7
@@ -31,9 +31,10 @@ from optuna.trial import TrialState
|
||||
from pydantic import ValidationError
|
||||
from questionary import Choice
|
||||
from rich.traceback import install
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from .analyzer import Analyzer
|
||||
from .config import Settings
|
||||
from .config import QuantizationMethod, Settings
|
||||
from .evaluator import Evaluator
|
||||
from .model import AbliterationParameters, Model
|
||||
from .utils import (
|
||||
@@ -50,6 +51,73 @@ from .utils import (
|
||||
)
|
||||
|
||||
|
||||
def obtain_merge_strategy(settings: Settings) -> str | None:
|
||||
"""
|
||||
Prompts the user for how to proceed with quantized models.
|
||||
Returns "merge", "adapter", or None (if cancelled/invalid).
|
||||
"""
|
||||
# Prompt for all PEFT models to ensure user is aware of merge implications
|
||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
# Quantized models need special handling - we must reload the base model
|
||||
# in full precision to merge the LoRA adapters
|
||||
print()
|
||||
print(
|
||||
"[yellow]Model was loaded with quantization. Merging requires reloading the base model.[/]"
|
||||
)
|
||||
print(
|
||||
"[red](!) WARNING: CPU Merging requires dequantizing the entire model to System RAM.[/]"
|
||||
)
|
||||
print("[red] This can lead to SYSTEM FREEZES if you run out of memory.[/]")
|
||||
print(
|
||||
"[yellow] Rule of thumb: You need approx. 3x the parameter count in GB.[/]"
|
||||
)
|
||||
|
||||
try:
|
||||
# Estimate memory requirements by loading the model structure on the "meta" device.
|
||||
# This doesn't consume actual RAM but allows us to inspect the parameter count/dtype.
|
||||
#
|
||||
# Suppress warnings during meta device loading (e.g., "Some weights were not initialized").
|
||||
# These are expected and harmless since we're only inspecting model structure, not running inference.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
meta_model = AutoModelForCausalLM.from_pretrained(
|
||||
settings.model,
|
||||
device_map="meta",
|
||||
torch_dtype=torch.bfloat16,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
footprint_bytes = meta_model.get_memory_footprint()
|
||||
footprint_gb = footprint_bytes / (1024**3)
|
||||
print(
|
||||
f"[yellow] Estimated net RAM required for model weights (excluding overhead): [bold]~{footprint_gb:.1f} GB[/][/]"
|
||||
)
|
||||
except Exception:
|
||||
# Fallback if meta loading fails (e.g. owing to custom model code
|
||||
# or `bitsandbytes` quantization config issues on the meta device)
|
||||
print(
|
||||
"[yellow] Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
|
||||
)
|
||||
print()
|
||||
|
||||
merge_choice = prompt_select(
|
||||
"How do you want to proceed?",
|
||||
choices=[
|
||||
Choice(
|
||||
title="Merge full model (reload base model on CPU - requires high RAM)",
|
||||
value="merge",
|
||||
),
|
||||
Choice(
|
||||
title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)",
|
||||
value="adapter",
|
||||
),
|
||||
],
|
||||
)
|
||||
return merge_choice
|
||||
|
||||
# Default for non-quantized models
|
||||
return "merge"
|
||||
|
||||
|
||||
def run():
|
||||
# Enable expandable segments to reduce memory fragmentation on multi-GPU setups.
|
||||
if (
|
||||
@@ -220,7 +288,7 @@ def run():
|
||||
print()
|
||||
print(f"Loading model [bold]{settings.evaluate_model}[/]...")
|
||||
settings.model = settings.evaluate_model
|
||||
model.reload_model()
|
||||
model.reset_model()
|
||||
print("* Evaluating...")
|
||||
evaluator.get_score()
|
||||
return
|
||||
@@ -330,8 +398,8 @@ def run():
|
||||
print("* Parameters:")
|
||||
for name, value in get_trial_parameters(trial).items():
|
||||
print(f" * {name} = [bold]{value}[/]")
|
||||
print("* Reloading model...")
|
||||
model.reload_model()
|
||||
print("* Resetting model...")
|
||||
model.reset_model()
|
||||
print("* Abliterating...")
|
||||
model.abliterate(refusal_directions, direction_index, parameters)
|
||||
print("* Evaluating...")
|
||||
@@ -446,8 +514,8 @@ def run():
|
||||
print("* Parameters:")
|
||||
for name, value in get_trial_parameters(trial).items():
|
||||
print(f" * {name} = [bold]{value}[/]")
|
||||
print("* Reloading model...")
|
||||
model.reload_model()
|
||||
print("* Resetting model...")
|
||||
model.reset_model()
|
||||
print("* Abliterating...")
|
||||
model.abliterate(
|
||||
refusal_directions,
|
||||
@@ -481,7 +549,19 @@ def run():
|
||||
continue
|
||||
|
||||
print("Saving model...")
|
||||
strategy = obtain_merge_strategy(settings)
|
||||
if strategy is None:
|
||||
print("[yellow]Action cancelled.[/]")
|
||||
continue
|
||||
|
||||
if strategy == "adapter":
|
||||
model.model.save_pretrained(save_directory)
|
||||
else:
|
||||
merged_model = model.get_merged_model()
|
||||
merged_model.save_pretrained(save_directory)
|
||||
del merged_model
|
||||
empty_cache()
|
||||
|
||||
model.tokenizer.save_pretrained(save_directory)
|
||||
print(f"Model saved to [bold]{save_directory}[/].")
|
||||
|
||||
@@ -517,13 +597,29 @@ def run():
|
||||
)
|
||||
private = visibility == "Private"
|
||||
|
||||
print("Uploading model...")
|
||||
strategy = obtain_merge_strategy(settings)
|
||||
if strategy is None:
|
||||
print("[yellow]Action cancelled.[/]")
|
||||
continue
|
||||
|
||||
if strategy == "adapter":
|
||||
print("Uploading LoRA adapter...")
|
||||
model.model.push_to_hub(
|
||||
repo_id,
|
||||
private=private,
|
||||
token=token,
|
||||
)
|
||||
else:
|
||||
print("Uploading merged model...")
|
||||
merged_model = model.get_merged_model()
|
||||
merged_model.push_to_hub(
|
||||
repo_id,
|
||||
private=private,
|
||||
token=token,
|
||||
)
|
||||
del merged_model
|
||||
empty_cache()
|
||||
|
||||
model.tokenizer.push_to_hub(
|
||||
repo_id,
|
||||
private=private,
|
||||
|
||||
+241
-52
@@ -6,22 +6,31 @@ from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
from torch import LongTensor, Tensor
|
||||
from torch.nn import ModuleList
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BatchEncoding,
|
||||
BitsAndBytesConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
TextStreamer,
|
||||
)
|
||||
from transformers.generation.utils import GenerateOutput
|
||||
from transformers.generation import (
|
||||
GenerateDecoderOnlyOutput,
|
||||
GenerateEncoderDecoderOutput,
|
||||
)
|
||||
|
||||
from .config import Settings
|
||||
from .config import QuantizationMethod, Settings
|
||||
from .utils import batchify, empty_cache, print
|
||||
|
||||
GenerateOutput = GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class AbliterationParameters:
|
||||
@@ -35,6 +44,7 @@ class Model:
|
||||
def __init__(self, settings: Settings):
|
||||
self.settings = settings
|
||||
self.response_prefix = ""
|
||||
self.needs_reload = False
|
||||
|
||||
print()
|
||||
print(f"Loading model [bold]{settings.model}[/]...")
|
||||
@@ -68,12 +78,21 @@ class Model:
|
||||
print(f"* Trying dtype [bold]{dtype}[/]... ", end="")
|
||||
|
||||
try:
|
||||
quantization_config = self._get_quantization_config(dtype)
|
||||
|
||||
# Build kwargs, only include quantization_config if it's not None
|
||||
# (some models like gpt-oss have issues with explicit None)
|
||||
extra_kwargs = {}
|
||||
if quantization_config is not None:
|
||||
extra_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
settings.model,
|
||||
dtype=dtype,
|
||||
device_map=settings.device_map,
|
||||
max_memory=self.max_memory,
|
||||
trust_remote_code=self.trusted_models.get(settings.model),
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
# If we reach this point and the model requires trust_remote_code,
|
||||
@@ -92,103 +111,241 @@ class Model:
|
||||
continue
|
||||
|
||||
print("[green]Ok[/]")
|
||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
print("[bold green]Model loaded in 4-bit precision.[/]")
|
||||
break
|
||||
|
||||
if self.model is None:
|
||||
raise Exception("Failed to load model with all configured dtypes.")
|
||||
|
||||
self._apply_lora()
|
||||
|
||||
# LoRA B matrices are initialized to zero by default in PEFT,
|
||||
# so we don't need to do anything manually.
|
||||
|
||||
print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
|
||||
print("* Abliterable components:")
|
||||
for component, matrices in self.get_layer_matrices(0).items():
|
||||
for component, modules in self.get_layer_modules(0).items():
|
||||
print(
|
||||
f" * [bold]{component}[/]: [bold]{len(matrices)}[/] matrices per layer"
|
||||
f" * [bold]{component}[/]: [bold]{len(modules)}[/] modules per layer"
|
||||
)
|
||||
|
||||
def reload_model(self):
|
||||
def _apply_lora(self):
|
||||
# Always use LoRA adapters for abliteration (faster reload, no weight modification)
|
||||
# We use the leaf names (e.g. "o_proj") as target modules.
|
||||
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
|
||||
# but this is harmless as we only abliterate the modules we target in `abliterate()`,
|
||||
# leaving the others at their default (identity) state.
|
||||
# NOTE: This will need to be updated when hybrid layer support (#43) is merged.
|
||||
target_modules = [
|
||||
comp.split(".")[-1] for comp in self.get_abliterable_components()
|
||||
]
|
||||
peft_config = LoraConfig(
|
||||
r=1, # Rank 1 is sufficient for directional ablation
|
||||
target_modules=target_modules,
|
||||
lora_alpha=1,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
self.model = get_peft_model(self.model, peft_config)
|
||||
print(
|
||||
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
|
||||
)
|
||||
|
||||
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
|
||||
"""
|
||||
Creates quantization config based on settings.
|
||||
|
||||
Args:
|
||||
dtype: The dtype string (e.g., "auto", "bfloat16")
|
||||
|
||||
Returns:
|
||||
BitsAndBytesConfig or None
|
||||
"""
|
||||
if self.settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
# BitsAndBytesConfig expects a torch.dtype, not a string.
|
||||
if dtype == "auto":
|
||||
compute_dtype = torch.bfloat16
|
||||
else:
|
||||
compute_dtype = getattr(torch, dtype)
|
||||
|
||||
return BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_compute_dtype=compute_dtype,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_use_double_quant=True,
|
||||
)
|
||||
return None
|
||||
|
||||
def get_merged_model(self) -> PreTrainedModel:
|
||||
"""
|
||||
Returns the model with LoRA adapters merged.
|
||||
For quantized models, performs CPU-based merge.
|
||||
"""
|
||||
|
||||
# Check if we need special handling for quantized models
|
||||
if self.settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
# Quantized models need special handling - we must reload the base model
|
||||
# in full precision to merge the LoRA adapters
|
||||
|
||||
# Get the adapter state dict before we do anything
|
||||
adapter_state = {}
|
||||
for name, param in self.model.named_parameters():
|
||||
if "lora_" in name:
|
||||
adapter_state[name] = param.data.clone().cpu()
|
||||
|
||||
# Load base model in full precision on CPU to avoid VRAM issues
|
||||
print("* Loading base model on CPU (this may take a while)...")
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.settings.model,
|
||||
torch_dtype=self.model.dtype,
|
||||
device_map="cpu",
|
||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||
)
|
||||
|
||||
# Apply LoRA adapters to the CPU model
|
||||
|
||||
print("* Applying LoRA adapters...")
|
||||
target_modules = self.get_abliterable_components()
|
||||
peft_config = LoraConfig(
|
||||
r=1,
|
||||
target_modules=target_modules,
|
||||
lora_alpha=1,
|
||||
lora_dropout=0,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
peft_model = get_peft_model(base_model, peft_config)
|
||||
|
||||
# Copy the trained adapter weights
|
||||
for name, param in peft_model.named_parameters():
|
||||
if name in adapter_state:
|
||||
param.data = adapter_state[name].to(param.device)
|
||||
|
||||
# Merge and unload
|
||||
print("* Merging LoRA adapters into base model...")
|
||||
merged_model = peft_model.merge_and_unload()
|
||||
return merged_model
|
||||
else:
|
||||
# Non-quantized model - can merge directly
|
||||
print("* Merging LoRA adapters into base model...")
|
||||
merged_model = self.model.merge_and_unload()
|
||||
# merge_and_unload() modifies self.model in-place, destroying LoRA adapters.
|
||||
# Mark for full reload if user switches trials later.
|
||||
self.needs_reload = True
|
||||
return merged_model
|
||||
|
||||
def reset_model(self):
|
||||
"""
|
||||
Resets the model to a clean state for the next trial or evaluation.
|
||||
|
||||
Behavior:
|
||||
- Fast path: If the same model is loaded and doesn't need full reload,
|
||||
resets LoRA adapter weights to zero (identity transformation).
|
||||
- Slow path: If switching models or after merge_and_unload(),
|
||||
performs full model reload with quantization config.
|
||||
"""
|
||||
current_model = getattr(self.model.config, "name_or_path", None)
|
||||
if current_model == self.settings.model and not self.needs_reload:
|
||||
# Reset LoRA adapters to zero (identity transformation)
|
||||
for name, module in self.model.named_modules():
|
||||
if "lora_B" in name and hasattr(module, "weight"):
|
||||
torch.nn.init.zeros_(module.weight)
|
||||
return
|
||||
|
||||
dtype = self.model.dtype
|
||||
|
||||
# Purge existing model object from memory to make space.
|
||||
self.model = None
|
||||
empty_cache()
|
||||
|
||||
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
|
||||
|
||||
# Build kwargs, only include quantization_config if it's not None
|
||||
extra_kwargs = {}
|
||||
if quantization_config is not None:
|
||||
extra_kwargs["quantization_config"] = quantization_config
|
||||
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
self.settings.model,
|
||||
dtype=dtype,
|
||||
device_map=self.settings.device_map,
|
||||
max_memory=self.max_memory,
|
||||
trust_remote_code=self.trusted_models.get(self.settings.model),
|
||||
**extra_kwargs,
|
||||
)
|
||||
|
||||
if self.trusted_models.get(self.settings.model) is None:
|
||||
self.trusted_models[self.settings.model] = True
|
||||
self._apply_lora()
|
||||
|
||||
self.needs_reload = False
|
||||
|
||||
def get_layers(self) -> ModuleList:
|
||||
model = self.model
|
||||
|
||||
# Unwrap PeftModel (always true after _apply_lora)
|
||||
if isinstance(model, PeftModel):
|
||||
model = model.base_model.model
|
||||
|
||||
# Most multimodal models.
|
||||
with suppress(Exception):
|
||||
return self.model.model.language_model.layers
|
||||
return model.model.language_model.layers
|
||||
|
||||
# Text-only models.
|
||||
return self.model.model.layers
|
||||
return model.model.layers
|
||||
|
||||
def get_layer_matrices(self, layer_index: int) -> dict[str, list[Tensor]]:
|
||||
def get_layer_modules(self, layer_index: int) -> dict[str, list[torch.nn.Module]]:
|
||||
layer = self.get_layers()[layer_index]
|
||||
|
||||
matrices = {}
|
||||
modules = {}
|
||||
|
||||
def try_add(component: str, matrix: Any):
|
||||
# Handle Triton tensors (e.g., from MXFP4 quantization) by extracting
|
||||
# the underlying PyTorch tensor via the .data attribute.
|
||||
if hasattr(matrix, "data") and torch.is_tensor(matrix.data):
|
||||
matrix = matrix.data
|
||||
|
||||
assert torch.is_tensor(matrix)
|
||||
|
||||
if component not in matrices:
|
||||
matrices[component] = []
|
||||
|
||||
matrices[component].append(matrix)
|
||||
def try_add(component: str, module: Any):
|
||||
# Only add if it's a proper nn.Module (PEFT can wrap these with LoRA)
|
||||
if isinstance(module, torch.nn.Module):
|
||||
if component not in modules:
|
||||
modules[component] = []
|
||||
modules[component].append(module)
|
||||
else:
|
||||
# Assert for unexpected types (catches architecture changes)
|
||||
assert not isinstance(module, torch.Tensor), (
|
||||
f"Unexpected Tensor in {component} - expected nn.Module"
|
||||
)
|
||||
|
||||
# Exceptions aren't suppressed here, because there is currently
|
||||
# no alternative location for the attention out-projection.
|
||||
try_add("attn.o_proj", layer.self_attn.o_proj.weight)
|
||||
try_add("attn.o_proj", layer.self_attn.o_proj)
|
||||
|
||||
# Most dense models.
|
||||
with suppress(Exception):
|
||||
try_add("mlp.down_proj", layer.mlp.down_proj.weight)
|
||||
try_add("mlp.down_proj", layer.mlp.down_proj)
|
||||
|
||||
# Some MoE models (e.g. Qwen3).
|
||||
with suppress(Exception):
|
||||
for expert in layer.mlp.experts:
|
||||
try_add("mlp.down_proj", expert.down_proj.weight)
|
||||
try_add("mlp.down_proj", expert.down_proj)
|
||||
|
||||
# Phi-3.5-MoE (and possibly others).
|
||||
with suppress(Exception):
|
||||
for expert in layer.block_sparse_moe.experts:
|
||||
try_add("mlp.down_proj", expert.w2.weight)
|
||||
|
||||
# gpt-oss MoE.
|
||||
with suppress(Exception):
|
||||
# The implementation of gpt-oss in Transformers differs from many other MoE models
|
||||
# in that it stores the down-projections for all experts in a single 3D tensor,
|
||||
# but thanks to PyTorch's broadcasting magic, it all just works anyway.
|
||||
try_add("mlp.down_proj", layer.mlp.experts.down_proj)
|
||||
try_add("mlp.down_proj", expert.w2)
|
||||
|
||||
# Granite MoE Hybrid - attention layers with shared_mlp.
|
||||
with suppress(Exception):
|
||||
try_add("mlp.down_proj", layer.shared_mlp.output_linear.weight)
|
||||
try_add("mlp.down_proj", layer.shared_mlp.output_linear)
|
||||
|
||||
# Granite MoE Hybrid - MoE layers with experts.
|
||||
with suppress(Exception):
|
||||
for expert in layer.moe.experts:
|
||||
try_add("mlp.down_proj", expert.output_linear.weight)
|
||||
try_add("mlp.down_proj", expert.output_linear)
|
||||
|
||||
# We need at least one MLP down-projection.
|
||||
assert matrices["mlp.down_proj"]
|
||||
# We need at least one module across all components for abliteration to work.
|
||||
total_modules = sum(len(mods) for mods in modules.values())
|
||||
assert total_modules > 0, "No abliterable modules found in layer"
|
||||
|
||||
return matrices
|
||||
return modules
|
||||
|
||||
def get_abliterable_components(self) -> list[str]:
|
||||
return list(self.get_layer_matrices(0).keys())
|
||||
return list(self.get_layer_modules(0).keys())
|
||||
|
||||
def abliterate(
|
||||
self,
|
||||
@@ -214,7 +371,7 @@ class Model:
|
||||
# Note that some implementations of abliteration also orthogonalize
|
||||
# the embedding matrix, but it's unclear if that has any benefits.
|
||||
for layer_index in range(len(self.get_layers())):
|
||||
for component, matrices in self.get_layer_matrices(layer_index).items():
|
||||
for component, modules in self.get_layer_modules(layer_index).items():
|
||||
params = parameters[component]
|
||||
|
||||
distance = abs(layer_index - params.max_weight_position)
|
||||
@@ -237,18 +394,50 @@ class Model:
|
||||
else:
|
||||
layer_refusal_direction = refusal_direction
|
||||
|
||||
# Projects any right-multiplied vector(s) onto the subspace
|
||||
# spanned by the refusal direction.
|
||||
projector = torch.outer(
|
||||
layer_refusal_direction,
|
||||
layer_refusal_direction,
|
||||
).to(self.model.dtype)
|
||||
for module in modules:
|
||||
# LoRA abliteration: delta W = -lambda * v * (v^T W)
|
||||
# lora_B = -lambda * v
|
||||
# lora_A = v^T W
|
||||
|
||||
for matrix in matrices:
|
||||
# Ensure projector is on the same device as the matrix for multi-GPU support.
|
||||
device_projector = projector.to(matrix.device)
|
||||
# In-place subtraction is safe as we're not using Autograd.
|
||||
matrix.sub_(weight * (device_projector @ matrix))
|
||||
# Use the FP32 refusal direction directly (no downcast/upcast)
|
||||
# and move to the correct device.
|
||||
# NOTE: Assumes module has .weight (true for Linear layers we target)
|
||||
v = layer_refusal_direction.to(module.weight.device)
|
||||
|
||||
# Get W (dequantize if necessary)
|
||||
# For LoRA-wrapped modules, the quantized weights are in base_layer
|
||||
base_weight = (
|
||||
module.base_layer.weight
|
||||
if hasattr(module, "base_layer")
|
||||
else module.weight
|
||||
)
|
||||
quant_state = getattr(base_weight, "quant_state", None)
|
||||
|
||||
if quant_state is not None:
|
||||
# 4-bit quantization
|
||||
W = bnb.functional.dequantize_4bit(
|
||||
base_weight.data, quant_state
|
||||
).to(torch.float32)
|
||||
else:
|
||||
W = base_weight.to(torch.float32)
|
||||
|
||||
# Calculate lora_A = v^T W
|
||||
# v is (d_out,), W is (d_out, d_in)
|
||||
# v @ W -> (d_in,)
|
||||
lora_A = (v @ W).view(1, -1)
|
||||
|
||||
# Calculate lora_B = -weight * v
|
||||
# v is (d_out,)
|
||||
lora_B = (-weight * v).view(-1, 1)
|
||||
|
||||
# Assign to adapters
|
||||
# We assume the default adapter name "default"
|
||||
module.lora_A["default"].weight.data = lora_A.to(
|
||||
module.lora_A["default"].weight.dtype
|
||||
)
|
||||
module.lora_B["default"].weight.data = lora_B.to(
|
||||
module.lora_B["default"].weight.dtype
|
||||
)
|
||||
|
||||
def get_chat(self, prompt: str) -> list[dict[str, str]]:
|
||||
return [
|
||||
|
||||
Reference in New Issue
Block a user