Add in retries, improve reliability
VikParuchuri committed Sep 28, 2023
1 parent 060132c commit f9fbdbd
Showing 8 changed files with 85 additions and 13 deletions.
16 changes: 12 additions & 4 deletions app/course/embeddings.py
@@ -19,6 +19,7 @@ def run_query(
     query_text: str | List[str], embeddings, model, result_count=1, score_thresh=0.6
 ):
     query_embedding = model.encode(query_text, convert_to_tensor=True)
+
     cos_scores = util.cos_sim(query_embedding, embeddings)
     top_results = torch.topk(cos_scores, k=result_count, dim=-1)

@@ -45,7 +46,10 @@ def dedup_list(topics, score_thresh=0.9):
         res = tc.query(topic, score_thresh=score_thresh)
         if len(res) == 0:
             clean_topics.append(topic)
-            tc.add_topics([topic])
+            try:
+                tc.add_topics([topic])
+            except KeyError:
+                pass
     return clean_topics


@@ -66,9 +70,13 @@ def add_topics(self, topics):
         self.embeddings = torch.cat((self.embeddings, embeddings), dim=0)

     def query(self, query_text, result_count=1, score_thresh=0.9) -> List[str]:
-        scores, selected_indices, item_mapping = run_query(
-            query_text, self.embeddings, model, result_count, score_thresh=score_thresh
-        )
+        try:
+            scores, selected_indices, item_mapping = run_query(
+                query_text, self.embeddings, model, result_count, score_thresh=score_thresh
+            )
+        except KeyError as e:
+            print(f"Error querying topic embedding: {e}")
+            return []

         results = []
         for index in selected_indices:
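
For context, the query path in run_query above follows the usual sentence-transformers pattern: encode, score with cosine similarity, take the top k, and drop anything under the threshold. A minimal standalone sketch of that pattern (not part of the commit; the model name, corpus, and return shape are illustrative assumptions):

import torch
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model choice
corpus = ["linear algebra", "organic chemistry", "world history"]
embeddings = model.encode(corpus, convert_to_tensor=True)

def query(text, result_count=1, score_thresh=0.6):
    query_embedding = model.encode(text, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_embedding, embeddings)  # shape (1, len(corpus))
    top = torch.topk(cos_scores, k=result_count, dim=-1)
    # Keep only hits at or above the similarity threshold
    return [
        (corpus[int(i)], float(s))
        for s, i in zip(top.values[0], top.indices[0])
        if float(s) >= score_thresh
    ]

print(query("history of the world"))  # e.g. [('world history', 0.7...)]

The KeyError guards added in the diff serve the same reliability goal as the retries elsewhere in this commit: a failed lookup degrades to an empty result instead of aborting the whole generation run.
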
2 changes: 1 addition & 1 deletion app/course/tasks.py
@@ -41,7 +41,7 @@ async def create_course_outline(
         queries = outline_data.queries
     except (GenerationError, RateLimitError, InvalidRequestError) as e:
         debug_print_trace()
-        print(f"Error generating outline for {course_name}: {e}")
+        print(f"Error generating outline for {course_name}")

     return outline_list, queries

25 changes: 24 additions & 1 deletion app/llm/generators/concepts.py
@@ -5,6 +5,8 @@
 from typing import List

 from pydantic import BaseModel
+from tenacity import stop_after_attempt, wait_fixed, before, after, retry, retry_if_exception_type
+import threading

 from app.llm.exceptions import GenerationError
 from app.llm.llm import GenerationSettings, generate_response
@@ -37,10 +39,31 @@ def concept_prompt(topic: str) -> str:
     return prompt


+local_data = threading.local()
+
+
+def before_retry_callback(retry_state):
+    local_data.is_retry = True
+
+
+def after_retry_callback(retry_state):
+    local_data.is_retry = False
+
+
+@retry(
+    retry=retry_if_exception_type(GenerationError),
+    stop=stop_after_attempt(2),
+    wait=wait_fixed(2),
+    before=before_retry_callback,
+    after=after_retry_callback,
+    reraise=True,
+)
 async def generate_concepts(topic: str) -> CourseGeneratedConcepts:
     prompt = concept_prompt(topic)
     text = ""
-    response = generate_response(prompt, concept_settings)
+    # If we should cache the prompt - skip cache if we're retrying
+    should_cache = not getattr(local_data, "is_retry", False)
+    response = generate_response(prompt, concept_settings, cache=should_cache)
     async for chunk in response:
         text += chunk
     try:
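
The retry recipe here (and mirrored in outline.py below) is: retry once on GenerationError after a fixed 2-second wait, re-raise the original error when attempts run out, and use a thread-local flag so the retry bypasses the response cache, since a cached bad response would just fail the same way again. A self-contained sync sketch of the same idea (not the commit's code; note that tenacity's before hook runs before every attempt, including the first, so this version derives the flag from attempt_number):

import threading
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

class GenerationError(Exception):
    pass

local_data = threading.local()

def before_callback(retry_state):
    # attempt_number is 1 on the first try, 2+ on retries
    local_data.is_retry = retry_state.attempt_number > 1

attempts = {"n": 0}

@retry(
    retry=retry_if_exception_type(GenerationError),  # only retry generation failures
    stop=stop_after_attempt(2),                      # one retry, then give up
    wait=wait_fixed(2),                              # wait 2s between attempts
    before=before_callback,
    reraise=True,                                    # surface GenerationError, not RetryError
)
def generate(prompt):
    use_cache = not getattr(local_data, "is_retry", False)
    attempts["n"] += 1
    if attempts["n"] == 1:
        raise GenerationError("unparseable response")  # simulate a bad first attempt
    return f"ok, cache {'on' if use_cache else 'off'}"

print(generate("list key concepts"))  # first attempt fails; the retry skips the cache
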
25 changes: 24 additions & 1 deletion app/llm/generators/outline.py
@@ -1,11 +1,13 @@
 import json
 import os
 import re
+import threading
 from collections import OrderedDict
 from json import JSONDecodeError
 from typing import AsyncGenerator, List

 from pydantic import BaseModel, parse_obj_as
+from tenacity import retry_if_exception_type, stop_after_attempt, retry, wait_fixed

 from app.llm.exceptions import GenerationError
 from app.llm.llm import GenerationSettings, generate_response
@@ -66,6 +68,25 @@ def try_parse_json(text: str) -> dict | None:
     return data


+local_data = threading.local()
+
+
+def before_retry_callback(retry_state):
+    local_data.is_retry = True
+
+
+def after_retry_callback(retry_state):
+    local_data.is_retry = False
+
+
+@retry(
+    retry=retry_if_exception_type(GenerationError),
+    stop=stop_after_attempt(2),
+    wait=wait_fixed(2),
+    before=before_retry_callback,
+    after=after_retry_callback,
+    reraise=True,
+)
 async def generate_outline(
     topic: str,
     concepts: List[str],
@@ -76,7 +97,9 @@ async def generate_outline(
     concepts = sorted(concepts)
     prompt = outline_prompt(topic, concepts, item_count=item_count)
     text = ""
-    response = generate_response(prompt, outline_settings)
+    # Do not hit cache on retries
+    should_cache = not getattr(local_data, "is_retry", False)
+    response = generate_response(prompt, outline_settings, cache=should_cache)

     chunk_len = 0
     async for chunk in response:
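
One behavior worth calling out: with reraise=True, exhausting both attempts re-raises the original GenerationError rather than wrapping it in tenacity's RetryError, which is why the existing except clause in app/course/tasks.py still catches outline failures. A small self-contained illustration (sync, with the wait shortened to zero just for the demo):

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

class GenerationError(Exception):
    pass

@retry(
    retry=retry_if_exception_type(GenerationError),
    stop=stop_after_attempt(2),
    wait=wait_fixed(0),  # no wait, demo only; the commit uses wait_fixed(2)
    reraise=True,
)
def always_fails():
    raise GenerationError("model returned invalid JSON")

try:
    always_fails()
except GenerationError as e:
    # Without reraise=True this would need `except RetryError` instead
    print(f"gave up after 2 attempts: {e}")
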
2 changes: 1 addition & 1 deletion app/settings.py
@@ -51,7 +51,7 @@ class Settings(BaseSettings):
     SEARCH_BACKEND: Optional[str] = "serply"

     # General
-    THREADS_PER_WORKER: int = 4  # How many threads to use per worker process to save RAM
+    THREADS_PER_WORKER: int = 1  # How many threads to use per worker process to save RAM

     class Config:
         env_file = find_dotenv("local.env")
11 changes: 7 additions & 4 deletions book_generator.py
@@ -74,14 +74,18 @@ async def generate_single_course(course_name, outline_items=12):
     return course


-async def _process_courses(courses):
+async def process_course(topic):
     try:
-        return await asyncio.gather(*[generate_single_course(course) for course in courses], return_exceptions=True)
+        return await generate_single_course(topic)
     except Exception as e:
         debug_print_trace()
         print(f"Unhandled error generating course: {e}")


+async def _process_courses(courses):
+    return await asyncio.gather(*[process_course(course) for course in courses])
+
+
 def process_courses(courses):
     return asyncio.run(_process_courses(courses))

@@ -130,8 +134,7 @@ def load_topics(in_file: str, max_topics: Optional[str]):
             "model": settings.LLM_TYPE,
             "concepts": course.concepts,
             "outline": course.outline,
-            "markdown": course.markdown,
-            "components": course.components
+            "markdown": course.markdown
         }
         f.write(json.dumps(json_data) + '\n')

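
The restructuring in book_generator.py moves error handling from around the whole batch to inside each task: previously asyncio.gather ran with return_exceptions=True, so a failing course surfaced as an exception object mixed into the results and the outer try/except never fired for task errors. Wrapping each course in process_course logs the failure where it happens and leaves None in that slot. A runnable sketch of the new shape (function bodies simplified):

import asyncio

async def generate_single_course(name):
    if name == "bad":
        raise ValueError("generation failed")
    return f"course: {name}"

async def process_course(name):
    # Per-task guard: a failing course is logged and yields None,
    # so the rest of the batch completes normally.
    try:
        return await generate_single_course(name)
    except Exception as e:
        print(f"Unhandled error generating course: {e}")

async def _process_courses(names):
    return await asyncio.gather(*[process_course(n) for n in names])

print(asyncio.run(_process_courses(["good", "bad"])))  # ['course: good', None]
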
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -24,6 +24,7 @@ sentence-transformers = "^2.2.2"
 datasets = "^2.14.5"
 pyyaml = "^6.0.1"
 ftfy = "^6.1.1"
+tenacity = "^8.2.3"

 [tool.poetry.group.dev.dependencies]
 invoke = "^2.2.0"