identify exact llm being used during generation
abi committed Mar 18, 2024
1 parent 4e30b20 commit 81c4fbe
Showing 5 changed files with 36 additions and 25 deletions.
28 changes: 17 additions & 11 deletions backend/llm.py
@@ -1,17 +1,23 @@
+from enum import Enum
from typing import Any, Awaitable, Callable, List, cast
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionChunk

from utils import pprint_prompt

-MODEL_GPT_4_VISION = "gpt-4-vision-preview"
-MODEL_CLAUDE_SONNET = "claude-3-sonnet-20240229"
-MODEL_CLAUDE_OPUS = "claude-3-opus-20240229"
-MODEL_CLAUDE_HAIKU = "claude-3-haiku-20240307"

+# Actual model versions that are passed to the LLMs and stored in our logs
+class Llm(Enum):
+GPT_4_VISION = "gpt-4-vision-preview"
+CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
+CLAUDE_3_OPUS = "claude-3-opus-20240229"
+CLAUDE_3_HAIKU = "claude-3-haiku-20240307"


# Keep in sync with frontend (lib/models.ts)
# User-facing names for the models (for example, in the future, gpt_4_vision might
# be backed by a different model version)
CODE_GENERATION_MODELS = [
"gpt_4_vision",
"claude_3_sonnet",
@@ -26,13 +26,13 @@ async def stream_openai_response(
) -> str:
client = AsyncOpenAI(api_key=api_key, base_url=base_url)

-model = MODEL_GPT_4_VISION
+model = Llm.GPT_4_VISION

# Base parameters
params = {"model": model, "messages": messages, "stream": True, "timeout": 600}

# Add 'max_tokens' only if the model is a GPT4 vision model
-if model == MODEL_GPT_4_VISION:
+if model == Llm.GPT_4_VISION:
params["max_tokens"] = 4096
params["temperature"] = 0

@@ -59,12 +65,12 @@ async def stream_claude_response(
client = AsyncAnthropic(api_key=api_key)

# Base parameters
-model = MODEL_CLAUDE_SONNET
+model = Llm.CLAUDE_3_SONNET
max_tokens = 4096
temperature = 0.0

# Translate OpenAI messages to Claude messages
-system_prompt = cast(str, messages[0]["content"])
+system_prompt = cast(str, messages[0].get("content"))
claude_messages = [dict(message) for message in messages[1:]]
for message in claude_messages:
if not isinstance(message["content"], list):
@@ -91,7 +97,7 @@ async def stream_claude_response(

# Stream Claude response
async with client.messages.stream(
-model=model,
+model=model.value,
max_tokens=max_tokens,
temperature=temperature,
system=system_prompt,
@@ -111,7 +117,7 @@ async def stream_claude_response_native(
api_key: str,
callback: Callable[[str], Awaitable[None]],
include_thinking: bool = False,
-model: str = MODEL_CLAUDE_OPUS,
+model: Llm = Llm.CLAUDE_3_OPUS,
) -> str:

client = AsyncAnthropic(api_key=api_key)
@@ -140,7 +146,7 @@ async def stream_claude_response_native(
pprint_prompt(messages_to_send)

async with client.messages.stream(
-model=model,
+model=model.value,
max_tokens=max_tokens,
temperature=temperature,
system=system_prompt,
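Note on the backend/llm.py change: the new Llm enum is the single source of truth for model versions. The enum member (e.g. Llm.GPT_4_VISION) is what the app keeps for its own logs, while .value unwraps the exact version string sent to the provider — the Claude streaming calls in this diff pass model.value for exactly that reason. A minimal runnable sketch of the pattern follows; build_request_params is an illustrative helper, not code from this commit, and its parameter values simply mirror the ones visible in the diff.

from enum import Enum


class Llm(Enum):
    GPT_4_VISION = "gpt-4-vision-preview"
    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"
    CLAUDE_3_OPUS = "claude-3-opus-20240229"
    CLAUDE_3_HAIKU = "claude-3-haiku-20240307"


def build_request_params(model: Llm) -> dict[str, object]:
    # The provider API expects the raw version string, so unwrap the enum with .value.
    params: dict[str, object] = {"model": model.value, "stream": True, "timeout": 600}
    # Mirrors the diff: max_tokens/temperature are only set for the GPT-4 Vision model.
    if model == Llm.GPT_4_VISION:
        params["max_tokens"] = 4096
        params["temperature"] = 0
    return params


print(build_request_params(Llm.GPT_4_VISION))
# {'model': 'gpt-4-vision-preview', 'stream': True, 'timeout': 600, 'max_tokens': 4096, 'temperature': 0}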
12 changes: 6 additions & 6 deletions backend/prompts/test_prompts.py
@@ -311,35 +311,35 @@ def test_prompts():
tailwind_prompt = assemble_prompt(
"image_data_url", "html_tailwind", "result_image_data_url"
)
-assert tailwind_prompt[0]["content"] == TAILWIND_SYSTEM_PROMPT
+assert tailwind_prompt[0].get("content") == TAILWIND_SYSTEM_PROMPT
assert tailwind_prompt[1]["content"][2]["text"] == USER_PROMPT # type: ignore

react_tailwind_prompt = assemble_prompt(
"image_data_url", "react_tailwind", "result_image_data_url"
)
-assert react_tailwind_prompt[0]["content"] == REACT_TAILWIND_SYSTEM_PROMPT
+assert react_tailwind_prompt[0].get("content") == REACT_TAILWIND_SYSTEM_PROMPT
assert react_tailwind_prompt[1]["content"][2]["text"] == USER_PROMPT # type: ignore

bootstrap_prompt = assemble_prompt(
"image_data_url", "bootstrap", "result_image_data_url"
)
-assert bootstrap_prompt[0]["content"] == BOOTSTRAP_SYSTEM_PROMPT
+assert bootstrap_prompt[0].get("content") == BOOTSTRAP_SYSTEM_PROMPT
assert bootstrap_prompt[1]["content"][2]["text"] == USER_PROMPT # type: ignore

ionic_tailwind = assemble_prompt(
"image_data_url", "ionic_tailwind", "result_image_data_url"
)
-assert ionic_tailwind[0]["content"] == IONIC_TAILWIND_SYSTEM_PROMPT
+assert ionic_tailwind[0].get("content") == IONIC_TAILWIND_SYSTEM_PROMPT
assert ionic_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore

vue_tailwind = assemble_prompt(
"image_data_url", "vue_tailwind", "result_image_data_url"
)
-assert vue_tailwind[0]["content"] == VUE_TAILWIND_SYSTEM_PROMPT
+assert vue_tailwind[0].get("content") == VUE_TAILWIND_SYSTEM_PROMPT
assert vue_tailwind[1]["content"][2]["text"] == USER_PROMPT # type: ignore

svg_prompt = assemble_prompt("image_data_url", "svg", "result_image_data_url")
-assert svg_prompt[0]["content"] == SVG_SYSTEM_PROMPT
+assert svg_prompt[0].get("content") == SVG_SYSTEM_PROMPT
assert svg_prompt[1]["content"][2]["text"] == SVG_USER_PROMPT # type: ignore


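The only change in these tests is swapping prompt[0]["content"] for prompt[0].get("content"), matching the same switch in backend/llm.py. Presumably this keeps static type checkers happy, since the OpenAI message params are a union of TypedDicts on which plain indexing of "content" is not always allowed; .get() also degrades to None instead of raising KeyError when the key is absent. A tiny illustration of the behavioural difference (the message literal is made up for the example):

message = {"role": "system", "content": "You are a helpful assistant."}

# Indexing requires the key to exist; .get() returns None for a missing key instead of raising.
assert message["content"] == message.get("content")
assert message.get("missing_key") is None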
10 changes: 8 additions & 2 deletions backend/routes/generate_code.py
@@ -6,7 +6,7 @@
from custom_types import InputMode
from llm import (
CODE_GENERATION_MODELS,
-MODEL_CLAUDE_OPUS,
+Llm,
stream_claude_response,
stream_claude_response_native,
stream_openai_response,
@@ -88,6 +88,7 @@ async def throw_error(
if code_generation_model not in CODE_GENERATION_MODELS:
await throw_error(f"Invalid model: {code_generation_model}")
raise Exception(f"Invalid model: {code_generation_model}")
+exact_llm_version = None

print(
f"Generating {generated_code_config} code for uploaded {input_mode} using {code_generation_model} model..."
@@ -238,9 +239,10 @@ async def process_chunk(content: str):
messages=prompt_messages, # type: ignore
api_key=ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x),
-model=MODEL_CLAUDE_OPUS,
+model=Llm.CLAUDE_3_OPUS,
include_thinking=True,
)
+exact_llm_version = Llm.CLAUDE_3_OPUS
elif code_generation_model == "claude_3_sonnet":
if not ANTHROPIC_API_KEY:
await throw_error(
@@ -253,13 +255,15 @@ async def process_chunk(content: str):
api_key=ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x),
)
+exact_llm_version = Llm.CLAUDE_3_SONNET
else:
completion = await stream_openai_response(
prompt_messages, # type: ignore
api_key=openai_api_key,
base_url=openai_base_url,
callback=lambda x: process_chunk(x),
)
+exact_llm_version = Llm.GPT_4_VISION
except openai.AuthenticationError as e:
print("[GENERATE_CODE] Authentication failed", e)
error_message = (
@@ -298,6 +302,8 @@ async def process_chunk(content: str):
if validated_input_mode == "video":
completion = extract_tag_content("html", completion)

+print("Exact used model for generation: ", exact_llm_version)

# Write the messages dict into a log so that we can debug later
write_logs(prompt_messages, completion) # type: ignore

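The route change above is the heart of the commit: an exact_llm_version variable is initialised to None, set in whichever provider branch actually runs, and printed before the logs are written. A condensed, runnable sketch of that control flow is below; the stream_* functions are stubs standing in for the real streaming helpers, and only the "claude_3_sonnet" branch condition is taken verbatim from the diff.

import asyncio
from enum import Enum


class Llm(Enum):
    GPT_4_VISION = "gpt-4-vision-preview"
    CLAUDE_3_SONNET = "claude-3-sonnet-20240229"


# Stubs standing in for the streaming helpers in backend/llm.py.
async def stream_claude_response(model: Llm) -> str:
    return f"<html><!-- generated by {model.value} --></html>"


async def stream_openai_response(model: Llm) -> str:
    return f"<html><!-- generated by {model.value} --></html>"


async def generate(code_generation_model: str) -> tuple[str, Llm]:
    exact_llm_version = None  # filled in by whichever provider branch runs

    if code_generation_model == "claude_3_sonnet":
        completion = await stream_claude_response(Llm.CLAUDE_3_SONNET)
        exact_llm_version = Llm.CLAUDE_3_SONNET
    else:
        completion = await stream_openai_response(Llm.GPT_4_VISION)
        exact_llm_version = Llm.GPT_4_VISION

    # Same log line the route now emits before writing the prompt/completion logs.
    print("Exact used model for generation: ", exact_llm_version)
    return completion, exact_llm_version


asyncio.run(generate("claude_3_sonnet"))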
6 changes: 3 additions & 3 deletions backend/video/utils.py
@@ -5,7 +5,7 @@
import os
import tempfile
import uuid
-from typing import Union, cast
+from typing import Any, Union, cast
from moviepy.editor import VideoFileClip # type: ignore
from PIL import Image
import math
@@ -17,7 +17,7 @@
)


-async def assemble_claude_prompt_video(video_data_url: str):
+async def assemble_claude_prompt_video(video_data_url: str) -> list[Any]:
images = split_video_into_screenshots(video_data_url)

# Save images to tmp if we're debugging
@@ -28,7 +28,7 @@ async def assemble_claude_prompt_video(video_data_url: str):
print(f"Number of frames extracted from video: {len(images)}")
if len(images) > 20:
print(f"Too many screenshots: {len(images)}")
-return
+raise ValueError("Too many screenshots extracted from video")

# Convert images to the message format for Claude
content_messages: list[dict[str, Union[dict[str, str], str]]] = []
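Two small hardening changes in video/utils.py: assemble_claude_prompt_video gains an explicit -> list[Any] return type, and the over-20-screenshots case now raises a ValueError instead of silently returning None. A self-contained sketch of the same guard plus Claude image-block shaping follows, assuming raw JPEG frame bytes are already in hand; the real function decodes a data URL and extracts frames via moviepy, which is omitted here, and the function name is illustrative.

import base64
from typing import Any, Union


def assemble_claude_prompt_from_frames(frames: list[bytes]) -> list[Any]:
    # Fail loudly instead of returning None when the video yields too many frames.
    if len(frames) > 20:
        raise ValueError("Too many screenshots extracted from video")

    # Convert each frame into the Anthropic messages API image-block format.
    content_messages: list[dict[str, Union[dict[str, str], str]]] = []
    for frame in frames:
        content_messages.append(
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/jpeg",
                    "data": base64.b64encode(frame).decode("utf-8"),
                },
            }
        )
    return [{"role": "user", "content": content_messages}]


print(len(assemble_claude_prompt_from_frames([b"\xff\xd8fake-jpeg-bytes"])))  # 1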
5 changes: 2 additions & 3 deletions backend/video_to_app.py
@@ -17,8 +17,7 @@
from config import ANTHROPIC_API_KEY
from video.utils import extract_tag_content, assemble_claude_prompt_video
from llm import (
-MODEL_CLAUDE_OPUS,
-# MODEL_CLAUDE_SONNET,
+Llm,
stream_claude_response_native,
)

@@ -87,7 +86,7 @@ async def process_chunk(content: str):
messages=prompt_messages,
api_key=ANTHROPIC_API_KEY,
callback=lambda x: process_chunk(x),
-model=MODEL_CLAUDE_OPUS,
+model=Llm.CLAUDE_3_OPUS,
include_thinking=True,
)

