mirror of
https://github.com/p-e-w/heretic.git
synced 2026-06-02 05:03:33 +02:00
feat: add configurable residual processing to reduce peak VRAM usage (#239)
* refactor residual memory optimizations * formatting * Fixed config.py positioning and default * fixed analyzier declaration in main.py * removing del statements * ruff * small updates * ty moveback ish
This commit is contained in:
@@ -137,6 +137,12 @@ refusal_markers = [
|
||||
# System prompt to use when prompting the model.
|
||||
system_prompt = "You are a helpful assistant."
|
||||
|
||||
# Move intermediate analysis tensors (such as residuals and logprobs)
|
||||
# to CPU memory as soon as possible to reduce peak VRAM usage.
|
||||
# This lowers peak VRAM usage during residual analysis and evaluation,
|
||||
# but may slightly reduce performance due to host/device transfers.
|
||||
offload_outputs_to_cpu = true
|
||||
|
||||
# Dataset of prompts that tend to not result in refusals (used for calculating refusal directions).
|
||||
[good_prompts]
|
||||
dataset = "mlabonne/harmless_alpaca"
|
||||
|
||||
@@ -397,6 +397,14 @@ class Settings(BaseSettings):
|
||||
description="System prompt to use when prompting the model.",
|
||||
)
|
||||
|
||||
offload_outputs_to_cpu: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether to move intermediate analysis tensors (such as residuals and logprobs) "
|
||||
"to CPU memory as soon as possible to reduce peak VRAM usage."
|
||||
),
|
||||
)
|
||||
|
||||
good_prompts: DatasetSpecification = Field(
|
||||
default=DatasetSpecification(
|
||||
dataset="mlabonne/harmless_alpaca",
|
||||
|
||||
+20
-8
@@ -421,6 +421,13 @@ def run():
|
||||
|
||||
print()
|
||||
print("Calculating per-layer refusal directions...")
|
||||
|
||||
needs_full_residuals = settings.print_residual_geometry or settings.plot_residuals
|
||||
|
||||
good_residuals = None
|
||||
bad_residuals = None
|
||||
|
||||
if needs_full_residuals:
|
||||
print("* Obtaining residuals for good prompts...")
|
||||
good_residuals = model.get_residuals_batched(good_prompts)
|
||||
print("* Obtaining residuals for bad prompts...")
|
||||
@@ -429,6 +436,19 @@ def run():
|
||||
good_means = good_residuals.mean(dim=0)
|
||||
bad_means = bad_residuals.mean(dim=0)
|
||||
|
||||
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
|
||||
|
||||
if settings.print_residual_geometry:
|
||||
analyzer.print_residual_geometry()
|
||||
|
||||
if settings.plot_residuals:
|
||||
analyzer.plot_residuals()
|
||||
else:
|
||||
print("* Obtaining residual mean for good prompts...")
|
||||
good_means = model.get_residuals_mean(good_prompts)
|
||||
print("* Obtaining residual mean for bad prompts...")
|
||||
bad_means = model.get_residuals_mean(bad_prompts)
|
||||
|
||||
refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)
|
||||
|
||||
if settings.orthogonalize_direction:
|
||||
@@ -442,14 +462,6 @@ def run():
|
||||
)
|
||||
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
|
||||
|
||||
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
|
||||
|
||||
if settings.print_residual_geometry:
|
||||
analyzer.print_residual_geometry()
|
||||
|
||||
if settings.plot_residuals:
|
||||
analyzer.plot_residuals()
|
||||
|
||||
# We don't need the residuals after computing refusal directions.
|
||||
del good_residuals, bad_residuals, analyzer
|
||||
empty_cache()
|
||||
|
||||
+42
-2
@@ -636,6 +636,9 @@ class Model:
|
||||
max_new_tokens=1,
|
||||
output_hidden_states=True,
|
||||
return_dict_in_generate=True,
|
||||
# KV cache is unnecessary here because we only need the hidden states
|
||||
# for the first generated token.
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
||||
@@ -669,7 +672,11 @@ class Model:
|
||||
dim=2,
|
||||
keepdim=True,
|
||||
)
|
||||
return torch.clamp(residuals, -thresholds, thresholds)
|
||||
residuals = torch.clamp(residuals, -thresholds, thresholds)
|
||||
|
||||
if self.settings.offload_outputs_to_cpu:
|
||||
residuals = residuals.cpu()
|
||||
empty_cache()
|
||||
|
||||
return residuals
|
||||
|
||||
@@ -681,6 +688,30 @@ class Model:
|
||||
|
||||
return torch.cat(residuals, dim=0)
|
||||
|
||||
def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor:
|
||||
if not prompts:
|
||||
raise ValueError("prompts must not be empty")
|
||||
|
||||
running_sum = None
|
||||
total_count = 0
|
||||
|
||||
for batch in batchify(prompts, self.settings.batch_size):
|
||||
batch_residuals = self.get_residuals(batch)
|
||||
|
||||
# Accumulate in high precision on CPU to reduce peak VRAM usage.
|
||||
batch_sum = batch_residuals.sum(dim=0, dtype=torch.float64).cpu()
|
||||
|
||||
if running_sum is None:
|
||||
running_sum = batch_sum
|
||||
else:
|
||||
running_sum += batch_sum
|
||||
|
||||
total_count += batch_residuals.shape[0]
|
||||
|
||||
assert running_sum is not None
|
||||
|
||||
return (running_sum / total_count).to(torch.float32)
|
||||
|
||||
# We work with logprobs rather than probabilities for numerical stability
|
||||
# when computing the KL divergence.
|
||||
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
|
||||
@@ -691,6 +722,7 @@ class Model:
|
||||
max_new_tokens=1,
|
||||
output_scores=True,
|
||||
return_dict_in_generate=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# This cast is valid because GenerateDecoderOnlyOutput is the return type
|
||||
@@ -702,7 +734,15 @@ class Model:
|
||||
logits = cast(tuple[FloatTensor], outputs.scores)[0]
|
||||
|
||||
# The returned tensor has shape (prompt, token).
|
||||
return F.log_softmax(logits, dim=-1)
|
||||
logprobs = F.log_softmax(logits, dim=-1)
|
||||
|
||||
del outputs
|
||||
|
||||
if self.settings.offload_outputs_to_cpu:
|
||||
logprobs = logprobs.cpu()
|
||||
empty_cache()
|
||||
|
||||
return logprobs
|
||||
|
||||
def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
|
||||
logprobs = []
|
||||
|
||||
Reference in New Issue
Block a user