Add chat functionality

This commit is contained in:
Philipp Emanuel Weidmann
2025-09-24 18:09:23 +05:30
parent f00d35dc46
commit fd0fa52552
2 changed files with 83 additions and 10 deletions
+54 -10
View File
@@ -222,14 +222,58 @@ def main():
)
print(f" * Score: [bold]{-study.best_value:.4f}[/]")
return
print()
action = questionary.select(
"What do you want to do with the optimized model?",
choices=[
"Save to a local folder",
"Upload to Hugging Face",
"Nothing (discard the model)",
],
).ask()
print("Restoring best model...")
print("* Reloading model...")
model.reload_model()
print("* Abliterating...")
model.abliterate(
refusal_directions,
study.best_params["max_weight"],
study.best_params["max_weight_position"],
study.best_params["min_weight"],
study.best_params["min_weight_distance"],
)
while True:
print()
action = questionary.select(
"What do you want to do with the optimized model?",
choices=[
"Save the model to a local folder",
"Upload the model to Hugging Face",
"Chat with the model",
"Nothing (Quit)",
],
).ask()
match action:
case "Save the model to a local folder":
# TODO
pass
case "Upload the model to Hugging Face":
# TODO
pass
case "Chat with the model":
print()
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
chat = [
{"role": "system", "content": settings.system_prompt},
]
while True:
try:
message = questionary.text("User:", qmark=">").unsafe_ask()
if not message:
break
chat.append({"role": "user", "content": message})
print("[bold]Assistant:[/] ", end="")
response = model.stream_chat_response(chat)
chat.append({"role": "assistant", "content": response})
except (KeyboardInterrupt, EOFError):
# Ctrl+C/Ctrl+D
break
case "Nothing (Quit)":
break
+29
View File
@@ -12,6 +12,7 @@ from transformers import (
AutoTokenizer,
BatchEncoding,
PreTrainedTokenizerBase,
TextStreamer,
)
from transformers.generation.utils import GenerateOutput
@@ -251,3 +252,31 @@ class Model:
logprobs.append(self.get_logprobs(batch))
return torch.cat(logprobs, dim=0)
def stream_chat_response(
    self,
    chat: list[dict[str, str]],
    max_new_tokens: int | None = 4096,
) -> str:
    """Generate a reply to ``chat``, streaming it while it is produced.

    The conversation is rendered with the tokenizer's chat template (with the
    generation prompt appended), tokenized, moved to the model's device and
    passed to ``generate`` together with a ``TextStreamer`` that skips the
    prompt and special tokens.

    Args:
        chat: Conversation so far, as a list of ``{"role": ..., "content": ...}``
            messages.
        max_new_tokens: Upper bound on the number of generated tokens. Pass
            ``None`` to send no explicit limit and rely on the model's own
            generation configuration.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        sliced off), with special tokens removed.
    """
    chat_prompt: str = self.tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=False,
    )

    # NOTE(review): the templated text is re-tokenized with default settings.
    # Hugging Face's chat templating guide recommends add_special_tokens=False
    # for this two-step flow to avoid duplicated special tokens -- confirm
    # against how the rest of this class tokenizes prompts before changing it.
    inputs = self.tokenizer(
        chat_prompt,
        return_tensors="pt",
    ).to(self.model.device)

    streamer = TextStreamer(
        self.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # Without an explicit limit, generation falls back to the generation
    # config's length settings, whose library default (max_length=20, prompt
    # included) is far too small for a chat reply.
    generation_kwargs = {}
    if max_new_tokens is not None:
        generation_kwargs["max_new_tokens"] = max_new_tokens

    outputs = self.model.generate(
        **inputs,
        streamer=streamer,
        **generation_kwargs,
    )

    # The output sequence starts with the prompt tokens; keep only what
    # follows them.
    prompt_length = inputs["input_ids"].shape[1]
    return self.tokenizer.decode(
        outputs[0, prompt_length:],
        skip_special_tokens=True,
    )