mirror of
https://github.com/p-e-w/heretic.git
synced 2026-06-02 05:03:33 +02:00
feat: reproducibility when saving & uploading a heretic model (#191)
* feat: implement reproducibility features with safetensors * feat: prompt user before creating reproducibility folder * fix: use prompt_confirm wrapper * style comment * style comment * fix: ignore None values in Settings dump for TOML compatibility * fix: imports * feat: auto-generate seed if none provided for full reproducibility * style: fix ruff formatting issues * style: ruff * style: fix ty check errors with ty:ignore * Update src/heretic/main.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update src/heretic/utils.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * add period at end. Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Improve: Add README, checkpoint.jsonl, to Reproduce * fix: use centralize device info, remove random states file * feat: Add CUDA driver version * ruff * ruff... * ty fix * LGTM: Rich native strip, use nvidia-smi * ruff fix * ruff * revert kaggle hack) * normalize names for deduplication of packages/versions * docstring * rufff * cleanup, add suffix for torch CUDA version, distinguish ROCm * add PyTorch index URL detection * revert index URL to be simple * flip priority of index.. * add Important note * add exact suffix for WHL in instruction * add warning for heterogeneous GPU env * extend driver version info (more accelerators) * fix: style * sync * no abbreviation * use multi-line string * fix: prompt_confirm * feat: CPU info * strip 'slow' warning from environment.txt * feat: Add virtual env info to environment.txt * ruffff * feat: AMD (Radeon) GPU driver version * Refactor: system.py * feat: LGTM capturing specifc installation origin of heretic * feat: Include chosen trial into reproduce/README * style: run ruff format on utils.py * feat: reproduce.json * fix: seperate values in different keys * restore comment * style, clean, seperate commit key * no abbreviation, cleanup * remove labels, store only dependencies * missed import, ruff * sort import * feat: More CPU Info * only store direct dependencies of heretic * complete comment * refactor: use cpuinfo package instead * ruff import sort * distinguish cores & threads * move function amd-driver * rename * moving heretic package info, * rufff * Move: cleanup memory cache * fix: model.py import * no unknowns * generalize all accelerator info stuff * ruff f * move package info * type change * feat: no reproducibility suite for local saving/model used * import fix * fix: type check * style change * style ruff * feat: no env.txt, SHA256SUMS file, cleanup * feat: ADD tip to readme * remove trial index, two-keys only * fix: No time-zone * feat: No suite for local datasets allowed * simplify * featt: capture both direct and transitive dependencies * style: sort readme of reproducibility suite * feat: Store commit hash for datasets too * add total refusal prompts for evaluation display * remove try/except from cpu * extend SHA256 support * remove .txt * only have safetensors for SHA256 * style comment * use HF api to get commit hash * fix: requirements containing irrelevant dependencies * only store heretic-llm if from PyPI.. * add SELECTED tag to the trial that was pushed * AttributeError fix * simplify trial preservation * add direction_index in trial info * remove unwanted CPU info * style: rename --------- Co-authored-by: Vinayyyy7 <vinayumrethe99@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -91,6 +91,10 @@ n_trials = 200
|
||||
# Number of trials that use random sampling for the purpose of exploration.
|
||||
n_startup_trials = 60
|
||||
|
||||
# Random seed for reproducible optimization. Set to an integer to enable.
|
||||
# Applies to Python's random module, NumPy, PyTorch, and Optuna.
|
||||
# seed = 75
|
||||
|
||||
# Directory to save and load study progress to/from.
|
||||
study_checkpoint_dir = "checkpoints"
|
||||
|
||||
@@ -140,6 +144,7 @@ split = "train[:400]"
|
||||
column = "text"
|
||||
residual_plot_label = '"Harmless" prompts'
|
||||
residual_plot_color = "royalblue"
|
||||
commit = ""
|
||||
|
||||
# Dataset of prompts that tend to result in refusals (used for calculating refusal directions).
|
||||
[bad_prompts]
|
||||
@@ -148,15 +153,18 @@ split = "train[:400]"
|
||||
column = "text"
|
||||
residual_plot_label = '"Harmful" prompts'
|
||||
residual_plot_color = "darkorange"
|
||||
commit = ""
|
||||
|
||||
# Dataset of prompts that tend to not result in refusals (used for evaluating model performance).
|
||||
[good_evaluation_prompts]
|
||||
dataset = "mlabonne/harmless_alpaca"
|
||||
split = "test[:100]"
|
||||
column = "text"
|
||||
commit = ""
|
||||
|
||||
# Dataset of prompts that tend to result in refusals (used for evaluating model performance).
|
||||
[bad_evaluation_prompts]
|
||||
dataset = "mlabonne/harmful_behaviors"
|
||||
split = "test[:100]"
|
||||
column = "text"
|
||||
commit = ""
|
||||
|
||||
@@ -35,9 +35,11 @@ dependencies = [
|
||||
"optuna~=4.7",
|
||||
"peft~=0.18",
|
||||
"psutil~=7.2",
|
||||
"py-cpuinfo~=9.0",
|
||||
"pydantic-settings~=2.13",
|
||||
"questionary~=2.1",
|
||||
"rich~=14.3",
|
||||
"tomli-w~=1.2",
|
||||
"tqdm~=4.67",
|
||||
"transformers~=5.3",
|
||||
]
|
||||
|
||||
@@ -59,6 +59,10 @@ class DatasetSpecification(BaseModel):
|
||||
default=None,
|
||||
description="Matplotlib color to use for the dataset in plots of residual vectors.",
|
||||
)
|
||||
commit: str | None = Field(
|
||||
default=None,
|
||||
description="Hugging Face commit hash of the dataset.",
|
||||
)
|
||||
|
||||
|
||||
class BenchmarkSpecification(BaseModel):
|
||||
@@ -276,6 +280,14 @@ class Settings(BaseSettings):
|
||||
description="Number of trials that use random sampling for the purpose of exploration.",
|
||||
)
|
||||
|
||||
seed: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Random seed for reproducible optimization. "
|
||||
"Applies to Python's random module, NumPy, PyTorch, and Optuna."
|
||||
),
|
||||
)
|
||||
|
||||
study_checkpoint_dir: str = Field(
|
||||
default="checkpoints",
|
||||
description="Directory to save and load study progress to/from.",
|
||||
|
||||
+50
-49
@@ -12,6 +12,7 @@ patch_tqdm()
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import time
|
||||
import warnings
|
||||
@@ -29,13 +30,6 @@ import questionary
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from accelerate.utils import (
|
||||
is_mlu_available,
|
||||
is_musa_available,
|
||||
is_npu_available,
|
||||
is_sdaa_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
from huggingface_hub import ModelCard, ModelCardData
|
||||
from lm_eval.models.huggingface import HFLM
|
||||
from optuna import Trial, TrialPruned
|
||||
@@ -54,18 +48,21 @@ from .analyzer import Analyzer
|
||||
from .config import QuantizationMethod, Settings
|
||||
from .evaluator import Evaluator
|
||||
from .model import AbliterationParameters, Model, get_model_class
|
||||
from .system import empty_cache, get_accelerator_info
|
||||
from .utils import (
|
||||
empty_cache,
|
||||
format_duration,
|
||||
get_readme_intro,
|
||||
get_trial_parameters,
|
||||
load_prompts,
|
||||
print,
|
||||
print_memory_usage,
|
||||
prompt_confirm,
|
||||
prompt_password,
|
||||
prompt_path,
|
||||
prompt_select,
|
||||
prompt_text,
|
||||
set_seed,
|
||||
upload_reproduce_folder,
|
||||
)
|
||||
|
||||
|
||||
@@ -186,46 +183,12 @@ def run():
|
||||
)
|
||||
return
|
||||
|
||||
# Adapted from https://github.com/huggingface/accelerate/blob/main/src/accelerate/commands/env.py
|
||||
if torch.cuda.is_available():
|
||||
count = torch.cuda.device_count()
|
||||
total_vram = sum(torch.cuda.mem_get_info(i)[1] for i in range(count))
|
||||
print(
|
||||
f"Detected [bold]{count}[/] CUDA device(s) ({total_vram / (1024**3):.2f} GB total VRAM):"
|
||||
)
|
||||
for i in range(count):
|
||||
vram = torch.cuda.mem_get_info(i)[1] / (1024**3)
|
||||
print(
|
||||
f"* GPU {i}: [bold]{torch.cuda.get_device_name(i)}[/] ({vram:.2f} GB)"
|
||||
)
|
||||
elif is_xpu_available():
|
||||
count = torch.xpu.device_count()
|
||||
print(f"Detected [bold]{count}[/] XPU device(s):")
|
||||
for i in range(count):
|
||||
print(f"* XPU {i}: [bold]{torch.xpu.get_device_name(i)}[/]")
|
||||
elif is_mlu_available():
|
||||
count = torch.mlu.device_count() # ty:ignore[unresolved-attribute]
|
||||
print(f"Detected [bold]{count}[/] MLU device(s):")
|
||||
for i in range(count):
|
||||
print(f"* MLU {i}: [bold]{torch.mlu.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
|
||||
elif is_sdaa_available():
|
||||
count = torch.sdaa.device_count() # ty:ignore[unresolved-attribute]
|
||||
print(f"Detected [bold]{count}[/] SDAA device(s):")
|
||||
for i in range(count):
|
||||
print(f"* SDAA {i}: [bold]{torch.sdaa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
|
||||
elif is_musa_available():
|
||||
count = torch.musa.device_count() # ty:ignore[unresolved-attribute]
|
||||
print(f"Detected [bold]{count}[/] MUSA device(s):")
|
||||
for i in range(count):
|
||||
print(f"* MUSA {i}: [bold]{torch.musa.get_device_name(i)}[/]") # ty:ignore[unresolved-attribute]
|
||||
elif is_npu_available():
|
||||
print(f"NPU detected (CANN version: [bold]{torch.version.cann}[/])") # ty:ignore[unresolved-attribute]
|
||||
elif torch.backends.mps.is_available():
|
||||
print("Detected [bold]1[/] MPS device (Apple Metal)")
|
||||
else:
|
||||
print(
|
||||
"[bold yellow]No GPU or other accelerator detected. Operations will be slow.[/]"
|
||||
)
|
||||
if settings.seed is None:
|
||||
settings.seed = random.randint(0, 2**32 - 1)
|
||||
|
||||
set_seed(settings.seed)
|
||||
|
||||
print(get_accelerator_info())
|
||||
|
||||
# We don't need gradients as we only do inference.
|
||||
torch.set_grad_enabled(False)
|
||||
@@ -581,6 +544,7 @@ def run():
|
||||
|
||||
trial.set_user_attr("kl_divergence", kl_divergence)
|
||||
trial.set_user_attr("refusals", refusals)
|
||||
trial.set_user_attr("total_refusal_prompts", len(evaluator.bad_prompts))
|
||||
|
||||
return score
|
||||
|
||||
@@ -597,6 +561,7 @@ def run():
|
||||
n_startup_trials=settings.n_startup_trials,
|
||||
n_ei_candidates=128,
|
||||
multivariate=True,
|
||||
seed=settings.seed,
|
||||
),
|
||||
directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE],
|
||||
storage=storage,
|
||||
@@ -835,6 +800,30 @@ def run():
|
||||
if strategy is None:
|
||||
continue
|
||||
|
||||
# Reproducibility requires that the model and all datasets
|
||||
# are available on the Hugging Face Hub (not local paths).
|
||||
datasets = [
|
||||
settings.good_prompts.dataset,
|
||||
settings.bad_prompts.dataset,
|
||||
settings.good_evaluation_prompts.dataset,
|
||||
settings.bad_evaluation_prompts.dataset,
|
||||
]
|
||||
can_reproduce = not Path(settings.model).exists() and all(
|
||||
not Path(d).exists() for d in datasets
|
||||
)
|
||||
|
||||
if can_reproduce:
|
||||
# Pin the number of trials to the number of actual completed trials
|
||||
# for the reproduction configuration.
|
||||
settings.n_trials = count_completed_trials()
|
||||
|
||||
include_reproduce = prompt_confirm(
|
||||
"""Include 'reproduce' folder?
|
||||
This saves your exact configuration and system information, along with the study checkpoint, to help others verify your results."""
|
||||
)
|
||||
else:
|
||||
include_reproduce = False
|
||||
|
||||
if strategy == "adapter":
|
||||
print("Uploading LoRA adapter...")
|
||||
model.model.push_to_hub(
|
||||
@@ -894,7 +883,19 @@ def run():
|
||||
)
|
||||
card.push_to_hub(repo_id, token=token)
|
||||
|
||||
print(f"Model uploaded to [bold]{repo_id}[/].")
|
||||
if include_reproduce:
|
||||
upload_reproduce_folder(
|
||||
repo_id,
|
||||
settings,
|
||||
token,
|
||||
checkpoint_path=study_checkpoint_file,
|
||||
trial=trial,
|
||||
)
|
||||
print(
|
||||
f"Model and reproducibility files uploaded to [bold]{repo_id}[/]."
|
||||
)
|
||||
else:
|
||||
print(f"Model uploaded to [bold]{repo_id}[/].")
|
||||
|
||||
case "Chat with the model":
|
||||
print()
|
||||
|
||||
@@ -30,7 +30,8 @@ from transformers.generation import (
|
||||
)
|
||||
|
||||
from .config import QuantizationMethod, RowNormalization, Settings
|
||||
from .utils import Prompt, batchify, empty_cache, print
|
||||
from .system import empty_cache
|
||||
from .utils import Prompt, batchify, print
|
||||
|
||||
|
||||
def get_model_class(
|
||||
|
||||
@@ -0,0 +1,462 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
import gc
|
||||
import importlib.metadata
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import cpuinfo
|
||||
import torch
|
||||
from accelerate.utils import (
|
||||
is_mlu_available,
|
||||
is_musa_available,
|
||||
is_npu_available,
|
||||
is_sdaa_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
|
||||
|
||||
def empty_cache():
|
||||
"""Clears the backend cache and collects garbage."""
|
||||
# Collecting garbage is not an idempotent operation, and to avoid OOM errors,
|
||||
# gc.collect() has to be called both before and after emptying the backend cache.
|
||||
# See https://github.com/p-e-w/heretic/pull/17 for details.
|
||||
gc.collect()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif is_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_mlu_available():
|
||||
torch.mlu.empty_cache() # ty:ignore[unresolved-attribute]
|
||||
elif is_sdaa_available():
|
||||
torch.sdaa.empty_cache() # ty:ignore[unresolved-attribute]
|
||||
elif is_musa_available():
|
||||
torch.musa.empty_cache() # ty:ignore[unresolved-attribute]
|
||||
elif torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
def get_nvidia_driver_version() -> str | None:
|
||||
"""Gets the NVIDIA driver version using nvidia-smi."""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
return output.strip().split("\n")[0]
|
||||
except (subprocess.CalledProcessError, FileNotFoundError, IndexError):
|
||||
return None
|
||||
|
||||
|
||||
def get_amdgpu_driver_version() -> str | None:
|
||||
"""Gets the AMD GPU (ROCm) driver and suite version info."""
|
||||
# 1. Try amd-smi (modern standard for ROCm 6.0+)
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["amd-smi", "version"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
if output.strip():
|
||||
return output.strip().replace("\n", " | ")
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
# 2. Try rocm-smi --showdriverversion
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["rocm-smi", "--showdriverversion"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
for line in output.split("\n"):
|
||||
if "Driver version" in line:
|
||||
return line.split(":")[-1].strip()
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
# 3. Try /sys/module/amdgpu/version (Linux kernel driver version)
|
||||
try:
|
||||
if platform.system() == "Linux":
|
||||
version_path = "/sys/module/amdgpu/version"
|
||||
if os.path.exists(version_path):
|
||||
with open(version_path, "r", encoding="utf-8") as f:
|
||||
return f.read().strip()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_xpu_driver_version() -> str | None:
|
||||
"""Gets the Intel XPU driver version."""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["xpu-smi", "discovery"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
for line in output.split("\n"):
|
||||
if "Driver Version" in line:
|
||||
return line.split(":")[-1].strip()
|
||||
return None
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return None
|
||||
|
||||
|
||||
def get_npu_driver_version() -> str | None:
|
||||
"""Gets the Huawei NPU driver version."""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["npu-smi", "info", "-t", "board", "-i", "0"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
for line in output.split("\n"):
|
||||
if "Software Version" in line:
|
||||
return line.split()[-1].strip()
|
||||
return None
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return None
|
||||
|
||||
|
||||
def get_mps_driver_version() -> str | None:
|
||||
"""Gets the Apple Silicon (MPS) driver version via macOS version."""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["sw_vers", "-productVersion"],
|
||||
stderr=subprocess.DEVNULL,
|
||||
text=True,
|
||||
)
|
||||
return output.strip()
|
||||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class HereticVersionInfo:
|
||||
"""Detailed information about the heretic-llm installation."""
|
||||
|
||||
version: str
|
||||
origin: str | None
|
||||
is_standard_pypi: bool
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
def get_heretic_version_info() -> HereticVersionInfo:
|
||||
"""Detects version and installation source (PyPI, Git, Local) of heretic-llm."""
|
||||
package_name = "heretic-llm"
|
||||
origin_metadata: dict[str, Any] = {"type": "unknown"}
|
||||
# This package must be installed for this code to run.
|
||||
distribution = importlib.metadata.distribution(package_name)
|
||||
|
||||
base_version = distribution.version.lstrip("v")
|
||||
|
||||
try:
|
||||
direct_url_content = distribution.read_text("direct_url.json")
|
||||
except Exception:
|
||||
direct_url_content = None
|
||||
|
||||
if not direct_url_content:
|
||||
# Standard PyPI installation.
|
||||
origin_metadata["type"] = "pypi"
|
||||
return HereticVersionInfo(
|
||||
version=base_version,
|
||||
origin="PyPI",
|
||||
is_standard_pypi=True,
|
||||
metadata=origin_metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
data = json.loads(direct_url_content)
|
||||
|
||||
# Check for Git source.
|
||||
if "vcs_info" in data and data["vcs_info"].get("vcs") == "git":
|
||||
vcs_info = data["vcs_info"]
|
||||
commit_hash = vcs_info.get("commit_id", "unknown")
|
||||
repo_url = data.get("url", "unknown_repo")
|
||||
requested_revision = vcs_info.get("requested_revision")
|
||||
|
||||
if requested_revision:
|
||||
origin_str = (
|
||||
f"Git ({repo_url}@{requested_revision} - commit: {commit_hash})"
|
||||
)
|
||||
else:
|
||||
origin_str = f"Git ({repo_url} @ {commit_hash})"
|
||||
|
||||
origin_metadata.update(
|
||||
{
|
||||
"type": "git",
|
||||
"url": repo_url,
|
||||
"commit_hash": commit_hash,
|
||||
"requested_revision": requested_revision,
|
||||
}
|
||||
)
|
||||
|
||||
return HereticVersionInfo(
|
||||
version=base_version,
|
||||
origin=origin_str,
|
||||
is_standard_pypi=False,
|
||||
metadata=origin_metadata,
|
||||
)
|
||||
|
||||
# Check for local file/wheel directory.
|
||||
if "url" in data and data["url"].startswith("file://"):
|
||||
origin_metadata["type"] = "local"
|
||||
return HereticVersionInfo(
|
||||
version=base_version,
|
||||
origin="Local",
|
||||
is_standard_pypi=False,
|
||||
metadata=origin_metadata,
|
||||
)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return HereticVersionInfo(
|
||||
version=base_version,
|
||||
origin=None,
|
||||
is_standard_pypi=False,
|
||||
metadata=origin_metadata,
|
||||
)
|
||||
|
||||
|
||||
def get_accelerator_info_dict() -> dict[str, Any]:
|
||||
"""Retrieves raw accelerator info (CUDA, ROCm, etc) directly into structured keys."""
|
||||
if torch.cuda.is_available():
|
||||
count = torch.cuda.device_count()
|
||||
is_rocm = getattr(torch.version, "hip", None) is not None
|
||||
|
||||
# ROCm (AMD) and CUDA (NVIDIA) share the same API in PyTorch.
|
||||
# We distinguish them by checking for the HIP version.
|
||||
info: dict[str, Any] = {
|
||||
"type": "ROCm" if is_rocm else "CUDA",
|
||||
"api_name": "HIP Version" if is_rocm else "CUDA Version",
|
||||
"api_version": torch.version.hip if is_rocm else torch.version.cuda, # ty:ignore[unresolved-attribute]
|
||||
"driver_version": get_amdgpu_driver_version()
|
||||
if is_rocm
|
||||
else get_nvidia_driver_version(),
|
||||
"devices": [],
|
||||
}
|
||||
|
||||
for i in range(count):
|
||||
name = torch.cuda.get_device_name(i)
|
||||
vram = torch.cuda.mem_get_info(i)[1] / (1024**3)
|
||||
info["devices"].append({"name": name, "vram_gb": round(vram, 2)})
|
||||
|
||||
return info
|
||||
|
||||
if is_xpu_available():
|
||||
count = torch.xpu.device_count() # ty:ignore[unresolved-attribute]
|
||||
return {
|
||||
"type": "XPU",
|
||||
"api_name": None,
|
||||
"api_version": None,
|
||||
"driver_version": get_xpu_driver_version(),
|
||||
"devices": [{"name": torch.xpu.get_device_name(i)} for i in range(count)], # ty:ignore[unresolved-attribute]
|
||||
}
|
||||
|
||||
if is_mlu_available():
|
||||
count = torch.mlu.device_count() # ty:ignore[unresolved-attribute]
|
||||
return {
|
||||
"type": "MLU",
|
||||
"api_name": None,
|
||||
"api_version": None,
|
||||
"driver_version": None,
|
||||
"devices": [{"name": torch.mlu.get_device_name(i)} for i in range(count)], # ty:ignore[unresolved-attribute]
|
||||
}
|
||||
|
||||
if is_sdaa_available():
|
||||
count = torch.sdaa.device_count() # ty:ignore[unresolved-attribute]
|
||||
return {
|
||||
"type": "SDAA",
|
||||
"api_name": None,
|
||||
"api_version": None,
|
||||
"driver_version": None,
|
||||
"devices": [{"name": torch.sdaa.get_device_name(i)} for i in range(count)], # ty:ignore[unresolved-attribute]
|
||||
}
|
||||
|
||||
if is_musa_available():
|
||||
count = torch.musa.device_count() # ty:ignore[unresolved-attribute]
|
||||
return {
|
||||
"type": "MUSA",
|
||||
"api_name": None,
|
||||
"api_version": None,
|
||||
"driver_version": None,
|
||||
"devices": [{"name": torch.musa.get_device_name(i)} for i in range(count)], # ty:ignore[unresolved-attribute]
|
||||
}
|
||||
|
||||
if is_npu_available():
|
||||
return {
|
||||
"type": "NPU",
|
||||
"api_name": "CANN Version",
|
||||
"api_version": torch.version.cann, # ty:ignore[unresolved-attribute]
|
||||
"driver_version": get_npu_driver_version(),
|
||||
"devices": [], # Multi-NPU is less common.
|
||||
}
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
return {
|
||||
"type": "MPS",
|
||||
"api_name": None,
|
||||
"api_version": None,
|
||||
"driver_version": get_mps_driver_version(),
|
||||
"devices": [{"name": "Apple Metal"}],
|
||||
}
|
||||
|
||||
return {"type": None}
|
||||
|
||||
|
||||
def get_accelerator_info(include_warnings: bool = True) -> str:
|
||||
"""Convenience wrapper for hardware detection and console-friendly formatting."""
|
||||
info = get_accelerator_info_dict()
|
||||
|
||||
if info["type"] is None:
|
||||
suffix = " Operations will be slow." if include_warnings else ""
|
||||
return (
|
||||
f"[bold yellow]No GPU or other accelerator detected.{suffix}[/]\n".strip()
|
||||
)
|
||||
|
||||
devices = info["devices"]
|
||||
count = len(devices)
|
||||
total_vram = sum(d.get("vram_gb", 0) for d in devices)
|
||||
|
||||
vram_suffix = f" ({total_vram:.2f} GB total VRAM)" if total_vram > 0 else ""
|
||||
report = f"Detected [bold]{count or 1}[/] {info['type']} device(s){vram_suffix}\n"
|
||||
|
||||
if info.get("api_name") and info.get("api_version"):
|
||||
report += f"{info['api_name']}: [bold]{info['api_version']}[/]\n"
|
||||
|
||||
driver = info.get("driver_version") or "Unknown"
|
||||
report += f"Driver Version: [bold]{driver}[/]\n"
|
||||
|
||||
for i, dev in enumerate(devices):
|
||||
vram = f" ({dev['vram_gb']:.2f} GB)" if dev.get("vram_gb") else ""
|
||||
report += f"* {info['type']} {i}: [bold]{dev['name']}[/]{vram}\n"
|
||||
|
||||
return report.strip()
|
||||
|
||||
|
||||
def get_cpu_info_dict() -> dict[str, str | int | None]:
|
||||
"""Gets granular CPU identifiers using the py-cpuinfo library."""
|
||||
info = cpuinfo.get_cpu_info()
|
||||
|
||||
return {
|
||||
"brand": info.get("brand_raw"),
|
||||
"vendor": info.get("vendor_id_raw"),
|
||||
"family": info.get("family"),
|
||||
"model": info.get("model"),
|
||||
"stepping": info.get("stepping"),
|
||||
}
|
||||
|
||||
|
||||
def get_cpu_info() -> str:
|
||||
"""Gets the CPU brand name."""
|
||||
info = get_cpu_info_dict()
|
||||
parts = []
|
||||
parts.append(
|
||||
f"Family {info['family']}, Model {info['model']}, Stepping {info['stepping']}"
|
||||
)
|
||||
|
||||
details = f" ({'; '.join(parts)})" if parts else ""
|
||||
brand = info["brand"] or "Unknown CPU"
|
||||
return f"{brand}{details}"
|
||||
|
||||
|
||||
def get_python_env_info_dict() -> dict[str, str]:
|
||||
implementation = platform.python_implementation()
|
||||
compiler = platform.python_compiler()
|
||||
|
||||
# Check for Conda.
|
||||
if "CONDA_PREFIX" in os.environ:
|
||||
env_type = "Conda"
|
||||
# Check for Virtualenv/Venv.
|
||||
elif hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix:
|
||||
env_type = "Virtualenv/Venv"
|
||||
else:
|
||||
env_type = "System"
|
||||
|
||||
return {
|
||||
"version": platform.python_version(),
|
||||
"implementation": implementation,
|
||||
"compiler": compiler,
|
||||
"environment": env_type,
|
||||
}
|
||||
|
||||
|
||||
def get_python_env_info() -> str:
|
||||
"""Detects the type of Python environment (Conda, Venv, etc.) and build info."""
|
||||
info = get_python_env_info_dict()
|
||||
return f"{info['version']} ({info['implementation']}, {info['compiler']}) [{info['environment']}]"
|
||||
|
||||
|
||||
def get_package_version(name: str) -> str | None:
|
||||
"""Gets the installed version of a package, stripping local suffixes like +cu128."""
|
||||
# Normalize name: pip considers hyphens and underscores equivalent.
|
||||
normalized_name = name.lower().replace("_", "-")
|
||||
version_str = importlib.metadata.version(normalized_name)
|
||||
return version_str.split("+")[0] if "+" in version_str else version_str
|
||||
|
||||
|
||||
def get_requirements_dict() -> dict[str, str]:
|
||||
"""Recursively finds all direct and transitive dependencies of heretic-llm and core libraries."""
|
||||
# We start with heretic-llm and the core compute libraries.
|
||||
packages_to_check = ["heretic-llm", "torch", "torchaudio", "torchvision"]
|
||||
visited = set()
|
||||
required_packages = set()
|
||||
|
||||
while packages_to_check:
|
||||
package = packages_to_check.pop(0)
|
||||
# Normalize name: pip considers hyphens and underscores equivalent.
|
||||
normalized_package = package.lower().replace("_", "-")
|
||||
if normalized_package in visited:
|
||||
continue
|
||||
visited.add(normalized_package)
|
||||
|
||||
try:
|
||||
distribution = importlib.metadata.distribution(normalized_package)
|
||||
required_packages.add(normalized_package)
|
||||
if distribution.requires:
|
||||
for requirement in distribution.requires:
|
||||
# Requirements can include environment markers like '; extra == "hf"'
|
||||
# or version constraints. We should ignore optional 'extra' dependencies
|
||||
# to keep the reproduction environment clean and relevant.
|
||||
if ";" in requirement and "extra ==" in requirement:
|
||||
continue
|
||||
|
||||
# We just want the base package name.
|
||||
match = re.match(r"^([a-zA-Z0-9_\-]+)", requirement)
|
||||
if match:
|
||||
dep_name = match.group(0).lower().replace("_", "-")
|
||||
if dep_name not in visited:
|
||||
packages_to_check.append(dep_name)
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
# If a package is listed as a dependency but not installed, we skip it.
|
||||
continue
|
||||
|
||||
# Lookup versions for all discovered packages.
|
||||
dependencies = {}
|
||||
version_info = get_heretic_version_info()
|
||||
for name in required_packages:
|
||||
# If heretic-llm was installed from source (Git/Local), exclude it
|
||||
# from requirements.txt to prevent pip from downloading an unrelated
|
||||
# version from PyPI during reproduction.
|
||||
if name == "heretic-llm" and not version_info.is_standard_pypi:
|
||||
continue
|
||||
|
||||
version_str = get_package_version(name)
|
||||
if version_str:
|
||||
dependencies[name] = version_str
|
||||
|
||||
return dependencies
|
||||
+396
-31
@@ -1,22 +1,22 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
import gc
|
||||
import getpass
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import random
|
||||
import tempfile
|
||||
from dataclasses import dataclass
|
||||
from importlib.metadata import version
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
import questionary
|
||||
import tomli_w
|
||||
import torch
|
||||
from accelerate.utils import (
|
||||
is_mlu_available,
|
||||
is_musa_available,
|
||||
is_sdaa_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
from datasets import DatasetDict, ReadInstruction, load_dataset, load_from_disk
|
||||
from datasets.config import DATASET_STATE_JSON_FILENAME
|
||||
from datasets.download.download_manager import DownloadMode
|
||||
@@ -27,6 +27,14 @@ from questionary import Choice, Style
|
||||
from rich.console import Console
|
||||
|
||||
from .config import DatasetSpecification, Settings
|
||||
from .system import (
|
||||
get_accelerator_info_dict,
|
||||
get_cpu_info_dict,
|
||||
get_heretic_version_info,
|
||||
get_python_env_info_dict,
|
||||
get_requirements_dict,
|
||||
is_xpu_available,
|
||||
)
|
||||
|
||||
print = Console(highlight=False).print
|
||||
|
||||
@@ -147,6 +155,18 @@ def prompt_password(message: str) -> str:
|
||||
return questionary.password(message).ask()
|
||||
|
||||
|
||||
def prompt_confirm(message: str, default: bool = True) -> bool:
|
||||
if is_notebook():
|
||||
print()
|
||||
choices = "[Y/n]" if default else "[y/N]"
|
||||
result = input(f"{message} {choices} ").strip().lower()
|
||||
if not result:
|
||||
return default
|
||||
return result in ("y", "yes")
|
||||
else:
|
||||
return questionary.confirm(message, default=default).ask()
|
||||
|
||||
|
||||
def format_duration(seconds: float) -> str:
|
||||
seconds = round(seconds)
|
||||
hours, seconds = divmod(seconds, 3600)
|
||||
@@ -234,28 +254,6 @@ def batchify(items: list[T], batch_size: int) -> list[list[T]]:
|
||||
return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
|
||||
|
||||
|
||||
def empty_cache():
|
||||
# Collecting garbage is not an idempotent operation, and to avoid OOM errors,
|
||||
# gc.collect() has to be called both before and after emptying the backend cache.
|
||||
# See https://github.com/p-e-w/heretic/pull/17 for details.
|
||||
gc.collect()
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
elif is_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_mlu_available():
|
||||
torch.mlu.empty_cache() # ty:ignore[unresolved-attribute]
|
||||
elif is_sdaa_available():
|
||||
torch.sdaa.empty_cache() # ty:ignore[unresolved-attribute]
|
||||
elif is_musa_available():
|
||||
torch.musa.empty_cache() # ty:ignore[unresolved-attribute]
|
||||
elif torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
|
||||
gc.collect()
|
||||
|
||||
|
||||
def get_trial_parameters(trial: Trial) -> dict[str, str]:
|
||||
params = {}
|
||||
|
||||
@@ -283,9 +281,10 @@ def get_readme_intro(
|
||||
else:
|
||||
model_link = f"[{settings.model}](https://huggingface.co/{settings.model})"
|
||||
|
||||
version_info = get_heretic_version_info()
|
||||
return f"""# This is a decensored version of {
|
||||
model_link
|
||||
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version("heretic-llm")}
|
||||
}, made using [Heretic](https://github.com/p-e-w/heretic) v{version_info.version}
|
||||
|
||||
## Abliteration parameters
|
||||
|
||||
@@ -312,3 +311,369 @@ def get_readme_intro(
|
||||
-----
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def generate_config_toml(settings: Settings) -> str:
|
||||
"""Serializes the full Settings object to TOML."""
|
||||
return tomli_w.dumps(settings.model_dump(exclude_none=True))
|
||||
|
||||
|
||||
def generate_requirements_txt() -> str:
|
||||
"""Collects direct project dependencies as a formatted string."""
|
||||
requirements = get_requirements_dict()
|
||||
sorted_requirements = sorted(
|
||||
[f"{name}=={version}" for name, version in requirements.items()],
|
||||
key=lambda x: x.lower(),
|
||||
)
|
||||
return "\n".join(sorted_requirements) + "\n"
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""Sets the seed for all RNGs."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
|
||||
def generate_reproduce_readme(
|
||||
settings: Settings,
|
||||
checkpoint_filename: str,
|
||||
trial: Trial,
|
||||
timestamp: str | None = None,
|
||||
base_model_commit: str | None = None,
|
||||
) -> str:
|
||||
"""Generates a README.md for the reproduce/ folder."""
|
||||
torch_version = torch.__version__
|
||||
install_hint = f"pip install torch=={torch_version}"
|
||||
if "+" in torch_version:
|
||||
suffix = torch_version.split("+")[1]
|
||||
if suffix:
|
||||
install_hint += f" --index-url https://download.pytorch.org/whl/{suffix}"
|
||||
|
||||
heterogeneous_warning = ""
|
||||
if torch.cuda.is_available():
|
||||
count = torch.cuda.device_count()
|
||||
if count > 1:
|
||||
device_names = {torch.cuda.get_device_name(i) for i in range(count)}
|
||||
if len(device_names) > 1:
|
||||
heterogeneous_warning = """
|
||||
> [!WARNING]
|
||||
> **Heterogeneous GPUs Detected!**
|
||||
> This system uses multiple non-identical GPUs. When operations are distributed across different GPUs (e.g. via `device_map='auto'`), non-deterministic behavior can occur. **Reproducibility ***cannot*** be guaranteed in this environment.**
|
||||
"""
|
||||
|
||||
version_info = get_heretic_version_info()
|
||||
origin_warning = ""
|
||||
if not version_info.is_standard_pypi:
|
||||
if version_info.origin and version_info.origin.startswith("Git"):
|
||||
repo_info = version_info.origin.split("Git (")[1].strip(")")
|
||||
origin_warning = f"""
|
||||
> [!NOTE]
|
||||
> **Git Installation Detected**
|
||||
> This system installed `heretic-llm` from source repository: `{repo_info}`.
|
||||
> To reproduce these results, you must install Heretic from this exact repository and commit.
|
||||
"""
|
||||
elif version_info.origin == "Local":
|
||||
origin_warning = """
|
||||
> [!WARNING]
|
||||
> **Local Code Detected!**
|
||||
> This system installed `heretic-llm` from a local directory or wheel. Uncommitted or experimental code may have been executed. **Reproducibility ***cannot*** be guaranteed in this environment.**
|
||||
"""
|
||||
else:
|
||||
origin_warning = """
|
||||
> [!WARNING]
|
||||
> **Non-Standard Installation Detected!**
|
||||
> This system installed `heretic-llm` from an unknown non-standard source. **Reproducibility ***cannot*** be guaranteed in this environment.**
|
||||
"""
|
||||
|
||||
def format_hf_link(
|
||||
name: str, commit: str | None = None, is_dataset: bool = False
|
||||
) -> str:
|
||||
if Path(name).exists():
|
||||
return f"`{name}` (Local)"
|
||||
|
||||
prefix = "datasets/" if is_dataset else ""
|
||||
base_url = f"https://huggingface.co/{prefix}{name}"
|
||||
link = f"[{name}]({base_url})"
|
||||
if commit:
|
||||
commit_url = f"{base_url}/commit/{commit}"
|
||||
link += f" (Commit: [{commit[:7]}]({commit_url}))"
|
||||
return link
|
||||
|
||||
model_link = format_hf_link(settings.model, base_model_commit)
|
||||
dataset_info = f"""## Dataset Information
|
||||
|
||||
- **Good Prompts:** {format_hf_link(settings.good_prompts.dataset, settings.good_prompts.commit, is_dataset=True)}
|
||||
- **Bad Prompts:** {format_hf_link(settings.bad_prompts.dataset, settings.bad_prompts.commit, is_dataset=True)}
|
||||
- **Good Evaluation Prompts:** {format_hf_link(settings.good_evaluation_prompts.dataset, settings.good_evaluation_prompts.commit, is_dataset=True)}
|
||||
- **Bad Evaluation Prompts:** {format_hf_link(settings.bad_evaluation_prompts.dataset, settings.bad_evaluation_prompts.commit, is_dataset=True)}"""
|
||||
|
||||
timestamp_str = f"- **Run started at (UTC):** `{timestamp}`" if timestamp else ""
|
||||
|
||||
# System and Accelerator info using structured dictionaries.
|
||||
cpu = get_cpu_info_dict()
|
||||
python_env = get_python_env_info_dict()
|
||||
accelerator = get_accelerator_info_dict()
|
||||
|
||||
# Build System Environment section.
|
||||
system_env_lines = [
|
||||
f"- **OS:** `{platform.platform()}` (`{platform.machine()}`)",
|
||||
f"- **CPU:** `{cpu['brand'] or 'Unknown CPU'}`",
|
||||
f" - **Information:** Family `{cpu['family']}`, Model `{cpu['model']}`, Stepping `{cpu['stepping']}`",
|
||||
]
|
||||
|
||||
system_env_lines.extend(
|
||||
[
|
||||
f"- **Python:** `{python_env['version']}` (`{python_env['implementation']}`, `{python_env['compiler']}`) [`{python_env['environment']}`]",
|
||||
f"- **Heretic:** `v{version_info.version}`"
|
||||
+ (f" (Origin: `{version_info.origin}`)" if version_info.origin else ""),
|
||||
f"- **PyTorch:** `{torch.__version__}`",
|
||||
]
|
||||
)
|
||||
system_environment_report = "\n".join(system_env_lines)
|
||||
|
||||
# Build Accelerators section.
|
||||
if accelerator["type"] is None:
|
||||
accelerator_report = "> [!WARNING]\n> **No GPU or other accelerator detected.**"
|
||||
else:
|
||||
devices = accelerator["devices"]
|
||||
total_vram = sum(d.get("vram_gb", 0) for d in devices)
|
||||
vram_suffix = f" (`{total_vram:.2f} GB` total VRAM)" if total_vram > 0 else ""
|
||||
accelerator_lines = [
|
||||
f"- **{accelerator['type']}:** Detected `{len(devices)}` device(s){vram_suffix}"
|
||||
]
|
||||
|
||||
if accelerator.get("api_name") and accelerator.get("api_version"):
|
||||
accelerator_lines.append(
|
||||
f" - **{accelerator['api_name']}:** `{accelerator['api_version']}`"
|
||||
)
|
||||
|
||||
if accelerator.get("driver_version"):
|
||||
accelerator_lines.append(
|
||||
f" - **Driver Version:** `{accelerator['driver_version']}`"
|
||||
)
|
||||
|
||||
accelerator_lines.append("- **Devices:**")
|
||||
for i, dev in enumerate(devices):
|
||||
vram = f" (`{dev['vram_gb']:.2f} GB`)" if dev.get("vram_gb") else ""
|
||||
accelerator_lines.append(
|
||||
f" - **{accelerator['type']} {i}:** `{dev['name']}`{vram}"
|
||||
)
|
||||
accelerator_report = "\n".join(accelerator_lines)
|
||||
|
||||
return f"""# Reproduction Guide
|
||||
|
||||
This directory contains the necessary information and assets to reproduce the results obtained during this Heretic run.{heterogeneous_warning}{origin_warning}
|
||||
|
||||
## Model Information
|
||||
|
||||
- **Base Model:** {model_link}
|
||||
{timestamp_str}
|
||||
|
||||
{dataset_info}
|
||||
|
||||
## Selected Trial
|
||||
|
||||
- **Trial Number:** `#{trial.user_attrs["index"]}`
|
||||
- **Refusal Count:** `{trial.user_attrs.get("refusals")}/{trial.user_attrs.get("total_refusal_prompts")}`
|
||||
- **KL Divergence:** `{trial.user_attrs.get("kl_divergence", 0):.6f}`
|
||||
|
||||
## System Environment
|
||||
|
||||
{system_environment_report}
|
||||
|
||||
### Accelerators
|
||||
|
||||
{accelerator_report}
|
||||
|
||||
## Contents
|
||||
|
||||
- **config.toml**: The exact configuration used, including the seed `{settings.seed}`.
|
||||
- **requirements.txt**: The exact versions of all installed Python packages.
|
||||
- **{checkpoint_filename}**: The Optuna study journal containing the history of all trials.
|
||||
- **reproduce.json**: A machine-readable version of this report.
|
||||
- **SHA256SUMS**: Cryptographic hashes for all uploaded weight files (if applicable).
|
||||
|
||||
## How to Reproduce
|
||||
|
||||
1. Ensure your hardware and environment match the specifications in the **System Environment** section above.
|
||||
2. Install the exact package versions listed in `requirements.txt`.
|
||||
3. Place the provided `config.toml` in your working directory.
|
||||
4. Run `heretic` without any additional arguments.
|
||||
5. Verify the integrity of the reproduced files by comparing their SHA256 hashes against the manifest in `SHA256SUMS`.
|
||||
|
||||
> [!TIP]
|
||||
> To use the included Optuna study journal `{checkpoint_filename}`, place it in a `checkpoints/` directory before running `heretic` on the same model.
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Make sure to install correct PyTorch version from: `{install_hint}`
|
||||
"""
|
||||
|
||||
|
||||
def generate_reproduce_json(
|
||||
settings: Settings,
|
||||
trial: Trial,
|
||||
timestamp: str | None = None,
|
||||
base_model_commit: str | None = None,
|
||||
uploaded_model_hashes: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
"""Generates a reproduce.json file for the reproduce/ folder."""
|
||||
version_info = get_heretic_version_info()
|
||||
data = {
|
||||
"base_model": {
|
||||
"id": settings.model,
|
||||
"commit_hash": base_model_commit,
|
||||
},
|
||||
"system": {
|
||||
"os": {"platform": platform.platform(), "machine": platform.machine()},
|
||||
"cpu": get_cpu_info_dict(),
|
||||
"python": get_python_env_info_dict(),
|
||||
"heretic": {
|
||||
"version": version_info.version,
|
||||
"is_standard_pypi": version_info.is_standard_pypi,
|
||||
"metadata": version_info.metadata,
|
||||
},
|
||||
"pytorch_version": torch.__version__,
|
||||
"accelerator": get_accelerator_info_dict(),
|
||||
},
|
||||
"requirements": get_requirements_dict(),
|
||||
"settings": settings.model_dump(exclude_none=True),
|
||||
"trial": {
|
||||
"direction_index": trial.user_attrs.get("direction_index"),
|
||||
"parameters": trial.user_attrs.get("parameters"),
|
||||
"metrics": {
|
||||
"refusals": trial.user_attrs.get("refusals"),
|
||||
"total_refusal_prompts": trial.user_attrs.get("total_refusal_prompts"),
|
||||
"kl_divergence": trial.user_attrs.get("kl_divergence"),
|
||||
},
|
||||
},
|
||||
"timestamp": timestamp,
|
||||
"uploaded_model_hashes": uploaded_model_hashes or {},
|
||||
}
|
||||
return json.dumps(data, indent=4)
|
||||
|
||||
|
||||
def generate_sha256sums(hashes: dict[str, str]) -> str:
|
||||
"""Generates a GNU Coreutils compatible SHA256SUMS file content."""
|
||||
lines = []
|
||||
for filename, sha256 in sorted(hashes.items()):
|
||||
# Use '*' to indicate binary mode for model weights.
|
||||
lines.append(f"{sha256} *{filename}")
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
def create_reproduce_folder(
|
||||
path: Path,
|
||||
settings: Settings,
|
||||
checkpoint_path: str | Path,
|
||||
trial: Trial,
|
||||
uploaded_model_hashes: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
reproduce_dir = path / "reproduce"
|
||||
reproduce_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
checkpoint_filename = Path(checkpoint_path).name
|
||||
|
||||
# Fetch commit hashes for all HF datasets to ensure reproducibility.
|
||||
for spec in [
|
||||
settings.good_prompts,
|
||||
settings.bad_prompts,
|
||||
settings.good_evaluation_prompts,
|
||||
settings.bad_evaluation_prompts,
|
||||
]:
|
||||
if not Path(spec.dataset).exists():
|
||||
# Fail if the dataset is missing or unreachable.
|
||||
spec.commit = huggingface_hub.dataset_info(spec.dataset).sha
|
||||
|
||||
# Fetch commit hash for the base model if it's on HF.
|
||||
base_model_commit = None
|
||||
if not Path(settings.model).exists():
|
||||
try:
|
||||
base_model_commit = huggingface_hub.model_info(settings.model).sha
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Strip microseconds and timezone for a clean format.
|
||||
timestamp = (
|
||||
datetime.now(timezone.utc).replace(microsecond=0, tzinfo=None).isoformat()
|
||||
)
|
||||
|
||||
(reproduce_dir / "config.toml").write_text(
|
||||
generate_config_toml(settings), encoding="utf-8"
|
||||
)
|
||||
(reproduce_dir / "requirements.txt").write_text(
|
||||
generate_requirements_txt(), encoding="utf-8"
|
||||
)
|
||||
(reproduce_dir / "README.md").write_text(
|
||||
generate_reproduce_readme(
|
||||
settings,
|
||||
checkpoint_filename,
|
||||
trial,
|
||||
timestamp=timestamp,
|
||||
base_model_commit=base_model_commit,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
if uploaded_model_hashes:
|
||||
(reproduce_dir / "SHA256SUMS").write_text(
|
||||
generate_sha256sums(uploaded_model_hashes), encoding="utf-8"
|
||||
)
|
||||
(reproduce_dir / "reproduce.json").write_text(
|
||||
generate_reproduce_json(
|
||||
settings,
|
||||
trial,
|
||||
timestamp=timestamp,
|
||||
base_model_commit=base_model_commit,
|
||||
uploaded_model_hashes=uploaded_model_hashes,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# Copy Optuna study journal.
|
||||
checkpoint_file = Path(checkpoint_path)
|
||||
if checkpoint_file.exists():
|
||||
(reproduce_dir / checkpoint_file.name).write_bytes(checkpoint_file.read_bytes())
|
||||
|
||||
|
||||
def upload_reproduce_folder(
|
||||
repo_id: str,
|
||||
settings: Settings,
|
||||
token: str,
|
||||
checkpoint_path: str | Path,
|
||||
trial: Trial,
|
||||
) -> None:
|
||||
uploaded_model_hashes = {}
|
||||
try:
|
||||
api = huggingface_hub.HfApi()
|
||||
info = api.model_info(repo_id=repo_id, files_metadata=True, token=token)
|
||||
# For weights, we only care about safetensors.
|
||||
weight_extensions = (".safetensors",)
|
||||
if info.siblings is not None:
|
||||
for file in info.siblings:
|
||||
if file.rfilename.endswith(weight_extensions):
|
||||
sha256 = getattr(file, "lfs", {}).get("sha256")
|
||||
if sha256:
|
||||
uploaded_model_hashes[file.rfilename] = sha256
|
||||
except Exception as e:
|
||||
# Fail if integrity checks cannot be completed.
|
||||
raise RuntimeError(f"Could not fetch uploaded model hashes: {e}") from e
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
tmp_path = Path(tmpdir)
|
||||
create_reproduce_folder(
|
||||
tmp_path,
|
||||
settings,
|
||||
checkpoint_path=checkpoint_path,
|
||||
trial=trial,
|
||||
uploaded_model_hashes=uploaded_model_hashes,
|
||||
)
|
||||
|
||||
reproduce_dir = tmp_path / "reproduce"
|
||||
for file_path in reproduce_dir.iterdir():
|
||||
if file_path.is_file():
|
||||
huggingface_hub.upload_file(
|
||||
path_or_fileobj=str(file_path),
|
||||
path_in_repo=f"reproduce/{file_path.name}",
|
||||
repo_id=repo_id,
|
||||
token=token,
|
||||
)
|
||||
|
||||
@@ -588,7 +588,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "4.7.0"
|
||||
version = "4.8.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "dill" },
|
||||
@@ -607,9 +607,9 @@ dependencies = [
|
||||
{ name = "tqdm" },
|
||||
{ name = "xxhash" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1c/9c/ba18de0b70858533e422ed6cfe0e46789473cef7fc7fc3653e23fa494730/datasets-4.7.0.tar.gz", hash = "sha256:4984cdfc65d04464da7f95205a55cb50515fd94ae3176caacb50a1b7273792e2", size = 602008, upload-time = "2026-03-09T19:01:49.298Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/22/22/73e46ac7a8c25e7ef0b3bd6f10da3465021d90219a32eb0b4d2afea4c56e/datasets-4.8.4.tar.gz", hash = "sha256:a1429ed853275ce7943a01c6d2e25475b4501eb758934362106a280470df3a52", size = 604382, upload-time = "2026-03-23T14:21:17.987Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/1e/03/c6d9c3119cf712f638fe763e887ecaac6acbb62bf1e2acc3cbde0df340fd/datasets-4.7.0-py3-none-any.whl", hash = "sha256:d5fe3025ec6acc3b5649f10d5576dff5e054134927604e6913c1467a04adc3c2", size = 527530, upload-time = "2026-03-09T19:01:47.443Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b0/e5/247d094108e42ac26363ab8dc57f168840cf7c05774b40ffeb0d78868fcc/datasets-4.8.4-py3-none-any.whl", hash = "sha256:cdc8bee4698e549d78bf1fed6aea2eebc760b22b084f07e6fc020c6577a6ce6d", size = 526991, upload-time = "2026-03-23T14:21:15.89Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -944,9 +944,11 @@ dependencies = [
|
||||
{ name = "optuna" },
|
||||
{ name = "peft" },
|
||||
{ name = "psutil" },
|
||||
{ name = "py-cpuinfo" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "questionary" },
|
||||
{ name = "rich" },
|
||||
{ name = "tomli-w" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
@@ -986,10 +988,12 @@ requires-dist = [
|
||||
{ name = "pacmap", marker = "extra == 'research'", specifier = "~=0.8" },
|
||||
{ name = "peft", specifier = "~=0.18" },
|
||||
{ name = "psutil", specifier = "~=7.2" },
|
||||
{ name = "py-cpuinfo", specifier = "~=9.0" },
|
||||
{ name = "pydantic-settings", specifier = "~=2.13" },
|
||||
{ name = "questionary", specifier = "~=2.1" },
|
||||
{ name = "rich", specifier = "~=14.3" },
|
||||
{ name = "scikit-learn", marker = "extra == 'research'", specifier = "~=1.7" },
|
||||
{ name = "tomli-w", specifier = "~=1.2" },
|
||||
{ name = "tqdm", specifier = "~=4.67" },
|
||||
{ name = "transformers", specifier = "~=5.3" },
|
||||
]
|
||||
@@ -1095,7 +1099,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "1.7.1"
|
||||
version = "1.7.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "filelock" },
|
||||
@@ -1108,9 +1112,9 @@ dependencies = [
|
||||
{ name = "typer" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b4/a8/94ccc0aec97b996a3a68f3e1fa06a4bd7185dd02bf22bfba794a0ade8440/huggingface_hub-1.7.1.tar.gz", hash = "sha256:be38fe66e9b03c027ad755cb9e4b87ff0303c98acf515b5d579690beb0bf3048", size = 722097, upload-time = "2026-03-13T09:36:07.758Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/19/15/eafc1c57bf0f8afffb243dcd4c0cceb785e956acc17bba4d9bf2ae21fc9c/huggingface_hub-1.7.2.tar.gz", hash = "sha256:7f7e294e9bbb822e025bdb2ada025fa4344d978175a7f78e824d86e35f7ab43b", size = 724684, upload-time = "2026-03-20T10:36:08.767Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/75/ca21955d6117a394a482c7862ce96216239d0e3a53133ae8510727a8bcfa/huggingface_hub-1.7.1-py3-none-any.whl", hash = "sha256:38c6cce7419bbde8caac26a45ed22b0cea24152a8961565d70ec21f88752bfaa", size = 616308, upload-time = "2026-03-13T09:36:06.062Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/08/de/3ad061a05f74728927ded48c90b73521b9a9328c85d841bdefb30e01fb85/huggingface_hub-1.7.2-py3-none-any.whl", hash = "sha256:288f33a0a17b2a73a1359e2a5fd28d1becb2c121748c6173ab8643fb342c850e", size = 618036, upload-time = "2026-03-20T10:36:06.824Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1180,7 +1184,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "kernels"
|
||||
version = "0.12.2"
|
||||
version = "0.12.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "huggingface-hub" },
|
||||
@@ -1188,9 +1192,9 @@ dependencies = [
|
||||
{ name = "pyyaml" },
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a9/07/d2b635e965b232cae1aa873c6e0458947196be8dca7bb02e64d3cd6e8d19/kernels-0.12.2.tar.gz", hash = "sha256:812fc43c2814f046cee655cbebf3918cddd489715773670bdb38cca3f5203b5b", size = 57108, upload-time = "2026-03-04T10:03:00.379Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b3/84/9f68f355f6ce99e977872021fbdbafadcf2820f51d3f7bd697ec3801cb7a/kernels-0.12.3.tar.gz", hash = "sha256:87e29716578e7e71dc5a7578e0132bfdae305bedaeb602698f87c88ca6c60e32", size = 57407, upload-time = "2026-03-20T10:20:42.166Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/08/be/f5d6758b48633e4f6a28198fcf4bf9f763cc6a82e2335d9fe8802a5cb440/kernels-0.12.2-py3-none-any.whl", hash = "sha256:1289261804748cf3cf8e3afab80b505b0f1b28e4ec88379cdf08dc31e64964b8", size = 55205, upload-time = "2026-03-04T10:02:59.305Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e7/3e/778e4a86830e9139df2d16d86c4488fce426ec19daa83cbd2854ef389030/kernels-0.12.3-py3-none-any.whl", hash = "sha256:5d1d33fcb774e03bb7f0688ac24d91ef6b963692f80f0a85ddd2286e69f3cf2f", size = 55501, upload-time = "2026-03-20T10:20:40.643Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2241,7 +2245,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "optuna"
|
||||
version = "4.7.0"
|
||||
version = "4.8.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "alembic" },
|
||||
@@ -2253,9 +2257,9 @@ dependencies = [
|
||||
{ name = "sqlalchemy" },
|
||||
{ name = "tqdm" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/58/b2/b5e12de7b4486556fe2257611b55dbabf30d0300bdb031831aa943ad20e4/optuna-4.7.0.tar.gz", hash = "sha256:d91817e2079825557bd2e97de2e8c9ae260bfc99b32712502aef8a5095b2d2c0", size = 479740, upload-time = "2026-01-19T05:45:52.604Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/bf/9b/62f120fb2ecbc4338bee70c5a3671c8e561714f3aa1a046b897ff142050e/optuna-4.8.0.tar.gz", hash = "sha256:6f7043e9f8ecb5e607af86a7eb00fb5ec2be26c3b08c201209a73d36aff37a38", size = 482603, upload-time = "2026-03-16T04:59:58.659Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/75/d1/6c8a4fbb38a9e3565f5c36b871262a85ecab3da48120af036b1e4937a15c/optuna-4.7.0-py3-none-any.whl", hash = "sha256:e41ec84018cecc10eabf28143573b1f0bde0ba56dba8151631a590ecbebc1186", size = 413894, upload-time = "2026-01-19T05:45:50.815Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ac/24/7c731839566d30dc70556d9824ef17692d896c15e3df627bce8c16f753e1/optuna-4.8.0-py3-none-any.whl", hash = "sha256:c57a7682679c36bfc9bca0da430698179e513874074b71bebedb0334964ab930", size = 419456, upload-time = "2026-03-16T04:59:56.977Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2641,6 +2645,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "py-cpuinfo"
|
||||
version = "9.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716, upload-time = "2022-10-25T20:38:06.303Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "22.0.0"
|
||||
@@ -3670,6 +3683,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomli-w"
|
||||
version = "1.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/19/75/241269d1da26b624c0d5e110e8149093c759b7a286138f4efd61a60e75fe/tomli_w-1.2.0.tar.gz", hash = "sha256:2dd14fac5a47c27be9cd4c976af5a12d87fb1f0b4512f81d69cce3b35ae25021", size = 7184, upload-time = "2025-01-15T12:07:24.262Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/18/c86eb8e0202e32dd3df50d43d7ff9854f8e0603945ff398974c1d91ac1ef/tomli_w-1.2.0-py3-none-any.whl", hash = "sha256:188306098d013b691fcadc011abd66727d3c414c571bb01b1a174ba8c983cf90", size = 6675, upload-time = "2025-01-15T12:07:22.074Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torch"
|
||||
version = "2.9.1"
|
||||
|
||||
Reference in New Issue
Block a user