Skip to content

Commit

Permalink
Type the backend properly to avoid code duplication and ensure type e…
Browse files Browse the repository at this point in the history
…rrors when a stack configuration is not properly added
  • Loading branch information
abi committed Jan 8, 2024
1 parent 15dc74a commit 1aeb8c4
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 71 deletions.
72 changes: 8 additions & 64 deletions backend/prompts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,10 @@
from typing import List, Literal, NoReturn, Union
from typing import List, NoReturn, Union

from openai.types.chat import ChatCompletionMessageParam, ChatCompletionContentPartParam

from prompts.imported_code_prompts import (
IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT,
IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT,
IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT,
IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
IMPORTED_CODE_SVG_SYSTEM_PROMPT,
)
from prompts.screenshot_system_prompts import (
BOOTSTRAP_SYSTEM_PROMPT,
IONIC_TAILWIND_SYSTEM_PROMPT,
REACT_TAILWIND_SYSTEM_PROMPT,
TAILWIND_SYSTEM_PROMPT,
SVG_SYSTEM_PROMPT,
VUE_TAILWIND_SYSTEM_PROMPT,
)
from prompts.imported_code_prompts import IMPORTED_CODE_SYSTEM_PROMPTS
from prompts.screenshot_system_prompts import SYSTEM_PROMPTS
from prompts.types import Stack


USER_PROMPT = """
Expand All @@ -27,39 +15,11 @@
Generate code for a SVG that looks exactly like this.
"""

Stack = Literal[
"html_tailwind",
"react_tailwind",
"bootstrap",
"ionic_tailwind",
"vue_tailwind",
"svg",
]


def assert_never(x: NoReturn) -> NoReturn:
raise AssertionError(f"Stack is not one of available options: {x}")


def assemble_imported_code_prompt(
code: str, stack: Stack, result_image_data_url: Union[str, None] = None
) -> List[ChatCompletionMessageParam]:
system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
if stack == "html_tailwind":
system_content = IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT
elif stack == "react_tailwind":
system_content = IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT
elif stack == "bootstrap":
system_content = IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT
elif stack == "ionic_tailwind":
system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
elif stack == "vue_tailwind":
# TODO: Fix this prompt to be vue tailwind
system_content = IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT
elif stack == "svg":
system_content = IMPORTED_CODE_SVG_SYSTEM_PROMPT
else:
assert_never(stack)
system_content = IMPORTED_CODE_SYSTEM_PROMPTS[stack]

