mirror of
https://github.com/TheR1D/shell_gpt.git
synced 2026-06-02 06:14:32 +02:00
159 lines
5.2 KiB
Python
159 lines
5.2 KiB
Python
import json
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, Generator, List, Optional
|
|
|
|
from ..cache import Cache
|
|
from ..config import cfg
|
|
from ..function import get_function
|
|
from ..printer import MarkdownPrinter, Printer, TextPrinter
|
|
from ..role import DefaultRoles, SystemRole
|
|
|
|
# Placeholder binding so that `completion` always exists with a callable type;
# it is rebound below to the streaming entry point of the selected backend.
completion: Callable[..., Any] = lambda *args, **kwargs: Generator[Any, None, None]

base_url = cfg.get("API_BASE_URL")
use_litellm = cfg.get("USE_LITELLM") == "true"

# Connection settings. The literal "default" in the config means "do not
# override the backend's own base URL", which is expressed here as None.
additional_kwargs = dict(
    timeout=int(cfg.get("REQUEST_TIMEOUT")),
    api_key=cfg.get("OPENAI_API_KEY"),
    base_url=base_url if base_url != "default" else None,
)

if not use_litellm:
    from openai import OpenAI

    # The client is constructed with the connection settings, so they are not
    # repeated on every request: the per-request extras are emptied afterwards.
    client = OpenAI(**additional_kwargs)  # type: ignore
    completion = client.chat.completions.create
    additional_kwargs = {}
else:
    # LiteLLM is imported only when enabled in the config. It is called as a
    # plain function, so the connection settings stay in `additional_kwargs`
    # and are forwarded with every request.
    import litellm  # type: ignore

    litellm.suppress_debug_info = True
    completion = litellm.completion
