mirror of
https://github.com/TheR1D/shell_gpt.git
synced 2026-06-02 06:14:32 +02:00
Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| dee88ff87b | |||
| 17be969232 | |||
| 9d6e75dfe8 | |||
| 29b77522ca | |||
| 4ea2f834cf | |||
| 6bd0bdebe1 | |||
| 880e7db0d0 | |||
| a04167c723 | |||
| 9615dfbec8 | |||
| 30f39782b0 | |||
| 439a3e848e | |||
| b7cad0bd85 | |||
| 005d4fc8fb | |||
| 8f93f280ce | |||
| 859c97915a | |||
| 47b1715bc3 | |||
| 8fbd94f60b | |||
| 9bd9420286 | |||
| aac2f5461b | |||
| ab6b475c9d |
@@ -13,7 +13,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10"]
|
||||
python-version: ["3.10", "3.11", "3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
@@ -35,4 +35,4 @@ jobs:
|
||||
- name: tests
|
||||
run: |
|
||||
export OPENAI_API_KEY=test_api_key
|
||||
pytest tests/ -p no:warnings
|
||||
pytest tests/ -p no:warnings -v -s
|
||||
|
||||
@@ -27,7 +27,7 @@ jobs:
|
||||
- name: Build a binary wheel and a source tarball
|
||||
run: python3 -m hatchling build
|
||||
- name: Store the distribution packages
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: python-package-distributions
|
||||
path: dist/
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Download all the dists
|
||||
uses: actions/download-artifact@v3
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: python-package-distributions
|
||||
path: dist/
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Download all the dists
|
||||
uses: actions/download-artifact@v3
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: python-package-distributions
|
||||
path: dist/
|
||||
@@ -73,7 +73,7 @@ jobs:
|
||||
echo "SGPT_VERSION=$(find dist -type f -name '*.tar.gz' | grep -oP '\d+.\d+.\d+')" >> $GITHUB_ENV
|
||||
echo "Release version $SGPT_VERSION"
|
||||
- name: Sign the dists with Sigstore
|
||||
uses: sigstore/gh-action-sigstore-python@v1.2.3
|
||||
uses: sigstore/gh-action-sigstore-python@v3.0.0
|
||||
with:
|
||||
inputs: >-
|
||||
./dist/*.tar.gz
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
# ShellGPT
|
||||
A command-line productivity tool powered by AI large language models (LLM). This command-line tool offers streamlined generation of **shell commands, code snippets, documentation**, eliminating the need for external resources (like Google search). Supports Linux, macOS, Windows and compatible with all major Shells like PowerShell, CMD, Bash, Zsh, etc.
|
||||
|
||||
https://github.com/TheR1D/shell_gpt/assets/16740832/9197283c-db6a-4b46-bfea-3eb776dd9093
|
||||
https://github.com/TheR1D/shell_gpt/assets/16740832/721ddb19-97e7-428f-a0ee-107d027ddd59
|
||||
|
||||
## Installation
|
||||
```shell
|
||||
pip install shell-gpt
|
||||
```
|
||||
By default, ShellGPT uses OpenAI's API and GPT-4 model. You'll need an API key, you can generate one [here](https://beta.openai.com/account/api-keys). You will be prompted for your key which will then be stored in `~/.config/shell_gpt/.sgptrc`. OpenAI API is not free of charge, please refer to the [OpenAI pricing](https://openai.com/pricing) for more information.
|
||||
By default, ShellGPT uses OpenAI's API and GPT-4 model. You'll need an API key, you can generate one [here](https://platform.openai.com/api-keys). You will be prompted for your key which will then be stored in `~/.config/shell_gpt/.sgptrc`. OpenAI API is not free of charge, please refer to the [OpenAI pricing](https://openai.com/pricing) for more information.
|
||||
|
||||
> [!TIP]
|
||||
> Alternatively, you can use locally hosted open source models which are available for free. To use local models, you will need to run your own LLM backend server such as [Ollama](https://github.com/ollama/ollama). To set up ShellGPT with Ollama, please follow this comprehensive [guide](https://github.com/TheR1D/shell_gpt/wiki/Ollama).
|
||||
> Alternatively, you can run open-source models locally for free. This requires setting up your own LLM backend, such as [Ollama](https://github.com/ollama/ollama). To get ShellGPT working with Ollama, follow this detailed [guide](https://github.com/TheR1D/shell_gpt/wiki/Ollama)
|
||||
>
|
||||
> **❗️Note that ShellGPT is not optimized for local models and may not work as expected.**
|
||||
|
||||
@@ -290,28 +290,7 @@ The snippet of code you've provided is written in Python. It prompts the user...
|
||||
sgpt --install-functions
|
||||
```
|
||||
|
||||
ShellGPT has a convenient way to define functions and use them. In order to create your custom function, navigate to `~/.config/shell_gpt/functions` and create a new .py file with the function name. Inside this file, you can define your function using the following syntax:
|
||||
```python
|
||||
# execute_shell_command.py
|
||||
import subprocess
|
||||
from pydantic import Field
|
||||
from instructor import OpenAISchema
|
||||
|
||||
|
||||
class Function(OpenAISchema):
|
||||
"""
|
||||
Executes a shell command and returns the output (result).
|
||||
"""
|
||||
shell_command: str = Field(..., example="ls -la", descriptions="Shell command to execute.")
|
||||
|
||||
class Config:
|
||||
title = "execute_shell_command"
|
||||
|
||||
@classmethod
|
||||
def execute(cls, shell_command: str) -> str:
|
||||
result = subprocess.run(shell_command.split(), capture_output=True, text=True)
|
||||
return f"Exit code: {result.returncode}, Output:\n{result.stdout}"
|
||||
```
|
||||
ShellGPT has a convenient way to define functions and use them. In order to create your custom function, navigate to `~/.config/shell_gpt/functions` and create a new .py file with the function name. Inside this file, you can define your function using this [example](https://github.com/TheR1D/shell_gpt/blob/main/sgpt/llm_functions/common/execute_shell.py).
|
||||
|
||||
The docstring comment inside the class will be passed to OpenAI API as a description for the function, along with the `title` attribute and parameters descriptions. The `execute` function will be called if LLM decides to use your function. In this case we are allowing LLM to execute any Shell commands in our system. Since we are returning the output of the command, LLM will be able to analyze it and decide if it is a good fit for the prompt. Here is an example how the function might be executed by LLM:
|
||||
```shell
|
||||
@@ -365,7 +344,7 @@ sgpt --role json_generator "random: user, password, email, address"
|
||||
}
|
||||
```
|
||||
|
||||
If the description of the role contains the words "APPLY MARKDOWN" (case sensitive), then chats will be displayed using markdown formatting.
|
||||
If the description of the role contains the words "APPLY MARKDOWN" (case sensitive), then chats will be displayed using markdown formatting unless it is explicitly turned off with `--no-md`.
|
||||
|
||||
### Request cache
|
||||
Control cache using `--cache` (default) and `--no-cache` options. This caching applies for all `sgpt` requests to OpenAI API:
|
||||
@@ -395,7 +374,7 @@ CACHE_PATH=/tmp/shell_gpt/cache
|
||||
# Request timeout in seconds.
|
||||
REQUEST_TIMEOUT=60
|
||||
# Default OpenAI model to use.
|
||||
DEFAULT_MODEL=gpt-4o
|
||||
DEFAULT_MODEL=gpt-5.4-mini
|
||||
# Default color for shell and code completions.
|
||||
DEFAULT_COLOR=magenta
|
||||
# When in --shell mode, default to "Y" for no input.
|
||||
@@ -422,7 +401,7 @@ Possible options for `CODE_THEME`: https://pygments.org/styles/
|
||||
│ prompt [PROMPT] The prompt to generate completions for. │
|
||||
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
|
||||
╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────╮
|
||||
│ --model TEXT Large language model to use. [default: gpt-4o] │
|
||||
│ --model TEXT Large language model to use. [default: gpt-5.4-mini] │
|
||||
│ --temperature FLOAT RANGE [0.0<=x<=2.0] Randomness of generated output. [default: 0.0] │
|
||||
│ --top-p FLOAT RANGE [0.0<=x<=1.0] Limits highest probable tokens (words). [default: 1.0] │
|
||||
│ --md --no-md Prettify markdown output. [default: md] │
|
||||
@@ -477,35 +456,6 @@ You also can use the provided `Dockerfile` to build your own image:
|
||||
docker build -t sgpt .
|
||||
```
|
||||
|
||||
### Docker + Ollama
|
||||
|
||||
If you want to send your requests to an Ollama instance and run ShellGPT inside a Docker container, you need to adjust the Dockerfile and build the container yourself: the litellm package is needed and env variables need to be set correctly.
|
||||
|
||||
Example Dockerfile:
|
||||
```
|
||||
FROM python:3-slim
|
||||
|
||||
ENV DEFAULT_MODEL=ollama/mistral:7b-instruct-v0.2-q4_K_M
|
||||
ENV API_BASE_URL=http://10.10.10.10:11434
|
||||
ENV USE_LITELLM=true
|
||||
ENV OPENAI_API_KEY=bad_key
|
||||
ENV SHELL_INTERACTION=false
|
||||
ENV PRETTIFY_MARKDOWN=false
|
||||
ENV OS_NAME="Arch Linux"
|
||||
ENV SHELL_NAME=auto
|
||||
|
||||
WORKDIR /app
|
||||
COPY . /app
|
||||
|
||||
RUN apt-get update && apt-get install -y gcc
|
||||
RUN pip install --no-cache /app[litellm] && mkdir -p /tmp/shell_gpt
|
||||
|
||||
VOLUME /tmp/shell_gpt
|
||||
|
||||
ENTRYPOINT ["sgpt"]
|
||||
```
|
||||
|
||||
|
||||
## Additional documentation
|
||||
* [Azure integration](https://github.com/TheR1D/shell_gpt/wiki/Azure)
|
||||
* [Ollama integration](https://github.com/TheR1D/shell_gpt/wiki/Ollama)
|
||||
|
||||
+6
-11
@@ -8,7 +8,7 @@ description = "A command-line productivity tool powered by large language models
|
||||
keywords = ["shell", "gpt", "openai", "ollama", "cli", "productivity", "cheet-sheet"]
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
requires-python = ">=3.6"
|
||||
requires-python = ">=3.10"
|
||||
authors = [{ name = "Farkhod Sadykov", email = "farkhod@sadykov.dev" }]
|
||||
dynamic = ["version"]
|
||||
classifiers = [
|
||||
@@ -18,28 +18,23 @@ classifiers = [
|
||||
"Intended Audience :: Information Technology",
|
||||
"Intended Audience :: System Administrators",
|
||||
"Intended Audience :: Developers",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
"Programming Language :: Python :: 3.6",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
]
|
||||
dependencies = [
|
||||
"openai >= 1.34.0, < 2.0.0",
|
||||
"openai >= 2.0.0, < 3.0.0",
|
||||
"typer >= 0.7.0, < 1.0.0",
|
||||
"click >= 7.1.1, < 9.0.0",
|
||||
"rich >= 13.1.0, < 14.0.0",
|
||||
"distro >= 1.8.0, < 2.0.0",
|
||||
"instructor >= 0.4.5, < 1.0.0",
|
||||
'pyreadline3 >= 3.4.1, < 4.0.0; sys_platform == "win32"',
|
||||
"prompt_toolkit >= 3.0.51",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
litellm = [
|
||||
"litellm == 1.24.5"
|
||||
"litellm == 1.83.4"
|
||||
]
|
||||
test = [
|
||||
"pytest >= 7.2.2, < 8.0.0",
|
||||
|
||||
Regular → Executable
Regular → Executable
Regular → Executable
+1
-1
@@ -1 +1 @@
|
||||
__version__ = "1.4.4"
|
||||
__version__ = "1.5.1"
|
||||
|
||||
+17
-8
@@ -5,8 +5,9 @@ import readline # noqa: F401
|
||||
import sys
|
||||
|
||||
import typer
|
||||
from click import BadArgumentUsage
|
||||
from click import UsageError
|
||||
from click.types import Choice
|
||||
from prompt_toolkit import PromptSession
|
||||
|
||||
from sgpt.config import cfg
|
||||
from sgpt.function import get_openai_schemas
|
||||
@@ -101,13 +102,12 @@ def main(
|
||||
),
|
||||
repl: str = typer.Option(
|
||||
None,
|
||||
help="Start a REPL (Read–eval–print loop) session.",
|
||||
help="Start a REPL (Read-eval-print loop) session.",
|
||||
rich_help_panel="Chat Options",
|
||||
),
|
||||
show_chat: str = typer.Option(
|
||||
None,
|
||||
help="Show all messages from provided chat id.",
|
||||
callback=ChatHandler.show_messages_callback,
|
||||
rich_help_panel="Chat Options",
|
||||
),
|
||||
list_chats: bool = typer.Option(
|
||||
@@ -183,16 +183,19 @@ def main(
|
||||
# Non-interactive shell.
|
||||
pass
|
||||
|
||||
if show_chat:
|
||||
ChatHandler.show_messages(show_chat, md)
|
||||
|
||||
if sum((shell, describe_shell, code)) > 1:
|
||||
raise BadArgumentUsage(
|
||||
raise UsageError(
|
||||
"Only one of --shell, --describe-shell, and --code options can be used at a time."
|
||||
)
|
||||
|
||||
if chat and repl:
|
||||
raise BadArgumentUsage("--chat and --repl options cannot be used together.")
|
||||
raise UsageError("--chat and --repl options cannot be used together.")
|
||||
|
||||
if editor and stdin_passed:
|
||||
raise BadArgumentUsage("--editor option cannot be used with stdin input.")
|
||||
raise UsageError("--editor option cannot be used with stdin input.")
|
||||
|
||||
if editor:
|
||||
prompt = get_edited_prompt()
|
||||
@@ -235,17 +238,23 @@ def main(
|
||||
functions=function_schemas,
|
||||
)
|
||||
|
||||
session: PromptSession[str] = PromptSession()
|
||||
|
||||
while shell and interaction:
|
||||
option = typer.prompt(
|
||||
text="[E]xecute, [D]escribe, [A]bort",
|
||||
type=Choice(("e", "d", "a", "y"), case_sensitive=False),
|
||||
text="[E]xecute, [M]odify, [D]escribe, [A]bort",
|
||||
type=Choice(("e", "m", "d", "a", "y"), case_sensitive=False),
|
||||
default="e" if cfg.get("DEFAULT_EXECUTE_SHELL_CMD") == "true" else "a",
|
||||
show_choices=False,
|
||||
show_default=False,
|
||||
)
|
||||
|
||||
if option in ("e", "y"):
|
||||
# "y" option is for keeping compatibility with old version.
|
||||
run_command(full_completion)
|
||||
elif option == "m":
|
||||
full_completion = session.prompt("", default=full_completion)
|
||||
continue
|
||||
elif option == "d":
|
||||
DefaultHandler(DefaultRoles.DESCRIBE_SHELL.get_role(), md).handle(
|
||||
full_completion,
|
||||
|
||||
+1
-1
@@ -38,7 +38,7 @@ class Cache:
|
||||
result += i
|
||||
yield i
|
||||
if "@FunctionCall" not in result:
|
||||
file.write_text(result)
|
||||
file.write_text(result, encoding="utf-8")
|
||||
self._delete_oldest_files(self.length) # type: ignore
|
||||
|
||||
return wrapper
|
||||
|
||||
+1
-1
@@ -22,7 +22,7 @@ DEFAULT_CONFIG = {
|
||||
"CHAT_CACHE_LENGTH": int(os.getenv("CHAT_CACHE_LENGTH", "100")),
|
||||
"CACHE_LENGTH": int(os.getenv("CHAT_CACHE_LENGTH", "100")),
|
||||
"REQUEST_TIMEOUT": int(os.getenv("REQUEST_TIMEOUT", "60")),
|
||||
"DEFAULT_MODEL": os.getenv("DEFAULT_MODEL", "gpt-4o"),
|
||||
"DEFAULT_MODEL": os.getenv("DEFAULT_MODEL", "gpt-5.4-mini"),
|
||||
"DEFAULT_COLOR": os.getenv("DEFAULT_COLOR", "magenta"),
|
||||
"ROLE_STORAGE_PATH": os.getenv("ROLE_STORAGE_PATH", str(ROLE_STORAGE_PATH)),
|
||||
"DEFAULT_EXECUTE_SHELL_CMD": os.getenv("DEFAULT_EXECUTE_SHELL_CMD", "false"),
|
||||
|
||||
+11
-17
@@ -1,9 +1,10 @@
|
||||
import importlib.util
|
||||
import sys
|
||||
from abc import ABCMeta
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .config import cfg
|
||||
|
||||
|
||||
@@ -11,8 +12,8 @@ class Function:
|
||||
def __init__(self, path: str):
|
||||
module = self._read(path)
|
||||
self._function = module.Function.execute
|
||||
self._openai_schema = module.Function.openai_schema
|
||||
self._name = self._openai_schema["name"]
|
||||
self._openai_schema = module.Function.openai_schema()
|
||||
self._name = self._openai_schema["function"]["name"]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -34,13 +35,17 @@ class Function:
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module) # type: ignore
|
||||
|
||||
if not isinstance(module.Function, ABCMeta):
|
||||
if not issubclass(module.Function, BaseModel):
|
||||
raise TypeError(
|
||||
f"Function {module_name} must be a subclass of pydantic.BaseModel"
|
||||
)
|
||||
if not hasattr(module.Function, "execute"):
|
||||
raise TypeError(
|
||||
f"Function {module_name} must have a 'execute' static method"
|
||||
f"Function {module_name} must have an 'execute' classmethod"
|
||||
)
|
||||
if not hasattr(module.Function, "openai_schema"):
|
||||
raise TypeError(
|
||||
f"Function {module_name} must have an 'openai_schema' classmethod"
|
||||
)
|
||||
|
||||
return module
|
||||
@@ -59,15 +64,4 @@ def get_function(name: str) -> Callable[..., Any]:
|
||||
|
||||
|
||||
def get_openai_schemas() -> List[Dict[str, Any]]:
|
||||
transformed_schemas = []
|
||||
for function in functions:
|
||||
schema = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function.openai_schema["name"],
|
||||
"description": function.openai_schema.get("description", ""),
|
||||
"parameters": function.openai_schema.get("parameters", {}),
|
||||
},
|
||||
}
|
||||
transformed_schemas.append(schema)
|
||||
return transformed_schemas
|
||||
return [function.openai_schema for function in functions]
|
||||
|
||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional
|
||||
|
||||
import typer
|
||||
from click import BadArgumentUsage
|
||||
from click import BadParameter, UsageError
|
||||
from rich.console import Console
|
||||
from rich.markdown import Markdown
|
||||
|
||||
@@ -71,7 +71,11 @@ class ChatSession:
|
||||
|
||||
def _write(self, messages: List[Dict[str, str]], chat_id: str) -> None:
|
||||
file_path = self.storage_path / chat_id
|
||||
json.dump(messages[-self.length :], file_path.open("w"))
|
||||
# Retain the first message since it defines the role
|
||||
truncated_messages = (
|
||||
messages[:1] + messages[1 + max(0, len(messages) - self.length) :]
|
||||
)
|
||||
json.dump(truncated_messages, file_path.open("w"))
|
||||
|
||||
def invalidate(self, chat_id: str) -> None:
|
||||
file_path = self.storage_path / chat_id
|
||||
@@ -127,9 +131,9 @@ class ChatHandler(Handler):
|
||||
typer.echo(chat_id)
|
||||
|
||||
@classmethod
|
||||
def show_messages(cls, chat_id: str) -> None:
|
||||
def show_messages(cls, chat_id: str, markdown: bool) -> None:
|
||||
color = cfg.get("DEFAULT_COLOR")
|
||||
if "APPLY MARKDOWN" in cls.initial_message(chat_id):
|
||||
if "APPLY MARKDOWN" in cls.initial_message(chat_id) and markdown:
|
||||
theme = cfg.get("CODE_THEME")
|
||||
for message in cls.chat_session.get_messages(chat_id):
|
||||
if message.startswith("assistant:"):
|
||||
@@ -143,24 +147,17 @@ class ChatHandler(Handler):
|
||||
running_color = color if index % 2 == 0 else "green"
|
||||
typer.secho(message, fg=running_color)
|
||||
|
||||
@classmethod
|
||||
@option_callback
|
||||
def show_messages_callback(cls, chat_id: str) -> None:
|
||||
cls.show_messages(chat_id)
|
||||
|
||||
def validate(self) -> None:
|
||||
if self.initiated:
|
||||
chat_role_name = self.role.get_role_name(self.initial_message(self.chat_id))
|
||||
if not chat_role_name:
|
||||
raise BadArgumentUsage(
|
||||
f'Could not determine chat role of "{self.chat_id}"'
|
||||
)
|
||||
raise BadParameter(f'Could not determine chat role of "{self.chat_id}"')
|
||||
if self.role.name == DefaultRoles.DEFAULT.value:
|
||||
# If user didn't pass chat mode, we will use the one that was used to initiate the chat.
|
||||
self.role = SystemRole.get(chat_role_name)
|
||||
else:
|
||||
if not self.is_same_role:
|
||||
raise BadArgumentUsage(
|
||||
raise UsageError(
|
||||
f'Cant change chat role to "{self.role.name}" '
|
||||
f'since it was initiated as "{chat_role_name}" chat.'
|
||||
)
|
||||
|
||||
@@ -58,14 +58,22 @@ class Handler:
|
||||
def handle_function_call(
|
||||
self,
|
||||
messages: List[dict[str, Any]],
|
||||
tool_call_id: str,
|
||||
name: str,
|
||||
arguments: str,
|
||||
) -> Generator[str, None, None]:
|
||||
# Add assistant message with tool call
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"function_call": {"name": name, "arguments": arguments},
|
||||
"content": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": tool_call_id,
|
||||
"type": "function",
|
||||
"function": {"name": name, "arguments": arguments},
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -79,7 +87,11 @@ class Handler:
|
||||
result = get_function(name)(**dict_args)
|
||||
if cfg.get("SHOW_FUNCTIONS_OUTPUT") == "true":
|
||||
yield f"```text\n{result}\n```\n"
|
||||
messages.append({"role": "function", "content": result, "name": name})
|
||||
|
||||
# Add tool response message
|
||||
messages.append(
|
||||
{"role": "tool", "content": result, "tool_call_id": tool_call_id}
|
||||
)
|
||||
|
||||
@cache
|
||||
def get_completion(
|
||||
@@ -90,7 +102,7 @@ class Handler:
|
||||
messages: List[Dict[str, Any]],
|
||||
functions: Optional[List[Dict[str, str]]],
|
||||
) -> Generator[str, None, None]:
|
||||
name = arguments = ""
|
||||
tool_call_id = name = arguments = ""
|
||||
is_shell_role = self.role.name == DefaultRoles.SHELL.value
|
||||
is_code_role = self.role.name == DefaultRoles.CODE.value
|
||||
is_dsc_shell_role = self.role.name == DefaultRoles.DESCRIBE_SHELL.value
|
||||
@@ -113,6 +125,8 @@ class Handler:
|
||||
|
||||
try:
|
||||
for chunk in response:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
|
||||
# LiteLLM uses dict instead of Pydantic object like OpenAI does.
|
||||
@@ -121,12 +135,21 @@ class Handler:
|
||||
)
|
||||
if tool_calls:
|
||||
for tool_call in tool_calls:
|
||||
if tool_call.function.name:
|
||||
name = tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
arguments += tool_call.function.arguments
|
||||
if use_litellm:
|
||||
# TODO: test.
|
||||
tool_call_id = tool_call.get("id") or tool_call_id
|
||||
name = tool_call.get("function", {}).get("name") or name
|
||||
arguments += tool_call.get("function", {}).get(
|
||||
"arguments", ""
|
||||
)
|
||||
else:
|
||||
tool_call_id = tool_call.id or tool_call_id
|
||||
name = tool_call.function.name or name
|
||||
arguments += tool_call.function.arguments or ""
|
||||
if chunk.choices[0].finish_reason == "tool_calls":
|
||||
yield from self.handle_function_call(messages, name, arguments)
|
||||
yield from self.handle_function_call(
|
||||
messages, tool_call_id, name, arguments
|
||||
)
|
||||
yield from self.get_completion(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
|
||||
@@ -24,7 +24,7 @@ class ReplHandler(ChatHandler):
|
||||
def handle(self, init_prompt: str, **kwargs: Any) -> None: # type: ignore
|
||||
if self.initiated:
|
||||
rich_print(Rule(title="Chat History", style="bold magenta"))
|
||||
self.show_messages(self.chat_id)
|
||||
self.show_messages(self.chat_id, self.markdown)
|
||||
rich_print(Rule(style="bold magenta"))
|
||||
|
||||
info_message = (
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import subprocess
|
||||
from typing import Any, Dict
|
||||
|
||||
from instructor import OpenAISchema
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Function(OpenAISchema):
|
||||
class Function(BaseModel):
|
||||
"""
|
||||
Executes a shell command and returns the output (result).
|
||||
"""
|
||||
@@ -12,11 +12,8 @@ class Function(OpenAISchema):
|
||||
shell_command: str = Field(
|
||||
...,
|
||||
example="ls -la",
|
||||
descriptions="Shell command to execute.",
|
||||
)
|
||||
|
||||
class Config:
|
||||
title = "execute_shell_command"
|
||||
description="Shell command to execute.",
|
||||
) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def execute(cls, shell_command: str) -> str:
|
||||
@@ -26,3 +23,20 @@ class Function(OpenAISchema):
|
||||
output, _ = process.communicate()
|
||||
exit_code = process.returncode
|
||||
return f"Exit code: {exit_code}, Output:\n{output.decode()}"
|
||||
|
||||
@classmethod
|
||||
def openai_schema(cls) -> Dict[str, Any]:
|
||||
"""Generate OpenAI function schema from Pydantic model."""
|
||||
schema = cls.model_json_schema()
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "execute_shell_command",
|
||||
"description": cls.__doc__.strip() if cls.__doc__ else "",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": schema.get("properties", {}),
|
||||
"required": schema.get("required", []),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,23 +1,20 @@
|
||||
import subprocess
|
||||
from typing import Any, Dict
|
||||
|
||||
from instructor import OpenAISchema
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Function(OpenAISchema):
|
||||
class Function(BaseModel):
|
||||
"""
|
||||
Executes Apple Script on macOS and returns the output (result).
|
||||
Can be used for actions like: draft (prepare) an email, show calendar events, create a note.
|
||||
"""
|
||||
|
||||
apple_script: str = Field(
|
||||
...,
|
||||
default=...,
|
||||
example='tell application "Finder" to get the name of every disk',
|
||||
descriptions="Apple Script to execute.",
|
||||
)
|
||||
|
||||
class Config:
|
||||
title = "execute_apple_script"
|
||||
description="Apple Script to execute.",
|
||||
) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def execute(cls, apple_script):
|
||||
@@ -31,3 +28,20 @@ class Function(OpenAISchema):
|
||||
return f"Output: {output}"
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
@classmethod
|
||||
def openai_schema(cls) -> Dict[str, Any]:
|
||||
"""Generate OpenAI function schema from Pydantic model."""
|
||||
schema = cls.model_json_schema()
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "execute_apple_script",
|
||||
"description": cls.__doc__.strip() if cls.__doc__ else "",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": schema.get("properties", {}),
|
||||
"required": schema.get("required", []),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
+2
-2
@@ -7,7 +7,7 @@ from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
import typer
|
||||
from click import BadArgumentUsage
|
||||
from click import UsageError
|
||||
from distro import name as distro_name
|
||||
|
||||
from .config import cfg
|
||||
@@ -76,7 +76,7 @@ class SystemRole:
|
||||
def get(cls, name: str) -> "SystemRole":
|
||||
file_path = cls.storage / f"{name}.json"
|
||||
if not file_path.exists():
|
||||
raise BadArgumentUsage(f'Role "{name}" not found.')
|
||||
raise UsageError(f'Role "{name}" not found.')
|
||||
return cls(**json.loads(file_path.read_text()))
|
||||
|
||||
@classmethod
|
||||
|
||||
+59
-59
@@ -34,8 +34,8 @@ class TestShellGpt(TestCase):
|
||||
def setUpClass(cls):
|
||||
# Response streaming should be enabled for these tests.
|
||||
assert cfg.get("DISABLE_STREAMING") == "false"
|
||||
# ShellGPT optimised and tested with gpt-4 turbo.
|
||||
assert cfg.get("DEFAULT_MODEL") == "gpt-4o"
|
||||
# ShellGPT optimised and tested with gpt-5.4-mini.
|
||||
assert cfg.get("DEFAULT_MODEL") == "gpt-5.4-mini"
|
||||
# Make sure we will not call any functions.
|
||||
assert cfg.get("OPENAI_USE_FUNCTIONS") == "false"
|
||||
|
||||
@@ -56,7 +56,7 @@ class TestShellGpt(TestCase):
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
assert "Prague" in result.stdout
|
||||
assert "Prague" in result.output
|
||||
|
||||
def test_shell(self):
|
||||
dict_arguments = {
|
||||
@@ -65,7 +65,7 @@ class TestShellGpt(TestCase):
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
assert "git commit" in result.stdout
|
||||
assert "git commit" in result.output
|
||||
|
||||
def test_describe_shell(self):
|
||||
dict_arguments = {
|
||||
@@ -74,7 +74,7 @@ class TestShellGpt(TestCase):
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
assert "lists" in result.stdout.lower()
|
||||
assert "lists" in result.output.lower()
|
||||
|
||||
def test_code(self):
|
||||
"""
|
||||
@@ -93,10 +93,10 @@ class TestShellGpt(TestCase):
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
print(result.stdout)
|
||||
print(result.output)
|
||||
# Since output will be slightly different, there is no way how to test it precisely.
|
||||
assert "print" in result.stdout
|
||||
assert "*" in result.stdout
|
||||
assert "print" in result.output
|
||||
assert "*" in result.output
|
||||
with NamedTemporaryFile("w+", delete=False) as file:
|
||||
try:
|
||||
compile(result.output, file.name, "exec")
|
||||
@@ -124,7 +124,7 @@ class TestShellGpt(TestCase):
|
||||
dict_arguments["prompt"] = "What is my favorite number + 2?"
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
assert "8" in result.stdout
|
||||
assert "8" in result.output
|
||||
dict_arguments["--shell"] = True
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 2
|
||||
@@ -143,14 +143,14 @@ class TestShellGpt(TestCase):
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
assert "docker run" in result.stdout
|
||||
assert "-p 80:80" in result.stdout
|
||||
assert "nginx" in result.stdout
|
||||
assert "docker run" in result.output
|
||||
assert "-p 80:80" in result.output
|
||||
assert "nginx" in result.output
|
||||
dict_arguments["prompt"] = "Also forward port 443."
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
assert "-p 80:80" in result.stdout
|
||||
assert "-p 443:443" in result.stdout
|
||||
assert "-p 80:80" in result.output
|
||||
assert "-p 443:443" in result.output
|
||||
dict_arguments["--code"] = True
|
||||
del dict_arguments["--shell"]
|
||||
assert "--shell" not in dict_arguments
|
||||
@@ -167,11 +167,11 @@ class TestShellGpt(TestCase):
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
assert "adds" in result.stdout.lower() or "stages" in result.stdout.lower()
|
||||
assert "adds" in result.output.lower() or "stages" in result.output.lower()
|
||||
dict_arguments["prompt"] = "'-A'"
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
assert "all" in result.stdout
|
||||
assert "all" in result.output
|
||||
|
||||
def test_chat_code(self):
|
||||
chat_name = uuid4()
|
||||
@@ -182,11 +182,11 @@ class TestShellGpt(TestCase):
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
assert "localhost:80" in result.stdout
|
||||
assert "localhost:80" in result.output
|
||||
dict_arguments["prompt"] = "Change port to 443."
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
assert "localhost:443" in result.stdout
|
||||
assert "localhost:443" in result.output
|
||||
del dict_arguments["--code"]
|
||||
assert "--code" not in dict_arguments
|
||||
dict_arguments["--shell"] = True
|
||||
@@ -197,7 +197,7 @@ class TestShellGpt(TestCase):
|
||||
def test_list_chat(self):
|
||||
result = runner.invoke(app, ["--list-chats"])
|
||||
assert result.exit_code == 0
|
||||
assert "test_" in result.stdout
|
||||
assert "test_" in result.output
|
||||
|
||||
def test_show_chat(self):
|
||||
chat_name = uuid4()
|
||||
@@ -210,9 +210,9 @@ class TestShellGpt(TestCase):
|
||||
runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
result = runner.invoke(app, ["--show-chat", f"test_{chat_name}"])
|
||||
assert result.exit_code == 0
|
||||
assert "Remember my favorite number: 6" in result.stdout
|
||||
assert "What is my favorite number + 2?" in result.stdout
|
||||
assert "8" in result.stdout
|
||||
assert "Remember my favorite number: 6" in result.output
|
||||
assert "What is my favorite number + 2?" in result.output
|
||||
assert "8" in result.output
|
||||
|
||||
def test_validation_code_shell(self):
|
||||
dict_arguments = {
|
||||
@@ -222,7 +222,7 @@ class TestShellGpt(TestCase):
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 2
|
||||
assert "Only one of --shell, --describe-shell, and --code" in result.stdout
|
||||
assert "Only one of --shell, --describe-shell, and --code" in result.output
|
||||
|
||||
def test_repl_default(
|
||||
self,
|
||||
@@ -240,9 +240,9 @@ class TestShellGpt(TestCase):
|
||||
app, self.get_arguments(**dict_arguments), input="\n".join(inputs)
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert ">>> Please remember my favorite number: 6" in result.stdout
|
||||
assert ">>> What is my favorite number + 2?" in result.stdout
|
||||
assert "8" in result.stdout
|
||||
assert ">>> Please remember my favorite number: 6" in result.output
|
||||
assert ">>> What is my favorite number + 2?" in result.output
|
||||
assert "8" in result.output
|
||||
|
||||
def test_repl_multiline(
|
||||
self,
|
||||
@@ -263,11 +263,11 @@ class TestShellGpt(TestCase):
|
||||
)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert '"""' in result.stdout
|
||||
assert "Please remember my favorite number: 6" in result.stdout
|
||||
assert "What is my favorite number + 2?" in result.stdout
|
||||
assert '"""' in result.stdout
|
||||
assert "8" in result.stdout
|
||||
assert '"""' in result.output
|
||||
assert "Please remember my favorite number: 6" in result.output
|
||||
assert "What is my favorite number + 2?" in result.output
|
||||
assert '"""' in result.output
|
||||
assert "8" in result.output
|
||||
|
||||
def test_repl_shell(self):
|
||||
# Temp chat session from previous test should be overwritten.
|
||||
@@ -281,11 +281,11 @@ class TestShellGpt(TestCase):
|
||||
app, self.get_arguments(**dict_arguments), input="\n".join(inputs)
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert "type [e] to execute commands" in result.stdout
|
||||
assert ">>> What is in current folder?" in result.stdout
|
||||
assert ">>> Simple sort by name" in result.stdout
|
||||
assert "ls -la" in result.stdout
|
||||
assert "sort" in result.stdout
|
||||
assert "type [e] to execute commands" in result.output
|
||||
assert ">>> What is in current folder?" in result.output
|
||||
assert ">>> Simple sort by name" in result.output
|
||||
assert "ls -la" in result.output
|
||||
assert "sort" in result.output
|
||||
chat_storage = cfg.get("CHAT_CACHE_PATH")
|
||||
tmp_chat = Path(chat_storage) / "temp"
|
||||
chat_messages = json.loads(tmp_chat.read_text())
|
||||
@@ -313,8 +313,8 @@ class TestShellGpt(TestCase):
|
||||
app, self.get_arguments(**dict_arguments), input="\n".join(inputs)
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert "install" in result.stdout.lower()
|
||||
assert "upgrade" in result.stdout.lower()
|
||||
assert "install" in result.output.lower()
|
||||
assert "upgrade" in result.output.lower()
|
||||
|
||||
chat_storage = cfg.get("CHAT_CACHE_PATH")
|
||||
tmp_chat = Path(chat_storage) / "temp"
|
||||
@@ -338,11 +338,11 @@ class TestShellGpt(TestCase):
|
||||
app, self.get_arguments(**dict_arguments), input="\n".join(inputs)
|
||||
)
|
||||
assert result.exit_code == 0
|
||||
assert f">>> {inputs[0]}" in result.stdout
|
||||
assert "requests.get" in result.stdout
|
||||
assert "localhost:8080" in result.stdout
|
||||
assert f">>> {inputs[1]}" in result.stdout
|
||||
assert "localhost:443" in result.stdout
|
||||
assert f">>> {inputs[0]}" in result.output
|
||||
assert "requests.get" in result.output
|
||||
assert "localhost:8080" in result.output
|
||||
assert f">>> {inputs[1]}" in result.output
|
||||
assert "localhost:443" in result.output
|
||||
|
||||
chat_storage = cfg.get("CHAT_CACHE_PATH")
|
||||
tmp_chat = Path(chat_storage) / dict_arguments["--repl"]
|
||||
@@ -356,8 +356,8 @@ class TestShellGpt(TestCase):
|
||||
app, self.get_arguments(**dict_arguments), input="\n".join(new_inputs)
|
||||
)
|
||||
# Should include previous chat history.
|
||||
assert "Chat History" in result.stdout
|
||||
assert f"user: {inputs[1]}" in result.stdout
|
||||
assert "Chat History" in result.output
|
||||
assert f"user: {inputs[1]}" in result.output
|
||||
|
||||
def test_zsh_command(self):
|
||||
"""
|
||||
@@ -372,12 +372,12 @@ class TestShellGpt(TestCase):
|
||||
"--shell": True,
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments), input="y\n")
|
||||
stdout = result.stdout.strip()
|
||||
stdout = result.output.strip()
|
||||
print(stdout)
|
||||
# TODO: Fix this test.
|
||||
# Not sure how os.system pipes the output to stdout,
|
||||
# but it is not part of the result.stdout.
|
||||
# assert "command not found" not in result.stdout
|
||||
# but it is not part of the result.output.
|
||||
# assert "command not found" not in result.output
|
||||
# assert "hello world" in stdout.split("\n")[-1]
|
||||
|
||||
@patch("sgpt.handlers.handler.Handler.get_completion")
|
||||
@@ -400,15 +400,15 @@ class TestShellGpt(TestCase):
|
||||
def test_color_output(self):
|
||||
color = cfg.get("DEFAULT_COLOR")
|
||||
role = SystemRole.get("ShellGPT")
|
||||
handler = Handler(role=role)
|
||||
handler = Handler(role=role, markdown=False)
|
||||
assert handler.color == color
|
||||
os.environ["DEFAULT_COLOR"] = "red"
|
||||
handler = Handler(role=role)
|
||||
handler = Handler(role=role, markdown=False)
|
||||
assert handler.color == "red"
|
||||
|
||||
def test_simple_stdin(self):
|
||||
result = runner.invoke(app, input="What is the capital of Germany?\n")
|
||||
assert "Berlin" in result.stdout
|
||||
assert "Berlin" in result.output
|
||||
|
||||
def test_shell_stdin_with_prompt(self):
|
||||
dict_arguments = {
|
||||
@@ -417,8 +417,8 @@ class TestShellGpt(TestCase):
|
||||
}
|
||||
stdin = "What is in current folder\n"
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments), input=stdin)
|
||||
assert "ls" in result.stdout
|
||||
assert "sort" in result.stdout
|
||||
assert "ls" in result.output
|
||||
assert "sort" in result.output
|
||||
|
||||
def test_role(self):
|
||||
test_role = Path(cfg.get("ROLE_STORAGE_PATH")) / "json_generator.json"
|
||||
@@ -440,7 +440,7 @@ class TestShellGpt(TestCase):
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
assert "json_generator" in result.stdout
|
||||
assert "json_generator" in result.output
|
||||
|
||||
dict_arguments = {
|
||||
"prompt": "test",
|
||||
@@ -448,7 +448,7 @@ class TestShellGpt(TestCase):
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
assert "You are json_generator" in result.stdout
|
||||
assert "You are json_generator" in result.output
|
||||
|
||||
# Test with command line argument prompt.
|
||||
dict_arguments = {
|
||||
@@ -457,7 +457,7 @@ class TestShellGpt(TestCase):
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments))
|
||||
assert result.exit_code == 0
|
||||
generated_json = json.loads(result.stdout)
|
||||
generated_json = json.loads(result.output)
|
||||
assert "username" in generated_json
|
||||
assert "password" in generated_json
|
||||
assert "email" in generated_json
|
||||
@@ -470,7 +470,7 @@ class TestShellGpt(TestCase):
|
||||
stdin = "random username, password, email"
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments), input=stdin)
|
||||
assert result.exit_code == 0
|
||||
generated_json = json.loads(result.stdout)
|
||||
generated_json = json.loads(result.output)
|
||||
assert "username" in generated_json
|
||||
assert "password" in generated_json
|
||||
assert "email" in generated_json
|
||||
@@ -485,7 +485,7 @@ class TestShellGpt(TestCase):
|
||||
assert result.exit_code == 0
|
||||
# Can't really test it since stdin in disable for --shell flag.
|
||||
# for word in ("prints", "hello", "console"):
|
||||
# assert word in result.stdout
|
||||
# assert word in result.output
|
||||
|
||||
def test_version(self):
|
||||
dict_arguments = {
|
||||
@@ -493,6 +493,6 @@ class TestShellGpt(TestCase):
|
||||
"--version": True,
|
||||
}
|
||||
result = runner.invoke(app, self.get_arguments(**dict_arguments), input="d\n")
|
||||
assert __version__ in result.stdout
|
||||
assert __version__ in result.output
|
||||
|
||||
# TODO: Implement function call tests.
|
||||
|
||||
+13
-13
@@ -18,7 +18,7 @@ def test_code_generation(completion):
|
||||
|
||||
completion.assert_called_once_with(**comp_args(role, args["prompt"]))
|
||||
assert result.exit_code == 0
|
||||
assert "print('Hello World')" in result.stdout
|
||||
assert "print('Hello World')" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.printer.TextPrinter.live_print")
|
||||
@@ -47,8 +47,8 @@ def test_code_generation_stdin(completion):
|
||||
expected_prompt = f"{stdin}\n\n{args['prompt']}"
|
||||
completion.assert_called_once_with(**comp_args(role, expected_prompt))
|
||||
assert result.exit_code == 0
|
||||
assert "# Hello" in result.stdout
|
||||
assert "print('Hello')" in result.stdout
|
||||
assert "# Hello" in result.output
|
||||
assert "print('Hello')" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.handlers.handler.completion")
|
||||
@@ -64,14 +64,14 @@ def test_code_chat(completion):
|
||||
args = {"prompt": "print hello", "--code": True, "--chat": chat_name}
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
assert result.exit_code == 0
|
||||
assert "print('hello')" in result.stdout
|
||||
assert "print('hello')" in result.output
|
||||
assert chat_path.exists()
|
||||
|
||||
args["prompt"] = "also print world"
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
assert result.exit_code == 0
|
||||
assert "print('hello')" in result.stdout
|
||||
assert "print('world')" in result.stdout
|
||||
assert "print('hello')" in result.output
|
||||
assert "print('world')" in result.output
|
||||
|
||||
expected_messages = [
|
||||
{"role": "system", "content": role.role},
|
||||
@@ -87,7 +87,7 @@ def test_code_chat(completion):
|
||||
args["--shell"] = True
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
assert result.exit_code == 2
|
||||
assert "Error" in result.stdout
|
||||
assert "Error" in result.output
|
||||
chat_path.unlink()
|
||||
# TODO: Code chat can be recalled without --code option.
|
||||
|
||||
@@ -118,10 +118,10 @@ def test_code_repl(completion):
|
||||
assert completion.call_count == 2
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert ">>> print hello" in result.stdout
|
||||
assert "print('hello')" in result.stdout
|
||||
assert ">>> also print world" in result.stdout
|
||||
assert "print('world')" in result.stdout
|
||||
assert ">>> print hello" in result.output
|
||||
assert "print('hello')" in result.output
|
||||
assert ">>> also print world" in result.output
|
||||
assert "print('world')" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.handlers.handler.completion")
|
||||
@@ -131,7 +131,7 @@ def test_code_and_shell(completion):
|
||||
|
||||
completion.assert_not_called()
|
||||
assert result.exit_code == 2
|
||||
assert "Error" in result.stdout
|
||||
assert "Error" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.handlers.handler.completion")
|
||||
@@ -141,4 +141,4 @@ def test_code_and_describe_shell(completion):
|
||||
|
||||
completion.assert_not_called()
|
||||
assert result.exit_code == 2
|
||||
assert "Error" in result.stdout
|
||||
assert "Error" in result.output
|
||||
|
||||
+23
-23
@@ -23,7 +23,7 @@ def test_default(completion):
|
||||
|
||||
completion.assert_called_once_with(**comp_args(role, **args))
|
||||
assert result.exit_code == 0
|
||||
assert "Prague" in result.stdout
|
||||
assert "Prague" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.handlers.handler.completion")
|
||||
@@ -35,7 +35,7 @@ def test_default_stdin(completion):
|
||||
|
||||
completion.assert_called_once_with(**comp_args(role, stdin))
|
||||
assert result.exit_code == 0
|
||||
assert "Prague" in result.stdout
|
||||
assert "Prague" in result.output
|
||||
|
||||
|
||||
@patch("rich.console.Console.print")
|
||||
@@ -70,7 +70,7 @@ def test_show_chat_no_use_markdown(completion, console_print):
|
||||
assert result.exit_code == 0
|
||||
assert chat_path.exists()
|
||||
|
||||
result = runner.invoke(app, ["--show-chat", chat_name])
|
||||
result = runner.invoke(app, ["--show-chat", chat_name, "--no-md"])
|
||||
assert result.exit_code == 0
|
||||
console_print.assert_not_called()
|
||||
|
||||
@@ -85,13 +85,13 @@ def test_default_chat(completion):
|
||||
args = {"prompt": "my number is 2", "--chat": chat_name}
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
assert result.exit_code == 0
|
||||
assert "ok" in result.stdout
|
||||
assert "ok" in result.output
|
||||
assert chat_path.exists()
|
||||
|
||||
args["prompt"] = "my number + 2?"
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
assert result.exit_code == 0
|
||||
assert "4" in result.stdout
|
||||
assert "4" in result.output
|
||||
|
||||
expected_messages = [
|
||||
{"role": "system", "content": role.role},
|
||||
@@ -106,24 +106,24 @@ def test_default_chat(completion):
|
||||
|
||||
result = runner.invoke(app, ["--list-chats"])
|
||||
assert result.exit_code == 0
|
||||
assert "_test" in result.stdout
|
||||
assert "_test" in result.output
|
||||
|
||||
result = runner.invoke(app, ["--show-chat", chat_name])
|
||||
assert result.exit_code == 0
|
||||
assert "my number is 2" in result.stdout
|
||||
assert "ok" in result.stdout
|
||||
assert "my number + 2?" in result.stdout
|
||||
assert "4" in result.stdout
|
||||
assert "my number is 2" in result.output
|
||||
assert "ok" in result.output
|
||||
assert "my number + 2?" in result.output
|
||||
assert "4" in result.output
|
||||
|
||||
args["--shell"] = True
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
assert result.exit_code == 2
|
||||
assert "Error" in result.stdout
|
||||
assert "Error" in result.output
|
||||
|
||||
args["--code"] = True
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
assert result.exit_code == 2
|
||||
assert "Error" in result.stdout
|
||||
assert "Error" in result.output
|
||||
chat_path.unlink()
|
||||
|
||||
|
||||
@@ -150,10 +150,10 @@ def test_default_repl(completion):
|
||||
assert completion.call_count == 2
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert ">>> my number is 6" in result.stdout
|
||||
assert "ok" in result.stdout
|
||||
assert ">>> my number + 2?" in result.stdout
|
||||
assert "8" in result.stdout
|
||||
assert ">>> my number is 6" in result.output
|
||||
assert "ok" in result.output
|
||||
assert ">>> my number + 2?" in result.output
|
||||
assert "8" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.handlers.handler.completion")
|
||||
@@ -183,11 +183,11 @@ def test_default_repl_stdin(completion):
|
||||
assert completion.call_count == 2
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "this is stdin" in result.stdout
|
||||
assert ">>> prompt" in result.stdout
|
||||
assert "ok init" in result.stdout
|
||||
assert ">>> another" in result.stdout
|
||||
assert "ok another" in result.stdout
|
||||
assert "this is stdin" in result.output
|
||||
assert ">>> prompt" in result.output
|
||||
assert "ok init" in result.output
|
||||
assert ">>> another" in result.output
|
||||
assert "ok another" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.handlers.handler.completion")
|
||||
@@ -212,7 +212,7 @@ def test_llm_options(completion):
|
||||
)
|
||||
completion.assert_called_once_with(**expected_args)
|
||||
assert result.exit_code == 0
|
||||
assert "Berlin" in result.stdout
|
||||
assert "Berlin" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.handlers.handler.completion")
|
||||
@@ -221,7 +221,7 @@ def test_version(completion):
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
|
||||
completion.assert_not_called()
|
||||
assert __version__ in result.stdout
|
||||
assert __version__ in result.output
|
||||
|
||||
|
||||
@patch("sgpt.printer.TextPrinter.live_print")
|
||||
|
||||
+4
-4
@@ -23,13 +23,13 @@ def test_role(completion):
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
completion.assert_not_called()
|
||||
assert result.exit_code == 0
|
||||
assert "json_gen_test" in result.stdout
|
||||
assert "json_gen_test" in result.output
|
||||
|
||||
args = {"--show-role": "json_gen_test"}
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
completion.assert_not_called()
|
||||
assert result.exit_code == 0
|
||||
assert "you are a JSON generator" in result.stdout
|
||||
assert "you are a JSON generator" in result.output
|
||||
|
||||
# Test with argument prompt.
|
||||
args = {
|
||||
@@ -40,7 +40,7 @@ def test_role(completion):
|
||||
role = SystemRole.get("json_gen_test")
|
||||
completion.assert_called_once_with(**comp_args(role, args["prompt"]))
|
||||
assert result.exit_code == 0
|
||||
generated_json = json.loads(result.stdout)
|
||||
generated_json = json.loads(result.output)
|
||||
assert "foo" in generated_json
|
||||
|
||||
# Test with stdin prompt.
|
||||
@@ -50,6 +50,6 @@ def test_role(completion):
|
||||
result = runner.invoke(app, cmd_args(**args), input=stdin)
|
||||
completion.assert_called_with(**comp_args(role, stdin))
|
||||
assert result.exit_code == 0
|
||||
generated_json = json.loads(result.stdout)
|
||||
generated_json = json.loads(result.output)
|
||||
assert "foo" in generated_json
|
||||
path.unlink(missing_ok=True)
|
||||
|
||||
+20
-25
@@ -17,9 +17,8 @@ def test_shell(completion):
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
|
||||
completion.assert_called_once_with(**comp_args(role, args["prompt"]))
|
||||
assert result.exit_code == 0
|
||||
assert "git commit" in result.stdout
|
||||
assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout
|
||||
assert "git commit" in result.output
|
||||
assert "[E]xecute, [M]odify, [D]escribe, [A]bort:" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.printer.TextPrinter.live_print")
|
||||
@@ -29,9 +28,8 @@ def test_shell_no_markdown(completion, markdown_printer, text_printer):
|
||||
completion.return_value = mock_comp("git commit -m test")
|
||||
|
||||
args = {"prompt": "make a commit using git", "--shell": True, "--md": True}
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
runner.invoke(app, cmd_args(**args))
|
||||
|
||||
assert result.exit_code == 0
|
||||
# Should ignore --md for --shell option and output text without markdown.
|
||||
markdown_printer.assert_not_called()
|
||||
text_printer.assert_called()
|
||||
@@ -48,9 +46,8 @@ def test_shell_stdin(completion):
|
||||
|
||||
expected_prompt = f"{stdin}\n\n{args['prompt']}"
|
||||
completion.assert_called_once_with(**comp_args(role, expected_prompt))
|
||||
assert result.exit_code == 0
|
||||
assert "ls -l | sort" in result.stdout
|
||||
assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout
|
||||
assert "ls -l | sort" in result.output
|
||||
assert "[E]xecute, [M]odify, [D]escribe, [A]bort:" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.handlers.handler.completion")
|
||||
@@ -63,7 +60,7 @@ def test_describe_shell(completion):
|
||||
|
||||
completion.assert_called_once_with(**comp_args(role, args["prompt"]))
|
||||
assert result.exit_code == 0
|
||||
assert "lists" in result.stdout
|
||||
assert "lists" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.handlers.handler.completion")
|
||||
@@ -78,7 +75,7 @@ def test_describe_shell_stdin(completion):
|
||||
expected_prompt = f"{stdin}"
|
||||
completion.assert_called_once_with(**comp_args(role, expected_prompt))
|
||||
assert result.exit_code == 0
|
||||
assert "lists" in result.stdout
|
||||
assert "lists" in result.output
|
||||
|
||||
|
||||
@patch("os.system")
|
||||
@@ -91,8 +88,8 @@ def test_shell_run_description(completion, system):
|
||||
shell = os.environ.get("SHELL", "/bin/sh")
|
||||
system.assert_called_once_with(f"{shell} -c 'echo hello'")
|
||||
assert result.exit_code == 0
|
||||
assert "echo hello" in result.stdout
|
||||
assert "prints hello" in result.stdout
|
||||
assert "echo hello" in result.output
|
||||
assert "prints hello" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.handlers.handler.completion")
|
||||
@@ -105,14 +102,12 @@ def test_shell_chat(completion):
|
||||
|
||||
args = {"prompt": "list folder", "--shell": True, "--chat": chat_name}
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
assert result.exit_code == 0
|
||||
assert "ls" in result.stdout
|
||||
assert "ls" in result.output
|
||||
assert chat_path.exists()
|
||||
|
||||
args["prompt"] = "sort by name"
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
assert result.exit_code == 0
|
||||
assert "ls | sort" in result.stdout
|
||||
assert "ls | sort" in result.output
|
||||
|
||||
expected_messages = [
|
||||
{"role": "system", "content": role.role},
|
||||
@@ -128,7 +123,7 @@ def test_shell_chat(completion):
|
||||
args["--code"] = True
|
||||
result = runner.invoke(app, cmd_args(**args))
|
||||
assert result.exit_code == 2
|
||||
assert "Error" in result.stdout
|
||||
assert "Error" in result.output
|
||||
chat_path.unlink()
|
||||
# TODO: Shell chat can be recalled without --shell option.
|
||||
|
||||
@@ -146,7 +141,7 @@ def test_shell_repl(completion, mock_system):
|
||||
inputs = ["__sgpt__eof__", "list folder", "sort by name", "e", "exit()"]
|
||||
result = runner.invoke(app, cmd_args(**args), input="\n".join(inputs))
|
||||
shell = os.environ.get("SHELL", "/bin/sh")
|
||||
mock_system.called_once_with(f"{shell} -c 'ls | sort'")
|
||||
mock_system.assert_called_once_with(f"{shell} -c 'ls | sort'")
|
||||
|
||||
expected_messages = [
|
||||
{"role": "system", "content": role.role},
|
||||
@@ -160,10 +155,10 @@ def test_shell_repl(completion, mock_system):
|
||||
assert completion.call_count == 2
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert ">>> list folder" in result.stdout
|
||||
assert "ls" in result.stdout
|
||||
assert ">>> sort by name" in result.stdout
|
||||
assert "ls | sort" in result.stdout
|
||||
assert ">>> list folder" in result.output
|
||||
assert "ls" in result.output
|
||||
assert ">>> sort by name" in result.output
|
||||
assert "ls | sort" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.handlers.handler.completion")
|
||||
@@ -173,7 +168,7 @@ def test_shell_and_describe_shell(completion):
|
||||
|
||||
completion.assert_not_called()
|
||||
assert result.exit_code == 2
|
||||
assert "Error" in result.stdout
|
||||
assert "Error" in result.output
|
||||
|
||||
|
||||
@patch("sgpt.handlers.handler.completion")
|
||||
@@ -190,5 +185,5 @@ def test_shell_no_interaction(completion):
|
||||
|
||||
completion.assert_called_once_with(**comp_args(role, args["prompt"]))
|
||||
assert result.exit_code == 0
|
||||
assert "git commit" in result.stdout
|
||||
assert "[E]xecute" not in result.stdout
|
||||
assert "git commit" in result.output
|
||||
assert "[E]xecute" not in result.output
|
||||
|
||||
Reference in New Issue
Block a user