Ollama integration 🦙 (#463)

This commit is contained in:
Farkhod Sadykov
2024-02-09 23:18:39 +01:00
committed by GitHub
parent ad6d297b28
commit 1cb61dee0a
9 changed files with 90 additions and 106 deletions
+7 -7
View File
@@ -7,9 +7,12 @@ https://github.com/TheR1D/shell_gpt/assets/16740832/9197283c-db6a-4b46-bfea-3eb7
```shell ```shell
pip install shell-gpt pip install shell-gpt
``` ```
By default, ShellGPT uses OpenAI's API and GPT-4 model. You'll need an API key, you can generate one [here](https://beta.openai.com/account/api-keys). You will be prompted for your key which will then be stored in `~/.config/shell_gpt/.sgptrc`. OpenAI API is not free of charge, please refer to the [OpenAI pricing](https://openai.com/pricing) for more information.
You'll need an OpenAI API key, you can generate one [here](https://beta.openai.com/account/api-keys). > [!TIP]
You will be prompted for your key which will then be stored in `~/.config/shell_gpt/.sgptrc`. > Alternatively, you can use locally hosted open source models which are available for free. To use local models, you will need to run your own LLM backend server such as [Ollama](https://github.com/ollama/ollama). To set up ShellGPT with Ollama, please follow this comprehensive [guide](https://github.com/TheR1D/shell_gpt/wiki/Ollama).
>
> **❗️Note that ShellGPT is not optimized for local models and may not work as expected.**
## Usage ## Usage
**ShellGPT** is designed to quickly analyse and retrieve information. It's useful for straightforward requests ranging from technical configurations to general knowledge. **ShellGPT** is designed to quickly analyse and retrieve information. It's useful for straightforward requests ranging from technical configurations to general knowledge.
@@ -24,7 +27,7 @@ git diff | sgpt "Generate git commit message, for my changes"
# -> Added main feature details into README.md # -> Added main feature details into README.md
``` ```
You can analyze logs from various sources by passing them using stdin, along with a prompt. This enables you to quickly identify errors and get suggestions for possible solutions: You can analyze logs from various sources by passing them using stdin, along with a prompt. For instance, we can use it to quickly analyze logs, identify errors and get suggestions for possible solutions:
```shell ```shell
docker logs -n 20 my_app | sgpt "check logs, find errors, provide possible solutions" docker logs -n 20 my_app | sgpt "check logs, find errors, provide possible solutions"
``` ```
@@ -40,7 +43,7 @@ You can also use all kind of redirection operators to pass input:
sgpt "summarise" < document.txt sgpt "summarise" < document.txt
# -> The document discusses the impact... # -> The document discusses the impact...
sgpt << EOF sgpt << EOF
What is the best way to lear Golang. What is the best way to lear Golang?
Provide simple hello world example. Provide simple hello world example.
EOF EOF
# -> The best way to learn Golang... # -> The best way to learn Golang...
@@ -444,9 +447,6 @@ Possible options for `CODE_THEME`: https://pygments.org/styles/
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯ ╰──────────────────────────────────────────────────────────────────────────────────────────────────────────╯
``` ```
## LocalAI
By default, ShellGPT leverages OpenAI's large language models. However, it also provides the flexibility to use locally hosted models, which can be a cost-effective alternative. To use local models, you will need to run your own API server. You can accomplish this by using [LocalAI](https://github.com/go-skynet/LocalAI), a self-hosted, OpenAI-compatible API. Setting up LocalAI allows you to run language models on your own hardware, potentially without the need for an internet connection, depending on your usage. To set up your LocalAI, please follow this comprehensive [guide](https://github.com/TheR1D/shell_gpt/wiki/LocalAI). Remember that the performance of your local models may depend on the specifications of your hardware and the specific language model you choose to deploy.
## Docker ## Docker
Run the container using the `OPENAI_API_KEY` environment variable, and a docker volume to store cache: Run the container using the `OPENAI_API_KEY` environment variable, and a docker volume to store cache:
```shell ```shell
+11 -12
View File
@@ -4,8 +4,8 @@ build-backend = "hatchling.build"
[project] [project]
name = "shell_gpt" name = "shell_gpt"
description = "A command-line productivity tool powered by OpenAI GPT models, will help you accomplish your tasks faster and more efficiently." description = "A command-line productivity tool powered by large language models, will help you accomplish your tasks faster and more efficiently."
keywords = ["shell", "gpt", "openai", "cli", "productivity", "cheet-sheet"] keywords = ["shell", "gpt", "openai", "ollama", "cli", "productivity", "cheet-sheet"]
readme = "README.md" readme = "README.md"
license = "MIT" license = "MIT"
requires-python = ">=3.6" requires-python = ">=3.6"
@@ -28,24 +28,15 @@ classifiers = [
"Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.11",
] ]
dependencies = [ dependencies = [
"requests >= 2.28.2, < 3.0.0", "litellm >= 1.20.1, < 2.0.0",
"typer >= 0.7.0, < 1.0.0", "typer >= 0.7.0, < 1.0.0",
"click >= 7.1.1, < 9.0.0", "click >= 7.1.1, < 9.0.0",
"rich >= 13.1.0, < 14.0.0", "rich >= 13.1.0, < 14.0.0",
"distro >= 1.8.0, < 2.0.0", "distro >= 1.8.0, < 2.0.0",
"openai >= 1.6.1, < 2.0.0",
"instructor >= 0.4.5, < 1.0.0", "instructor >= 0.4.5, < 1.0.0",
'pyreadline3 >= 3.4.1, < 4.0.0; sys_platform == "win32"', 'pyreadline3 >= 3.4.1, < 4.0.0; sys_platform == "win32"',
] ]
[project.scripts]
sgpt = "sgpt:cli"
[project.urls]
homepage = "https://github.com/ther1d/shell_gpt"
repository = "https://github.com/ther1d/shell_gpt"
documentation = "https://github.com/TheR1D/shell_gpt/blob/main/README.md"
[project.optional-dependencies] [project.optional-dependencies]
test = [ test = [
"pytest >= 7.2.2, < 8.0.0", "pytest >= 7.2.2, < 8.0.0",
@@ -61,6 +52,14 @@ dev = [
"pre-commit >= 3.1.1, < 4.0.0", "pre-commit >= 3.1.1, < 4.0.0",
] ]
[project.scripts]
sgpt = "sgpt:cli"
[project.urls]
homepage = "https://github.com/ther1d/shell_gpt"
repository = "https://github.com/ther1d/shell_gpt"
documentation = "https://github.com/TheR1D/shell_gpt/blob/main/README.md"
[tool.hatch.version] [tool.hatch.version]
path = "sgpt/__version__.py" path = "sgpt/__version__.py"
+1 -1
View File
@@ -1 +1 @@
__version__ = "1.2.0" __version__ = "1.3.0"
+16 -17
View File
@@ -2,7 +2,7 @@ import json
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Generator, List, Optional from typing import Any, Dict, Generator, List, Optional
from openai import OpenAI import litellm # type: ignore
from ..cache import Cache from ..cache import Cache
from ..config import cfg from ..config import cfg
@@ -10,16 +10,13 @@ from ..function import get_function
from ..printer import MarkdownPrinter, Printer, TextPrinter from ..printer import MarkdownPrinter, Printer, TextPrinter
from ..role import DefaultRoles, SystemRole from ..role import DefaultRoles, SystemRole
litellm.suppress_debug_info = True
class Handler: class Handler:
cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH"))) cache = Cache(int(cfg.get("CACHE_LENGTH")), Path(cfg.get("CACHE_PATH")))
def __init__(self, role: SystemRole) -> None: def __init__(self, role: SystemRole) -> None:
self.client = OpenAI(
base_url=cfg.get("OPENAI_BASE_URL"),
api_key=cfg.get("OPENAI_API_KEY"),
timeout=int(cfg.get("REQUEST_TIMEOUT")),
)
self.role = role self.role = role
@property @property
@@ -73,28 +70,30 @@ class Handler:
if is_shell_role or is_code_role or is_dsc_shell_role: if is_shell_role or is_code_role or is_dsc_shell_role:
functions = None functions = None
for chunk in self.client.chat.completions.create( for chunk in litellm.completion(
model=model, model=model,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
messages=messages, # type: ignore messages=messages,
functions=functions, # type: ignore functions=functions,
stream=True, stream=True,
api_key=cfg.get("OPENAI_API_KEY"),
): ):
delta = chunk.choices[0].delta # type: ignore delta = chunk.choices[0].delta
if delta.function_call: function_call = delta.get("function_call")
if delta.function_call.name: if function_call:
name = delta.function_call.name if function_call.name:
if delta.function_call.arguments: name = function_call.name
arguments += delta.function_call.arguments if function_call.arguments:
if chunk.choices[0].finish_reason == "function_call": # type: ignore arguments += function_call.arguments
if chunk.choices[0].finish_reason == "function_call":
yield from self.handle_function_call(messages, name, arguments) yield from self.handle_function_call(messages, name, arguments)
yield from self.get_completion( yield from self.get_completion(
model, temperature, top_p, messages, functions, caching=False model, temperature, top_p, messages, functions, caching=False
) )
return return
yield delta.content or "" yield delta.get("content") or ""
def handle( def handle(
self, self,
+13 -13
View File
@@ -4,14 +4,14 @@ from unittest.mock import patch
from sgpt.config import cfg from sgpt.config import cfg
from sgpt.role import DefaultRoles, SystemRole from sgpt.role import DefaultRoles, SystemRole
from .utils import app, cmd_args, comp_args, comp_chunks, runner from .utils import app, cmd_args, comp_args, mock_comp, runner
role = SystemRole.get(DefaultRoles.CODE.value) role = SystemRole.get(DefaultRoles.CODE.value)
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_code_generation(mock): def test_code_generation(mock):
mock.return_value = comp_chunks("print('Hello World')") mock.return_value = mock_comp("print('Hello World')")
args = {"prompt": "hello world python", "--code": True} args = {"prompt": "hello world python", "--code": True}
result = runner.invoke(app, cmd_args(**args)) result = runner.invoke(app, cmd_args(**args))
@@ -21,9 +21,9 @@ def test_code_generation(mock):
assert "print('Hello World')" in result.stdout assert "print('Hello World')" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_code_generation_stdin(completion): def test_code_generation_stdin(completion):
completion.return_value = comp_chunks("# Hello\nprint('Hello')") completion.return_value = mock_comp("# Hello\nprint('Hello')")
args = {"prompt": "make comments for code", "--code": True} args = {"prompt": "make comments for code", "--code": True}
stdin = "print('Hello')" stdin = "print('Hello')"
@@ -36,11 +36,11 @@ def test_code_generation_stdin(completion):
assert "print('Hello')" in result.stdout assert "print('Hello')" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_code_chat(completion): def test_code_chat(completion):
completion.side_effect = [ completion.side_effect = [
comp_chunks("print('hello')"), mock_comp("print('hello')"),
comp_chunks("print('hello')\nprint('world')"), mock_comp("print('hello')\nprint('world')"),
] ]
chat_name = "_test" chat_name = "_test"
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
@@ -77,11 +77,11 @@ def test_code_chat(completion):
# TODO: Code chat can be recalled without --code option. # TODO: Code chat can be recalled without --code option.
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_code_repl(completion): def test_code_repl(completion):
completion.side_effect = [ completion.side_effect = [
comp_chunks("print('hello')"), mock_comp("print('hello')"),
comp_chunks("print('hello')\nprint('world')"), mock_comp("print('hello')\nprint('world')"),
] ]
chat_name = "_test" chat_name = "_test"
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
@@ -109,7 +109,7 @@ def test_code_repl(completion):
assert "print('world')" in result.stdout assert "print('world')" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_code_and_shell(completion): def test_code_and_shell(completion):
args = {"--code": True, "--shell": True} args = {"--code": True, "--shell": True}
result = runner.invoke(app, cmd_args(**args)) result = runner.invoke(app, cmd_args(**args))
@@ -119,7 +119,7 @@ def test_code_and_shell(completion):
assert "Error" in result.stdout assert "Error" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_code_and_describe_shell(completion): def test_code_and_describe_shell(completion):
args = {"--code": True, "--describe-shell": True} args = {"--code": True, "--describe-shell": True}
result = runner.invoke(app, cmd_args(**args)) result = runner.invoke(app, cmd_args(**args))
+14 -14
View File
@@ -8,15 +8,15 @@ from sgpt import config, main
from sgpt.__version__ import __version__ from sgpt.__version__ import __version__
from sgpt.role import DefaultRoles, SystemRole from sgpt.role import DefaultRoles, SystemRole
from .utils import app, cmd_args, comp_args, comp_chunks, runner from .utils import app, cmd_args, comp_args, mock_comp, runner
role = SystemRole.get(DefaultRoles.DEFAULT.value) role = SystemRole.get(DefaultRoles.DEFAULT.value)
cfg = config.cfg cfg = config.cfg
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_default(completion): def test_default(completion):
completion.return_value = comp_chunks("Prague") completion.return_value = mock_comp("Prague")
args = {"prompt": "capital of the Czech Republic?"} args = {"prompt": "capital of the Czech Republic?"}
result = runner.invoke(app, cmd_args(**args)) result = runner.invoke(app, cmd_args(**args))
@@ -26,9 +26,9 @@ def test_default(completion):
assert "Prague" in result.stdout assert "Prague" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_default_stdin(completion): def test_default_stdin(completion):
completion.return_value = comp_chunks("Prague") completion.return_value = mock_comp("Prague")
stdin = "capital of the Czech Republic?" stdin = "capital of the Czech Republic?"
result = runner.invoke(app, cmd_args(), input=stdin) result = runner.invoke(app, cmd_args(), input=stdin)
@@ -38,9 +38,9 @@ def test_default_stdin(completion):
assert "Prague" in result.stdout assert "Prague" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_default_chat(completion): def test_default_chat(completion):
completion.side_effect = [comp_chunks("ok"), comp_chunks("4")] completion.side_effect = [mock_comp("ok"), mock_comp("4")]
chat_name = "_test" chat_name = "_test"
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
chat_path.unlink(missing_ok=True) chat_path.unlink(missing_ok=True)
@@ -90,9 +90,9 @@ def test_default_chat(completion):
chat_path.unlink() chat_path.unlink()
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_default_repl(completion): def test_default_repl(completion):
completion.side_effect = [comp_chunks("ok"), comp_chunks("8")] completion.side_effect = [mock_comp("ok"), mock_comp("8")]
chat_name = "_test" chat_name = "_test"
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
chat_path.unlink(missing_ok=True) chat_path.unlink(missing_ok=True)
@@ -119,9 +119,9 @@ def test_default_repl(completion):
assert "8" in result.stdout assert "8" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_default_repl_stdin(completion): def test_default_repl_stdin(completion):
completion.side_effect = [comp_chunks("ok init"), comp_chunks("ok another")] completion.side_effect = [mock_comp("ok init"), mock_comp("ok another")]
chat_name = "_test" chat_name = "_test"
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
chat_path.unlink(missing_ok=True) chat_path.unlink(missing_ok=True)
@@ -153,9 +153,9 @@ def test_default_repl_stdin(completion):
assert "ok another" in result.stdout assert "ok another" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_llm_options(completion): def test_llm_options(completion):
completion.return_value = comp_chunks("Berlin") completion.return_value = mock_comp("Berlin")
args = { args = {
"prompt": "capital of the Germany?", "prompt": "capital of the Germany?",
@@ -179,7 +179,7 @@ def test_llm_options(completion):
assert "Berlin" in result.stdout assert "Berlin" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_version(completion): def test_version(completion):
args = {"--version": True} args = {"--version": True}
result = runner.invoke(app, cmd_args(**args)) result = runner.invoke(app, cmd_args(**args))
+4 -3
View File
@@ -5,12 +5,12 @@ from unittest.mock import patch
from sgpt.config import cfg from sgpt.config import cfg
from sgpt.role import SystemRole from sgpt.role import SystemRole
from .utils import app, cmd_args, comp_args, comp_chunks, runner from .utils import app, cmd_args, comp_args, mock_comp, runner
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_role(completion): def test_role(completion):
completion.return_value = comp_chunks('{"foo": "bar"}') completion.return_value = mock_comp('{"foo": "bar"}')
path = Path(cfg.get("ROLE_STORAGE_PATH")) / "json_gen_test.json" path = Path(cfg.get("ROLE_STORAGE_PATH")) / "json_gen_test.json"
path.unlink(missing_ok=True) path.unlink(missing_ok=True)
args = {"--create-role": "json_gen_test"} args = {"--create-role": "json_gen_test"}
@@ -44,6 +44,7 @@ def test_role(completion):
assert "foo" in generated_json assert "foo" in generated_json
# Test with stdin prompt. # Test with stdin prompt.
completion.return_value = mock_comp('{"foo": "bar"}')
args = {"--role": "json_gen_test"} args = {"--role": "json_gen_test"}
stdin = "generate foo, bar" stdin = "generate foo, bar"
result = runner.invoke(app, cmd_args(**args), input=stdin) result = runner.invoke(app, cmd_args(**args), input=stdin)
+18 -18
View File
@@ -5,13 +5,13 @@ from unittest.mock import patch
from sgpt.config import cfg from sgpt.config import cfg
from sgpt.role import DefaultRoles, SystemRole from sgpt.role import DefaultRoles, SystemRole
from .utils import app, cmd_args, comp_args, comp_chunks, runner from .utils import app, cmd_args, comp_args, mock_comp, runner
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_shell(completion): def test_shell(completion):
role = SystemRole.get(DefaultRoles.SHELL.value) role = SystemRole.get(DefaultRoles.SHELL.value)
completion.return_value = comp_chunks("git commit -m test") completion.return_value = mock_comp("git commit -m test")
args = {"prompt": "make a commit using git", "--shell": True} args = {"prompt": "make a commit using git", "--shell": True}
result = runner.invoke(app, cmd_args(**args)) result = runner.invoke(app, cmd_args(**args))
@@ -22,9 +22,9 @@ def test_shell(completion):
assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_shell_stdin(completion): def test_shell_stdin(completion):
completion.return_value = comp_chunks("ls -l | sort") completion.return_value = mock_comp("ls -l | sort")
role = SystemRole.get(DefaultRoles.SHELL.value) role = SystemRole.get(DefaultRoles.SHELL.value)
args = {"prompt": "Sort by name", "--shell": True} args = {"prompt": "Sort by name", "--shell": True}
@@ -38,9 +38,9 @@ def test_shell_stdin(completion):
assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout assert "[E]xecute, [D]escribe, [A]bort:" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_describe_shell(completion): def test_describe_shell(completion):
completion.return_value = comp_chunks("lists the contents of a folder") completion.return_value = mock_comp("lists the contents of a folder")
role = SystemRole.get(DefaultRoles.DESCRIBE_SHELL.value) role = SystemRole.get(DefaultRoles.DESCRIBE_SHELL.value)
args = {"prompt": "ls", "--describe-shell": True} args = {"prompt": "ls", "--describe-shell": True}
@@ -51,9 +51,9 @@ def test_describe_shell(completion):
assert "lists" in result.stdout assert "lists" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_describe_shell_stdin(completion): def test_describe_shell_stdin(completion):
completion.return_value = comp_chunks("lists the contents of a folder") completion.return_value = mock_comp("lists the contents of a folder")
role = SystemRole.get(DefaultRoles.DESCRIBE_SHELL.value) role = SystemRole.get(DefaultRoles.DESCRIBE_SHELL.value)
args = {"--describe-shell": True} args = {"--describe-shell": True}
@@ -67,9 +67,9 @@ def test_describe_shell_stdin(completion):
@patch("os.system") @patch("os.system")
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_shell_run_description(completion, system): def test_shell_run_description(completion, system):
completion.side_effect = [comp_chunks("echo hello"), comp_chunks("prints hello")] completion.side_effect = [mock_comp("echo hello"), mock_comp("prints hello")]
args = {"prompt": "echo hello", "--shell": True} args = {"prompt": "echo hello", "--shell": True}
inputs = "__sgpt__eof__\nd\ne\n" inputs = "__sgpt__eof__\nd\ne\n"
result = runner.invoke(app, cmd_args(**args), input=inputs) result = runner.invoke(app, cmd_args(**args), input=inputs)
@@ -80,9 +80,9 @@ def test_shell_run_description(completion, system):
assert "prints hello" in result.stdout assert "prints hello" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_shell_chat(completion): def test_shell_chat(completion):
completion.side_effect = [comp_chunks("ls"), comp_chunks("ls | sort")] completion.side_effect = [mock_comp("ls"), mock_comp("ls | sort")]
role = SystemRole.get(DefaultRoles.SHELL.value) role = SystemRole.get(DefaultRoles.SHELL.value)
chat_name = "_test" chat_name = "_test"
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
@@ -119,9 +119,9 @@ def test_shell_chat(completion):
@patch("os.system") @patch("os.system")
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_shell_repl(completion, mock_system): def test_shell_repl(completion, mock_system):
completion.side_effect = [comp_chunks("ls"), comp_chunks("ls | sort")] completion.side_effect = [mock_comp("ls"), mock_comp("ls | sort")]
role = SystemRole.get(DefaultRoles.SHELL.value) role = SystemRole.get(DefaultRoles.SHELL.value)
chat_name = "_test" chat_name = "_test"
chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name chat_path = Path(cfg.get("CHAT_CACHE_PATH")) / chat_name
@@ -151,7 +151,7 @@ def test_shell_repl(completion, mock_system):
assert "ls | sort" in result.stdout assert "ls | sort" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_shell_and_describe_shell(completion): def test_shell_and_describe_shell(completion):
args = {"prompt": "ls", "--describe-shell": True, "--shell": True} args = {"prompt": "ls", "--describe-shell": True, "--shell": True}
result = runner.invoke(app, cmd_args(**args)) result = runner.invoke(app, cmd_args(**args))
@@ -161,9 +161,9 @@ def test_shell_and_describe_shell(completion):
assert "Error" in result.stdout assert "Error" in result.stdout
@patch("openai.resources.chat.Completions.create") @patch("litellm.completion")
def test_shell_no_interaction(completion): def test_shell_no_interaction(completion):
completion.return_value = comp_chunks("git commit -m test") completion.return_value = mock_comp("git commit -m test")
role = SystemRole.get(DefaultRoles.SHELL.value) role = SystemRole.get(DefaultRoles.SHELL.value)
args = { args = {
+6 -21
View File
@@ -1,9 +1,7 @@
import datetime from unittest.mock import ANY
import typer import typer
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from litellm import completion as completion
from openai.types.chat.chat_completion_chunk import Choice as StreamChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from typer.testing import CliRunner from typer.testing import CliRunner
from sgpt import main from sgpt import main
@@ -14,23 +12,9 @@ app = typer.Typer()
app.command()(main) app.command()(main)
def comp_chunks(tokens_string): def mock_comp(tokens_string):
return [ model = cfg.get("DEFAULT_MODEL")
ChatCompletionChunk( return completion(model=model, mock_response=tokens_string, stream=True)
id="foo",
model=cfg.get("DEFAULT_MODEL"),
object="chat.completion.chunk",
choices=[
StreamChoice(
index=0,
finish_reason=None,
delta=ChoiceDelta(content=token, role="assistant"),
),
],
created=int(datetime.datetime.now().timestamp()),
)
for token in tokens_string
]
def cmd_args(prompt="", **kwargs): def cmd_args(prompt="", **kwargs):
@@ -56,5 +40,6 @@ def comp_args(role, prompt, **kwargs):
"top_p": 1.0, "top_p": 1.0,
"functions": None, "functions": None,
"stream": True, "stream": True,
"api_key": ANY,
**kwargs, **kwargs,
} }