mirror of
https://github.com/p-e-w/heretic.git
synced 2026-06-02 05:03:33 +02:00
fix: replace tqdm progress bars with Rich progress bars
This commit is contained in:
@@ -38,6 +38,7 @@ dependencies = [
|
||||
"pydantic-settings~=2.13",
|
||||
"questionary~=2.1",
|
||||
"rich~=14.3",
|
||||
"tqdm~=4.67",
|
||||
"transformers~=5.3",
|
||||
]
|
||||
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
# ruff: noqa: E402
|
||||
|
||||
from .progress import patch_tqdm
|
||||
|
||||
# This patches tqdm class definitions, which must happen
|
||||
# before any other module imports tqdm.
|
||||
patch_tqdm()
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
|
||||
@@ -91,7 +91,7 @@ class Model:
|
||||
self.trusted_models[settings.evaluate_model] = settings.trust_remote_code
|
||||
|
||||
for dtype in settings.dtypes:
|
||||
print(f"* Trying dtype [bold]{dtype}[/]... ", end="")
|
||||
print(f"* Trying dtype [bold]{dtype}[/]...")
|
||||
|
||||
try:
|
||||
quantization_config = self._get_quantization_config(dtype)
|
||||
@@ -131,13 +131,11 @@ class Model:
|
||||
except Exception as error:
|
||||
self.model = None # ty:ignore[invalid-assignment]
|
||||
empty_cache()
|
||||
print(f"[red]Failed[/] ({error})")
|
||||
print(f"* [red]Failed[/] ({error})")
|
||||
continue
|
||||
|
||||
if settings.quantization == QuantizationMethod.BNB_4BIT:
|
||||
print("[green]Ok[/] (quantized to 4-bit precision)")
|
||||
else:
|
||||
print("[green]Ok[/]")
|
||||
print("* Quantized to 4-bit precision")
|
||||
|
||||
break
|
||||
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
# SPDX-License-Identifier: AGPL-3.0-or-later
|
||||
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
|
||||
|
||||
from typing import Any
|
||||
|
||||
import tqdm
|
||||
import tqdm.auto
|
||||
from rich.progress import Progress
|
||||
|
||||
|
||||
# A class that provides the same interface as tqdm,
|
||||
# but displays progress bars using Rich.
|
||||
class TqdmShim(tqdm.tqdm):
|
||||
def __init__(self, *args: Any, **kwargs: Any):
|
||||
self.rich_progress = Progress(transient=True)
|
||||
self.rich_progress.start()
|
||||
self.rich_task_id = self.rich_progress.add_task(
|
||||
kwargs.get("desc", ""),
|
||||
total=kwargs.get("total", None),
|
||||
)
|
||||
|
||||
# Chain up to the parent constructor to ensure that the internal state of the superclass
|
||||
# is correctly initialized, which some methods that we don't override might rely on.
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def display(self, *args: Any, **kwargs: Any):
|
||||
self.rich_progress.update(
|
||||
self.rich_task_id,
|
||||
description=self.desc,
|
||||
total=self.total,
|
||||
completed=self.n,
|
||||
)
|
||||
|
||||
def close(self, *args: Any, **kwargs: Any):
|
||||
self.rich_progress.stop()
|
||||
|
||||
|
||||
def patch_tqdm():
|
||||
tqdm.tqdm = TqdmShim # ty:ignore[invalid-assignment]
|
||||
tqdm.auto.tqdm = TqdmShim # ty:ignore[invalid-assignment]
|
||||
@@ -953,6 +953,7 @@ dependencies = [
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "questionary" },
|
||||
{ name = "rich" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "transformers" },
|
||||
]
|
||||
|
||||
@@ -961,8 +962,6 @@ research = [
|
||||
{ name = "geom-median" },
|
||||
{ name = "imageio" },
|
||||
{ name = "matplotlib" },
|
||||
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||
{ name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||
{ name = "pacmap" },
|
||||
{ name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
|
||||
{ name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
|
||||
@@ -983,13 +982,12 @@ requires-dist = [
|
||||
{ name = "hf-transfer", specifier = "~=0.1" },
|
||||
{ name = "huggingface-hub", specifier = "~=1.7" },
|
||||
{ name = "imageio", marker = "extra == 'research'", specifier = "~=2.37" },
|
||||
{ name = "immutabledict", specifier = ">=4.3.1" },
|
||||
{ name = "immutabledict", specifier = "~=4.3" },
|
||||
{ name = "kernels", specifier = "~=0.12" },
|
||||
{ name = "langdetect", specifier = ">=1.0.9" },
|
||||
{ name = "lm-eval", extras = ["hf"], specifier = "~=0.4.11" },
|
||||
{ name = "langdetect", specifier = "~=1.0" },
|
||||
{ name = "lm-eval", extras = ["hf"], specifier = "~=0.4" },
|
||||
{ name = "matplotlib", marker = "extra == 'research'", specifier = "~=3.10" },
|
||||
{ name = "numpy", specifier = ">=2.2.6" },
|
||||
{ name = "numpy", marker = "extra == 'research'", specifier = "~=2.2" },
|
||||
{ name = "numpy", specifier = "~=2.2" },
|
||||
{ name = "optuna", specifier = "~=4.7" },
|
||||
{ name = "pacmap", marker = "extra == 'research'", specifier = "~=0.8" },
|
||||
{ name = "peft", specifier = "~=0.18" },
|
||||
@@ -998,6 +996,7 @@ requires-dist = [
|
||||
{ name = "questionary", specifier = "~=2.1" },
|
||||
{ name = "rich", specifier = "~=14.3" },
|
||||
{ name = "scikit-learn", marker = "extra == 'research'", specifier = "~=1.7" },
|
||||
{ name = "tqdm", specifier = "~=4.67" },
|
||||
{ name = "transformers", specifier = "~=5.3" },
|
||||
]
|
||||
provides-extras = ["research"]
|
||||
|
||||
Reference in New Issue
Block a user