feat: add configurable residual processing to reduce peak VRAM usage (#239)

* refactor residual memory optimizations

* formatting

* Fixed config.py positioning and default

* fixed analyzier declaration in main.py

* removing del statements

* ruff

* small updates

* ty moveback ish
This commit is contained in:
Magic
2026-04-18 07:16:22 -04:00
committed by GitHub
parent 5083fc0dd7
commit ed5d8b9104
4 changed files with 82 additions and 16 deletions
+6
View File
@@ -137,6 +137,12 @@ refusal_markers = [
# System prompt to use when prompting the model.
system_prompt = "You are a helpful assistant."
# Move intermediate analysis tensors (such as residuals and logprobs)
# to CPU memory as soon as possible to reduce peak VRAM usage.
# This lowers peak VRAM usage during residual analysis and evaluation,
# but may slightly reduce performance due to host/device transfers.
offload_outputs_to_cpu = true
# Dataset of prompts that tend to not result in refusals (used for calculating refusal directions).
[good_prompts]
dataset = "mlabonne/harmless_alpaca"
+8
View File
@@ -397,6 +397,14 @@ class Settings(BaseSettings):
description="System prompt to use when prompting the model.",
)
offload_outputs_to_cpu: bool = Field(
default=True,
description=(
"Whether to move intermediate analysis tensors (such as residuals and logprobs) "
"to CPU memory as soon as possible to reduce peak VRAM usage."
),
)
good_prompts: DatasetSpecification = Field(
default=DatasetSpecification(
dataset="mlabonne/harmless_alpaca",
+26 -14
View File
@@ -421,13 +421,33 @@ def run():
print()
print("Calculating per-layer refusal directions...")
print("* Obtaining residuals for good prompts...")
good_residuals = model.get_residuals_batched(good_prompts)
print("* Obtaining residuals for bad prompts...")
bad_residuals = model.get_residuals_batched(bad_prompts)
good_means = good_residuals.mean(dim=0)
bad_means = bad_residuals.mean(dim=0)
needs_full_residuals = settings.print_residual_geometry or settings.plot_residuals
good_residuals = None
bad_residuals = None
if needs_full_residuals:
print("* Obtaining residuals for good prompts...")
good_residuals = model.get_residuals_batched(good_prompts)
print("* Obtaining residuals for bad prompts...")
bad_residuals = model.get_residuals_batched(bad_prompts)
good_means = good_residuals.mean(dim=0)
bad_means = bad_residuals.mean(dim=0)
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
if settings.print_residual_geometry:
analyzer.print_residual_geometry()
if settings.plot_residuals:
analyzer.plot_residuals()
else:
print("* Obtaining residual mean for good prompts...")
good_means = model.get_residuals_mean(good_prompts)
print("* Obtaining residual mean for bad prompts...")
bad_means = model.get_residuals_mean(bad_prompts)
refusal_directions = F.normalize(bad_means - good_means, p=2, dim=1)
@@ -442,14 +462,6 @@ def run():
)
refusal_directions = F.normalize(refusal_directions, p=2, dim=1)
analyzer = Analyzer(settings, model, good_residuals, bad_residuals)
if settings.print_residual_geometry:
analyzer.print_residual_geometry()
if settings.plot_residuals:
analyzer.plot_residuals()
# We don't need the residuals after computing refusal directions.
del good_residuals, bad_residuals, analyzer
empty_cache()
+42 -2
View File
@@ -636,6 +636,9 @@ class Model:
max_new_tokens=1,
output_hidden_states=True,
return_dict_in_generate=True,
# KV cache is unnecessary here because we only need the hidden states
# for the first generated token.
use_cache=False,
)
# This cast is valid because GenerateDecoderOnlyOutput is the return type
@@ -669,7 +672,11 @@ class Model:
dim=2,
keepdim=True,
)
return torch.clamp(residuals, -thresholds, thresholds)
residuals = torch.clamp(residuals, -thresholds, thresholds)
if self.settings.offload_outputs_to_cpu:
residuals = residuals.cpu()
empty_cache()
return residuals
@@ -681,6 +688,30 @@ class Model:
return torch.cat(residuals, dim=0)
def get_residuals_mean(self, prompts: list[Prompt]) -> Tensor:
if not prompts:
raise ValueError("prompts must not be empty")
running_sum = None
total_count = 0
for batch in batchify(prompts, self.settings.batch_size):
batch_residuals = self.get_residuals(batch)
# Accumulate in high precision on CPU to reduce peak VRAM usage.
batch_sum = batch_residuals.sum(dim=0, dtype=torch.float64).cpu()
if running_sum is None:
running_sum = batch_sum
else:
running_sum += batch_sum
total_count += batch_residuals.shape[0]
assert running_sum is not None
return (running_sum / total_count).to(torch.float32)
# We work with logprobs rather than probabilities for numerical stability
# when computing the KL divergence.
def get_logprobs(self, prompts: list[Prompt]) -> Tensor:
@@ -691,6 +722,7 @@ class Model:
max_new_tokens=1,
output_scores=True,
return_dict_in_generate=True,
use_cache=False,
)
# This cast is valid because GenerateDecoderOnlyOutput is the return type
@@ -702,7 +734,15 @@ class Model:
logits = cast(tuple[FloatTensor], outputs.scores)[0]
# The returned tensor has shape (prompt, token).
return F.log_softmax(logits, dim=-1)
logprobs = F.log_softmax(logits, dim=-1)
del outputs
if self.settings.offload_outputs_to_cpu:
logprobs = logprobs.cpu()
empty_cache()
return logprobs
def get_logprobs_batched(self, prompts: list[Prompt]) -> Tensor:
logprobs = []