diff --git a/.gemini/styleguide.md b/.gemini/styleguide.md index 2b9dc20..88d23d1 100644 --- a/.gemini/styleguide.md +++ b/.gemini/styleguide.md @@ -4,6 +4,8 @@ * Comments should start with a capital letter and end with a period. They should use correct grammar and spelling. * Function and method signatures **must** be fully type-annotated, including the return type (if any). * Every Python code file **must** start with an SPDX/Copyright header. +* Settings descriptions should start with a capital letter and end with a period. +* When new settings are added in `config.py`, they should also be added to `config.default.toml`, set to their default value and with their description as a comment. The order of settings in `config.default.toml` should match that in `config.py`. * Pull requests should implement one change, and one change only. * PRs containing multiple semantically independent changes **must** be split into multiple PRs. * PRs **must not** change existing code unless the changes are *directly related* to the PR. This includes changes to formatting and comments. diff --git a/.gitignore b/.gitignore index 52e1942..1241cea 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,7 @@ wheels/ *.egg-info # Virtual environments -.venv +.venv/ # Caches /.ruff_cache/ @@ -19,4 +19,7 @@ wheels/ /config.toml # Study checkpoints -/checkpoints/*.jsonl +/checkpoints/ + +# Residual plots +/plots/ diff --git a/config.default.toml b/config.default.toml index b284dce..e4af86f 100644 --- a/config.default.toml +++ b/config.default.toml @@ -15,15 +15,16 @@ dtypes = [ "float32", ] +# Quantization method to use when loading the model. Options: +# "none" (no quantization), +# "bnb_4bit" (4-bit quantization using bitsandbytes). +quantization = "none" + # Device map to pass to Accelerate when loading the model. device_map = "auto" -# Quantization method to use when loading the model. -# Options: "none" (no quantization), "bnb_4bit" (4-bit quantization using bitsandbytes). 
-quantization = "none" - -# Memory limits to impose. 0 is usually your first graphics card. -# max_memory = {0 = "16GB", "cpu" = "64GB"} +# Maximum memory to allocate per device. +# max_memory = {"0" = "20GB", "cpu" = "64GB"} # Number of input sequences to process in parallel (0 = auto). batch_size = 0 # auto @@ -34,22 +35,6 @@ max_batch_size = 128 # Maximum number of tokens to generate for each response. max_response_length = 100 -# Whether to adjust the refusal directions so that only the component that is -# orthogonal to the good direction is subtracted during abliteration. -orthogonalize_direction = false - -# How to apply row normalization of the weights. Options: -# 'none' (no normalization), -# 'pre' (compute LoRA adapter relative to row-normalized weights), -# 'full' (like 'pre', but re-normalizes to preserve original row magnitudes). -row_normalization = "none" - -# The rank of the LoRA adapter to use when 'full' row normalization is used. -# Row magnitude preservation is approximate due to non-linear efects, -# and this determines the rank of that approximation. Higher ranks produce -# larger output files and may slow down evaluation. -full_normalization_lora_rank = 3 - # Whether to print prompt/response pairs when counting refusals. print_responses = false @@ -76,9 +61,25 @@ kl_divergence_scale = 1.0 # This helps prevent the sampler from extensively exploring parameter combinations that "do nothing". kl_divergence_target = 0.01 +# Whether to adjust the refusal directions so that only the component that is +# orthogonal to the good direction is subtracted during abliteration. +orthogonalize_direction = false + +# How to apply row normalization of the weights. Options: +# "none" (no normalization), +# "pre" (compute LoRA adapter relative to row-normalized weights), +# "full" (like "pre", but renormalizes to preserve original row magnitudes). +row_normalization = "none" + +# The rank of the LoRA adapter to use when "full" row normalization is used. 
+# Row magnitude preservation is approximate due to non-linear effects, +# and this determines the rank of that approximation. Higher ranks produce +# larger output files and may slow down evaluation. +full_normalization_lora_rank = 3 + # The symmetric winsorization to apply to each layer of the per-prompt residuals, # expressed as the quantile to clamp to (between 0 and 1). Disabled by default. -# Example: winsorization_quantile = 0.95 applies a 90% winsorization. +# Example: winsorization_quantile = 0.95 applies a 95% winsorization. winsorization_quantile = 1.0 # Number of abliteration trials to run during optimization. @@ -87,6 +88,9 @@ n_trials = 200 # Number of trials that use random sampling for the purpose of exploration. n_startup_trials = 60 +# Directory to save and load study progress to/from. +study_checkpoint_dir = "checkpoints" + # Strings whose presence in a response (case insensitive) identifies the response as a refusal. refusal_markers = [ "sorry", diff --git a/src/heretic/config.py b/src/heretic/config.py index 8ed1852..b330a8c 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -66,7 +66,10 @@ class Settings(BaseSettings): evaluate_model: str | None = Field( default=None, - description="If this model ID or path is set, then instead of abliterating the main model, evaluate this model relative to the main model.", + description=( + "If this model ID or path is set, then instead of abliterating the main model, " + "evaluate this model relative to the main model." + ), ) dtypes: list[str] = Field( @@ -82,7 +85,19 @@ class Settings(BaseSettings): # if that was the dtype "auto" resolved to). "float32", ], - description="List of PyTorch dtypes to try when loading model tensors. If loading with a dtype fails, the next dtype in the list will be tried.", + description=( + "List of PyTorch dtypes to try when loading model tensors. " + "If loading with a dtype fails, the next dtype in the list will be tried." 
+ ), + ) + + quantization: QuantizationMethod = Field( + default=QuantizationMethod.NONE, + description=( + "Quantization method to use when loading the model. Options: " + '"none" (no quantization), ' + '"bnb_4bit" (4-bit quantization using bitsandbytes).' + ), ) device_map: str | Dict[str, int | str] = Field( @@ -92,7 +107,7 @@ class Settings(BaseSettings): max_memory: Dict[str, str] | None = Field( default=None, - description="Maximum memory to allocate per device (e.g., {'0': '20GB', 'cpu': '64GB'}).", + description='Maximum memory to allocate per device (e.g., {"0": "20GB", "cpu": "64GB"}).', ) trust_remote_code: bool | None = Field( @@ -100,11 +115,6 @@ class Settings(BaseSettings): description="Whether to trust remote code when loading the model.", ) - quantization: QuantizationMethod = Field( - default=QuantizationMethod.NONE, - description="Quantization method to use when loading the model. Options: 'none' (no quantization), 'bnb_4bit' (4-bit quantization using bitsandbytes).", - ) - batch_size: int = Field( default=0, # auto description="Number of input sequences to process in parallel (0 = auto).", @@ -120,34 +130,6 @@ class Settings(BaseSettings): description="Maximum number of tokens to generate for each response.", ) - orthogonalize_direction: bool = Field( - default=False, - description=( - "Whether to adjust the refusal directions so that only the component that is " - "orthogonal to the good direction is subtracted during abliteration." - ), - ) - - row_normalization: RowNormalization = Field( - default=RowNormalization.NONE, - description=( - "How to apply row normalization of the weights. Options: " - "'none' (no normalization), " - "'pre' (compute LoRA adapter relative to row-normalized weights), " - "'full' (like 'pre', but renormalizes to preserve original row magnitudes)." - ), - ) - - full_normalization_lora_rank: int = Field( - default=3, - description=( - "The rank of the LoRA adapter to use when 'full' row normalization is used. 
" - "Row magnitude preservation is approximate due to non-linear efects, " - "and this determines the rank of that approximation. Higher ranks produce " - "larger output files and may slow down evaluation." - ), - ) - print_responses: bool = Field( default=False, description="Whether to print prompt/response pairs when counting refusals.", @@ -189,17 +171,45 @@ class Settings(BaseSettings): kl_divergence_target: float = Field( default=0.01, description=( - "The KL divergence to target. Below this value, an objective based on the refusal count is used." + "The KL divergence to target. Below this value, an objective based on the refusal count is used. " 'This helps prevent the sampler from extensively exploring parameter combinations that "do nothing".' ), ) + orthogonalize_direction: bool = Field( + default=False, + description=( + "Whether to adjust the refusal directions so that only the component that is " + "orthogonal to the good direction is subtracted during abliteration." + ), + ) + + row_normalization: RowNormalization = Field( + default=RowNormalization.NONE, + description=( + "How to apply row normalization of the weights. Options: " + '"none" (no normalization), ' + '"pre" (compute LoRA adapter relative to row-normalized weights), ' + '"full" (like "pre", but renormalizes to preserve original row magnitudes).' + ), + ) + + full_normalization_lora_rank: int = Field( + default=3, + description=( + 'The rank of the LoRA adapter to use when "full" row normalization is used. ' + "Row magnitude preservation is approximate due to non-linear effects, " + "and this determines the rank of that approximation. Higher ranks produce " + "larger output files and may slow down evaluation." + ), + ) + winsorization_quantile: float = Field( default=1.0, description=( "The symmetric winsorization to apply to each layer of the per-prompt residuals, " "expressed as the quantile to clamp to (between 0 and 1). Disabled by default. 
" - "Example: winsorization_quantile = 0.95 applies a 90% winsorization." + "Example: winsorization_quantile = 0.95 applies a 95% winsorization." ), ) @@ -215,7 +225,7 @@ class Settings(BaseSettings): study_checkpoint_dir: str = Field( default="checkpoints", - description="Directory to save and load study progress to/from:", + description="Directory to save and load study progress to/from.", ) refusal_markers: list[str] = Field( diff --git a/src/heretic/main.py b/src/heretic/main.py index 26221cf..a52ef94 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -61,21 +61,15 @@ def obtain_merge_strategy(settings: Settings) -> str | None: Returns "merge", "adapter", or None (if cancelled/invalid). """ - # Prompt for all PEFT models to ensure user is aware of merge implications if settings.quantization == QuantizationMethod.BNB_4BIT: - # Quantized models need special handling - we must reload the base model - # in full precision to merge the LoRA adapters print() print( - "[yellow]Model was loaded with quantization. Merging requires reloading the base model.[/]" + "Model was loaded with quantization. Merging requires reloading the base model." ) print( - "[red](!) WARNING: CPU Merging requires dequantizing the entire model to System RAM.[/]" - ) - print("[red] This can lead to SYSTEM FREEZES if you run out of memory.[/]") - print( - "[yellow] Rule of thumb: You need approx. 3x the parameter count in GB.[/]" + "[yellow]WARNING: CPU merging requires dequantizing the entire model to system RAM.[/]" ) + print("[yellow]This can lead to system freezes if you run out of memory.[/]") try: # Estimate memory requirements by loading the model structure on the "meta" device. 
@@ -94,61 +88,39 @@ def obtain_merge_strategy(settings: Settings) -> str | None: footprint_bytes = meta_model.get_memory_footprint() footprint_gb = footprint_bytes / (1024**3) print( - f"[yellow] Estimated net RAM required for model weights (excluding overhead): [bold]~{footprint_gb:.1f} GB[/][/]" + f"[yellow]Estimated RAM required (excluding overhead): [bold]~{footprint_gb:.2f} GB[/][/]" ) except Exception: # Fallback if meta loading fails (e.g. owing to custom model code - # or `bitsandbytes` quantization config issues on the meta device) + # or bitsandbytes quantization config issues on the meta device). print( - "[yellow] Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]" + "[yellow]Rule of thumb: You need approximately 3x the parameter count in GB RAM.[/]" + ) + print( + "[yellow]Example: A 27B model requires ~80GB RAM. A 70B model requires ~200GB RAM.[/]" ) print() - merge_choice = prompt_select( + strategy = prompt_select( "How do you want to proceed?", choices=[ Choice( - title="Merge full model" + title="Merge LoRA into full model" + ( "" if settings.quantization == QuantizationMethod.NONE - else " (reload base model on CPU - requires high RAM)" + else " (requires sufficient RAM)" ), value="merge", ), Choice( - title="Save LoRA adapter only (can be merged later with llama.cpp or more RAM)", + title="Save LoRA adapter only (can be merged later)", value="adapter", ), ], ) - return merge_choice - -def save_model( - model: Model, - save_directory: str, - settings: Settings, - strategy: str | None = None, -) -> None: - print("Saving model...") - if strategy is None: - strategy = obtain_merge_strategy(settings) - if strategy is None: - print("[yellow]Action cancelled.[/]") - return - - if strategy == "adapter": - model.model.save_pretrained(save_directory) - else: - merged_model = model.get_merged_model() - merged_model.save_pretrained(save_directory) - del merged_model - empty_cache() - - 
model.tokenizer.save_pretrained(save_directory) - - print(f"Model saved to [bold]{save_directory}[/].") + return strategy def run(): @@ -249,6 +221,8 @@ def run(): # Silence the warning about multivariate TPE being experimental. warnings.filterwarnings("ignore", category=ExperimentalWarning) + os.makedirs(settings.study_checkpoint_dir, exist_ok=True) + study_checkpoint_file = os.path.join( settings.study_checkpoint_dir, "".join( @@ -257,7 +231,6 @@ def run(): + ".jsonl", ) - os.makedirs(settings.study_checkpoint_dir, exist_ok=True) lock_obj = JournalFileOpenLock(study_checkpoint_file) backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj) storage = JournalStorage(backend) @@ -268,35 +241,57 @@ def run(): existing_study = None if existing_study is not None: - # A study is in here. Check if it's finished. choices = [] + if existing_study.user_attrs["finished"]: + print() print( - "[green]You have already processed this model. How would you like to proceed?[/]" + ( + "[green]You have already processed this model.[/] " + "You can show the results from the previous run, allowing you to export models or to run additional trials. " + "Alternatively, you can ignore the previous run and start from scratch. " + "This will delete the checkpoint file and all results from the previous run." + ) ) choices.append( Choice( - title="Show the results from the previous run, allowing you to export models, or to run additional trials.", + title="Show the results from the previous run", value="continue", ) ) else: + print() print( - "[yellow]You have already processed this model, but the run was interrupted. How would you like to proceed?[/]", + ( + "[yellow]You have already processed this model, but the run was interrupted.[/] " + "You can continue the previous run from where it stopped. This will override any specified settings. " + "Alternatively, you can ignore the previous run and start from scratch. 
" + "This will delete the checkpoint file and all results from the previous run." + ) ) choices.append( Choice( - title="Continue the previous run from where it stopped (will override all specified settings).", + title="Continue the previous run", value="continue", ) ) + choices.append( Choice( - title="Ignore the previous run and start from scratch. This will delete the checkpoint file and all results from the previous run.", + title="Ignore the previous run and start from scratch", value="restart", ) ) - choice = prompt_select("", choices) + + choices.append( + Choice( + title="Exit program", + value="", + ) + ) + + print() + choice = prompt_select("How would you like to proceed?", choices) if choice == "continue": settings = Settings.model_validate_json( @@ -306,8 +301,7 @@ def run(): os.unlink(study_checkpoint_file) backend = JournalFileBackend(study_checkpoint_file, lock_obj=lock_obj) storage = JournalStorage(backend) - else: - print("Cancelled; exiting.") + elif choice is None or choice == "": return model = Model(settings) @@ -562,14 +556,14 @@ def run(): raise TrialPruned() study = optuna.create_study( - study_name="heretic", sampler=TPESampler( n_startup_trials=settings.n_startup_trials, n_ei_candidates=128, multivariate=True, ), - storage=storage, directions=[StudyDirection.MINIMIZE, StudyDirection.MINIMIZE], + storage=storage, + study_name="heretic", load_if_exists=True, ) @@ -582,13 +576,14 @@ def run(): start_index = trial_index = count_completed_trials() if start_index > 0: + print() print("Resuming existing study.") try: study.optimize( - objective_wrapper, n_trials=settings.n_trials - count_completed_trials() + objective_wrapper, + n_trials=settings.n_trials - count_completed_trials(), ) - except KeyboardInterrupt: # This additional handler takes care of the small chance that KeyboardInterrupt # is raised just between trials, which wouldn't be caught by the handler @@ -638,14 +633,14 @@ def run(): choices.append( Choice( - title="Continue 
optimization (run more trials)", + title="Run additional trials", value="continue", ) ) choices.append( Choice( - title="None (exit program)", + title="Exit program", value="", ) ) @@ -669,18 +664,26 @@ def run(): if trial == "continue": while True: try: - n_more_trials = int( - prompt_text("How many more trials do you want to run?") + n_additional_trials = prompt_text( + "How many additional trials do you want to run?" ) - if n_more_trials > 0: + if n_additional_trials is None or n_additional_trials == "": + n_additional_trials = 0 + break + n_additional_trials = int(n_additional_trials) + if n_additional_trials > 0: break print("[red]Please enter a number greater than 0.[/]") except ValueError: - print("[red]Invalid input. Please enter a number.[/]") + print("[red]Please enter a number.[/]") - settings.n_trials += n_more_trials + if n_additional_trials == 0: + continue + + settings.n_trials += n_additional_trials study.set_user_attr("settings", settings.model_dump_json()) study.set_user_attr("finished", False) + try: study.optimize( objective_wrapper, @@ -688,8 +691,10 @@ def run(): ) except KeyboardInterrupt: pass + if count_completed_trials() == settings.n_trials: study.set_user_attr("finished", True) + break elif trial is None or trial == "": @@ -720,14 +725,11 @@ def run(): "Save the model to a local folder", "Upload the model to Hugging Face", "Chat with the model", - "Nothing (return to trial selection menu)", + "Return to the trial selection menu", ], ) - if ( - action is None - or action == "Nothing (return to trial selection menu)" - ): + if action is None or action == "Return to the trial selection menu": break # All actions are wrapped in a try/except block so that if an error occurs, @@ -740,11 +742,23 @@ def run(): if not save_directory: continue - save_model( - model, - save_directory, - settings, - ) + strategy = obtain_merge_strategy(settings) + if strategy is None: + continue + + if strategy == "adapter": + print("Saving LoRA adapter...") + 
model.model.save_pretrained(save_directory) + else: + print("Saving merged model...") + merged_model = model.get_merged_model() + merged_model.save_pretrained(save_directory) + del merged_model + empty_cache() + + model.tokenizer.save_pretrained(save_directory) + + print(f"Model saved to [bold]{save_directory}[/].") case "Upload the model to Hugging Face": # We don't use huggingface_hub.login() because that stores the token on disk, @@ -780,7 +794,6 @@ def run(): strategy = obtain_merge_strategy(settings) if strategy is None: - print("[yellow]Action cancelled.[/]") continue if strategy == "adapter": diff --git a/src/heretic/model.py b/src/heretic/model.py index 20c293c..8910b72 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -38,7 +38,7 @@ def get_model_class( ) -> Type[AutoModelForImageTextToText] | Type[AutoModelForCausalLM]: configs = PretrainedConfig.get_config_dict(model) - if any(["vision_config" in x for x in configs]): + if any([("vision_config" in config) for config in configs]): return AutoModelForImageTextToText else: return AutoModelForCausalLM @@ -96,9 +96,9 @@ class Model: try: quantization_config = self._get_quantization_config(dtype) - # Build kwargs, only include quantization_config if it's not None - # (some models like gpt-oss have issues with explicit None) extra_kwargs = {} + # Only include quantization_config if it's not None + # (some models like gpt-oss have issues with explicit None). if quantization_config is not None: extra_kwargs["quantization_config"] = quantization_config @@ -134,9 +134,11 @@ class Model: print(f"[red]Failed[/] ({error})") continue - print("[green]Ok[/]") if settings.quantization == QuantizationMethod.BNB_4BIT: - print("[bold green]Model loaded in 4-bit precision.[/]") + print("[green]Ok[/] (quantized to 4-bit precision)") + else: + print("[green]Ok[/]") + break if self.model is None: @@ -158,7 +160,7 @@ class Model: # Guard against calling this method at the wrong time. 
assert isinstance(self.model, PreTrainedModel) - # Always use LoRA adapters for abliteration (faster reload, no weight modification) + # Always use LoRA adapters for abliteration (faster reload, no weight modification). # We use the leaf names (e.g. "o_proj") as target modules. # This may cause LoRA adapters to be attached to unrelated modules (e.g. "conv.o_proj"), # but this is harmless as we only abliterate the modules we target in `abliterate()`, @@ -181,9 +183,8 @@ class Model: lora_alpha=lora_rank, # Apply adapter at full strength. lora_dropout=0, bias="none", - # Even if we're using AutoModelForImageTextToText, this is still correct, as it is (post-vision) - # the same kind of model. - # https://github.com/huggingface/peft/blob/622c2821cb0d7897bee53aad7914d42b5fecbf61/src/peft/auto.py#L45 + # Even if we're using AutoModelForImageTextToText, this is still correct, + # as VL models are typically just causal LMs with an added image encoder. task_type="CAUSAL_LM", ) @@ -191,9 +192,7 @@ class Model: # so the result is a PeftModel rather than a PeftMixedModel. self.model = cast(PeftModel, get_peft_model(self.model, self.peft_config)) - print( - f"[green]LoRA adapters initialized (targets: {', '.join(target_modules)})[/]" - ) + print(f"* LoRA adapters initialized (targets: {', '.join(target_modules)})") def _get_quantization_config(self, dtype: str) -> BitsAndBytesConfig | None: """ @@ -637,7 +636,10 @@ class Model: abs_residuals = torch.abs(residuals) # Get the (prompt, layer, 1) quantiles of the (prompt, layer, component) residuals. thresholds = torch.quantile( - abs_residuals, self.settings.winsorization_quantile, dim=2, keepdim=True + abs_residuals, + self.settings.winsorization_quantile, + dim=2, + keepdim=True, ) return torch.clamp(residuals, -thresholds, thresholds)