user_content = (
"Here is the code of the app: " + code
Expand All @@ -81,27 +41,11 @@ def assemble_imported_code_prompt(

def assemble_prompt(
image_data_url: str,
generated_code_config: Stack,
stack: Stack,
result_image_data_url: Union[str, None] = None,
) -> List[ChatCompletionMessageParam]:
# Set the system prompt based on the output settings
system_content = TAILWIND_SYSTEM_PROMPT
if generated_code_config == "html_tailwind":
system_content = TAILWIND_SYSTEM_PROMPT
elif generated_code_config == "react_tailwind":
system_content = REACT_TAILWIND_SYSTEM_PROMPT
elif generated_code_config == "bootstrap":
system_content = BOOTSTRAP_SYSTEM_PROMPT
elif generated_code_config == "ionic_tailwind":
system_content = IONIC_TAILWIND_SYSTEM_PROMPT
elif generated_code_config == "vue_tailwind":
system_content = VUE_TAILWIND_SYSTEM_PROMPT
elif generated_code_config == "svg":
system_content = SVG_SYSTEM_PROMPT
else:
assert_never(generated_code_config)

user_prompt = USER_PROMPT if generated_code_config != "svg" else SVG_USER_PROMPT
system_content = SYSTEM_PROMPTS[stack]
user_prompt = USER_PROMPT if stack != "svg" else SVG_USER_PROMPT

user_content: List[ChatCompletionContentPartParam] = [
{
Expand Down
13 changes: 13 additions & 0 deletions backend/prompts/imported_code_prompts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from prompts.types import SystemPrompts


IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT = """
You are an expert Tailwind developer.
Expand Down Expand Up @@ -90,3 +93,13 @@
Return only the full code in <svg></svg> tags.
Do not include markdown "```" or "```svg" at the start or end.
"""

IMPORTED_CODE_SYSTEM_PROMPTS = SystemPrompts(
html_tailwind=IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
react_tailwind=IMPORTED_CODE_REACT_TAILWIND_SYSTEM_PROMPT,
bootstrap=IMPORTED_CODE_BOOTSTRAP_SYSTEM_PROMPT,
ionic_tailwind=IMPORTED_CODE_IONIC_TAILWIND_SYSTEM_PROMPT,
# TODO: Fix this prompt to actually be Vue
vue_tailwind=IMPORTED_CODE_TAILWIND_SYSTEM_PROMPT,
svg=IMPORTED_CODE_SVG_SYSTEM_PROMPT,
)
15 changes: 14 additions & 1 deletion backend/prompts/screenshot_system_prompts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
TAILWIND_SYSTEM_PROMPT = """
from prompts.types import SystemPrompts


HTML_TAILWIND_SYSTEM_PROMPT = """
You are an expert Tailwind developer
You take screenshots of a reference web page from the user, and then build single page apps
using Tailwind, HTML and JS.
Expand Down Expand Up @@ -170,3 +173,13 @@
Return only the full code in <svg></svg> tags.
Do not include markdown "```" or "```svg" at the start or end.
"""


SYSTEM_PROMPTS = SystemPrompts(
html_tailwind=HTML_TAILWIND_SYSTEM_PROMPT,
react_tailwind=REACT_TAILWIND_SYSTEM_PROMPT,
bootstrap=BOOTSTRAP_SYSTEM_PROMPT,
ionic_tailwind=IONIC_TAILWIND_SYSTEM_PROMPT,
vue_tailwind=VUE_TAILWIND_SYSTEM_PROMPT,
svg=SVG_SYSTEM_PROMPT,
)
20 changes: 20 additions & 0 deletions backend/prompts/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from typing import Literal, TypedDict


class SystemPrompts(TypedDict):
html_tailwind: str
react_tailwind: str
bootstrap: str
ionic_tailwind: str
vue_tailwind: str
svg: str


Stack = Literal[
"html_tailwind",
"react_tailwind",
"bootstrap",
"ionic_tailwind",
"vue_tailwind",
"svg",
]
18 changes: 12 additions & 6 deletions backend/routes/generate_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from llm import stream_openai_response
from openai.types.chat import ChatCompletionMessageParam
from mock_llm import mock_completion
from typing import Dict, List
from typing import Dict, List, cast, get_args
from image_generation import create_alt_url_mapping, generate_images
from prompts import assemble_imported_code_prompt, assemble_prompt
from access_token import validate_access_token
from datetime import datetime
import json
from prompts.types import Stack

from utils import pprint_prompt # type: ignore

Expand Down Expand Up @@ -96,6 +97,13 @@ async def throw_error(
)
return

# Validate the generated code config
if not generated_code_config in get_args(Stack):
await throw_error(f"Invalid generated code config: {generated_code_config}")
return
# Cast the variable to the Stack type
valid_stack = cast(Stack, generated_code_config)

# Get the OpenAI Base URL from the request. Fall back to environment variable if not provided.
openai_base_url = None
# Disable user-specified OpenAI Base URL in prod
Expand Down Expand Up @@ -131,7 +139,7 @@ async def process_chunk(content: str):
if params.get("isImportedFromCode") and params["isImportedFromCode"]:
original_imported_code = params["history"][0]
prompt_messages = assemble_imported_code_prompt(
original_imported_code, generated_code_config
original_imported_code, valid_stack
)
for index, text in enumerate(params["history"][1:]):
if index % 2 == 0:
Expand All @@ -150,12 +158,10 @@ async def process_chunk(content: str):
try:
if params.get("resultImage") and params["resultImage"]:
prompt_messages = assemble_prompt(
params["image"], generated_code_config, params["resultImage"]
params["image"], valid_stack, params["resultImage"]
)
else:
prompt_messages = assemble_prompt(
params["image"], generated_code_config
)
prompt_messages = assemble_prompt(params["image"], valid_stack)
except:
await websocket.send_json(
{
Expand Down

0 comments on commit 1aeb8c4

Please sign in to comment.