fix: fix remaining issues

This commit is contained in:
Philipp Emanuel Weidmann
2026-04-23 18:36:01 +05:30
parent 54f5daad90
commit 5c0f344760
4 changed files with 57 additions and 93 deletions
+21 -47
View File
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
import copy
from enum import Enum
from typing import Dict
@@ -17,8 +16,8 @@ from pydantic_settings import (
# !!!IMPORTANT!!!
#
# Any settings added to the classes defined in this module
# must be evaluated for privacy implications and added to
# the logic in get_essential_settings if required.
# must be evaluated for privacy implications and have
# exclude=True set in their field definitions if appropriate.
class QuantizationMethod(str, Enum):
@@ -65,11 +64,13 @@ class DatasetSpecification(BaseModel):
residual_plot_label: str | None = Field(
default=None,
description="Label to use for the dataset in plots of residual vectors.",
exclude=True,
)
residual_plot_color: str | None = Field(
default=None,
description="Matplotlib color to use for the dataset in plots of residual vectors.",
exclude=True,
)
@@ -99,6 +100,7 @@ class Settings(BaseSettings):
"If this model ID or path is set, then instead of abliterating the main model, "
"evaluate this model relative to the main model."
),
exclude=True,
)
dtypes: list[str] = Field(
@@ -142,6 +144,8 @@ class Settings(BaseSettings):
trust_remote_code: bool | None = Field(
default=None,
description="Whether to trust remote code when loading the model.",
# For security reasons, we don't store this setting.
exclude=True,
)
batch_size: int = Field(
@@ -152,6 +156,9 @@ class Settings(BaseSettings):
max_batch_size: int = Field(
default=128,
description="Maximum batch size to try when automatically determining the optimal batch size.",
# When storing a settings object, the batch size is already fixed,
# either determined by the automatic mechanism or by explicit user choice.
exclude=True,
)
max_response_length: int = Field(
@@ -196,36 +203,45 @@ class Settings(BaseSettings):
"the Chain-of-Thought block in responses, so that evaluation happens "
"at the start of the actual response."
),
# When storing a settings object, the response prefix is already fixed,
# either determined by the automatic mechanism or by explicit user choice.
exclude=True,
)
print_responses: bool = Field(
default=False,
description="Whether to print prompt/response pairs when counting refusals.",
exclude=True,
)
print_residual_geometry: bool = Field(
default=False,
description="Whether to print detailed information about residuals and refusal directions.",
exclude=True,
)
plot_residuals: bool = Field(
default=False,
description="Whether to generate plots showing PaCMAP projections of residual vectors.",
exclude=True,
)
residual_plot_path: str = Field(
default="plots",
description="Base path to save plots of residual vectors to.",
exclude=True,
)
residual_plot_title: str = Field(
default='PaCMAP Projection of Residual Vectors for "Harmless" and "Harmful" Prompts',
description="Title placed above plots of residual vectors.",
exclude=True,
)
residual_plot_style: str = Field(
default="dark_background",
description="Matplotlib style sheet to use for plots of residual vectors.",
exclude=True,
)
kl_divergence_scale: float = Field(
@@ -304,6 +320,7 @@ class Settings(BaseSettings):
study_checkpoint_dir: str = Field(
default="checkpoints",
description="Directory to save and load study progress to/from.",
exclude=True,
)
benchmarks: list[BenchmarkSpecification] = Field(
@@ -365,6 +382,7 @@ class Settings(BaseSettings):
),
],
description="Benchmarks to offer to the user for evaluating abliterated models.",
exclude=True,
)
max_shard_size: int | str = Field(
@@ -485,47 +503,3 @@ class Settings(BaseSettings):
file_secret_settings,
TomlConfigSettingsSource(settings_cls, toml_file="config.toml"),
)
def get_essential_settings(settings: Settings) -> Settings:
    """
    Builds a reduced copy of the given settings object that keeps only the
    settings which directly influence the results of the abliteration run.

    All removals are applied to a deep copy; the object passed in is not modified.
    The returned object contains no file system paths other than (possibly)
    paths to local models and datasets.
    """
    essential_settings = copy.deepcopy(settings)

    # Top-level fields to strip, listed in the order in which they are removed.
    stripped_fields = (
        "evaluate_model",
        # We always use the default for security reasons.
        "trust_remote_code",
        # When storing a settings object, the batch size is already fixed,
        # either determined by the automatic mechanism or by explicit user choice.
        "max_batch_size",
        # When storing a settings object, the response prefix is already fixed,
        # either determined by the automatic mechanism or by explicit user choice.
        "chain_of_thought_skips",
        "print_responses",
        "print_residual_geometry",
        "plot_residuals",
        "residual_plot_path",
        "residual_plot_title",
        "residual_plot_style",
        "study_checkpoint_dir",
        "benchmarks",
    )
    for field_name in stripped_fields:
        delattr(essential_settings, field_name)

    # Plot-only attributes are stripped from each of the four dataset specifications.
    dataset_specifications = (
        essential_settings.good_prompts,
        essential_settings.bad_prompts,
        essential_settings.good_evaluation_prompts,
        essential_settings.bad_evaluation_prompts,
    )
    for specification in dataset_specifications:
        for plot_field in ("residual_plot_label", "residual_plot_color"):
            delattr(specification, plot_field)

    return essential_settings
