fix: types for console and wrapper funcs
codito committed Oct 10, 2024
1 parent 28a41ab · commit 0dfe693
Showing 8 changed files with 51 additions and 42 deletions.
1 change: 1 addition & 0 deletions .github/workflows/linux.yml
@@ -34,6 +34,7 @@ jobs:
- name: Lint with ruff
run: |
uv run ruff check
uv run basedpyright
- name: Test with pytest
run: |
uv run pytest
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -7,7 +7,7 @@ repos:
args: [--fix]
- id: ruff-format
types_or: [python, pyi, jupyter]
# - repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror
# rev: 1.18.3
#hooks:
# - id: basedpyright
- repo: https://github.com/DetachHead/basedpyright-pre-commit-mirror
rev: 1.18.3
hooks:
- id: basedpyright
60 changes: 32 additions & 28 deletions arey/chat.py
@@ -14,7 +14,6 @@
)
from arey.platform.console import capture_stderr
from arey.platform.llm import get_completion_llm
from arey.prompt import Prompt

config = get_config()
completion_settings = config.chat.profile
@@ -39,6 +38,10 @@ class Message(ChatMessage):
timestamp: int # unix timestamp
context: MessageContext | None

def to_chat(self):
"""Convert to a chat message."""
return ChatMessage(sender=self.sender, text=self.text)


@dataclass
class ChatContext:
@@ -56,14 +59,14 @@ class Chat:
context: ChatContext = field(default_factory=ChatContext)


def _get_max_tokens(model: CompletionModel, prompt_model: Prompt, text: str) -> int:
context_size = model.context_size
prompt_tokens_without_history = model.count_tokens(
prompt_model.get("chat", {"user_query": text, "chat_history": ""})
)
buffer = 200

return context_size - prompt_tokens_without_history - buffer
# def _get_max_tokens(model: CompletionModel, prompt_model: Prompt, text: str) -> int:
# context_size = model.context_size
# prompt_tokens_without_history = model.count_tokens(
# prompt_model.get("chat", {"user_query": text, "chat_history": ""})
# )
# buffer = 200
#
# return context_size - prompt_tokens_without_history - buffer


def create_chat() -> tuple[Chat, ModelMetrics]:
@@ -79,23 +82,23 @@ def create_chat() -> tuple[Chat, ModelMetrics]:
return chat, model.metrics


def get_history(
model: CompletionModel, chat: Chat, prompt_model: Prompt, max_tokens: int
) -> str:
"""Get the messages for a chat."""
messages = []
token_count = 0
for message in reversed(chat.messages):
role = message.sender.role()
formatted_message = prompt_model.get_message(role, message.text)
message_tokens = model.count_tokens(formatted_message)
messages.insert(0, formatted_message)

token_count += message_tokens
if message.sender == SenderType.USER and token_count >= max_tokens:
break

return "".join(messages)
# def get_history(
# model: CompletionModel, chat: Chat, prompt_model: Prompt, max_tokens: int
# ) -> str:
# """Get the messages for a chat."""
# messages = []
# token_count = 0
# for message in reversed(chat.messages):
# role = message.sender.role()
# formatted_message = prompt_model.get_message(role, message.text)
# message_tokens = model.count_tokens(formatted_message)
# messages.insert(0, formatted_message)
#
# token_count += message_tokens
# if message.sender == SenderType.USER and token_count >= max_tokens:
# break
#
# return "".join(messages)


def create_response(chat: Chat, message: str) -> str:
@@ -120,11 +123,12 @@ def stream_response(chat: Chat, message: str) -> Iterator[str]:
chat.messages.append(user_msg)

ai_msg_text = ""
usage_series = []
usage_series: list[CompletionMetrics] = []
finish_reason = ""
with capture_stderr() as stderr:
# for chunk in model.complete(chat.messages, {"stop": prompt_model.stop_words}):
for chunk in model.complete(chat.messages, {}):
chat_messages = [m.to_chat() for m in chat.messages]
for chunk in model.complete(chat_messages, {}):
ai_msg_text += chunk.text
finish_reason = chunk.finish_reason
usage_series.append(chunk.metrics)
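A note on the arey/chat.py change above: model.complete presumably expects a list[ChatMessage], and list is invariant in its element type, so passing chat.messages (a list[Message]) fails type checking even though Message subclasses ChatMessage; the new to_chat() call converts each entry to a plain ChatMessage first. Likewise, annotating the empty accumulator as usage_series: list[CompletionMetrics] = [] keeps basedpyright from inferring an unknown element type. A minimal sketch of the invariance point, with simplified hypothetical types rather than the repository's own:

    from dataclasses import dataclass

    @dataclass
    class ChatMessage:
        sender: str
        text: str

    @dataclass
    class Message(ChatMessage):
        timestamp: int = 0  # extra field, standing in for the richer Message

    def complete(messages: list[ChatMessage]) -> None:
        """Stand-in for a completion call typed against ChatMessage."""

    history: list[Message] = [Message("user", "hi")]
    # complete(history)  # rejected: list[Message] is not list[ChatMessage]
    complete([ChatMessage(m.sender, m.text) for m in history])  # accepted
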
12 changes: 7 additions & 5 deletions arey/main.py
@@ -5,6 +5,7 @@
import signal
from collections.abc import Iterable
from functools import wraps
from types import FrameType
from typing import Any, Callable

import click
@@ -29,7 +30,7 @@ def _generate_response(
) -> None:
stop_completion = False

def stop_completion_handler(_signal: signal.Signals, _frame: Any):
def stop_completion_handler(_signal: signal.Signals, _frame: FrameType):
nonlocal stop_completion
stop_completion = True

@@ -96,11 +97,11 @@ def _print_logs(console: Console, verbose: bool, logs: str | None) -> None:
console.print()


def error_handler(func: Callable[..., Any]):
def error_handler(func: Callable[..., int]):
"""Global error handler for Arey."""

@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> int:
try:
return func(*args, **kwargs)
except AreyError as e:
@@ -125,18 +126,19 @@ def wrapper(*args, **kwargs):
Markdown(help_text),
)
console.print(error_text)
return 1

return wrapper


def common_options(func: Callable[..., Any]):
def common_options(func: Callable[..., int]):
"""Get common options for arey commands."""

@click.option(
"-v", "--verbose", is_flag=True, default=False, help="Show verbose logs."
)
@wraps(func)
def wrapper(*args, **kwargs):
def wrapper(*args: list[Any], **kwargs: dict[str, Any]) -> int:
return func(*args, **kwargs)

return wrapper
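A note on the decorator typing in arey/main.py: declaring func as Callable[..., int] and returning 1 on the error path lets the wrapper honestly advertise an int return. In annotations of the form *args: T, T describes each individual positional argument, so a ParamSpec-based spelling (Python 3.10+) is a common way to preserve the wrapped command's exact parameter types instead of falling back to Any. The sketch below is illustrative, not the repository's code:

    from functools import wraps
    from typing import Callable, ParamSpec

    P = ParamSpec("P")

    def error_handler(func: Callable[P, int]) -> Callable[P, int]:
        """Wrap a CLI command and turn known errors into an exit code."""

        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> int:
            try:
                return func(*args, **kwargs)
            except Exception:  # the real handler catches AreyError and prints help
                return 1

        return wrapper
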
2 changes: 1 addition & 1 deletion arey/platform/console.py
@@ -33,7 +33,7 @@ def get_console() -> Console:
class SignalContextManager:
"""Context manager for console signals."""

def __init__(self, signal_num: int, handler: Callable[..., Any]):
def __init__(self, signal_num: int, handler: Callable[..., None]):
"""Create a signal context instance."""
self.signal_num = signal_num
self.handler = handler
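A note on the SignalContextManager change: narrowing the handler type to Callable[..., None] matches handlers that only flip a flag or print, and it stays compatible with what signal.signal accepts, namely a callable taking the signal number and the current frame (which may be None), the same FrameType used in the arey/main.py change above. An illustrative handler shape, not repository code:

    import signal
    from types import FrameType

    def on_interrupt(signum: int, frame: FrameType | None) -> None:
        """Handle Ctrl+C by reporting the signal; returns nothing."""
        print(f"received signal {signum}")

    signal.signal(signal.SIGINT, on_interrupt)
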
2 changes: 1 addition & 1 deletion arey/prompt.py
@@ -110,7 +110,7 @@ def get_message(
) -> str:
"""Get a chat message for given role and text."""
merged_context = {"message_text": text} | self.custom_tokens | token_overrides
return Template(self.message_formats[role]).substitute(merged_context)
return Template(self.message_formats[role.role()]).substitute(merged_context)


@lru_cache(maxsize=1)
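A note on the arey/prompt.py change: get_message now indexes message_formats with role.role() rather than with the role object itself. Assuming message_formats is keyed by role-name strings and role is an enum-like sender type (as the sender.role() calls elsewhere in this diff suggest), indexing with the enum member would fail type checking and raise KeyError at runtime. A simplified sketch with assumed types, not repository code:

    from enum import Enum
    from string import Template

    class SenderType(Enum):
        """Illustrative stand-in for the real sender enum."""

        USER = "user"
        AI = "assistant"

        def role(self) -> str:
            return str(self.value)

    message_formats: dict[str, str] = {
        "user": "<|user|>\n$message_text",
        "assistant": "<|assistant|>\n$message_text",
    }

    # message_formats[SenderType.USER] would be a type error (and a KeyError);
    # the string key returned by .role() is what the dict expects.
    text = Template(message_formats[SenderType.USER.role()]).substitute(
        {"message_text": "hello"}
    )
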
6 changes: 3 additions & 3 deletions arey/task.py
@@ -42,7 +42,7 @@ class TaskResult:
class Task:
"""A task is a stateless script invocation with a prompt."""

prompt_overrides: dict = field(default_factory=dict)
prompt_overrides: dict[str, str] = field(default_factory=dict)
result: TaskResult | None = None


@@ -74,7 +74,7 @@ def run(task: Task, user_input: str) -> Iterator[str]:
] # prompt_model.get("task", context)

ai_msg_text = ""
usage_series = []
usage_series: list[CompletionMetrics] = []
finish_reason = ""
with capture_stderr() as stderr:
for chunk in model.complete(
@@ -93,7 +93,7 @@
)


def close(task: Task):
def close(_task: Task):
"""Close a task and free the model."""
if model:
model.free()
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -107,3 +107,5 @@ select = ["D", "E", "F", "W"]
[tool.basedpyright]
include = ["arey", "tests", "docs"]
reportUnusedCallResult = "none"
venvPath = "."
venv = ".venv"
