mirror of
https://github.com/p-e-w/heretic.git
synced 2026-06-02 05:03:33 +02:00
Add chat functionality
This commit is contained in:
+54
-10
@@ -222,14 +222,58 @@ def main():
|
||||
)
|
||||
print(f" * Score: [bold]{-study.best_value:.4f}[/]")
|
||||
|
||||
return
|
||||
|
||||
print()
|
||||
action = questionary.select(
|
||||
"What do you want to do with the optimized model?",
|
||||
choices=[
|
||||
"Save to a local folder",
|
||||
"Upload to Hugging Face",
|
||||
"Nothing (discard the model)",
|
||||
],
|
||||
).ask()
|
||||
print("Restoring best model...")
|
||||
print("* Reloading model...")
|
||||
model.reload_model()
|
||||
print("* Abliterating...")
|
||||
model.abliterate(
|
||||
refusal_directions,
|
||||
study.best_params["max_weight"],
|
||||
study.best_params["max_weight_position"],
|
||||
study.best_params["min_weight"],
|
||||
study.best_params["min_weight_distance"],
|
||||
)
|
||||
|
||||
while True:
|
||||
print()
|
||||
action = questionary.select(
|
||||
"What do you want to do with the optimized model?",
|
||||
choices=[
|
||||
"Save the model to a local folder",
|
||||
"Upload the model to Hugging Face",
|
||||
"Chat with the model",
|
||||
"Nothing (Quit)",
|
||||
],
|
||||
).ask()
|
||||
|
||||
match action:
|
||||
case "Save the model to a local folder":
|
||||
# TODO
|
||||
pass
|
||||
case "Upload the model to Hugging Face":
|
||||
# TODO
|
||||
pass
|
||||
case "Chat with the model":
|
||||
print()
|
||||
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
|
||||
|
||||
chat = [
|
||||
{"role": "system", "content": settings.system_prompt},
|
||||
]
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = questionary.text("User:", qmark=">").unsafe_ask()
|
||||
if not message:
|
||||
break
|
||||
chat.append({"role": "user", "content": message})
|
||||
|
||||
print("[bold]Assistant:[/] ", end="")
|
||||
response = model.stream_chat_response(chat)
|
||||
chat.append({"role": "assistant", "content": response})
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
# Ctrl+C/Ctrl+D
|
||||
break
|
||||
case "Nothing (Quit)":
|
||||
break
|
||||
|
||||
@@ -12,6 +12,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
BatchEncoding,
|
||||
PreTrainedTokenizerBase,
|
||||
TextStreamer,
|
||||
)
|
||||
from transformers.generation.utils import GenerateOutput
|
||||
|
||||
@@ -251,3 +252,31 @@ class Model:
|
||||
logprobs.append(self.get_logprobs(batch))
|
||||
|
||||
return torch.cat(logprobs, dim=0)
|
||||
|
||||
def stream_chat_response(
    self,
    chat: list[dict[str, str]],
    max_new_tokens: int = 4096,
) -> str:
    """Generate the assistant's reply to *chat* and stream it as it is produced.

    The conversation is rendered with the tokenizer's chat template (with
    the generation prompt appended), tokenized, and passed to the model's
    ``generate`` method together with a ``TextStreamer`` so that the text
    is emitted incrementally while generation is running.

    Args:
        chat: Conversation so far, as a list of ``{"role", "content"}``
            message dicts in the format accepted by
            ``tokenizer.apply_chat_template``.
        max_new_tokens: Upper bound on the number of tokens generated for
            this reply. Defaults to 4096.

    Returns:
        The decoded reply (prompt tokens and special tokens stripped), so
        the caller can append it to the chat history.
    """
    chat_prompt: str = self.tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=False,
    )

    # NOTE(review): the templated prompt is tokenized with the tokenizer's
    # default special-token handling; if the chat template already emits
    # BOS, this may duplicate it -- confirm against the other generation
    # paths in this class before changing.
    inputs = self.tokenizer(
        chat_prompt,
        return_tensors="pt",
        # Some tokenizers emit token_type_ids, which generate() would
        # receive via **inputs and which certain causal LMs reject as an
        # unused model argument.
        return_token_type_ids=False,
    ).to(self.model.device)

    streamer = TextStreamer(
        self.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    outputs = self.model.generate(
        **inputs,
        streamer=streamer,
        # Without an explicit bound, the reply length is governed by the
        # model's generation config, whose library fallback is a very short
        # total length that truncates (or rejects) long chat prompts.
        max_new_tokens=max_new_tokens,
    )

    # generate() returns prompt + completion; keep only the completion.
    prompt_length = inputs["input_ids"].shape[1]

    return self.tokenizer.decode(
        outputs[0, prompt_length:],
        skip_special_tokens=True,
    )
|
||||
|
||||
Reference in New Issue
Block a user