mirror of
https://github.com/p-e-w/heretic.git
synced 2026-06-02 05:03:33 +02:00
feat: add functionality for collecting reproduce.json files from Hugging Face
This commit is contained in:
@@ -103,6 +103,16 @@ class Settings(BaseSettings):
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
collect_reproducibles: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"If this directory path is set, then instead of abliterating a model, "
|
||||
"download all reproduce.json files from public Heretic model repositories "
|
||||
"on Hugging Face, and store them in that directory for archival purposes."
|
||||
),
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
dtypes: list[str] = Field(
|
||||
default=[
|
||||
# In practice, "auto" almost always means bfloat16.
|
||||
|
||||
@@ -65,6 +65,7 @@ from .analyzer import Analyzer
|
||||
from .config import QuantizationMethod
|
||||
from .evaluator import Evaluator
|
||||
from .model import AbliterationParameters, Model, get_model_class
|
||||
from .reproduce import collect_reproducibles
|
||||
from .system import empty_cache, get_accelerator_info
|
||||
from .utils import (
|
||||
format_duration,
|
||||
@@ -177,6 +178,8 @@ def run():
|
||||
if (
|
||||
# There is at least one argument (argv[0] is the program name).
|
||||
len(sys.argv) > 1
|
||||
# Heretic is being invoked in standard (model processing) mode.
|
||||
and "--collect-reproducibles" not in sys.argv
|
||||
# No model has been explicitly provided.
|
||||
and "--model" not in sys.argv
|
||||
# The last argument is a parameter value rather than a flag (such as "--help").
|
||||
@@ -185,6 +188,11 @@ def run():
|
||||
# Assume the last argument is the model.
|
||||
sys.argv.insert(-1, "--model")
|
||||
|
||||
# Work around the "model" argument being required
|
||||
# when Heretic is invoked in a non-processing mode.
|
||||
if "--collect-reproducibles" in sys.argv and "--model" not in sys.argv:
|
||||
sys.argv.extend(["--model", ""])
|
||||
|
||||
try:
|
||||
# The required argument "model" must be provided by the user,
|
||||
# either on the command line or in the configuration file.
|
||||
@@ -201,6 +209,10 @@ def run():
|
||||
)
|
||||
return
|
||||
|
||||
if settings.collect_reproducibles is not None:
|
||||
collect_reproducibles(settings.collect_reproducibles)
|
||||
return
|
||||
|
||||
if settings.seed is None:
|
||||
settings.seed = random.randint(0, 2**32 - 1)
|
||||
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from huggingface_hub.utils import disable_progress_bars, enable_progress_bars
|
||||
|
||||
from .utils import print
|
||||
|
||||
|
||||
def collect_reproducibles(path: str):
    """Archive reproduce.json files from Heretic model repositories.

    Lists the Hugging Face model repositories matching the "heretic" and
    "reproducible" filters, skips those tagged "gguf", and copies each
    repository's ``reproduce/reproduce.json`` to
    ``<path>/huggingface.co/<user>/<repository>-<short commit hash>.json``.
    Files that already exist at the target location are not downloaded again.
    Per-repository progress and a final summary are printed.

    Args:
        path: Directory under which the archive tree is created.
    """
    print(
        f"Collecting [bold]reproduce.json[/] files from Hugging Face and storing them in [bold]{path}[/]..."
    )
    print()

    api = HfApi()

    models = api.list_models(
        filter=["heretic", "reproducible"],
        sort="created_at",
    )

    found = 0
    downloaded = 0

    # We're only downloading tiny files, so the progress bars are just noise.
    disable_progress_bars()

    try:
        for model in models:
            # Ignore repositories containing quantizations.
            if model.tags is not None and "gguf" in model.tags:
                continue

            print(f"[bold]{model.id}[/]...", end="")

            # Split on the last slash only. Unlike a bare split("/") with
            # two-target unpacking, this cannot raise for an ID that has
            # no namespace; in that case "user" is empty and the file
            # lands directly under the "huggingface.co" directory.
            user, _, repository = model.id.rpartition("/")

            paths_info = api.get_paths_info(
                model.id,
                "reproduce/reproduce.json",
                expand=True,
            )

            # The reproduce.json file might not exist in the repository
            # despite the relevant tags being present.
            if not paths_info:
                print(" [yellow]no reproduce.json found[/]")
                continue

            found += 1

            commit_hash = paths_info[0].last_commit.oid

            file_path = (
                Path(path)
                / "huggingface.co"
                / user
                / f"{repository}-{commit_hash[:7]}.json"
            )

            if file_path.exists():
                print(" already stored")
                continue

            # Pin the download to the commit that names the archive file,
            # so the stored content always matches the hash in its name
            # even if the repository is updated between the two API calls.
            cache_path = hf_hub_download(
                model.id,
                "reproduce/reproduce.json",
                revision=commit_hash,
            )

            file_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copyfile(cache_path, file_path)
            print(" [green]downloaded[/]")

            downloaded += 1
    finally:
        enable_progress_bars()

    print()
    print(f"Found: [bold]{found}[/] files")
    print(f"Downloaded: [bold]{downloaded}[/] files")
    print(f"Already stored: [bold]{found - downloaded}[/] files")
|
||||
Reference in New Issue
Block a user