mirror of
https://github.com/p-e-w/heretic.git
synced 2026-06-02 05:03:33 +02:00
fix: fix remaining issues
This commit is contained in:
+21
-47
@@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
|
||||
@@ -17,8 +16,8 @@ from pydantic_settings import (
|
||||
# !!!IMPORTANT!!!
|
||||
#
|
||||
# Any settings added to the classes defined in this module
|
||||
# must be evaluated for privacy implications and added to
|
||||
# the logic in get_essential_settings if required.
|
||||
# must be evaluated for privacy implications and have
|
||||
# exclude=True set in their field definitions if appropriate.
|
||||
|
||||
|
||||
class QuantizationMethod(str, Enum):
|
||||
@@ -65,11 +64,13 @@ class DatasetSpecification(BaseModel):
|
||||
residual_plot_label: str | None = Field(
|
||||
default=None,
|
||||
description="Label to use for the dataset in plots of residual vectors.",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
residual_plot_color: str | None = Field(
|
||||
default=None,
|
||||
description="Matplotlib color to use for the dataset in plots of residual vectors.",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -99,6 +100,7 @@ class Settings(BaseSettings):
|
||||
"If this model ID or path is set, then instead of abliterating the main model, "
|
||||
"evaluate this model relative to the main model."
|
||||
),
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
dtypes: list[str] = Field(
|
||||
@@ -142,6 +144,8 @@ class Settings(BaseSettings):
|
||||
trust_remote_code: bool | None = Field(
|
||||
default=None,
|
||||
description="Whether to trust remote code when loading the model.",
|
||||
# For security reasons, we don't store this setting.
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
batch_size: int = Field(
|
||||
@@ -152,6 +156,9 @@ class Settings(BaseSettings):
|
||||
max_batch_size: int = Field(
|
||||
default=128,
|
||||
description="Maximum batch size to try when automatically determining the optimal batch size.",
|
||||
# When storing a settings object, the batch size is already fixed,
|
||||
# either determined by the automatic mechanism or by explicit user choice.
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
max_response_length: int = Field(
|
||||
@@ -196,36 +203,45 @@ class Settings(BaseSettings):
|
||||
"the Chain-of-Thought block in responses, so that evaluation happens "
|
||||
"at the start of the actual response."
|
||||
),
|
||||
# When storing a settings object, the response prefix is already fixed,
|
||||
# either determined by the automatic mechanism or by explicit user choice.
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
print_responses: bool = Field(
|
||||
default=False,
|
||||
description="Whether to print prompt/response pairs when counting refusals.",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
print_residual_geometry: bool = Field(
|
||||
default=False,
|
||||
description="Whether to print detailed information about residuals and refusal directions.",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
plot_residuals: bool = Field(
|
||||
default=False,
|
||||
description="Whether to generate plots showing PaCMAP projections of residual vectors.",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
residual_plot_path: str = Field(
|
||||
default="plots",
|
||||
description="Base path to save plots of residual vectors to.",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
residual_plot_title: str = Field(
|
||||
default='PaCMAP Projection of Residual Vectors for "Harmless" and "Harmful" Prompts',
|
||||
description="Title placed above plots of residual vectors.",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
residual_plot_style: str = Field(
|
||||
default="dark_background",
|
||||
description="Matplotlib style sheet to use for plots of residual vectors.",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
kl_divergence_scale: float = Field(
|
||||
@@ -304,6 +320,7 @@ class Settings(BaseSettings):
|
||||
study_checkpoint_dir: str = Field(
|
||||
default="checkpoints",
|
||||
description="Directory to save and load study progress to/from.",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
benchmarks: list[BenchmarkSpecification] = Field(
|
||||
@@ -365,6 +382,7 @@ class Settings(BaseSettings):
|
||||
),
|
||||
],
|
||||
description="Benchmarks to offer to the user for evaluating abliterated models.",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
max_shard_size: int | str = Field(
|
||||
@@ -485,47 +503,3 @@ class Settings(BaseSettings):
|
||||
file_secret_settings,
|
||||
TomlConfigSettingsSource(settings_cls, toml_file="config.toml"),
|
||||
)
|
||||
|
||||
|
||||
def get_essential_settings(settings: Settings) -> Settings:
|
||||
"""
|
||||
Returns a stripped-down version of the settings object that only contains
|
||||
settings that directly influence the results of the abliteration run.
|
||||
In particular, this object contains no file system paths other than (possibly)
|
||||
paths to local models and datasets.
|
||||
"""
|
||||
|
||||
essential_settings = copy.deepcopy(settings)
|
||||
|
||||
del essential_settings.evaluate_model
|
||||
|
||||
# We always use the default for security reasons.
|
||||
del essential_settings.trust_remote_code
|
||||
|
||||
# When storing a settings object, the batch size is already fixed,
|
||||
# either determined by the automatic mechanism or by explicit user choice.
|
||||
del essential_settings.max_batch_size
|
||||
|
||||
# When storing a settings object, the response prefix is already fixed,
|
||||
# either determined by the automatic mechanism or by explicit user choice.
|
||||
del essential_settings.chain_of_thought_skips
|
||||
|
||||
del essential_settings.print_responses
|
||||
del essential_settings.print_residual_geometry
|
||||
del essential_settings.plot_residuals
|
||||
del essential_settings.residual_plot_path
|
||||
del essential_settings.residual_plot_title
|
||||
del essential_settings.residual_plot_style
|
||||
del essential_settings.study_checkpoint_dir
|
||||
del essential_settings.benchmarks
|
||||
|
||||
for dataset in [
|
||||
essential_settings.good_prompts,
|
||||
essential_settings.bad_prompts,
|
||||
essential_settings.good_evaluation_prompts,
|
||||
essential_settings.bad_evaluation_prompts,
|
||||
]:
|
||||
del dataset.residual_plot_label
|
||||
del dataset.residual_plot_color
|
||||
|
||||
return essential_settings
|
||||
|
||||
+9
-19
@@ -58,7 +58,7 @@ from rich.table import Table
|
||||
from rich.traceback import install
|
||||
|
||||
from .analyzer import Analyzer
|
||||
from .config import QuantizationMethod, get_essential_settings
|
||||
from .config import QuantizationMethod
|
||||
from .evaluator import Evaluator
|
||||
from .model import AbliterationParameters, Model, get_model_class
|
||||
from .system import empty_cache, get_accelerator_info
|
||||
@@ -598,10 +598,7 @@ def run():
|
||||
load_if_exists=True,
|
||||
)
|
||||
|
||||
study.set_user_attr(
|
||||
"settings",
|
||||
get_essential_settings(settings).model_dump_json(exclude_none=True),
|
||||
)
|
||||
study.set_user_attr("settings", settings.model_dump_json())
|
||||
study.set_user_attr("finished", False)
|
||||
|
||||
def count_completed_trials() -> int:
|
||||
@@ -715,10 +712,7 @@ def run():
|
||||
continue
|
||||
|
||||
settings.n_trials += n_additional_trials
|
||||
study.set_user_attr(
|
||||
"settings",
|
||||
get_essential_settings(settings).model_dump_json(exclude_none=True),
|
||||
)
|
||||
study.set_user_attr("settings", settings.model_dump_json())
|
||||
study.set_user_attr("finished", False)
|
||||
|
||||
try:
|
||||
@@ -910,22 +904,18 @@ def run():
|
||||
token=token,
|
||||
)
|
||||
|
||||
# If the model path exists locally and includes the
|
||||
# card, use it directly. If the model path doesn't
|
||||
# exist locally, it can be assumed to be a model
|
||||
# hosted on the Hugging Face Hub, in which case
|
||||
# we can retrieve the model card.
|
||||
model_path = Path(settings.model)
|
||||
if model_path.exists():
|
||||
if is_hf_path(settings.model):
|
||||
card = ModelCard.load(settings.model)
|
||||
else:
|
||||
card_path = (
|
||||
model_path / huggingface_hub.constants.REPOCARD_NAME
|
||||
Path(settings.model)
|
||||
/ huggingface_hub.constants.REPOCARD_NAME
|
||||
)
|
||||
if card_path.exists():
|
||||
card = ModelCard.load(card_path)
|
||||
else:
|
||||
card = None
|
||||
else:
|
||||
card = ModelCard.load(settings.model)
|
||||
|
||||
if card is not None:
|
||||
if card.data is None:
|
||||
card.data = ModelCardData()
|
||||
|
||||
@@ -411,7 +411,7 @@ def get_python_env_info() -> str:
|
||||
return f"{info['version']} ({info['implementation']}, {info['compiler']}) [{info['environment']}]"
|
||||
|
||||
|
||||
def get_package_version(name: str) -> str | None:
|
||||
def get_package_version(name: str) -> str:
|
||||
"""Gets the installed version of a package, stripping local suffixes like +cu128."""
|
||||
|
||||
# Normalize name: pip considers hyphens and underscores equivalent.
|
||||
@@ -427,6 +427,7 @@ def get_requirements_dict() -> dict[str, str]:
|
||||
# PyTorch is not listed as a dependency in the heretic-llm package
|
||||
# because installation is hardware-specific and must be done manually.
|
||||
packages_to_check = ["heretic-llm", "torch", "torchaudio", "torchvision"]
|
||||
|
||||
visited = set()
|
||||
required_packages = set()
|
||||
|
||||
@@ -459,18 +460,19 @@ def get_requirements_dict() -> dict[str, str]:
|
||||
# If a package is listed as a dependency but not installed, we skip it.
|
||||
continue
|
||||
|
||||
required_packages_sorted = sorted(required_packages)
|
||||
|
||||
# Lookup versions for all discovered packages.
|
||||
dependencies = {}
|
||||
version_info = get_heretic_version_info()
|
||||
for name in required_packages:
|
||||
|
||||
for package in required_packages_sorted:
|
||||
# If heretic-llm was installed from source (Git/Local), exclude it
|
||||
# from requirements.txt to prevent pip from downloading an unrelated
|
||||
# version from PyPI during reproduction.
|
||||
if name == "heretic-llm" and not version_info.is_standard_pypi:
|
||||
if package == "heretic-llm" and not version_info.is_standard_pypi:
|
||||
continue
|
||||
|
||||
version_str = get_package_version(name)
|
||||
if version_str:
|
||||
dependencies[name] = version_str
|
||||
dependencies[package] = get_package_version(package)
|
||||
|
||||
return dependencies
|
||||
|
||||
+19
-21
@@ -26,7 +26,7 @@ from psutil import Process
|
||||
from questionary import Choice, Style
|
||||
from rich.console import Console
|
||||
|
||||
from .config import DatasetSpecification, Settings, get_essential_settings
|
||||
from .config import DatasetSpecification, Settings
|
||||
from .system import (
|
||||
get_accelerator_info_dict,
|
||||
get_cpu_info_dict,
|
||||
@@ -193,7 +193,13 @@ def load_prompts(
|
||||
path = specification.dataset
|
||||
split_str = specification.split
|
||||
|
||||
if os.path.isdir(path):
|
||||
if is_hf_path(path):
|
||||
dataset = load_dataset(
|
||||
path,
|
||||
revision=specification.commit,
|
||||
split=split_str,
|
||||
)
|
||||
else:
|
||||
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
|
||||
# Dataset saved with datasets.save_to_disk; needs special handling.
|
||||
# Path should be the subdirectory for a particular split.
|
||||
@@ -211,19 +217,15 @@ def load_prompts(
|
||||
# Get the dataset by applying the indices.
|
||||
dataset = dataset[abs_instruction.from_ : abs_instruction.to]
|
||||
else:
|
||||
# Path is a local directory.
|
||||
# Path should be a local directory.
|
||||
dataset = load_dataset(
|
||||
path,
|
||||
revision=specification.commit,
|
||||
split=split_str,
|
||||
# Don't require the number of examples (lines) per split to be pre-defined.
|
||||
verification_mode=VerificationMode.NO_CHECKS,
|
||||
# But also don't use cached data, as the dataset may have changed on disk.
|
||||
download_mode=DownloadMode.FORCE_REDOWNLOAD,
|
||||
)
|
||||
else:
|
||||
# Probably a repository path; let load_dataset figure it out.
|
||||
dataset = load_dataset(path, split=split_str)
|
||||
|
||||
prompts = list(dataset[specification.column])
|
||||
|
||||
@@ -327,18 +329,16 @@ def get_readme_intro(
|
||||
def generate_config_toml(settings: Settings) -> str:
|
||||
"""Serializes the full Settings object to TOML."""
|
||||
|
||||
return tomli_w.dumps(get_essential_settings(settings).model_dump(exclude_none=True))
|
||||
return tomli_w.dumps(settings.model_dump(exclude_none=True))
|
||||
|
||||
|
||||
def generate_requirements_txt() -> str:
|
||||
"""Collects direct project dependencies as a formatted string."""
|
||||
|
||||
requirements = get_requirements_dict()
|
||||
sorted_requirements = sorted(
|
||||
[f"{name}=={version}" for name, version in requirements.items()],
|
||||
key=lambda x: x.lower(),
|
||||
)
|
||||
return "\n".join(sorted_requirements) + "\n"
|
||||
requirements = [
|
||||
f"{package}=={version}" for package, version in get_requirements_dict().items()
|
||||
]
|
||||
return "\n".join(requirements) + "\n"
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
@@ -350,16 +350,14 @@ def set_seed(seed: int):
|
||||
|
||||
|
||||
def format_hf_link(
|
||||
name: str,
|
||||
path: str,
|
||||
commit: str | None = None,
|
||||
is_dataset: bool = False,
|
||||
) -> str:
|
||||
if Path(name).exists():
|
||||
return f"`{name}` (Local)"
|
||||
|
||||
prefix = "datasets/" if is_dataset else ""
|
||||
base_url = f"https://huggingface.co/{prefix}{name}"
|
||||
link = f"[{name}]({base_url})"
|
||||
base_url = f"https://huggingface.co/{prefix}{path}"
|
||||
link = f"[{path}]({base_url})"
|
||||
|
||||
if commit:
|
||||
commit_url = f"{base_url}/commit/{commit}"
|
||||
link += f" (Commit: [`{commit[:7]}`]({commit_url}))"
|
||||
@@ -561,7 +559,7 @@ def generate_reproduce_json(
|
||||
"pytorch_version": torch.__version__,
|
||||
"requirements": get_requirements_dict(),
|
||||
},
|
||||
"settings": get_essential_settings(settings).model_dump(exclude_none=True),
|
||||
"settings": settings.model_dump(),
|
||||
"parameters": {
|
||||
"direction_index": trial.user_attrs["direction_index"],
|
||||
"abliteration_parameters": trial.user_attrs["parameters"],
|
||||
|
||||
Reference in New Issue
Block a user