|
|
|
|
|
|
class Handler:
    """Base handler: request a streamed chat completion and print the reply.

    Subclasses must implement ``make_messages``. This class runs the streaming
    loop, executes functions the model asks for, and hands the resulting text
    stream to a printer.
    """

    # Class-level cache object, applied below as the decorator of
    # ``get_completion``.
    cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH")))

    def __init__(self, role: SystemRole, markdown: bool) -> None:
        """Store the role and read presentation/connection settings from config.

        :param role: System role used for this conversation.
        :param markdown: Whether the caller wants markdown rendering; it is
            only honoured when the role text contains "APPLY MARKDOWN".
        """
        self.role = role

        api_base_url = cfg.get("API_BASE_URL")
        # "default" in the config means no explicit base URL.
        self.base_url = None if api_base_url == "default" else api_base_url
        self.timeout = int(cfg.get("REQUEST_TIMEOUT"))

        self.markdown = "APPLY MARKDOWN" in self.role.role and markdown
        self.code_theme, self.color = cfg.get("CODE_THEME"), cfg.get("DEFAULT_COLOR")

    @property
    def printer(self) -> Printer:
        """Return a new printer: markdown when enabled, plain text otherwise."""
        if self.markdown:
            return MarkdownPrinter(self.code_theme)
        return TextPrinter(self.color)

    def make_messages(self, prompt: str) -> List[Dict[str, str]]:
        """Build the message list for ``prompt``; must be provided by subclasses.

        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def handle_function_call(
        self,
        messages: List[Dict[str, Any]],
        name: str,
        arguments: str,
    ) -> Generator[str, None, None]:
        """Run the function requested by the model and record it in ``messages``.

        Appends the assistant's function-call message and then the function's
        result message to ``messages`` (mutated in place), yielding text
        fragments that describe the call for display.

        :param messages: Conversation so far; extended in place.
        :param name: Name of the function to execute.
        :param arguments: JSON-encoded keyword arguments accumulated from the
            stream. Empty or whitespace-only text is treated as "no arguments".
        :raises json.JSONDecodeError: If ``arguments`` is non-empty but is not
            valid JSON.
        """
        # A call without parameters may arrive with no argument text at all;
        # normalise it so both the recorded message and the parser see valid
        # JSON instead of failing on an empty string.
        if not arguments.strip():
            arguments = "{}"

        messages.append(
            {
                "role": "assistant",
                "content": "",
                "function_call": {"name": name, "arguments": arguments},
            }
        )

        # Separate the function-call banner from any text streamed before it.
        yield "\n"

        dict_args = json.loads(arguments)
        joined_args = ", ".join(f'{k}="{v}"' for k, v in dict_args.items())
        yield f"> @FunctionCall `{name}({joined_args})` \n\n"

        result = get_function(name)(**dict_args)
        if cfg.get("SHOW_FUNCTIONS_OUTPUT") == "true":
            yield f"```text\n{result}\n```\n"
        messages.append({"role": "function", "content": result, "name": name})

    @cache
    def get_completion(
        self,
        model: str,
        temperature: float,
        top_p: float,
        messages: List[Dict[str, Any]],
        functions: Optional[List[Dict[str, str]]],
    ) -> Generator[str, None, None]:
        """Stream the model's reply, resolving function calls along the way.

        Callers in this class also pass a ``caching`` keyword that is not a
        parameter here; presumably the ``cache`` decorator consumes it
        (NOTE(review): confirm against ``Cache``).

        :param model: Model identifier forwarded to the backend.
        :param temperature: Sampling temperature forwarded to the backend.
        :param top_p: Nucleus-sampling value forwarded to the backend.
        :param messages: Conversation messages; extended in place when the
            model requests a function call.
        :param functions: Function schemas offered to the model; ignored for
            the shell, code and describe-shell roles.
        :return: Generator of text fragments of the reply.
        """
        name = arguments = ""

        # These roles never get function calling, whatever the caller passed.
        roles_without_functions = (
            DefaultRoles.SHELL.value,
            DefaultRoles.CODE.value,
            DefaultRoles.DESCRIBE_SHELL.value,
        )
        if self.role.name in roles_without_functions:
            functions = None

        response = completion(
            model=model,
            temperature=temperature,
            top_p=top_p,
            messages=messages,
            functions=functions,
            stream=True,
            **additional_kwargs,
        )

        try:
            for chunk in response:
                delta = chunk.choices[0].delta
                # LiteLLM uses dict instead of Pydantic object like OpenAI does.
                function_call = (
                    delta.get("function_call") if use_litellm else delta.function_call
                )
                if function_call:
                    # The name and the argument text arrive spread over
                    # several chunks; accumulate them until the call is done.
                    if function_call.name:
                        name = function_call.name
                    if function_call.arguments:
                        arguments += function_call.arguments
                if chunk.choices[0].finish_reason == "function_call":
                    yield from self.handle_function_call(messages, name, arguments)
                    # Ask the model again now that ``messages`` contains the
                    # function result; this follow-up request passes
                    # caching=False.
                    yield from self.get_completion(
                        model=model,
                        temperature=temperature,
                        top_p=top_p,
                        messages=messages,
                        functions=functions,
                        caching=False,
                    )
                    return

                yield delta.content or ""
        except KeyboardInterrupt:
            # Ctrl+C ends the stream quietly: close the response and stop.
            response.close()

    def handle(
        self,
        prompt: str,
        model: str,
        temperature: float,
        top_p: float,
        caching: bool,
        functions: Optional[List[Dict[str, str]]] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the model and print the reply.

        :param prompt: User prompt; surrounding whitespace is stripped.
        :param model: Model identifier.
        :param temperature: Sampling temperature.
        :param top_p: Nucleus-sampling value.
        :param caching: Forwarded to ``get_completion`` as ``caching``.
        :param functions: Optional function schemas offered to the model.
        :param kwargs: Extra keyword arguments forwarded to ``get_completion``.
        :return: Whatever the printer returns for the consumed stream.
        """
        disable_stream = cfg.get("DISABLE_STREAMING") == "true"
        messages = self.make_messages(prompt.strip())
        generator = self.get_completion(
            model=model,
            temperature=temperature,
            top_p=top_p,
            messages=messages,
            functions=functions,
            caching=caching,
            **kwargs,
        )
        # The second argument tells the printer whether to render live.
        return self.printer(generator, not disable_stream)
|