Commit: Include wiki embeddings

VikParuchuri committed Oct 13, 2023
1 parent 7ceae0d commit 610cdb8
Showing 12 changed files with 184 additions and 41 deletions.
3 changes: 0 additions & 3 deletions app/course/embeddings.py
@@ -129,9 +129,6 @@ def query(self, query_text, result_count=1, score_thresh=0.6) -> List[ResearchNote]:
text_data = self.text_data[i]
result = ResearchNote(
content=self.content[index],
title=text_data.title,
link=text_data.link,
description=text_data.description,
outline_items=[k for k in item_mapping.keys() if item_mapping[k] == index]
)
results.append(result)
3 changes: 0 additions & 3 deletions app/course/schemas.py
@@ -5,7 +5,4 @@

class ResearchNote(BaseModel):
content: str
title: str
link: str
description: str
outline_items: List[int]
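
With title, link, and description dropped, a ResearchNote now carries only the retrieved passage and the outline items it supports. A small illustrative construction (the values are invented, not from the repository):

    from app.course.schemas import ResearchNote

    note = ResearchNote(
        content="The symmetric simple random walk on the integers is recurrent.",
        outline_items=[2, 5],  # indices of the outline entries this passage supports
    )
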
12 changes: 11 additions & 1 deletion app/course/tasks.py
@@ -8,6 +8,7 @@
from app.llm.generators.concepts import generate_concepts
from app.llm.generators.outline import generate_outline
from app.services.generators.pdf import download_and_parse_pdfs, search_pdfs
from app.services.generators.wiki import search_wiki
from app.settings import settings
from app.util import debug_print_trace

@@ -48,12 +49,21 @@ async def create_course_outline(


async def query_course_context(
model, queries: List[str], outline_items: List[str]
model, queries: List[str], outline_items: List[str], course_name: str
) -> List[ResearchNote] | None:
# Store the pdf data in the database
# These are general background queries
pdf_results = await search_pdfs(queries)
pdf_data = await download_and_parse_pdfs(pdf_results)

# Make queries for each chapter and subsection, but not below that level
# These are specific queries related closely to the content
specific_queries = [f"{course_name}: {o}" for o in outline_items if o.count(".") < 3]
if settings.CUSTOM_SEARCH_SERVER:
if "wiki" in settings.CUSTOM_SEARCH_TYPES:
wiki_results = await search_wiki(specific_queries)
pdf_data += wiki_results

# If there are no resources, don't generate research notes
if len(pdf_data) == 0:
return
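
query_course_context now takes the course name and builds wiki queries only for chapter- and subsection-level outline entries, judged by how many dots each item string contains (a proxy for outline depth). A worked illustration of just that filter, with invented outline strings:

    # Hypothetical outline entries; the filter keys off the dot count in each string.
    course_name = "Stochastic Processes"
    outline_items = ["1. Random Walks", "1.1 Symmetric Walks", "1.1.1.2 Proof details"]

    specific_queries = [f"{course_name}: {o}" for o in outline_items if o.count(".") < 3]
    # ["Stochastic Processes: 1. Random Walks", "Stochastic Processes: 1.1 Symmetric Walks"]
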
7 changes: 7 additions & 0 deletions app/llm/examples/rewrite.json
@@ -0,0 +1,7 @@
[
{
"topic": "Lectures on Stochastic Processes",
"draft": "Random walk\n1.1\nSymmetric simple random walk\nLet X0 = x and\nXn+1 = Xn + ξn+1.\n(1.1)\nThe ξi are independent, identically distributed random variables such that\nP[ξi = ±1] = 1/2.\nThe probabilities for this random walk also depend on\nx, and we shall denote them by Px. We can think of this as a fair gambling\ngame, where at each stage one either wins or loses a fixed amount.\nLet Ty be the first time n ≥ 1 when Xn = y. Let ρxy = Px[Ty < ∞] be the\nprobability that the walk starting at x ever gets to y at some future time.\nFirst we show that ρ12 = 1. This says that in the fair game one is almost\nsure to eventually get ahead by one unit. This follows from the following three\nequations.\nThe first equation says that in the first step the walk either goes from 1 to\n2 directly, or it goes from 1 to 0 and then must go from 0 to 2. Thus\nρ12 = 1\n2 + 1\n2ρ02.\n(1.2)\nThe second equation says that to go from 0 to 2, the walk has to go from\n0 to 1 and then from 1 to 2. Furthermore, these two events are independent.",
"markdown": "## Random Walk\n\n### Symmetric Simple Random Walk\n\nConsider a random walk where the initial position is denoted by X0 = x and subsequent positions are given by the formula:\n\nXn+1 = Xn + ξn+1. (1.1)\n\nHere, ξi are independent and identically distributed random variables with the property:\n\nP[ξi = ±1] = 1/2.\n\nThe probabilities for this random walk are also dependent on the initial position x, and we will denote them by Px. We can think of this random walk as a fair gambling game, where at each stage, one either wins or loses a fixed amount.\n\nLet Ty be the first time n ≥ 1 when Xn = y. We define ρxy as the probability that the random walk starting at x ever reaches the position y at some future time.\n\nFirst, we will show that ρ12 = 1. This means that in this fair game, one is almost certain to eventually get ahead by one unit. This can be proven using the following three equations.\n\nThe first equation states that in the first step, the random walk either goes directly from 1 to 2 or goes from 1 to 0 and then must go from 0 to 2. Thus,\n\nρ12 = 1/2 + 1/2 * ρ02. (1.2)\n\nThe second equation states that to go from 0 to 2, the random walk must first go from 0 to 1 and then from 1 to 2. Furthermore, these two events are independent."
}
]
1 change: 1 addition & 0 deletions app/llm/models.py
@@ -11,6 +11,7 @@ class PromptTypes(str, BaseEnum):
topic = "topic"
title = "title"
toc = "toc"
rewrite = "rewrite"


class Prompt(BaseDBModel, table=True):
42 changes: 42 additions & 0 deletions app/services/adaptors/custom_search.py
@@ -0,0 +1,42 @@
import urllib.parse
from json import JSONDecodeError

import aiohttp

from app.services.exceptions import RequestError, ResponseError
from app.services.schemas import ServiceInfo, ServiceNames, ServiceSettings
from app.settings import settings

wiki_search_settings = ServiceSettings(name=ServiceNames.custom, type="wiki")


async def custom_search_router(service_settings: ServiceSettings, service_info: ServiceInfo):
    match service_settings.type:
        case "wiki":
            response = await run_search(
                service_info.query, "search", extract_field="match"
            )
        case _:
            raise RequestError(f"Unknown external search service type {service_settings.type}")
    return response


async def run_search(query: str, endpoint: str, extract_field: str = None):
    if not settings.CUSTOM_SEARCH_SERVER:
        raise RequestError(f"Custom search server not configured")

    params = {"query": query}
    auth = aiohttp.BasicAuth(settings.CUSTOM_SEARCH_USER, settings.CUSTOM_SEARCH_PASSWORD)

    request_url = f"{settings.CUSTOM_SEARCH_SERVER}/{endpoint}"

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(request_url, params=params, auth=auth) as response:
                json = await response.json()
    except aiohttp.ClientResponseError as e:
        raise RequestError(f"Custom search request failed with status {e.status}")
    except JSONDecodeError as e:
        raise ResponseError(f"Could not decode custom search response as JSON: {e}")

    return {"text": json[extract_field]}
52 changes: 42 additions & 10 deletions app/services/generators/pdf.py
@@ -16,12 +16,10 @@
from app.services.exceptions import ProcessingError
from app.services.models import store_scraped_data
from app.services.network import download_and_save
from app.services.schemas import PDFData, ServiceInfo, ServiceNames
from app.services.schemas import SearchData, ServiceInfo, ServiceNames
from app.services.service import get_service_response
from app.settings import settings

BLOCK_SIZE = 2000 # Characters per text block

SEARCH_SETTINGS = {
"serply": serply_pdf_search_settings,
"serpapi": serpapi_pdf_search_settings,
@@ -103,7 +101,7 @@ async def search_pdf(query: str, max_count) -> List[PDFSearchResult]:

async def download_and_parse_pdfs(
search_results: List[PDFSearchResult],
) -> List[PDFData]:
) -> List[SearchData]:
# Deduplicate links
deduped_search_results = []
seen_links = set()
@@ -130,7 +128,7 @@ async def download_and_parse_pdfs(
return results


async def download_and_parse_pdf(search_result: PDFSearchResult, pdf_path: Optional[str]) -> Optional[PDFData]:
async def download_and_parse_pdf(search_result: PDFSearchResult, pdf_path: Optional[str]) -> Optional[SearchData]:
stored = False
if pdf_path:
with open(os.path.join(settings.PDF_CACHE_DIR, pdf_path), "rb") as f:
@@ -153,11 +151,10 @@ async def download_and_parse_pdf(search_result: PDFSearchResult, pdf_path: Optio
except FileDataError:
return

pdf_cls = PDFData(
pdf_cls = SearchData(
pdf_path=pdf_path,
link=search_result.link,
title=search_result.title,
description=search_result.description,
content=pdf_content,
query=search_result.query,
stored=stored,
@@ -174,6 +171,41 @@ async def download_and_parse_pdf(search_result: PDFSearchResult, pdf_path: Optio
return pdf_cls


def smart_split(s, max_remove=settings.CONTEXT_BLOCK_SIZE // 4):
    # Split into chunks based on actual word boundaries
    s_len = len(s)

    # Don't remove anything if string is too short
    if max_remove > s_len:
        return s, ""

    delimiter = None
    max_len = 0

    for split_delimiter in ["\n\n", ". ", "! ", "? ", "}\n", ":\n", ")\n", ".\n", "!\n", "?\n"]:
        split_str = s.rsplit(split_delimiter, 1)
        if len(split_str) > 1 and len(split_str[0]) > max_len:
            max_len = len(split_str[0])
            delimiter = split_delimiter

    if delimiter is not None and max_len > s_len - max_remove:
        return s.rsplit(delimiter, 1)

    # Try \n as a last resort
    str_split = s.rsplit("\n", 1)
    if len(str_split) > 1 and len(str_split[0]) > max_len:
        max_len = len(str_split[0])
        delimiter = "\n"

    if delimiter is None:
        return s, ""

    if max_len < s_len - max_remove:
        return s, ""

    return s.rsplit(delimiter, 1)


def parse_pdf(data) -> List[str]:
with pymupdf.open(stream=data) as doc:
blocks = []
@@ -196,8 +228,8 @@ def parse_pdf(data) -> List[str]:
block = ""
for i, b in enumerate(blocks):
block += b[4]
if len(block) > BLOCK_SIZE:
parsed_blocks.append(block)
block = ""
if len(block) > settings.CONTEXT_BLOCK_SIZE:
parsed_block, block = smart_split(block)
parsed_blocks.append(parsed_block)
parsed_blocks.append(block)
return parsed_blocks
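
smart_split replaces the hard character cutoff in parse_pdf: when a block exceeds CONTEXT_BLOCK_SIZE, it is split at the latest sentence- or paragraph-style delimiter within max_remove characters of the end, and the remainder carries over into the next block. A small illustration of the intended behaviour with an artificially small max_remove:

    from app.services.generators.pdf import smart_split

    text = (
        "The first equation covers the direct step from 1 to 2. "
        "The second equation covers the detour through 0. "
        "This final sentence would otherwise be cut mid-thought"
    )

    # Splits at the last ". " near the end rather than at an arbitrary character offset.
    head, tail = smart_split(text, max_remove=80)
    # head ends with "...covers the detour through 0"
    # tail == "This final sentence would otherwise be cut mid-thought"
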
47 changes: 47 additions & 0 deletions app/services/generators/wiki.py
@@ -0,0 +1,47 @@
import asyncio
from typing import List

from pydantic import BaseModel

from app.services.schemas import ServiceInfo, SearchData
from app.services.service import get_service_response
from app.settings import settings
from app.services.adaptors.custom_search import wiki_search_settings


async def search_wiki(queries: List[str]) -> List[SearchData]:
    coroutines = [_search_wiki(query) for query in queries]

    # Run queries in parallel
    results = await asyncio.gather(*coroutines)

    # Filter results to only unique wiki entries
    filtered = []
    seen_text = []
    for r in results:
        text = r.content[0]
        if text not in seen_text:
            seen_text.append(text)
            filtered.append(r)
    return filtered


async def _search_wiki(query):
    if not settings.CUSTOM_SEARCH_SERVER:
        return []

    service_info = ServiceInfo(query=query)
    response = await get_service_response(wiki_search_settings, service_info, cache=False)

    content = []
    curr_block = ""
    for line in response["text"].split("\n"):
        curr_block += line + "\n"
        if len(curr_block) > settings.CONTEXT_BLOCK_SIZE:
            content.append(curr_block.strip())
            curr_block = ""

    if curr_block:
        content.append(curr_block.strip())

    return SearchData(content=content, query=query)
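
search_wiki runs one request per query concurrently and keeps only results whose first content block has not been seen before, so near-duplicate outline items do not produce duplicate notes. A rough usage sketch, assuming CUSTOM_SEARCH_SERVER is configured (the queries are illustrative):

    import asyncio

    from app.services.generators.wiki import search_wiki


    async def main():
        queries = [
            "Stochastic Processes: Symmetric simple random walk",
            "Stochastic Processes: Recurrence and transience",
        ]
        results = await search_wiki(queries)
        for r in results:
            # Each SearchData holds the chunked article text and the query that produced it.
            print(r.query, "->", len(r.content), "blocks")


    asyncio.run(main())
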
13 changes: 6 additions & 7 deletions app/services/schemas.py
@@ -8,6 +8,7 @@
class ServiceNames(str, BaseEnum):
serply = "serply"
serpapi = "serpapi"
custom = "custom"


class ServiceSettings(BaseModel):
@@ -20,12 +21,10 @@ class ServiceInfo(BaseModel):
content: Optional[str] = None


class PDFData(BaseModel):
pdf_path: str
link: str
text_link: str | None = None
title: str
description: str
class SearchData(BaseModel):
content: List[str]
query: str | None = None
query: str
stored: bool = False
pdf_path: Optional[str] = None
link: Optional[str] = None
title: Optional[str] = None
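
PDFData is generalized into SearchData: content and query are required, while pdf_path, link, and title become optional so the same shape can hold either a parsed PDF or a wiki extract. A quick sketch of both cases with invented values:

    from app.services.schemas import SearchData

    # A parsed PDF keeps its link, title, and cache path.
    pdf_result = SearchData(
        content=["Chapter 1: Random walk ...", "Chapter 2: Markov chains ..."],
        query="introduction to stochastic processes",
        stored=True,
        pdf_path="a1b2c3.pdf",  # hypothetical cache filename
        link="https://example.com/lectures.pdf",
        title="Lectures on Stochastic Processes",
    )

    # A wiki extract only has content and the query that produced it.
    wiki_result = SearchData(
        content=["Random walk\nIn mathematics, a random walk ..."],
        query="Stochastic Processes: Symmetric simple random walk",
    )
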
34 changes: 20 additions & 14 deletions app/services/service.py
@@ -1,6 +1,7 @@
import hashlib

from app.db.session import get_session
from app.services.adaptors.custom_search import custom_search_router
from app.services.adaptors.serpapi import serpapi_router
from app.services.adaptors.serply import serply_router
from app.services.dependencies import get_service_response_model
@@ -11,36 +12,41 @@
async def get_service_response(
service_settings: ServiceSettings,
service_info: ServiceInfo,
cache=True,
) -> dict:
hash = hashlib.sha512()
# Turn dict into list, sort keys, then hash. This ensures consistent order.
service_info_str = str(sorted(service_info.dict().items())).encode("utf-8")
hash.update(service_info_str)
hex = hash.hexdigest()

# Break if we've already run this query
service_model = await get_service_response_model(service_settings.name, hex)
if cache:
# Break if we've already run this query
service_model = await get_service_response_model(service_settings.name, hex)

if service_model is not None:
return service_model.response
if service_model is not None:
return service_model.response

match service_settings.name:
case ServiceNames.serply:
response = await serply_router(service_settings, service_info)
case ServiceNames.serpapi:
response = await serpapi_router(service_settings, service_info)
case ServiceNames.custom:
response = await custom_search_router(service_settings, service_info)
case _:
raise NotImplementedError("This Service type is not currently supported.")

async with get_session() as db:
# Save the response to the DB
service_model = ServiceResponse(
hash=hex,
request=service_info.dict(),
response=response,
name=service_settings.name,
)
db.add(service_model)
await db.commit()
if cache:
async with get_session() as db:
# Save the response to the DB
service_model = ServiceResponse(
hash=hex,
request=service_info.dict(),
response=response,
name=service_settings.name,
)
db.add(service_model)
await db.commit()

return response
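
Service responses are cached under a SHA-512 hash of the request payload with its keys sorted, so logically identical requests map to the same row regardless of dict ordering, and the new cache flag lets callers such as the wiki search skip both the lookup and the write. A minimal sketch of the hashing idea on its own:

    import hashlib


    def request_hash(service_info: dict) -> str:
        # Sorting the items first makes the hash independent of key order.
        h = hashlib.sha512()
        h.update(str(sorted(service_info.items())).encode("utf-8"))
        return h.hexdigest()


    assert request_hash({"query": "random walk", "content": None}) == request_hash(
        {"content": None, "query": "random walk"}
    )
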
9 changes: 7 additions & 2 deletions app/settings.py
@@ -1,5 +1,5 @@
import os
from typing import Literal, Optional
from typing import Literal, Optional, List

from dotenv import find_dotenv
from pydantic import BaseSettings
@@ -53,12 +53,17 @@ class Settings(BaseSettings):
SERPLY_KEY: str = ""
SERPAPI_KEY: str = ""
SEARCH_BACKEND: Optional[str] = "serply"
CUSTOM_SEARCH_SERVER: Optional[str] = None
CUSTOM_SEARCH_USER: Optional[str] = None
CUSTOM_SEARCH_PASSWORD: Optional[str] = None
CUSTOM_SEARCH_TYPES: Optional[List[str]] = ["wiki"]
CONTEXT_BLOCK_SIZE: int = 2200 # Characters per text block

# General
THREADS_PER_WORKER: int = 1 # How many threads to use per worker process to save RAM
RAY_CACHE_PATH: Optional[str] = None # Where to save ray cache
RAY_DASHBOARD_HOST: str = "0.0.0.0"
RAY_CORES_PER_WORKER = .5 # How many cpu cores to allocate per worker
RAY_CORES_PER_WORKER = 1 # How many cpu cores to allocate per worker

class Config:
env_file = find_dotenv("local.env")
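
The new settings gate the custom search backend and replace the old module-level BLOCK_SIZE with CONTEXT_BLOCK_SIZE. A hedged example of supplying them directly, for instance in a test; in normal use they would come from local.env, and the URL and credentials below are placeholders:

    from app.settings import Settings

    test_settings = Settings(
        CUSTOM_SEARCH_SERVER="http://localhost:8000",  # placeholder
        CUSTOM_SEARCH_USER="user",                     # placeholder
        CUSTOM_SEARCH_PASSWORD="password",             # placeholder
        CUSTOM_SEARCH_TYPES=["wiki"],
        CONTEXT_BLOCK_SIZE=2200,
    )
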
2 changes: 1 addition & 1 deletion book_generator.py
@@ -97,7 +97,7 @@ async def generate_single_course(model, course_data: Dict | str, revision=1, out
# Up to one retrieved passage per outline item
# Remove numbers from outline for use in retrieval
context_outline = [item.split(" ", 1)[-1] for item in outline]
context = await query_course_context(model, queries, context_outline)
context = await query_course_context(model, queries, context_outline, course_name)
except Exception as e:
debug_print_trace()
print(f"Error generating context for {course_name}: {e}")
