fix: replace tqdm progress bars with Rich progress bars

This commit is contained in:
Philipp Emanuel Weidmann
2026-03-28 18:30:15 +05:30
parent 1126332281
commit 96c7a7d98a
5 changed files with 58 additions and 12 deletions
+1
View File
@@ -38,6 +38,7 @@ dependencies = [
"pydantic-settings~=2.13", "pydantic-settings~=2.13",
"questionary~=2.1", "questionary~=2.1",
"rich~=14.3", "rich~=14.3",
"tqdm~=4.67",
"transformers~=5.3", "transformers~=5.3",
] ]
+8
View File
@@ -1,6 +1,14 @@
# SPDX-License-Identifier: AGPL-3.0-or-later # SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors # Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
# ruff: noqa: E402
from .progress import patch_tqdm
# This patches tqdm class definitions, which must happen
# before any other module imports tqdm.
patch_tqdm()
import logging import logging
import math import math
import os import os
+3 -5
View File
@@ -91,7 +91,7 @@ class Model:
self.trusted_models[settings.evaluate_model] = settings.trust_remote_code self.trusted_models[settings.evaluate_model] = settings.trust_remote_code
for dtype in settings.dtypes: for dtype in settings.dtypes:
print(f"* Trying dtype [bold]{dtype}[/]... ", end="") print(f"* Trying dtype [bold]{dtype}[/]...")
try: try:
quantization_config = self._get_quantization_config(dtype) quantization_config = self._get_quantization_config(dtype)
@@ -131,13 +131,11 @@ class Model:
except Exception as error: except Exception as error:
self.model = None # ty:ignore[invalid-assignment] self.model = None # ty:ignore[invalid-assignment]
empty_cache() empty_cache()
print(f"[red]Failed[/] ({error})") print(f"* [red]Failed[/] ({error})")
continue continue
if settings.quantization == QuantizationMethod.BNB_4BIT: if settings.quantization == QuantizationMethod.BNB_4BIT:
print("[green]Ok[/] (quantized to 4-bit precision)") print("* Quantized to 4-bit precision")
else:
print("[green]Ok[/]")
break break
+40
View File
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: AGPL-3.0-or-later
# Copyright (C) 2025-2026 Philipp Emanuel Weidmann <pew@worldwidemann.com> + contributors
from typing import Any
import tqdm
import tqdm.auto
from rich.progress import Progress
# A class that provides the same interface as tqdm,
# but displays progress bars using Rich.
class TqdmShim(tqdm.tqdm):
def __init__(self, *args: Any, **kwargs: Any):
self.rich_progress = Progress(transient=True)
self.rich_progress.start()
self.rich_task_id = self.rich_progress.add_task(
kwargs.get("desc", ""),
total=kwargs.get("total", None),
)
# Chain up to the parent constructor to ensure that the internal state of the superclass
# is correctly initialized, which some methods that we don't override might rely on.
super().__init__(*args, **kwargs)
def display(self, *args: Any, **kwargs: Any):
self.rich_progress.update(
self.rich_task_id,
description=self.desc,
total=self.total,
completed=self.n,
)
def close(self, *args: Any, **kwargs: Any):
self.rich_progress.stop()
def patch_tqdm():
tqdm.tqdm = TqdmShim # ty:ignore[invalid-assignment]
tqdm.auto.tqdm = TqdmShim # ty:ignore[invalid-assignment]
Generated
+6 -7
View File
@@ -953,6 +953,7 @@ dependencies = [
{ name = "pydantic-settings" }, { name = "pydantic-settings" },
{ name = "questionary" }, { name = "questionary" },
{ name = "rich" }, { name = "rich" },
{ name = "tqdm" },
{ name = "transformers" }, { name = "transformers" },
] ]
@@ -961,8 +962,6 @@ research = [
{ name = "geom-median" }, { name = "geom-median" },
{ name = "imageio" }, { name = "imageio" },
{ name = "matplotlib" }, { name = "matplotlib" },
{ name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "numpy", version = "2.3.5", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
{ name = "pacmap" }, { name = "pacmap" },
{ name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "scikit-learn", version = "1.7.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" },
{ name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "scikit-learn", version = "1.8.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" },
@@ -983,13 +982,12 @@ requires-dist = [
{ name = "hf-transfer", specifier = "~=0.1" }, { name = "hf-transfer", specifier = "~=0.1" },
{ name = "huggingface-hub", specifier = "~=1.7" }, { name = "huggingface-hub", specifier = "~=1.7" },
{ name = "imageio", marker = "extra == 'research'", specifier = "~=2.37" }, { name = "imageio", marker = "extra == 'research'", specifier = "~=2.37" },
{ name = "immutabledict", specifier = ">=4.3.1" }, { name = "immutabledict", specifier = "~=4.3" },
{ name = "kernels", specifier = "~=0.12" }, { name = "kernels", specifier = "~=0.12" },
{ name = "langdetect", specifier = ">=1.0.9" }, { name = "langdetect", specifier = "~=1.0" },
{ name = "lm-eval", extras = ["hf"], specifier = "~=0.4.11" }, { name = "lm-eval", extras = ["hf"], specifier = "~=0.4" },
{ name = "matplotlib", marker = "extra == 'research'", specifier = "~=3.10" }, { name = "matplotlib", marker = "extra == 'research'", specifier = "~=3.10" },
{ name = "numpy", specifier = ">=2.2.6" }, { name = "numpy", specifier = "~=2.2" },
{ name = "numpy", marker = "extra == 'research'", specifier = "~=2.2" },
{ name = "optuna", specifier = "~=4.7" }, { name = "optuna", specifier = "~=4.7" },
{ name = "pacmap", marker = "extra == 'research'", specifier = "~=0.8" }, { name = "pacmap", marker = "extra == 'research'", specifier = "~=0.8" },
{ name = "peft", specifier = "~=0.18" }, { name = "peft", specifier = "~=0.18" },
@@ -998,6 +996,7 @@ requires-dist = [
{ name = "questionary", specifier = "~=2.1" }, { name = "questionary", specifier = "~=2.1" },
{ name = "rich", specifier = "~=14.3" }, { name = "rich", specifier = "~=14.3" },
{ name = "scikit-learn", marker = "extra == 'research'", specifier = "~=1.7" }, { name = "scikit-learn", marker = "extra == 'research'", specifier = "~=1.7" },
{ name = "tqdm", specifier = "~=4.67" },
{ name = "transformers", specifier = "~=5.3" }, { name = "transformers", specifier = "~=5.3" },
] ]
provides-extras = ["research"] provides-extras = ["research"]