forked from kharvd/gpt-cli
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcli.py
197 lines (159 loc) · 6.12 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import re
from typing import Any, Dict, Optional, Tuple

from openai import OpenAIError, InvalidRequestError
from prompt_toolkit import PromptSession
from prompt_toolkit.history import FileHistory
from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
from rich.markup import escape
from rich.text import Text

from gptcli.session import (
    ALL_COMMANDS,
    COMMAND_CLEAR,
    COMMAND_QUIT,
    COMMAND_RERUN,
    ChatListener,
    InvalidArgumentError,
    ResponseStreamer,
    UserInputProvider,
)
# Welcome banner shown when a chat starts. It is rendered as Markdown by
# CLIChatListener.on_chat_start, so backticks mark inline code; the doubled
# backslash is a single literal backslash in the displayed text.
TERMINAL_WELCOME = """
Hi! I'm here to help. Type `q` or Ctrl-D to exit, `c` or Ctrl-C to clear
the conversation, `r` or Ctrl-R to re-generate the last response.
To enter multi-line mode, enter a backslash `\\` followed by a new line.
Exit the multi-line mode by pressing ESC and then Enter (Meta+Enter).
"""
class StreamingMarkdownPrinter:
    """Context manager that prints incrementally received text to a console.

    With ``markdown`` enabled, all text received so far is re-rendered as
    green Markdown inside a ``Live`` display on every ``print`` call.
    Otherwise each chunk is written straight to the console as green text
    without a trailing newline.
    """

    def __init__(self, console: Console, markdown: bool):
        self.console = console
        self.markdown = markdown
        # Everything passed to print() so far, concatenated.
        self.current_text = ""
        # Created in __enter__, and only when markdown rendering is enabled.
        self.live: Optional[Live] = None

    def __enter__(self) -> "StreamingMarkdownPrinter":
        if not self.markdown:
            return self
        self.live = Live(
            console=self.console, auto_refresh=False, vertical_overflow="visible"
        )
        self.live.__enter__()
        return self

    def print(self, text: str):
        """Append ``text`` to the accumulated output and display it."""
        self.current_text += text
        if not self.markdown:
            # Plain mode: emit just the new chunk, staying on the same line.
            self.console.print(Text(text, style="green"), end="")
            return
        assert self.live
        # Markdown mode: re-render the whole accumulated text; auto_refresh is
        # off, so the display is refreshed explicitly.
        self.live.update(Markdown(self.current_text, style="green"))
        self.live.refresh()

    def __exit__(self, *args):
        if self.markdown:
            assert self.live
            self.live.__exit__(*args)
        # Finish the streamed output with a newline.
        self.console.print()
class CLIResponseStreamer(ResponseStreamer):
    """Response streamer that forwards tokens to a StreamingMarkdownPrinter.

    A single leading space on the very first token is dropped before the
    token is printed.
    """

    def __init__(self, console: Console, markdown: bool):
        self.console = console
        self.markdown = markdown
        self.first_token = True
        self.printer = StreamingMarkdownPrinter(console, markdown)

    def __enter__(self):
        self.printer.__enter__()
        return self

    def on_next_token(self, token: str):
        """Print ``token``, trimming one leading space if it is the first."""
        # Read and clear the "first token" flag in one step.
        is_first, self.first_token = self.first_token, False
        if is_first and token[:1] == " ":
            token = token[1:]
        self.printer.print(token)

    def __exit__(self, *args):
        self.printer.__exit__(*args)
class CLIChatListener(ChatListener):
    """Chat listener that reports chat events on a Rich console.

    Args:
        markdown: Whether responses should be rendered as Markdown; passed on
            to every response streamer created by ``response_streamer``.
    """

    def __init__(self, markdown: bool):
        self.markdown = markdown
        self.console = Console()

    def on_chat_start(self):
        """Print the welcome banner, rendered as Markdown at a width of 80."""
        console = Console(width=80)
        console.print(Markdown(TERMINAL_WELCOME))

    def on_chat_clear(self):
        self.console.print("[bold]Cleared the conversation.[/bold]")

    def on_chat_rerun(self, success: bool):
        if success:
            self.console.print("[bold]Re-running the last message.[/bold]")
        else:
            self.console.print("[bold]Nothing to re-run.[/bold]")

    def on_error(self, e: Exception):
        """Print a red error message appropriate to the type of ``e``.

        Exception text is escaped before being embedded in console markup:
        error messages may contain square brackets, which would otherwise be
        parsed as markup tags (silently dropped, or rejected as unmatched).
        """
        # InvalidRequestError is tested before OpenAIError, as in the
        # original ordering of the checks.
        if isinstance(e, InvalidRequestError):
            self.console.print(
                f"[red]Request Error. The last prompt was not saved: {type(e)}: {escape(str(e))}[/red]"
            )
        elif isinstance(e, OpenAIError):
            self.console.print(
                f"[red]API Error. Type `r` or Ctrl-R to try again: {type(e)}: {escape(str(e))}[/red]"
            )
        elif isinstance(e, InvalidArgumentError):
            self.console.print(f"[red]{escape(str(e.message))}[/red]")
        else:
            self.console.print(f"[red]Error: {type(e)}: {escape(str(e))}[/red]")

    def response_streamer(self) -> ResponseStreamer:
        """Return a new streamer bound to this listener's console."""
        return CLIResponseStreamer(self.console, self.markdown)
