feat: Add 4-bit loading + LoRA support for low VRAM optimization (#60)

* Add files via upload

* perf: optimize abliteration matrix op (#46)

* perf: optimize abliteration matrix op

* refactor: comments and var names correspond with arditi

* refactor: fix comments and improve var notation

* fix: accidental line change and improve comments

---------

Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com>

* Fix line endings to LF

* Add hybrid approach for GPT-OSS compatibility

- Check for LoRA adapters before attempting LoRA abliteration
- Fall back to direct weight modification for nn.Parameter (GPT-OSS)
- Ensures compatibility across all model architectures

* Fix projector bug, update print statement, revert README

* Revert README changes to match upstream

* Fix import sorting for ruff

* Fix reload_model for evaluate_model, add type hints and validation

* Apply ruff formatting

* Replace load_in_4bit with quantization enum

* Fix precision loss: use FP32 refusal direction directly

* Move r assignment into non-LoRA path

* Fix linting: apply ruff formatting

* Add auto-merge for LoRA adapters on save/upload

* Fix linting: apply ruff formatting

* Implement CPU-based merge for 4-bit models with OOM fallback

* Remove use_lora flag (LoRA always on), add user prompt for 4-bit export

* Fix: PEFT target_modules expects module names without path prefix

* Fix linting: apply ruff formatting

* Add LoRA fallback and fix quantization_config handling

- Add try/except around LoRA initialization with fallback to direct weight modification
- Only pass quantization_config when not None (fixes gpt-oss loading)
- Use simple forward pass instead of generate() for model test (avoids chat template issues)
- Reset non-LoRA models by reloading in reload_model()
- Check self.use_lora before accessing LoRA adapters in abliterate()

* Add 8-bit quantization support via bitsandbytes

- Add BNB_8BIT option to QuantizationMethod enum
- Add --load-in-8bit CLI support (auto via pydantic-settings)
- Update documentation in config.py and config.default.toml
- Useful for mid-range VRAM (12-16 GB) as balance between memory and numeric stability

* Improve LoRA merge warning and fix linting

* Apply final ruff formatting

* Fix CI: apply ruff import sorting

* Use tiny model for CI efficiency

* Fix import sorting in test_lora.py

* Fix formatting in test_lora.py

* feat: Show merge warning for all models (requires high RAM)

* style: Apply ruff fixes

* Fix undefined Style import in main.py

* Fix(model): Support MoE/3D tensors and enforce dtype safety in abliterate

* Fix(ci): Format model.py with ruff

* Fix(main): Remove invalid style argument from prompt_select and unused import

* Fix logic errors, memory leak, and redundant merges in main.py

* Fix linting and formatting issues (isort, ruff)

* chore: Simplify .gitattributes as requested

* refactor: Remove defensive try-except around LoRA initialization

* chore: Update uv.lock with peft and bitsandbytes

* chore: Regenerate uv.lock to include missing peft dependency

* style: Fix import sorting (isort) for CI compliance

* style: Simplify .gitattributes to single line as requested

* Address PR #60 feedback: Remove caching, fix LoRA reload, global LoRA usage, style fixes

* Address PR review comments: clarify code, fix quantization, rename method

- Add explanatory comments for warning suppression and gc behavior
- Remove redundant gc.collect() calls (empty_cache handles it)
- Fix output message order (ask merge strategy before 'Uploading...')
- Add comment explaining 8-bit quantization doesn't need compute_dtype
- Remove extra newline after dtype comment
- Add future-proofing note for hybrid layer support (#43)
- Remove leftover comment in get_merged_model
- Delete test_lora.py (debug script, not a real test)
- Add comment explaining needs_reload flag purpose
- Extract quantization config into _get_quantization_config() helper
- Rename reload_model() to reset_model_for_trial() for clarity
- Fix reload_model to respect quantization config (fixes evaluate_model bug)
- Remove unused gc import

* Restore gc.collect() before empty_cache() for large models

* refactor: Remove LoRA fallback remnants, simplify code

- Remove use_lora flag (always true since LoRA is always applied)
- Remove isinstance(PeftModel) check in get_merged_model() (always true)
- Simplify reset_model_for_trial() by removing defensive try/except
- Remove redundant gc.collect() calls (empty_cache handles GC)
- Remove unused gc import from main.py

* Address p-e-w review feedback: rename reset_model, remove loaded_model_name, fix type hints, remove GPT-OSS MoE, update assertion

* Restore skip logic for non-LoRA modules and fix 4-bit base_layer.weight access

* Remove defensive lora_A check per review - get_layer_modules already filters

* Fix try_add: nest component init inside Module check, add assert for unexpected types

* Add note about module.weight assumption for type checking

* Change 'Reloading model' to 'Resetting model' in logging

---------

Co-authored-by: accemlcc <accemlcc@users.noreply.github.com>
Co-authored-by: mad-cat-lon <113548315+mad-cat-lon@users.noreply.github.com>
Co-authored-by: Hager <Michael.Hager@bruker.com>
michaelh
2025-12-14 15:49:09 +01:00
committed by GitHub
parent 9d1734855d
commit 243f821d93
7 changed files with 2346 additions and 1470 deletions
+1
@@ -0,0 +1 @@
* text eol=lf
+4
@@ -18,6 +18,10 @@ dtypes = [
# Device map to pass to Accelerate when loading the model.
device_map = "auto"
# Quantization method to use when loading the model.
# Options: "none" (no quantization), "bnb_4bit" (4-bit quantization using bitsandbytes).
quantization = "none"
# Memory limits to impose. 0 is usually your first graphics card.
# max_memory = {0 = "16GB", "cpu" = "64GB"}
+2
@@ -23,10 +23,12 @@ classifiers = [
]
dependencies = [
"accelerate>=1.10.0",
"bitsandbytes>=0.45.0",
"datasets>=4.0.0",
"hf-transfer>=0.1.9",
"huggingface-hub>=0.34.4",
"optuna>=4.5.0",
"peft>=0.14.0",
"pydantic-settings>=2.10.1",
"questionary>=2.1.1",
"rich>=14.1.0",
+11
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
from enum import Enum
from typing import Dict
from pydantic import BaseModel, Field
@@ -12,6 +13,11 @@ from pydantic_settings import (
)
class QuantizationMethod(str, Enum):
NONE = "none"
BNB_4BIT = "bnb_4bit"
class DatasetSpecification(BaseModel):
dataset: str = Field(
description="Hugging Face dataset ID, or path to dataset on disk."
@@ -71,6 +77,11 @@ class Settings(BaseSettings):
description="Whether to trust remote code when loading the model.",
)
quantization: QuantizationMethod = Field(
default=QuantizationMethod.NONE,
description="Quantization method to use when loading the model. Options: 'none' (no quantization), 'bnb_4bit' (4-bit quantization using bitsandbytes).",
)
batch_size: int = Field(
default=0, # auto
description="Number of input sequences to process in parallel (0 = auto).",
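The `QuantizationMethod` enum above is a `str`-based `Enum`, so a plain string from the config file or CLI maps directly onto a member. The following is a minimal stdlib-only sketch of that mapping; in the commit itself the validation is done by pydantic-settings, and `parse_quantization` is a hypothetical helper introduced here purely for illustration.

```python
from enum import Enum


class QuantizationMethod(str, Enum):
    NONE = "none"
    BNB_4BIT = "bnb_4bit"


def parse_quantization(raw: str) -> QuantizationMethod:
    """Map a raw config string to the enum (hypothetical helper, not part of the commit)."""
    try:
        # Enum lookup by value: "bnb_4bit" -> QuantizationMethod.BNB_4BIT
        return QuantizationMethod(raw)
    except ValueError:
        valid = ", ".join(member.value for member in QuantizationMethod)
        raise ValueError(
            f"Unknown quantization {raw!r}; expected one of: {valid}"
        ) from None


method = parse_quantization("bnb_4bit")
# Because of the str mixin, members also compare equal to their raw values.
is_four_bit = method == "bnb_4bit"
```

Any value outside the enum (for example a method that is not defined in it) is rejected with an error listing the accepted options, which mirrors how an invalid `quantization` setting would presumably fail validation at startup.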
+103 -7
@@ -31,9 +31,10 @@ from optuna.trial import TrialState
from pydantic import ValidationError
from questionary import Choice
from rich.traceback import install
from transformers import AutoModelForCausalLM
from .analyzer import Analyzer
from .config import Settings
from .config import QuantizationMethod, Settings
from .evaluator import Evaluator
from .model import AbliterationParameters, Model
from .utils import (
@@ -50,6 +51,73 @@ from .utils import (
)
def obtain_merge_strategy(settings: Settings) -> str | None:
"""
Prompts the user for how to proceed with quantized models.
Returns "merge", "adapter", or None (if cancelled/invalid).
"""
# Only quantized models trigger the prompt; all other models fall through
# to the default "merge" strategy at the end of this function.
if settings.quantization == QuantizationMethod.BNB_4BIT:
# Quantized models need special handling - we must reload the base model
# in full precision to merge the LoRA adapters
print()
print(
"[yellow]Model was loaded with quantization. Merging requires reloading the base model.[/]"
)
print(
"[red](!) WARNING: CPU Merging requires dequantizing the entire model to System RAM.[/]"
)
print("[red] This can lead to SYSTEM FREEZES if you run out of memory.[/]")
print(
"[yellow] Rule of thumb: You need approx. 3 GB of RAM per billion parameters.[/]"
)
try:
# Estimate memory requirements by loading the model structure on the "meta" device.
# This doesn't consume actual RAM but allows us to inspect the parameter count/dtype.
#
# Suppress warnings during meta device loading (e.g., "Some weights were not initialized").
# These are expected and harmless since we're only inspecting model structure, not running inference.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
meta_model = AutoModelForCausalLM.from_pretrained(
settings.model,
device_map="meta",
torch_dtype=torch.bfloat16,
trust_remote_code=True,
)
footprint_bytes = meta_model.get_memory_footprint()
footprint_gb = footprint_bytes / (1024**3)
print(
f"[yellow] Estimated net RAM required for model weights (excluding overhead): [bold]~{footprint_gb:.1f} GB[/][/]"
)
except Exception:
# Fallback if meta loading fails (e.g. owing to custom model code
# or `bitsandbytes` quantization config issues on the meta device)
print(
"[yellow] Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]"
)
print()
merge_choice = prompt_select(
"How do you want to proceed?",
choices=[
Choice(
title="Merge full model (reload base model on CPU - requires high RAM)",
value="merge",
),
Choice(
title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)",
value="adapter",
),
],
)
return merge_choice
# Default for non-quantized models
return "merge"
def run():
# Enable expandable segments to reduce memory fragmentation on multi-GPU setups.
if (
@@ -220,7 +288,7 @@ def run():
print()
print(f"Loading model [bold]{settings.evaluate_model}[/]...")
settings.model = settings.evaluate_model
-model.reload_model()
+model.reset_model()
print("* Evaluating...")
evaluator.get_score()
return
@@ -330,8 +398,8 @@ def run():
print("* Parameters:")
for name, value in get_trial_parameters(trial).items():
print(f" * {name} = [bold]{value}[/]")
-print("* Reloading model...")
-model.reload_model()
+print("* Resetting model...")
+model.reset_model()
print("* Abliterating...")
model.abliterate(refusal_directions, direction_index, parameters)
print("* Evaluating...")
@@ -446,8 +514,8 @@ def run():
print("* Parameters:")
for name, value in get_trial_parameters(trial).items():
print(f" * {name} = [bold]{value}[/]")
-print("* Reloading model...")
-model.reload_model()
+print("* Resetting model...")
+model.reset_model()
print("* Abliterating...")
model.abliterate(
refusal_directions,
@@ -481,7 +549,19 @@ def run():
continue
print("Saving model...")
strategy = obtain_merge_strategy(settings)
if strategy is None:
print("[yellow]Action cancelled.[/]")
continue
if strategy == "adapter":
model.model.save_pretrained(save_directory)
else:
merged_model = model.get_merged_model()
merged_model.save_pretrained(save_directory)
del merged_model
empty_cache()
model.tokenizer.save_pretrained(save_directory)
print(f"Model saved to [bold]{save_directory}[/].")
@@ -517,13 +597,29 @@ def run():
)
private = visibility == "Private"
-print("Uploading model...")
strategy = obtain_merge_strategy(settings)
if strategy is None:
print("[yellow]Action cancelled.[/]")
continue
if strategy == "adapter":
print("Uploading LoRA adapter...")
model.model.push_to_hub(
repo_id,
private=private,
token=token,
)
else:
print("Uploading merged model...")
merged_model = model.get_merged_model()
merged_model.push_to_hub(
repo_id,
private=private,
token=token,
)
del merged_model
empty_cache()
model.tokenizer.push_to_hub(
repo_id,
private=private,
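The merge prompt in `obtain_merge_strategy()` combines two numbers: the raw weight footprint reported by `get_memory_footprint()` and a rule of thumb of roughly 3 GB of RAM per billion parameters. The arithmetic can be sketched as follows, assuming 2 bytes per parameter (bfloat16, as in the meta-device load above). Both function names are hypothetical, and the factor of 3 is taken from the warning text rather than measured.

```python
def weights_footprint_gib(num_params: int, bytes_per_param: int = 2) -> float:
    """Net size of the weights alone, in GiB (same 1024**3 divisor as the diff)."""
    return num_params * bytes_per_param / (1024**3)


def rule_of_thumb_ram_gb(num_params: int, gb_per_billion: float = 3.0) -> float:
    """Rough RAM budget for a CPU merge, per the warning's rule of thumb."""
    return num_params / 1e9 * gb_per_billion


params_27b = 27_000_000_000
net_weights = weights_footprint_gib(params_27b)   # a little over 50 GiB
ram_budget = rule_of_thumb_ram_gb(params_27b)     # 81 GB, i.e. the "~80GB" in the fallback message
```

The gap between the two figures is the headroom the warning asks for: the net weight size excludes whatever temporary copies and framework overhead the merge needs on top.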
+241 -52
@@ -6,22 +6,31 @@ from contextlib import suppress
from dataclasses import dataclass
from typing import Any
import bitsandbytes as bnb
import torch
import torch.nn.functional as F
from peft import LoraConfig, PeftModel, get_peft_model
from torch import LongTensor, Tensor
from torch.nn import ModuleList
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BatchEncoding,
BitsAndBytesConfig,
PreTrainedModel,
PreTrainedTokenizerBase,
TextStreamer,
)
-from transformers.generation.utils import GenerateOutput
+from transformers.generation import (
+GenerateDecoderOnlyOutput,
+GenerateEncoderDecoderOutput,
+)
-from .config import Settings
+from .config import QuantizationMethod, Settings
from .utils import batchify, empty_cache, print
GenerateOutput = GenerateDecoderOnlyOutput | GenerateEncoderDecoderOutput
@dataclass
class AbliterationParameters:
@@ -35,6 +44,7 @@ class Model:
def __init__(self, settings: Settings):
self.settings = settings
self.response_prefix = ""
self.needs_reload = False
print()
print(f"Loading model [bold]{settings.model}[/]...")
@@ -68,12 +78,21 @@ class Model:
print(f"* Trying dtype [bold]{dtype}[/]... ", end="")
try:
quantization_config = self._get_quantization_config(dtype)
# Build kwargs, only include quantization_config if it's not None
# (some models like gpt-oss have issues with explicit None)
extra_kwargs = {}
if quantization_config is not None:
extra_kwargs["quantization_config"] = quantization_config
self.model = AutoModelForCausalLM.from_pretrained(
settings.model,
dtype=dtype,
device_map=settings.device_map,
max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(settings.model),
**extra_kwargs,
)
# If we reach this point and the model requires trust_remote_code,
@@ -92,103 +111,241 @@ class Model:
continue
print("[green]Ok[/]")
if settings.quantization == QuantizationMethod.BNB_4BIT:
print("[bold green]Model loaded in 4-bit precision.[/]")
break
if self.model is None:
raise Exception("Failed to load model with all configured dtypes.")
self._apply_lora()
# LoRA B matrices are initialized to zero by default in PEFT,
# so we don't need to do anything manually.
print(f"* Transformer model with [bold]{len(self.get_layers())}[/] layers")
print("* Abliterable components:")
-for component, matrices in self.get_layer_matrices(0).items():
+for component, modules in self.get_layer_modules(0).items():
print(
-f" * [bold]{component}[/]: [bold]{len(matrices)}[/] matrices per layer"
+f" * [bold]{component}[/]: [bold]{len(modules)}[/] modules per layer"
)
-def reload_model(self):
+def _apply_lora(self):
# Always use LoRA adapters for abliteration (faster reload, no weight modification)
# We use the leaf names (e.g. "o_proj") as target modules.
# This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"),
# but this is harmless as we only abliterate the modules we target in `abliterate()`,
# leaving the others at their default (identity) state.
# NOTE: This will need to be updated when hybrid layer support (#43) is merged.
target_modules = [
comp.split(".")[-1] for comp in self.get_abliterable_components()
]
peft_config = LoraConfig(
r=1, # Rank 1 is sufficient for directional ablation
target_modules=target_modules,
lora_alpha=1,
lora_dropout=0,
bias="none",
task_type="CAUSAL_LM",
)
self.model = get_peft_model(self.model, peft_config)
print(
f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]"
)
def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None:
"""
Creates quantization config based on settings.
Args:
dtype: The dtype string (e.g., "auto", "bfloat16")
Returns:
BitsAndBytesConfig or None
"""
if self.settings.quantization == QuantizationMethod.BNB_4BIT:
# BitsAndBytesConfig expects a torch.dtype, not a string.
if dtype == "auto":
compute_dtype = torch.bfloat16
else:
compute_dtype = getattr(torch, dtype)
return BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
)
return None
def get_merged_model(self) -> PreTrainedModel:
"""
Returns the model with LoRA adapters merged.
For quantized models, performs CPU-based merge.
"""
# Check if we need special handling for quantized models
if self.settings.quantization == QuantizationMethod.BNB_4BIT:
# Quantized models need special handling - we must reload the base model
# in full precision to merge the LoRA adapters
# Get the adapter state dict before we do anything
adapter_state = {}
for name, param in self.model.named_parameters():
if "lora_" in name:
adapter_state[name] = param.data.clone().cpu()
# Load base model in full precision on CPU to avoid VRAM issues
print("* Loading base model on CPU (this may take a while)...")
base_model = AutoModelForCausalLM.from_pretrained(
self.settings.model,
torch_dtype=self.model.dtype,
device_map="cpu",
trust_remote_code=self.trusted_models.get(self.settings.model),
)
# Apply LoRA adapters to the CPU model
print("* Applying LoRA adapters...")
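# Use the leaf names (e.g. "o_proj"), exactly as in _apply_lora().
# PEFT matches target_modules against module-name suffixes, so a
# component name such as "attn.o_proj" would not match "self_attn.o_proj".
target_modules = [
comp.split(".")[-1] for comp in self.get_abliterable_components()
]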
peft_config = LoraConfig(
r=1,
target_modules=target_modules,
lora_alpha=1,
lora_dropout=0,
bias="none",
task_type="CAUSAL_LM",
)
peft_model = get_peft_model(base_model, peft_config)
# Copy the trained adapter weights
for name, param in peft_model.named_parameters():
if name in adapter_state:
param.data = adapter_state[name].to(param.device)
# Merge and unload
print("* Merging LoRA adapters into base model...")
merged_model = peft_model.merge_and_unload()
return merged_model
else:
# Non-quantized model - can merge directly
print("* Merging LoRA adapters into base model...")
merged_model = self.model.merge_and_unload()
# merge_and_unload() modifies self.model in-place, destroying LoRA adapters.
# Mark for full reload if user switches trials later.
self.needs_reload = True
return merged_model
def reset_model(self):
"""
Resets the model to a clean state for the next trial or evaluation.
Behavior:
- Fast path: If the same model is loaded and doesn't need full reload,
resets LoRA adapter weights to zero (identity transformation).
- Slow path: If switching models or after merge_and_unload(),
performs full model reload with quantization config.
"""
current_model = getattr(self.model.config, "name_or_path", None)
if current_model == self.settings.model and not self.needs_reload:
# Reset LoRA adapters to zero (identity transformation)
for name, module in self.model.named_modules():
if "lora_B" in name and hasattr(module, "weight"):
torch.nn.init.zeros_(module.weight)
return
dtype = self.model.dtype
# Purge existing model object from memory to make space.
self.model = None
empty_cache()
quantization_config = self._get_quantization_config(str(dtype).split(".")[-1])
# Build kwargs, only include quantization_config if it's not None
extra_kwargs = {}
if quantization_config is not None:
extra_kwargs["quantization_config"] = quantization_config
self.model = AutoModelForCausalLM.from_pretrained(
self.settings.model,
dtype=dtype,
device_map=self.settings.device_map,
max_memory=self.max_memory,
trust_remote_code=self.trusted_models.get(self.settings.model),
**extra_kwargs,
)
if self.trusted_models.get(self.settings.model) is None:
self.trusted_models[self.settings.model] = True
self._apply_lora()
self.needs_reload = False
def get_layers(self) -> ModuleList:
model = self.model
# Unwrap PeftModel (always true after _apply_lora)
if isinstance(model, PeftModel):
model = model.base_model.model
# Most multimodal models.
with suppress(Exception):
-return self.model.model.language_model.layers
+return model.model.language_model.layers
# Text-only models.
-return self.model.model.layers
+return model.model.layers
-def get_layer_matrices(self, layer_index: int) -> dict[str, list[Tensor]]:
+def get_layer_modules(self, layer_index: int) -> dict[str, list[torch.nn.Module]]:
layer = self.get_layers()[layer_index]
-matrices = {}
+modules = {}
-def try_add(component: str, matrix: Any):
-# Handle Triton tensors (e.g., from MXFP4 quantization) by extracting
-# the underlying PyTorch tensor via the .data attribute.
-if hasattr(matrix, "data") and torch.is_tensor(matrix.data):
-matrix = matrix.data
-assert torch.is_tensor(matrix)
-if component not in matrices:
-matrices[component] = []
-matrices[component].append(matrix)
+def try_add(component: str, module: Any):
+# Only add if it's a proper nn.Module (PEFT can wrap these with LoRA)
+if isinstance(module, torch.nn.Module):
+if component not in modules:
+modules[component] = []
+modules[component].append(module)
+else:
+# Assert for unexpected types (catches architecture changes)
+assert not isinstance(module, torch.Tensor), (
+f"Unexpected Tensor in {component} - expected nn.Module"
+)
# Exceptions aren't suppressed here, because there is currently
# no alternative location for the attention out-projection.
-try_add("attn.o_proj", layer.self_attn.o_proj.weight)
+try_add("attn.o_proj", layer.self_attn.o_proj)
# Most dense models.
with suppress(Exception):
-try_add("mlp.down_proj", layer.mlp.down_proj.weight)
+try_add("mlp.down_proj", layer.mlp.down_proj)
# Some MoE models (e.g. Qwen3).
with suppress(Exception):
for expert in layer.mlp.experts:
-try_add("mlp.down_proj", expert.down_proj.weight)
+try_add("mlp.down_proj", expert.down_proj)
# Phi-3.5-MoE (and possibly others).
with suppress(Exception):
for expert in layer.block_sparse_moe.experts:
-try_add("mlp.down_proj", expert.w2.weight)
-# gpt-oss MoE.
-with suppress(Exception):
-# The implementation of gpt-oss in Transformers differs from many other MoE models
-# in that it stores the down-projections for all experts in a single 3D tensor,
-# but thanks to PyTorch's broadcasting magic, it all just works anyway.
-try_add("mlp.down_proj", layer.mlp.experts.down_proj)
+try_add("mlp.down_proj", expert.w2)
# Granite MoE Hybrid - attention layers with shared_mlp.
with suppress(Exception):
-try_add("mlp.down_proj", layer.shared_mlp.output_linear.weight)
+try_add("mlp.down_proj", layer.shared_mlp.output_linear)
# Granite MoE Hybrid - MoE layers with experts.
with suppress(Exception):
for expert in layer.moe.experts:
-try_add("mlp.down_proj", expert.output_linear.weight)
+try_add("mlp.down_proj", expert.output_linear)
-# We need at least one MLP down-projection.
-assert matrices["mlp.down_proj"]
+# We need at least one module across all components for abliteration to work.
+total_modules = sum(len(mods) for mods in modules.values())
+assert total_modules > 0, "No abliterable modules found in layer"
-return matrices
+return modules
def get_abliterable_components(self) -> list[str]:
-return list(self.get_layer_matrices(0).keys())
+return list(self.get_layer_modules(0).keys())
def abliterate(
self,
@@ -214,7 +371,7 @@ class Model:
# Note that some implementations of abliteration also orthogonalize
# the embedding matrix, but it's unclear if that has any benefits.
for layer_index in range(len(self.get_layers())):
-for component, matrices in self.get_layer_matrices(layer_index).items():
+for component, modules in self.get_layer_modules(layer_index).items():
params = parameters[component]
distance = abs(layer_index - params.max_weight_position)
@@ -237,18 +394,50 @@ class Model:
else:
layer_refusal_direction = refusal_direction
-# Projects any right-multiplied vector(s) onto the subspace
-# spanned by the refusal direction.
-projector = torch.outer(
-layer_refusal_direction,
-layer_refusal_direction,
-).to(self.model.dtype)
+for module in modules:
+# LoRA abliteration: delta W = -lambda * v * (v^T W)
+# lora_B = -lambda * v
+# lora_A = v^T W
-for matrix in matrices:
-# Ensure projector is on the same device as the matrix for multi-GPU support.
-device_projector = projector.to(matrix.device)
-# In-place subtraction is safe as we're not using Autograd.
-matrix.sub_(weight * (device_projector @ matrix))
# Use the FP32 refusal direction directly (no downcast/upcast)
# and move to the correct device.
# NOTE: Assumes module has .weight (true for Linear layers we target)
v = layer_refusal_direction.to(module.weight.device)
# Get W (dequantize if necessary)
# For LoRA-wrapped modules, the quantized weights are in base_layer
base_weight = (
module.base_layer.weight
if hasattr(module, "base_layer")
else module.weight
)
quant_state = getattr(base_weight, "quant_state", None)
if quant_state is not None:
# 4-bit quantization
W = bnb.functional.dequantize_4bit(
base_weight.data, quant_state
).to(torch.float32)
else:
W = base_weight.to(torch.float32)
# Calculate lora_A = v^T W
# v is (d_out,), W is (d_out, d_in)
# v @ W -> (d_in,)
lora_A = (v @ W).view(1, -1)
# Calculate lora_B = -weight * v
# v is (d_out,)
lora_B = (-weight * v).view(-1, 1)
# Assign to adapters
# We assume the default adapter name "default"
module.lora_A["default"].weight.data = lora_A.to(
module.lora_A["default"].weight.dtype
)
module.lora_B["default"].weight.data = lora_B.to(
module.lora_B["default"].weight.dtype
)
def get_chat(self, prompt: str) -> list[dict[str, str]]:
return [
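The `abliterate()` change above replaces the in-place update `W - weight * (v v^T) W` with a rank-1 LoRA pair, `lora_A = v^T W` and `lora_B = -weight * v`. The sketch below checks on a tiny example that the two formulations give the same effective weight. It uses plain Python lists instead of tensors, all helper names are hypothetical, and it assumes a LoRA scaling factor of 1 (`lora_alpha / r = 1 / 1`, matching the `LoraConfig` in the diff).

```python
def row_times_matrix(v, W):
    """Compute v^T W for a vector v (d_out) and a matrix W (d_out x d_in)."""
    return [sum(v[i] * W[i][j] for i in range(len(v))) for j in range(len(W[0]))]


def ablate_directly(W, v, weight):
    """Pre-LoRA formulation: W - weight * (v v^T) W."""
    vW = row_times_matrix(v, W)
    return [[W[i][j] - weight * v[i] * vW[j] for j in range(len(vW))]
            for i in range(len(v))]


def lora_factors(W, v, weight):
    """Rank-1 factors as in the diff: lora_A = v^T W, lora_B = -weight * v."""
    lora_A = [row_times_matrix(v, W)]      # shape (1, d_in)
    lora_B = [[-weight * vi] for vi in v]  # shape (d_out, 1)
    return lora_A, lora_B


def effective_weight(W, lora_A, lora_B):
    """W + lora_B @ lora_A for rank 1, assuming a scaling factor of 1."""
    return [[W[i][j] + lora_B[i][0] * lora_A[0][j] for j in range(len(W[0]))]
            for i in range(len(W))]


W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # d_out = 2, d_in = 3
v = [0.6, 0.8]                          # unit-length stand-in for the refusal direction

lora_A, lora_B = lora_factors(W, v, weight=1.0)
via_lora = effective_weight(W, lora_A, lora_B)
direct = ablate_directly(W, v, weight=1.0)

# Zeroing lora_B (the fast path in reset_model) restores the original weights.
after_reset = effective_weight(W, lora_A, [[0.0] for _ in v])
```

With `weight=1.0` and a unit-length `v`, the resulting matrix has no remaining component along `v`, and zeroing `lora_B` alone is enough to return to the base weights. That is why the fast path in `reset_model()` only touches the `lora_B` modules.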
Generated
+1978 -1405
File diff suppressed because it is too large