Skip to content

Commit

Permalink
strictly type python backend
Browse files Browse the repository at this point in the history
  • Loading branch information
abi committed Dec 9, 2023
1 parent 68a8d27 commit 6a28ee2
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 37 deletions.
3 changes: 3 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"python.analysis.typeCheckingMode": "strict"
}
20 changes: 11 additions & 9 deletions backend/image_generation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import asyncio
import os
import re
from typing import Dict, List, Union
from openai import AsyncOpenAI
from bs4 import BeautifulSoup


async def process_tasks(prompts, api_key, base_url):
async def process_tasks(prompts: List[str], api_key: str, base_url: str):
tasks = [generate_image(prompt, api_key, base_url) for prompt in prompts]
results = await asyncio.gather(*tasks, return_exceptions=True)

processed_results = []
processed_results: List[Union[str, None]] = []
for result in results:
if isinstance(result, Exception):
print(f"An exception occurred: {result}")
Expand All @@ -20,9 +20,9 @@ async def process_tasks(prompts, api_key, base_url):
return processed_results


async def generate_image(prompt, api_key, base_url):
async def generate_image(prompt: str, api_key: str, base_url: str):
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
image_params = {
image_params: Dict[str, Union[str, int]] = {
"model": "dall-e-3",
"quality": "standard",
"style": "natural",
Expand All @@ -35,7 +35,7 @@ async def generate_image(prompt, api_key, base_url):
return res.data[0].url


def extract_dimensions(url):
def extract_dimensions(url: str):
# Regular expression to match numbers in the format '300x200'
matches = re.findall(r"(\d+)x(\d+)", url)

Expand All @@ -48,11 +48,11 @@ def extract_dimensions(url):
return (100, 100)


def create_alt_url_mapping(code):
def create_alt_url_mapping(code: str) -> Dict[str, str]:
soup = BeautifulSoup(code, "html.parser")
images = soup.find_all("img")

mapping = {}
mapping: Dict[str, str] = {}

for image in images:
if not image["src"].startswith("https://placehold.co"):
Expand All @@ -61,7 +61,9 @@ def create_alt_url_mapping(code):
return mapping


async def generate_images(code, api_key, base_url, image_cache):
async def generate_images(
code: str, api_key: str, base_url: Union[str, None], image_cache: Dict[str, str]
):
# Find all images
soup = BeautifulSoup(code, "html.parser")
images = soup.find_all("img")
Expand Down
13 changes: 7 additions & 6 deletions backend/llm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import os
from typing import Awaitable, Callable
from typing import Awaitable, Callable, List
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk

MODEL_GPT_4_VISION = "gpt-4-vision-preview"


async def stream_openai_response(
messages,
messages: List[ChatCompletionMessageParam],
api_key: str,
base_url: str | None,
callback: Callable[[str], Awaitable[None]],
):
) -> str:
client = AsyncOpenAI(api_key=api_key, base_url=base_url)

model = MODEL_GPT_4_VISION
Expand All @@ -23,9 +23,10 @@ async def stream_openai_response(
params["max_tokens"] = 4096
params["temperature"] = 0

completion = await client.chat.completions.create(**params)
stream = await client.chat.completions.create(**params) # type: ignore
full_response = ""
async for chunk in completion:
async for chunk in stream: # type: ignore
assert isinstance(chunk, ChatCompletionChunk)
content = chunk.choices[0].delta.content or ""
full_response += content
await callback(content)
Expand Down
27 changes: 19 additions & 8 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from fastapi.responses import HTMLResponse
import openai
from llm import stream_openai_response
from mock import mock_completion
from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
from utils import pprint_prompt
from typing import Dict, List
from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_prompt
from routes import screenshot
Expand Down Expand Up @@ -53,7 +55,7 @@ async def get_status():
)


def write_logs(prompt_messages, completion):
def write_logs(prompt_messages: List[ChatCompletionMessageParam], completion: str):
# Get the logs path from environment, default to the current working directory
logs_path = os.environ.get("LOGS_PATH", os.getcwd())

Expand Down Expand Up @@ -84,7 +86,8 @@ async def throw_error(
await websocket.send_json({"type": "error", "value": message})
await websocket.close()

params = await websocket.receive_json()
# TODO: Are the values always strings?
params: Dict[str, str] = await websocket.receive_json()

print("Received params")

Expand Down Expand Up @@ -154,7 +157,7 @@ async def throw_error(
print("generating code...")
await websocket.send_json({"type": "status", "value": "Generating code..."})

async def process_chunk(content):
async def process_chunk(content: str):
await websocket.send_json({"type": "chunk", "value": content})

# Assemble the prompt
Expand All @@ -176,15 +179,23 @@ async def process_chunk(content):
return

# Image cache for updates so that we don't have to regenerate images
image_cache = {}
image_cache: Dict[str, str] = {}

if params["generationType"] == "update":
# Transform into message format
# TODO: Move this to frontend
for index, text in enumerate(params["history"]):
prompt_messages += [
{"role": "assistant" if index % 2 == 0 else "user", "content": text}
]
if index % 2 == 0:
message: ChatCompletionMessageParam = {
"role": "assistant",
"content": text,
}
else:
message: ChatCompletionMessageParam = {
"role": "user",
"content": text,
}
prompt_messages.append(message)
image_cache = create_alt_url_mapping(params["history"][-2])

if SHOULD_MOCK_AI_RESPONSE:
Expand Down
3 changes: 2 additions & 1 deletion backend/mock.py → backend/mock_llm.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
from typing import Awaitable, Callable


async def mock_completion(process_chunk):
async def mock_completion(process_chunk: Callable[[str], Awaitable[None]]) -> str:
code_to_return = NO_IMAGES_NYTIMES_MOCK_CODE

for i in range(0, len(code_to_return), 10):
Expand Down
13 changes: 10 additions & 3 deletions backend/prompts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from typing import List, Union

from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam


TAILWIND_SYSTEM_PROMPT = """
You are an expert Tailwind developer
You take screenshots of a reference web page from the user, and then build single page apps
Expand Down Expand Up @@ -117,8 +122,10 @@


def assemble_prompt(
image_data_url, generated_code_config: str, result_image_data_url=None
):
image_data_url: str,
generated_code_config: str,
result_image_data_url: Union[str, None] = None,
) -> List[ChatCompletionMessageParam]:
# Set the system prompt based on the output settings
system_content = TAILWIND_SYSTEM_PROMPT
if generated_code_config == "html_tailwind":
Expand All @@ -132,7 +139,7 @@ def assemble_prompt(
else:
raise Exception("Code config is not one of available options")

user_content = [
user_content: List[ChatCompletionContentPartParam] = [
{
"type": "image_url",
"image_url": {"url": image_data_url, "detail": "high"},
Expand Down
4 changes: 3 additions & 1 deletion backend/routes/screenshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ def bytes_to_data_url(image_bytes: bytes, mime_type: str) -> str:
return f"data:{mime_type};base64,{base64_image}"


async def capture_screenshot(target_url, api_key, device="desktop") -> bytes:
async def capture_screenshot(
target_url: str, api_key: str, device: str = "desktop"
) -> bytes:
api_base_url = "https://api.screenshotone.com/take"

params = {
Expand Down
20 changes: 11 additions & 9 deletions backend/utils.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
import copy
import json
from typing import List
from openai.types.chat import ChatCompletionMessageParam


def pprint_prompt(prompt_messages: List[ChatCompletionMessageParam]):
    """Pretty-print chat messages, with long data strings shortened for readability."""
    truncated = truncate_data_strings(prompt_messages)
    print(json.dumps(truncated, indent=4))


def truncate_data_strings(data: Any) -> Any:
    """Return a deep copy of *data* with every long string truncated.

    Recursively walks nested dicts and lists; any string value longer than
    40 characters is cut to 40 characters and suffixed with
    "... (<original length> chars)". The input is never mutated.

    Typed as ``Any`` (rather than ``List[ChatCompletionMessageParam]``)
    because the function recurses into arbitrary nested values — the
    narrower annotation was unsound and required ``# type: ignore`` on
    nearly every line.

    Args:
        data: Any JSON-like structure (dicts, lists, strings, scalars).

    Returns:
        A deep-copied structure of the same shape with strings truncated.
    """
    # Deep clone so the caller's object is never modified.
    cloned_data = copy.deepcopy(data)

    if isinstance(cloned_data, dict):
        for key, value in cloned_data.items():
            # Recurse into nested containers.
            if isinstance(value, (dict, list)):
                cloned_data[key] = truncate_data_strings(value)
            # Truncate long strings and note the original length.
            elif isinstance(value, str):
                cloned_data[key] = value[:40]
                if len(value) > 40:
                    cloned_data[key] += "..." + f" ({len(value)} chars)"

    elif isinstance(cloned_data, list):
        # Process each item in the list.
        cloned_data = [truncate_data_strings(item) for item in cloned_data]

    return cloned_data

0 comments on commit 6a28ee2

Please sign in to comment.