Merge pull request #12 from VikParuchuri/dev
Add wiki retrieval
VikParuchuri authored Oct 13, 2023
2 parents 055a92a + 6bc45eb commit 3548c55
Showing 23 changed files with 336 additions and 145 deletions.
8 changes: 4 additions & 4 deletions app/course/embeddings.py
@@ -97,11 +97,13 @@ def __init__(self, model):
         self.lengths = []
         self.text_data = []
         self.model = model
+        self.kinds = []

     def add_resources(self, resources):
         for resource in resources:
             self.content += resource.content
             self.lengths.append(len(self.content))
+            self.kinds.append(resource.kind)

             embeddings = create_embeddings(resource.content, self.model)
@@ -129,10 +131,8 @@ def query(self, query_text, result_count=1, score_thresh=0.6) -> List[ResearchNote]:
             text_data = self.text_data[i]
             result = ResearchNote(
                 content=self.content[index],
-                title=text_data.title,
-                link=text_data.link,
-                description=text_data.description,
-                outline_items=[k for k in item_mapping.keys() if item_mapping[k] == index]
+                outline_items=[k for k in item_mapping.keys() if item_mapping[k] == index],
+                kind=self.kinds[i]
             )
             results.append(result)
             break
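Note: with kind recorded per resource, query results can now distinguish wiki-sourced text from PDF text. A minimal usage sketch, assuming the class is named EmbeddingContext and that each resource exposes .content and .kind (both names are inferred from this diff and from app/course/tasks.py below, not confirmed here):

# Hypothetical usage of the updated embedding store; EmbeddingContext is an
# assumed class name, and pdf_data/wiki_results stand in for parsed resources.
context = EmbeddingContext(model)
context.add_resources(pdf_data + wiki_results)

# query() results are ResearchNote objects tagged with the source kind, so
# downstream prompt builders can treat wiki text and PDF text differently.
for note in context.query("simple random walk", result_count=2):
    print(note.kind, note.outline_items)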
4 changes: 1 addition & 3 deletions app/course/schemas.py
@@ -5,7 +5,5 @@

 class ResearchNote(BaseModel):
     content: str
-    title: str
-    link: str
-    description: str
     outline_items: List[int]
+    kind: str = "pdf"
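Note: the schema is now much slimmer. title, link, and description are gone, and the new kind field defaults to "pdf" so existing PDF-only call sites keep working unchanged. A quick construction sketch (field names taken from this diff, values invented):

# Sketch: the default kind preserves the old PDF-only behavior.
note = ResearchNote(content="...", outline_items=[1, 2])
assert note.kind == "pdf"

# Wiki-sourced notes are tagged explicitly (see app/course/tasks.py below).
wiki_note = ResearchNote(content="...", outline_items=[3], kind="wiki")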
22 changes: 15 additions & 7 deletions app/course/tasks.py
@@ -8,6 +8,7 @@
 from app.llm.generators.concepts import generate_concepts
 from app.llm.generators.outline import generate_outline
 from app.services.generators.pdf import download_and_parse_pdfs, search_pdfs
+from app.services.generators.wiki import search_wiki
 from app.settings import settings
 from app.util import debug_print_trace

@@ -34,12 +35,9 @@ async def create_course_outline(
     outline_list = None
     queries = None
     try:
-        response = generate_outline(course_name, concepts, revision, item_count=outline_items, include_examples=settings.INCLUDE_EXAMPLES)
-
-        # Stream outline as it generates
-        async for outline_data in response:
-            outline_list = outline_data.outline
-            queries = outline_data.queries
+        outline_data = await generate_outline(course_name, concepts, revision, item_count=outline_items, include_examples=settings.INCLUDE_EXAMPLES)
+        outline_list = outline_data.outline
+        queries = outline_data.queries
     except (GenerationError, RateLimitError, InvalidRequestError, RetryError) as e:
         debug_print_trace()
         print(f"Error generating outline for {course_name}")
@@ -48,12 +46,21 @@


 async def query_course_context(
-    model, queries: List[str], outline_items: List[str]
+    model, queries: List[str], outline_items: List[str], course_name: str
 ) -> List[ResearchNote] | None:
     # Store the pdf data in the database
+    # These are general background queries
     pdf_results = await search_pdfs(queries)
     pdf_data = await download_and_parse_pdfs(pdf_results)

+    # Make queries for each chapter and subsection, but not below that level
+    # These are specific queries related closely to the content
+    specific_queries = [f"{course_name}: {o}" for o in outline_items if o.count(".") < 3]
+    if settings.CUSTOM_SEARCH_SERVER:
+        if "wiki" in settings.CUSTOM_SEARCH_TYPES:
+            wiki_results = await search_wiki(specific_queries)
+            pdf_data += wiki_results
+
     # If there are no resources, don't generate research notes
     if len(pdf_data) == 0:
         return
@@ -62,4 +69,5 @@ async def query_course_context(
     embedding_context.add_resources(pdf_data)

     results = embedding_context.query(outline_items)
+
     return results
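Note: search_wiki comes from app/services/generators/wiki.py, one of the 23 changed files but not displayed on this page. A rough sketch of the shape it would need to satisfy this call site. Everything below (the /search endpoint path, the query/type parameter names, and the make_resource wrapper) is an assumption, not the actual implementation:

# Hypothetical sketch of app/services/generators/wiki.py. Assumed details:
# the custom search server referenced by settings.CUSTOM_SEARCH_SERVER
# exposes a /search endpoint, and make_resource is a stand-in for however
# results get wrapped so each carries .content and .kind for
# EmbeddingContext.add_resources (see app/course/embeddings.py above).
import aiohttp

from app.settings import settings

async def search_wiki(queries: list[str]):
    results = []
    async with aiohttp.ClientSession() as session:
        for query in queries:
            async with session.get(
                f"{settings.CUSTOM_SEARCH_SERVER}/search",
                params={"query": query, "type": "wiki"},
            ) as resp:
                pages = await resp.json()
            for page in pages:
                # Tag wiki results so ResearchNote.kind reflects the source.
                results.append(make_resource(content=page["text"], kind="wiki"))
    return results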
37 changes: 20 additions & 17 deletions app/lesson/tasks.py
@@ -80,7 +80,7 @@ async def generate_lesson(
             selected_research_notes.append(research_note)

     try:
-        response = generate_single_lesson_chunk(
+        new_components = await generate_single_lesson_chunk(
             numbered_outline,
             current_section,
             generated_sections,
@@ -92,14 +92,6 @@
             cache=use_cache,
             stop_section=stop_section,
         )
-        new_components = []
-        new_component_keys = []
-        async for chunk in response:
-            new_components = chunk
-            # Set keys for the new components to the same as the ones in the last iteration
-            for i, key in enumerate(new_component_keys):
-                new_components[i].key = key
-            new_component_keys = [c.key for c in new_components]
     except (GenerationError, RateLimitError, InvalidRequestError) as e:
         debug_print_trace()
         print(f"Error generating lesson: {e}")
@@ -145,8 +137,8 @@ async def generate_single_lesson_chunk(
     include_examples: bool,
     cache: bool,
     stop_section: str | None = None,
-) -> AsyncGenerator[List[AllLessonComponentData], None]:
-    response = generate_lessons(
+) -> List[AllLessonComponentData]:
+    chunk = await generate_lessons(
         numbered_outline,
         current_section,
         current_section_index,
@@ -161,10 +153,21 @@

     section_start = f"---{ComponentNames.section}"

-    async for chunk in response:
-        # Remove the final section header from the chunk
-        # This happens when we hit the stop token
-        if chunk.strip().endswith(section_start):
-            chunk = chunk.strip()[:-len(section_start)]
-        new_components = parse_lesson_markdown(chunk)
-        yield new_components
+    chunk = chunk.strip()
+    # Remove the section header from the chunk
+    if chunk.endswith(section_start):
+        chunk = chunk[:-len(section_start)]
+
+    if stop_section and chunk.endswith(stop_section):
+        chunk = chunk[:-len(stop_section)]
+
+    chunk = chunk.strip()
+
+    new_components = parse_lesson_markdown(chunk)
+    if len(new_components) > 1 and new_components[-1].type == ComponentNames.section:
+        # Remove the final section header from the chunk
+        # This happens when we hit the stop token
+        new_components = new_components[:-1]
+
+    return new_components
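Note: the new trimming logic is easier to follow outside the diff. A self-contained rerun of the same string handling on a toy chunk, assuming ComponentNames.section renders as "section" (not confirmed here) and leaving out parse_lesson_markdown, which is not shown in this view:

# Standalone illustration of the trimming added above (marker value assumed).
section_start = "---section"
stop_section = "---section: 4"

chunk = "Some lesson text...\n---section: 4"

chunk = chunk.strip()
if chunk.endswith(section_start):
    chunk = chunk[:-len(section_start)]
if stop_section and chunk.endswith(stop_section):
    chunk = chunk[:-len(stop_section)]
chunk = chunk.strip()

print(repr(chunk))  # 'Some lesson text...' with the stop marker removed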
8 changes: 8 additions & 0 deletions app/llm/examples/rewrite.json
@@ -0,0 +1,8 @@
[
{
"topic": "Lectures on Stochastic Processes",
"research notes\n": "* ```\nA popular random walk model is that of a random walk on a regular lattice, where at each step the location jumps to another site according to some probability distribution. In a simple random walk, the location can only jump to neighboring sites of the lattice, forming a lattice path. In a simple symmetric random walk on a locally finite lattice, the probabilities of the location jumping to each one of its immediate neighbors are the same.\n```\n* ```\nA stochastic process {Xn}n∈N0 is called a simple random walk if\n1. X0 = 0,\n2. the increment Xn+1 − Xn is independent of (X0, X1, . . . , Xn) for each n ∈ N0, and\n3. the increment Xn+1 − Xn has the coin-toss distribution, i.e.\nP[Xn+1 − Xn = 1] = P[Xn+1 − Xn = −1] = 1\n2 .\nFor the sequence {γn}n∈N, given by Theorem 3.2, define the following, new, sequence {ξn}n∈N\nof random variables:\nξn =\n{\n1, γn ≥ 1\n2\n−1, otherwise.\nThen, we set\nX0 = 0, Xn =\nn∑\nk=1\nξk, n ∈ N.\nIntuitively, we use each ξn to emulate a coin toss and then define the value of the process X at\ntime n as the cumulative sum of the first n coin-tosses.\n```",
"draft\n": "Random walk\n1.1\nSymmetric simple random walk\nLet X0 = x and\nXn+1 = Xn + ξn+1.\n(1.1)\nThe ξi are independent, identically distributed random variables such that\nP[ξi = ±1] = 1/2.\nThe probabilities for this random walk also depend on\nx, and we shall denote them by Px. We can think of this as a fair gambling\ngame, where at each stage one either wins or loses a fixed amount.\nLet Ty be the first time n ≥ 1 when Xn = y. Let ρxy = Px[Ty < ∞] be the\nprobability that the walk starting at x ever gets to y at some future time.\nFirst we show that ρ12 = 1. This says that in the fair game one is almost\nsure to eventually get ahead by one unit. This follows from the following three\nequations.\nThe first equation says that in the first step the walk either goes from 1 to\n2 directly, or it goes from 1 to 0 and then must go from 0 to 2. Thus\nρ12 = 1\n2 + 1\n2ρ02.\n(1.2)\nThe second equation says that to go from 0 to 2, the walk has to go from\n0 to 1 and then from 1 to 2. Furthermore, these two events are independent.",
"markdown": "### 1.1 Symmetric Simple Random Walk\nA random walk is a mathematical model that describes the path of a randomly moving object in discrete time steps. \n\nLet's start with the simplest form of a random walk, called the symmetric simple random walk. In this case, the object can only move one step to the left or one step to the right with equal probability at each time step. \n\nWe denote the current position of the object at time $n$ as $Xn$. The position at time $n+1$ is determined by adding an increment $ξn+1$ to the current position Xn. The increments ξn+1 are independent, identically distributed random variables, where P[ξn+1 = ±1] = 1/2.\n\nTo understand the behavior of the symmetric simple random walk, we introduce the concept of hitting time. Let Ty be the first time n ≥ 1 when Xn = y. The hitting time represents the number of steps it takes for the random walk to reach a specific position y.\n\nWe are interested in the probability that the random walk starting at position x ever reaches position y at some future time. We denote this probability as ρxy = Px[Ty < ∞].\n\nTo calculate ρxy, we can use a recursive approach. Let's consider the case where y = 1. We want to find ρ12, the probability that the random walk starting at position x ever reaches position 1.\n\nWe can break down the calculation of ρ12 into three equations:\n\n1. The first equation states that in the first step, the walk either goes directly from 1 to 2 or goes from 1 to 0 and then from 0 to 2. Therefore, we have:\n\n ρ12 = 1/2 + 1/2 * ρ02. (1.1)\n\n2. The second equation states that to go from 0 to 2, the walk must go from 0 to 1 and then from 1 to 2. Furthermore, these two events are independent.\n\n3. The third equation states that to go from 0 to 2, the walk must go from 0 to 1 and then from 1 to 2. Furthermore, these two events are independent.\n\nNow, let's solve these equations to find ρ12.\n\n```\nρ12 = 1/2 + 1/2 * ρ02\n```\n\nTo solve this equation, we need to find ρ02. We can use the same approach of breaking down the calculation into three equations:\n\n1. The first equation states that in the first step, the walk either goes directly from 0 to 1 or goes from 0 to -1 and then from -1 to 1. Therefore, we have:\n\n ρ02 = 1/2 + 1/2 * ρ-12. (1.2)\n\n2. The second equation states that to go from -1 to 1, the walk must go from -1 to 0 and then from 0 to 1. Furthermore, these two events are independent.\n\n3. The third equation states that to go from -1 to 1, the walk must go from -1 to 0 and then from 0 to 1. Furthermore, these two events are independent.\n\nBy solving these equations iteratively, we can find the value of ρ12.\n\n### Example\n\nLet's consider a symmetric simple random walk starting at position 0. We want to calculate the probability of reaching position 2 at some future time.\n\nUsing the recursive approach, we can calculate ρ02 as follows:\n\n```\nρ02 = 1/2 + 1/2 * ρ-12\n```\n\nTo find ρ-12, we repeat the process:\n\n```\nρ-12 = 1/2 + 1/2 * ρ-22\n```\n\nContinuing this process, we eventually find the value of ρ02.\n\nNow, let's substitute the value of ρ02 into the equation for ρ12:\n\n```\nρ12 = 1/2 + 1/2 * ρ02\n```\n\nBy solving this equation, we can determine the probability of reaching position 2 starting from position 0.\n\n"
}
]
5 changes: 1 addition & 4 deletions app/llm/generators/concepts.py
@@ -60,12 +60,9 @@ def after_retry_callback(retry_state):
 )
 async def generate_concepts(topic: str, revision: int, include_examples: bool = True) -> CourseGeneratedConcepts:
     prompt = concept_prompt(topic, include_examples=include_examples)
-    text = ""
     # If we should cache the prompt - skip cache if we're retrying
     should_cache = not getattr(local_data, "is_retry", False)
-    response = generate_response(prompt, concept_settings, cache=should_cache, revision=revision)
-    async for chunk in response:
-        text += chunk
+    text = await generate_response(prompt, concept_settings, cache=should_cache, revision=revision)
     try:
         text = extract_only_json_dict(text)
         text = str(ftfy.fix_text(text))
27 changes: 6 additions & 21 deletions app/llm/generators/lesson.py
@@ -6,7 +6,7 @@
 from app.components.schemas import ComponentNames
 from app.course.schemas import ResearchNote
 from app.llm.llm import GenerationSettings, generate_response
-from app.llm.prompts import build_prompt
+from app.llm.prompts import build_prompt, render_research_notes
 from app.settings import settings
 from copy import deepcopy
@@ -85,11 +85,7 @@ def lesson_prompt(

     research_notes_exist = research_notes is not None and len(research_notes) > 0
     if research_notes_exist:
-        research_content = ""
-        for research_note in research_notes:
-            content = research_note.content.replace("```", " ")
-            content = f"```{content}```"
-            research_content += f"* {content}\n"
+        research_content = render_research_notes(research_notes)
         items.append(("research notes\n", research_content))

     items.append(("course\n\n", current_section))
@@ -125,10 +121,9 @@ async def generate_lessons(
     revision: int,
     research_notes: List[ResearchNote] | None = None,
     include_examples: bool = True,
-    update_after_chars: int = 500,
     cache: bool = True,
     stop_section: str | None = None,
-) -> AsyncGenerator[str, None]:
+) -> str:
     prompt = lesson_prompt(
         outline,
         current_section,
@@ -139,20 +134,10 @@
         research_notes,
     )

-    text = ""
     stop_sequences = None
     if stop_section is not None:
         stop_sequences = [stop_section]

-    response = generate_response(prompt, lesson_settings, cache=cache, revision=revision, stop_sequences=stop_sequences)
-    chunk_len = 0
-
-    # Yield text in batches, to avoid creating too many DB models
-    async for chunk in response:
-        text += chunk
-        chunk_len += len(chunk)
-        if chunk_len >= update_after_chars:
-            yield text
-            chunk_len = 0
-    # Yield the remaining text
-    yield text
+    text = await generate_response(prompt, lesson_settings, cache=cache, revision=revision, stop_sequences=stop_sequences)
+
+    return text
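Note: render_research_notes is imported from app.llm.prompts, which is not shown in this view. Given the inline loop it replaces above, the helper presumably just factors that code out; a sketch under that assumption:

# Presumed shape of the new helper in app/llm/prompts.py; the body mirrors
# the inline loop removed from lesson_prompt above.
from typing import List

from app.course.schemas import ResearchNote

def render_research_notes(research_notes: List[ResearchNote]) -> str:
    research_content = ""
    for research_note in research_notes:
        # Neutralize backtick fences inside the note, then wrap it in one.
        content = research_note.content.replace("```", " ")
        content = f"```{content}```"
        research_content += f"* {content}\n"
    return research_content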
34 changes: 4 additions & 30 deletions app/llm/generators/outline.py
@@ -56,20 +56,6 @@ def parse_json_data(outline: dict) -> GeneratedOutlineData:
     return outline


-def try_parse_json(text: str) -> dict | None:
-    data = None
-    try:
-        data = json.loads(text.strip())
-    except json.decoder.JSONDecodeError:
-        # Try to re-parse if it failed
-        try:
-            data = json.loads(text.strip() + '"]}')
-        except json.decoder.JSONDecodeError:
-            # If it fails again, keep going with the loop.
-            pass
-    return data
-
-
 local_data = threading.local()

@@ -93,10 +79,9 @@ async def generate_outline(
     topic: str,
     concepts: List[str],
     revision: int,
-    update_after_chars: int = 50,
     item_count: int = 10,
     include_examples: bool = True
-) -> AsyncGenerator[GeneratedOutlineData, None]:
+) -> GeneratedOutlineData:
     # Sort concepts alphabetically so that the prompt is the same every time
     concepts = sorted(concepts)
     prompt = outline_prompt(topic, concepts, item_count=item_count, include_examples=include_examples)
@@ -105,27 +90,16 @@
     text = prompt_start_hint
     # Do not hit cache on retries
     should_cache = not getattr(local_data, "is_retry", False)
-    response = generate_response(prompt, outline_settings, cache=should_cache, revision=revision)
-
-    chunk_len = 0
-    async for chunk in response:
-        text += chunk
-        chunk_len += len(chunk)
-        if chunk_len >= update_after_chars:
-            data = try_parse_json(text.strip())
-            if data:
-                yield parse_json_data(data)
-            chunk_len = 0
-
-    # Handle the last bit of data
+    text += await generate_response(prompt, outline_settings, cache=should_cache, revision=revision)
+
     try:
         # Strip out text before/after the json. Sometimes the LLM will include something before the json input.
         text = extract_only_json_dict(text)
         text = str(ftfy.fix_text(text))
         data = json.loads(text.strip())
     except JSONDecodeError as e:
         raise GenerationError(e)
-    yield parse_json_data(data)
+    return parse_json_data(data)


 def renumber_outline(outline):
67 changes: 67 additions & 0 deletions app/llm/generators/rewrite.py
@@ -0,0 +1,67 @@
import json
import os
from collections import OrderedDict
from typing import AsyncGenerator, List, get_args

from app.course.schemas import ResearchNote
from app.llm.llm import GenerationSettings, generate_response
from app.llm.prompts import build_prompt, render_research_notes
from app.settings import settings

rewrite_settings = GenerationSettings(
    temperature=.6,
    max_tokens=6000,
    timeout=1200,
    prompt_type="rewrite"
)


def rewrite_prompt(
    topic: str,
    draft: str,
    include_examples: bool,
    research_notes: List[ResearchNote] | None = None,
) -> str:
    with open(os.path.join(settings.EXAMPLE_JSON_DIR, "rewrite.json")) as f:
        examples = json.load(f)

    items = [("topic", topic)]

    research_notes_exist = research_notes is not None and len(research_notes) > 0
    if research_notes_exist:
        research_content = render_research_notes(research_notes)
        items.append(("research notes\n", research_content))

    items.append(("draft\n", draft))

    input = OrderedDict(items)

    prompt = build_prompt(
        "rewrite",
        input,
        examples,
        include_examples=include_examples,
        topic=topic,
        research_notes=research_notes_exist,
    )
    return prompt


async def generate_rewrite(
    topic: str,
    draft: str,
    revision: int,
    research_notes: List[ResearchNote] | None = None,
    include_examples: bool = True,
    cache: bool = True,
) -> str:
    prompt = rewrite_prompt(
        topic,
        draft,
        include_examples,
        research_notes,
    )

    text = await generate_response(prompt, rewrite_settings, cache=cache, revision=revision)

    return text
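Note: a minimal usage sketch of the new generator. The topic and draft values below are lifted from the rewrite.json example above; the cache/revision call pattern follows the other generators in this diff, and the call site itself is invented:

# Hypothetical call site for the new module; argument values are illustrative.
import asyncio

from app.llm.generators.rewrite import generate_rewrite

async def main():
    rewritten = await generate_rewrite(
        topic="Lectures on Stochastic Processes",
        draft="Random walk\n1.1 Symmetric simple random walk ...",
        revision=1,
        research_notes=None,  # optional; rendered into the prompt when present
    )
    print(rewritten)

asyncio.run(main())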
5 changes: 1 addition & 4 deletions app/llm/generators/title.py
@@ -30,10 +30,7 @@ async def generate_title(
     subject: str,
 ) -> List[str]:
     prompt = title_prompt(subject)
-    text = ""
-    response = generate_response(prompt, title_settings, cache=False)
-    async for chunk in response:
-        text += chunk
+    text = await generate_response(prompt, title_settings, cache=False)

     try:
         text = extract_only_json_list(text)
5 changes: 1 addition & 4 deletions app/llm/generators/toc.py
@@ -44,17 +44,14 @@ def toc_prompt(topic: str, toc: str, include_examples=True) -> str:

 async def generate_tocs(topic: str, draft_toc: str, include_examples: bool = True) -> GeneratedTOC | None:
     prompt = toc_prompt(topic, draft_toc, include_examples=include_examples)
-    text = ""

     settings_inst = deepcopy(toc_settings)
     try:
         settings_inst.max_tokens = oai_tokenize_prompt(draft_toc) + 512  # Max tokens to generate
     except Exception:
         return

-    response = generate_response(prompt, settings_inst)
-    async for chunk in response:
-        text += chunk
+    text = await generate_response(prompt, settings_inst)
     try:
         text = extract_only_json_dict(text)
         text = str(ftfy.fix_text(text))
(The remaining changed files, including app/services/generators/wiki.py and app/llm/prompts.py, are not shown on this page.)