mirror of
https://github.com/p-e-w/heretic.git
synced 2026-06-02 05:03:33 +02:00
feat: allow injecting prefixes and suffixes into prompts
This commit is contained in:
@@ -27,6 +27,16 @@ class DatasetSpecification(BaseModel):
|
||||
|
||||
column: str = Field(description="Column in the dataset that contains the prompts.")
|
||||
|
||||
prefix: str = Field(
|
||||
default="",
|
||||
description="Text to prepend to each prompt.",
|
||||
)
|
||||
|
||||
suffix: str = Field(
|
||||
default="",
|
||||
description="Text to append to each prompt.",
|
||||
)
|
||||
|
||||
residual_plot_label: str | None = Field(
|
||||
default=None,
|
||||
description="Label to use for the dataset in plots of residual vectors.",
|
||||
|
||||
@@ -171,7 +171,15 @@ def load_prompts(specification: DatasetSpecification) -> list[str]:
|
||||
# Probably a repository path; let load_dataset figure it out.
|
||||
dataset = load_dataset(path, split=split_str)
|
||||
|
||||
return list(dataset[specification.column])
|
||||
prompts = list(dataset[specification.column])
|
||||
|
||||
if specification.prefix:
|
||||
prompts = [f"{specification.prefix} {prompt}" for prompt in prompts]
|
||||
|
||||
if specification.suffix:
|
||||
prompts = [f"{prompt} {specification.suffix}" for prompt in prompts]
|
||||
|
||||
return prompts
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
Reference in New Issue
Block a user