+9 -19
View File
@@ -58,7 +58,7 @@ from rich.table import Table
from rich.traceback import install
from .analyzer import Analyzer
from .config import QuantizationMethod, get_essential_settings
from .config import QuantizationMethod
from .evaluator import Evaluator
from .model import AbliterationParameters, Model, get_model_class
from .system import empty_cache, get_accelerator_info
@@ -598,10 +598,7 @@ def run():
load_if_exists=True,
)
study.set_user_attr(
"settings",
get_essential_settings(settings).model_dump_json(exclude_none=True),
)
study.set_user_attr("settings", settings.model_dump_json())
study.set_user_attr("finished", False)
def count_completed_trials() -> int:
@@ -715,10 +712,7 @@ def run():
continue
settings.n_trials += n_additional_trials
study.set_user_attr(
"settings",
get_essential_settings(settings).model_dump_json(exclude_none=True),
)
study.set_user_attr("settings", settings.model_dump_json())
study.set_user_attr("finished", False)
try:
@@ -910,22 +904,18 @@ def run():
token=token,
)
# If the model path exists locally and includes the
# card, use it directly. If the model path doesn't
# exist locally, it can be assumed to be a model
# hosted on the Hugging Face Hub, in which case
# we can retrieve the model card.
model_path = Path(settings.model)
if model_path.exists():
if is_hf_path(settings.model):
card = ModelCard.load(settings.model)
else:
card_path = (
model_path / huggingface_hub.constants.REPOCARD_NAME
Path(settings.model)
/ huggingface_hub.constants.REPOCARD_NAME
)
if card_path.exists():
card = ModelCard.load(card_path)
else:
card = None
else:
card = ModelCard.load(settings.model)
if card is not None:
if card.data is None:
card.data = ModelCardData()
+8 -6
View File
@@ -411,7 +411,7 @@ def get_python_env_info() -> str:
return f"{info['version']} ({info['implementation']}, {info['compiler']}) [{info['environment']}]"
def get_package_version(name: str) -> str | None:
def get_package_version(name: str) -> str:
"""Gets the installed version of a package, stripping local suffixes like +cu128."""
# Normalize name: pip considers hyphens and underscores equivalent.
@@ -427,6 +427,7 @@ def get_requirements_dict() -> dict[str, str]:
# PyTorch is not listed as a dependency in the heretic-llm package
# because installation is hardware-specific and must be done manually.
packages_to_check = ["heretic-llm", "torch", "torchaudio", "torchvision"]
visited = set()
required_packages = set()
@@ -459,18 +460,19 @@ def get_requirements_dict() -> dict[str, str]:
# If a package is listed as a dependency but not installed, we skip it.
continue
required_packages_sorted = sorted(required_packages)
# Look up versions for all discovered packages.
dependencies = {}
version_info = get_heretic_version_info()
for name in required_packages:
for package in required_packages_sorted:
# If heretic-llm was installed from source (Git/Local), exclude it
# from requirements.txt to prevent pip from downloading an unrelated
# version from PyPI during reproduction.
if name == "heretic-llm" and not version_info.is_standard_pypi:
if package == "heretic-llm" and not version_info.is_standard_pypi:
continue
version_str = get_package_version(name)
if version_str:
dependencies[name] = version_str
dependencies[package] = get_package_version(package)
return dependencies
+19 -21
View File
@@ -26,7 +26,7 @@ from psutil import Process
from questionary import Choice, Style
from rich.console import Console
from .config import DatasetSpecification, Settings, get_essential_settings
from .config import DatasetSpecification, Settings
from .system import (
get_accelerator_info_dict,
get_cpu_info_dict,
@@ -193,7 +193,13 @@ def load_prompts(
path = specification.dataset
split_str = specification.split
if os.path.isdir(path):
if is_hf_path(path):
dataset = load_dataset(
path,
revision=specification.commit,
split=split_str,
)
else:
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
# Dataset saved with datasets.save_to_disk; needs special handling.
# Path should be the subdirectory for a particular split.
@@ -211,19 +217,15 @@ def load_prompts(
# Get the dataset by applying the indices.
dataset = dataset[abs_instruction.from_ : abs_instruction.to]
else:
# Path is a local directory.
# Path should be a local directory.
dataset = load_dataset(
path,
revision=specification.commit,
split=split_str,
# Don't require the number of examples (lines) per split to be pre-defined.
verification_mode=VerificationMode.NO_CHECKS,
# But also don't use cached data, as the dataset may have changed on disk.
download_mode=DownloadMode.FORCE_REDOWNLOAD,
)
else:
# Probably a repository path; let load_dataset figure it out.
dataset = load_dataset(path, split=split_str)
prompts = list(dataset[specification.column])
@@ -327,18 +329,16 @@ def get_readme_intro(
def generate_config_toml(settings: Settings) -> str:
"""Serializes the full Settings object to TOML."""
# exclude_none=True leaves fields whose value is None out of the dict
# that is handed to tomli_w.dumps.
return tomli_w.dumps(get_essential_settings(settings).model_dump(exclude_none=True))
return tomli_w.dumps(settings.model_dump(exclude_none=True))
def generate_requirements_txt() -> str:
"""Collects direct project dependencies as a formatted string."""
requirements = get_requirements_dict()
# Each entry is rendered as `name==version`; the lowercased sort key
# makes the ordering case-insensitive.
sorted_requirements = sorted(
[f"{name}=={version}" for name, version in requirements.items()],
key=lambda x: x.lower(),
)
# One entry per line, followed by a terminating newline.
return "\n".join(sorted_requirements) + "\n"
# Entries keep the iteration order of the dict returned by get_requirements_dict().
requirements = [
f"{package}=={version}" for package, version in get_requirements_dict().items()
]
# One entry per line, followed by a terminating newline.
return "\n".join(requirements) + "\n"
def set_seed(seed: int):
@@ -350,16 +350,14 @@ def set_seed(seed: int):
def format_hf_link(
name: str,
path: str,
commit: str | None = None,
is_dataset: bool = False,
) -> str:
if Path(name).exists():
return f"`{name}` (Local)"
prefix = "datasets/" if is_dataset else ""
base_url = f"https://huggingface.co/{prefix}{name}"
link = f"[{name}]({base_url})"
base_url = f"https://huggingface.co/{prefix}{path}"
link = f"[{path}]({base_url})"
if commit:
commit_url = f"{base_url}/commit/{commit}"
link += f" (Commit: [`{commit[:7]}`]({commit_url}))"
@@ -561,7 +559,7 @@ def generate_reproduce_json(
"pytorch_version": torch.__version__,
"requirements": get_requirements_dict(),
},
"settings": get_essential_settings(settings).model_dump(exclude_none=True),
"settings": settings.model_dump(),
"parameters": {
"direction_index": trial.user_attrs["direction_index"],
"abliteration_parameters": trial.user_attrs["parameters"],