Merge pull request ggozad#59 from ggozad/feat/edit-model
Add support for "editing" a chat, allowing for changing system prompt and template.
ggozad authored Feb 14, 2024
2 parents 10d34c5 + 3992c45 commit b00ac87
Showing 7 changed files with 151 additions and 54 deletions.
3 changes: 3 additions & 0 deletions CHANGES.txt
@@ -4,6 +4,9 @@ Changelog
0.1.23 -
-------------------

- Add support for "editing" a chat, allowing for changing system prompt and template.
[ggozad]

- Update textual and remove our own monkey patching
for Markdown. Increase Markdown size from 20 lines to 50.
[ggozad]
3 changes: 3 additions & 0 deletions README.md
@@ -34,6 +34,7 @@ OLLAMA_URL=http://host:port/api
The following keyboard shortcuts are available:

* `ctrl+n` - create a new chat session
* `ctrl+e` - edit the chat session (change template, system prompt or format)
* `ctrl+r` - rename the current chat session
* `ctrl+x` - delete the current chat session
* `ctrl+t` - toggle between dark/light theme
@@ -48,6 +49,8 @@ While Ollama is inferring the next message, you can press `ESC` to cancel the in

When creating a new chat, you may not only select the model, but also customize the `template` as well as the `system` instruction to pass to the model. Checking the `JSON output` checkbox will cause the model to reply in JSON format. Please note that `oterm` will not (yet) pull models for you; use `ollama` to do that. All the models you have pulled or created will be available to `oterm`.

You can also "edit" the chat to change the template, system prompt or format. Note, that the model cannot be changed once the chat has started. In addition whatever "context" the chat had (an embedding of the previous messages) will be kept.

### Chat session storage

All your chat sessions are stored locally in a sqlite database.
115 changes: 71 additions & 44 deletions oterm/app/model_selection.py → oterm/app/chat_edit.py
@@ -1,5 +1,6 @@
import json
from ast import literal_eval
from typing import Any

from rich.text import Text
from textual.app import ComposeResult
@@ -13,34 +14,28 @@
from oterm.ollama import OllamaAPI


class ModelSelection(ModalScreen[str]):
class ChatEdit(ModalScreen[str]):
api = OllamaAPI()
models = []
models_info: dict[str, dict] = {}

model_name: reactive[str] = reactive("")
tag: reactive[str] = reactive("")
bytes: reactive[int] = reactive(0)
model_info: reactive[dict[str, str]] = reactive({}, layout=True)
model_info: dict[str, str] = {}
template: reactive[str] = reactive("")
system: reactive[str] = reactive("")
params: reactive[list[tuple[str, str]]] = reactive([], layout=True)
params: reactive[list[tuple[str, str]]] = reactive([])
json_format: reactive[bool] = reactive(False)

edit_mode: reactive[bool] = reactive(False)
last_highlighted_index = None

BINDINGS = [
("escape", "cancel", "Cancel"),
("enter", "create", "Create"),
("enter", "save", "Save"),
]

def action_cancel(self) -> None:
self.dismiss()

def action_create(self) -> None:
self._create_chat()

def _create_chat(self) -> None:
def _return_chat_meta(self) -> None:
model = f"{self.model_name}:{self.tag}"
template = self.query_one(".template", TextArea).text
template = template if template != self.model_info.get("template", "") else None
@@ -57,6 +52,32 @@ def _create_chat(self) -> None:
)
self.dismiss(result)

def _parse_model_params(self, parameter_text: str) -> list[tuple[str, Any]]:
lines = parameter_text.split("\n")
params = []
for line in lines:
if line:
key, value = line.split(maxsplit=1)
try:
value = literal_eval(value)
except (SyntaxError, ValueError):
pass
params.append((key, value))
return params

def action_cancel(self) -> None:
self.dismiss()

def action_save(self) -> None:
self._return_chat_meta()

def select_model(self, model: str) -> None:
select = self.query_one("#model-select", OptionList)
for index, option in enumerate(select._options):
if str(option.prompt) == model:
select.highlighted = index
break

async def on_mount(self) -> None:
self.models = await self.api.get_models()
models = [model["name"] for model in self.models]
@@ -73,7 +94,7 @@ async def on_mount(self) -> None:
option_list.highlighted = self.last_highlighted_index

def on_option_list_option_selected(self, option: OptionList.OptionSelected) -> None:
self._create_chat()
self._return_chat_meta()

def on_option_list_option_highlighted(
self, option: OptionList.OptionHighlighted
@@ -87,15 +108,27 @@ def on_option_list_option_highlighted(
self.bytes = model_meta["size"]

self.model_info = self.models_info[model_meta["name"]]

# Now that there is a model selected we can create the chat.
create_button = self.query_one("#create-btn", Button)
create_button.disabled = False
ModelSelection.last_highlighted_index = option.option_index
self.params = self._parse_model_params(
self.model_info.get("parameters", "")
)
try:
widget = self.query_one(".parameters", Pretty)
widget.update(self.params)
widget = self.query_one(".template", TextArea)
widget.load_text(self.template or self.model_info.get("template", ""))
widget = self.query_one(".system", TextArea)
widget.load_text(self.system or self.model_info.get("system", ""))
except NoMatches:
pass

# Now that there is a model selected we can save the chat.
save_button = self.query_one("#save-btn", Button)
save_button.disabled = False
ChatEdit.last_highlighted_index = option.option_index

def on_button_pressed(self, event: Button.Pressed) -> None:
if event.button.name == "create":
self._create_chat()
if event.button.name == "save":
self._return_chat_meta()
else:
self.dismiss()

@@ -124,30 +157,24 @@ def watch_bytes(self, size: int) -> None:
except NoMatches:
pass

def watch_model_info(self, model_info: dict[str, str]) -> None:
self.template = model_info.get("template", "")
self.system = model_info.get("system", "")
params = model_info.get("parameters", "")
lines = params.split("\n")
params = []
for line in lines:
if line:
key, value = line.split(maxsplit=1)
try:
value = literal_eval(value)
except (SyntaxError, ValueError):
pass
params.append((key, value))
self.params = params

def watch_template(self, template: str) -> None:
try:
widget = self.query_one(".parameters", Pretty)
widget.update(self.params)
widget = self.query_one(".template", TextArea)
widget.clear()
widget.load_text(self.template)
widget.load_text(template)
except NoMatches:
pass

def watch_system(self, system: str) -> None:
try:
widget = self.query_one(".system", TextArea)
widget.load_text(self.system)
widget.load_text(system)
except NoMatches:
pass

def watch_edit_mode(self, edit_mode: bool) -> None:
try:
widget = self.query_one("#model-select", OptionList)
widget.disabled = edit_mode
except NoMatches:
pass

@@ -174,9 +201,9 @@ def compose(self) -> ComposeResult:

with Horizontal(classes="button-container"):
yield Button(
"Create",
id="create-btn",
name="create",
"Save",
id="save-btn",
name="save",
disabled=True,
variant="primary",
)
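
The `_parse_model_params` helper added above turns Ollama's plain-text `parameters` blob into `(key, value)` tuples, using `literal_eval` so numbers and quoted strings come back as Python values. A minimal standalone sketch of that behaviour, with a made-up parameter blob for illustration:

```python
from ast import literal_eval


def parse_model_params(parameter_text: str) -> list[tuple[str, object]]:
    """Split Ollama's 'parameters' text into (key, value) pairs."""
    params: list[tuple[str, object]] = []
    for line in parameter_text.split("\n"):
        if not line:
            continue
        key, value = line.split(maxsplit=1)
        try:
            # "0.8" -> 0.8, '"</s>"' -> "</s>", "4096" -> 4096
            value = literal_eval(value)
        except (SyntaxError, ValueError):
            pass  # keep the raw string when it is not a Python literal
        params.append((key, value))
    return params


# Hypothetical parameter text, illustrative only.
print(parse_model_params('temperature 0.8\nstop "</s>"\nnum_ctx 4096'))
# [('temperature', 0.8), ('stop', '</s>'), ('num_ctx', 4096)]
```
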
6 changes: 3 additions & 3 deletions oterm/app/oterm.py
@@ -4,7 +4,7 @@
from textual.app import App, ComposeResult
from textual.widgets import Footer, Header, TabbedContent, TabPane

from oterm.app.model_selection import ModelSelection
from oterm.app.chat_edit import ChatEdit
from oterm.app.splash import SplashScreen
from oterm.app.widgets.chat import ChatContainer
from oterm.config import appConfig
@@ -16,7 +16,7 @@ class OTerm(App):
SUB_TITLE = "A terminal-based Ollama client."
CSS_PATH = "oterm.tcss"
BINDINGS = [
("ctrl+n", "new_chat", "new chat"),
("ctrl+n", "new_chat", "new"),
("ctrl+t", "toggle_dark", "toggle theme"),
("ctrl+q", "quit", "quit"),
]
@@ -58,7 +58,7 @@ async def on_model_select(model_info: str) -> None:
tabs.add_pane(pane)
tabs.active = f"chat-{id}"

self.push_screen(ModelSelection(), on_model_select)
self.push_screen(ChatEdit(), on_model_select)

async def on_mount(self) -> None:
self.store = await Store.create()
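
`ChatEdit` is a `ModalScreen[str]`, so `push_screen(ChatEdit(), on_model_select)` registers a callback that receives whatever string the screen passes to `dismiss()`. A minimal, self-contained sketch of that Textual pattern — the screen and payload below are simplified stand-ins, not oterm's actual classes:

```python
import json

from textual.app import App, ComposeResult
from textual.screen import ModalScreen
from textual.widgets import Button


class DemoEdit(ModalScreen[str]):
    """Stand-in for ChatEdit: dismisses itself with a JSON payload."""

    def compose(self) -> ComposeResult:
        yield Button("Save", id="save-btn")

    def on_button_pressed(self, event: Button.Pressed) -> None:
        # ChatEdit returns the model name, template, system prompt and format.
        self.dismiss(json.dumps({"name": "llama2:latest", "format": "json"}))


class DemoApp(App):
    def on_mount(self) -> None:
        def on_model_select(payload: str | None) -> None:
            # Runs when the modal is dismissed; payload is None on cancel.
            self.log(json.loads(payload) if payload else "cancelled")

        self.push_screen(DemoEdit(), on_model_select)


if __name__ == "__main__":
    DemoApp().run()
```
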
50 changes: 46 additions & 4 deletions oterm/app/widgets/chat.py
@@ -12,8 +12,14 @@
from textual.events import Click
from textual.reactive import reactive
from textual.widget import Widget
from textual.widgets import LoadingIndicator, Markdown, Static, TabbedContent

from textual.widgets import (
LoadingIndicator,
Markdown,
Static,
TabbedContent,
)

from oterm.app.chat_edit import ChatEdit
from oterm.app.chat_rename import ChatRename
from oterm.app.widgets.image import ImageAdded
from oterm.app.widgets.prompt import FlexibleInput
@@ -35,8 +41,9 @@ class ChatContainer(Widget):
images: list[tuple[Path, str]] = []

BINDINGS = [
("ctrl+r", "rename_chat", "rename chat"),
("ctrl+x", "forget_chat", "forget chat"),
Binding("ctrl+e", "edit_chat", "edit", priority=True),
("ctrl+r", "rename_chat", "rename"),
("ctrl+x", "forget_chat", "forget"),
Binding(
"escape", "cancel_inference", "cancel inference", show=False, priority=True
),
@@ -152,6 +159,41 @@ def key_escape(self) -> None:
if hasattr(self, "inference_task"):
self.inference_task.cancel()

async def action_edit_chat(self) -> None:
async def on_model_select(model_info: str) -> None:
model: dict = json.loads(model_info)
self.template = model.get("template")
self.system = model.get("system")
self.format = model.get("format")
await self.app.store.edit_chat(
id=self.db_id,
name=self.chat_name,
template=model["template"],
system=model["system"],
format=model["format"],
)
_, _, _, context, _, _, _ = await self.app.store.get_chat(self.db_id)
self.ollama = OllamaLLM(
model=model["name"],
context=context,
template=model["template"],
system=model["system"],
format=model["format"],
)

screen = ChatEdit()
screen.model_name = self.ollama.model

await self.app.push_screen(screen, on_model_select)
screen.edit_mode = True
screen.select_model(self.ollama.model)

if self.template:
screen.template = self.template

if self.system:
screen.system = self.system

async def action_rename_chat(self) -> None:
async def on_chat_rename(name: str) -> None:
tabs = self.app.query_one(TabbedContent)
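
The `on_model_select` callback in `action_edit_chat` parses the JSON string produced by `ChatEdit._return_chat_meta()`. Based on the keys read in the diff, the payload has roughly this shape (the values below are purely illustrative):

```python
import json

# Illustrative payload carrying the keys action_edit_chat reads.
model_info = json.dumps(
    {
        "name": "llama2:latest",   # the model itself cannot change once a chat exists
        "template": None,          # None when left at the model's default template
        "system": "You are a terse assistant.",
        "format": "json",          # "json" for JSON replies, otherwise empty/None
    }
)

model: dict = json.loads(model_info)
assert {"name", "template", "system", "format"} <= model.keys()
```
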
2 changes: 2 additions & 0 deletions oterm/store/chat.py
@@ -8,6 +8,8 @@
UPDATE chat SET context = :context WHERE id = :id;
-- name: rename_chat
UPDATE chat SET name = :name WHERE id = :id;
-- name: edit_chat
UPDATE chat SET name = :name, template = :template, system = :system, format = :format WHERE id = :id;
-- name: get_chats
SELECT id, name, model, context, template, system, format FROM chat;
-- name: get_chat
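
The `-- name:` annotations follow the aiosql convention, where each named statement becomes a callable attribute on the loaded queries object. The loading code is not part of this diff, so treat the sketch below as an assumption about how `edit_chat` could be wired up with the aiosqlite driver:

```python
import aiosql
import aiosqlite

SQL = """
-- name: edit_chat
UPDATE chat SET name = :name, template = :template, system = :system, format = :format WHERE id = :id;
"""

# Build query functions from the SQL text using the aiosqlite driver adapter.
chat_queries = aiosql.from_str(SQL, "aiosqlite")


async def edit(db_path: str) -> None:
    async with aiosqlite.connect(db_path) as connection:
        await chat_queries.edit_chat(
            connection,
            id=1,
            name="my chat",
            template=None,
            system="You are a helpful assistant.",
            format="json",
        )
        await connection.commit()
```
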
26 changes: 23 additions & 3 deletions oterm/store/store.py
@@ -96,6 +96,25 @@ async def rename_chat(self, id: int, name: str) -> None:
)
await connection.commit()

async def edit_chat(
self,
id: int,
name: str,
template: str | None,
system: str | None,
format: str | None,
) -> None:
async with aiosqlite.connect(self.db_path) as connection:
await chat_queries.edit_chat( # type: ignore
connection,
id=id,
name=name,
template=template,
system=system,
format=format,
)
await connection.commit()

async def get_chats(
self,
) -> list[
@@ -111,9 +130,10 @@ async def get_chats(

async def get_chat(
self, id
) -> tuple[
int, str, str, list[int], str | None, str | None, Literal["json"] | None
] | None:
) -> (
tuple[int, str, str, list[int], str | None, str | None, Literal["json"] | None]
| None
):
async with aiosqlite.connect(self.db_path) as connection:
chat = await chat_queries.get_chat(connection, id=id) # type: ignore
if chat:
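
At the `Store` level, callers pass the full set of editable fields to `edit_chat`, using `None` for anything left at the model default, while `get_chat` still returns the seven-column tuple unpacked elsewhere in this commit. A hedged usage sketch (the chat id and values are illustrative):

```python
import asyncio

from oterm.store.store import Store


async def main() -> None:
    store = await Store.create()

    # Change the prompt settings of chat 1; the model itself stays fixed.
    await store.edit_chat(
        id=1,
        name="my chat",
        template=None,                      # keep the model's default template
        system="You are a terse assistant.",
        format="json",
    )

    chat = await store.get_chat(1)
    if chat is not None:
        chat_id, name, model, context, template, system, fmt = chat
        print(name, model, fmt)


asyncio.run(main())
```
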