def parse_args(input: str) -> Tuple[str, Dict[str, Any]]:
    """Split a line of user input into the prompt text and ``--key value`` args.

    Arguments have the form ``--name value`` or ``--name=value``, where the
    value is a run of non-whitespace characters. If at least one argument is
    found, the prompt is the text before the first argument, stripped of
    surrounding whitespace; anything after it that is not itself an argument
    is discarded. If no argument is found, the input is returned unchanged.

    Args:
        input: The raw line entered by the user.

    Returns:
        A ``(prompt, args)`` tuple; ``args`` maps argument names to their
        string values and is empty when no arguments are present. When a name
        is repeated, the last value wins.
    """
    pattern = r"--(\w+)(?:\s+|=)([^\s]+)"
    found = list(re.finditer(pattern, input))
    if not found:
        return input, {}

    args: Dict[str, Any] = {m.group(1): m.group(2) for m in found}
    # Cut at the first *recognised* argument rather than at the first "--"
    # anywhere, so a literal "--" inside the prompt text is not lost.
    prompt_text = input[: found[0].start()].strip()
    return prompt_text, args
class CLIFileHistory(FileHistory):
    """File-backed prompt history that does not record chat commands."""

    def append_string(self, string: str) -> None:
        # Entries listed in ALL_COMMANDS are skipped; everything else is
        # handed to the parent implementation.
        if string not in ALL_COMMANDS:
            return super().append_string(string)
        return None
class CLIUserInputProvider(UserInputProvider):
    """Reads user input from the terminal through a prompt_toolkit session."""

    def __init__(self, history_filename) -> None:
        history = CLIFileHistory(history_filename)
        self.prompt_session = PromptSession[str](history=history)

    def get_user_input(self) -> Tuple[str, Dict[str, Any]]:
        """Prompt until a non-empty string is entered, then return it parsed."""
        while True:
            entered = self._request_input()
            if entered != "":
                return self._parse_input(entered)

    def prompt(self, multiline=False):
        """Show one prompt and return the entered text.

        Returns an empty string if the prompt is aborted with
        ``KeyboardInterrupt``.
        """
        bindings = KeyBindings()

        @bindings.add("c-c")
        def _clear_or_abort(event: KeyPressEvent):
            buffer = event.current_buffer
            if multiline or len(buffer.text) != 0:
                # Something was typed (or multi-line mode): abort this prompt.
                event.app.exit(exception=KeyboardInterrupt, style="class:aborting")
                return
            # Empty single-line buffer: submit the "clear" command.
            buffer.text = COMMAND_CLEAR[0]
            buffer.validate_and_handle()

        @bindings.add("c-d")
        def _quit_or_submit(event: KeyPressEvent):
            buffer = event.current_buffer
            if len(buffer.text) != 0:
                return
            # On an empty buffer, single-line mode submits the "quit" command;
            # multi-line mode submits the empty buffer as it is.
            if not multiline:
                buffer.text = COMMAND_QUIT[0]
            buffer.validate_and_handle()

        @bindings.add("c-r")
        def _rerun(event: KeyPressEvent):
            buffer = event.current_buffer
            if len(buffer.text) != 0:
                return
            buffer.text = COMMAND_RERUN[0]
            buffer.validate_and_handle()

        message = "multiline> " if multiline else "> "
        try:
            return self.prompt_session.prompt(
                message,
                vi_mode=True,
                multiline=multiline,
                enable_open_in_editor=True,
                key_bindings=bindings,
            )
        except KeyboardInterrupt:
            return ""

    def _request_input(self):
        """Read one input; a lone backslash switches to a multi-line prompt."""
        first_line = self.prompt()
        if first_line == "\\":
            return self.prompt(multiline=True)
        return first_line

    def _parse_input(self, input: str) -> Tuple[str, Dict[str, Any]]:
        text, args = parse_args(input)
        return text, args