PDF Generation and Analysis Improvements
- Simplified PDF generation configuration and HTML-based flashcard generator
- Added HTML-based PDF flashcard generator using Playwright and Jinja2
- Implemented multiprocessing for PDF image analysis with batch processing
- Added card type support and safety settings for PDF image analysis
- Created fixed grid with consistent card dimensions for flashcards
- Adjusted grid layout to fill page without gaps using dynamic configuration
- Removed card gaps by adjusting border and sizing
- Corrected row flipping logic in HTML template for double-sided printing
- Updated dependencies and refactored PDF generation
rjpower committed Jan 8, 2025
1 parent e4d0ad7 commit 2b4381d
Showing 10 changed files with 495 additions and 151 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -18,6 +18,8 @@ dependencies = [
"pytest>=8.3.4",
"reportlab>=4.2.5",
"typer>=0.15.1",
"jinja2>=3.1.5",
"playwright>=1.49.1",
]

[build-system]
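These two new dependencies support the HTML-based generator described in the commit message. The srt/generate_pdf_html.py diff itself did not load (see the note at the end), so the following is only a rough sketch of how a Jinja2-plus-Playwright PDF pipeline typically looks; the function and template names here are hypothetical rather than the module's actual API:

from pathlib import Path

from jinja2 import Environment, FileSystemLoader
from playwright.sync_api import sync_playwright


def render_cards_to_pdf(cards: list, template_dir: Path, output_path: Path) -> None:
    # Render the card grid to an HTML page with Jinja2.
    env = Environment(loader=FileSystemLoader(str(template_dir)))
    html = env.get_template("flashcards.html").render(cards=cards)

    # Print the rendered page to PDF with headless Chromium.
    with sync_playwright() as p:
        browser = p.chromium.launch()
        page = browser.new_page()
        page.set_content(html)
        page.pdf(path=str(output_path), format="A4", print_background=True)
        browser.close()

Note that page.pdf is Chromium-only, and the browser binary must be installed once via `playwright install chromium`.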
26 changes: 13 additions & 13 deletions sample/n2.csv
@@ -1,16 +1,3 @@
- 2,題名
- 2,でたらめ
- 2,目上
- 2,成分
- 2,サイレン
- 2,毛糸
- 2,酌む
- 2,中世
- 2,雨戸
- 2,矢印
- 2,御中
- 2,目安
- 2,架空
2,詰まる
2,売行き
2,民謡
@@ -1832,3 +1819,16 @@
2,思い掛けない
2,煮える
2,的確
+ 2,題名
+ 2,でたらめ
+ 2,目上
+ 2,成分
+ 2,サイレン
+ 2,毛糸
+ 2,酌む
+ 2,中世
+ 2,雨戸
+ 2,矢印
+ 2,御中
+ 2,目安
+ 2,架空
48 changes: 29 additions & 19 deletions scripts/main.py
@@ -1,12 +1,13 @@
import logging
import sys
import traceback
from pathlib import Path
from typing import Optional

import typer
from rich import print

- from srt.analyze_pdf import process_pdf_images
+ from srt.analyze_pdf import AnalyzePdfConfig, CardType, process_pdf_images
from srt.config import settings
from srt.lib import (
CSVProcessConfig,
@@ -48,6 +49,12 @@ def flashcards_from_pdf(
"-f",
help="Output format (pdf or apkg)",
),
+ card_type: CardType = typer.Option(
+ CardType.VOCAB,
+ "--card-type",
+ "-c",
+ help="Type of card to generate (vocab or sentences)",
+ ),
):
"""Extract learning content from PDF files and create flashcards."""
output_dir = settings.output_dir
@@ -58,7 +65,14 @@

print(f"[blue]Analyzing PDF: {pdf_path}[/blue]")

- for progress in process_pdf_images(pdf_path, output_path, output_format):
+ config = AnalyzePdfConfig(
+ pdf_path=pdf_path,
+ output_path=output_path,
+ output_format=OutputFormat(output_format),
+ card_type=card_type,
+ )
+
+ for progress in process_pdf_images(config):
if progress.stage == "error":
print(f"[red]Error: {progress.error}[/red]", file=sys.stderr)
sys.exit(1)
@@ -112,26 +126,22 @@ def flashcards_from_csv(
else:
# Parse mapping string or infer mapping
field_mapping = None
- with open(input_path, "r", encoding="utf-8") as f:
- content = f.read()
- separator, df = read_csv(content)

if mapping:
- try:
- # Parse mapping string (term=col,reading=col2)
- pairs = dict(pair.split("=") for pair in mapping.split(","))
- field_mapping = SourceMapping(
- term=pairs.get("term"),
- reading=pairs.get("reading"),
- meaning=pairs.get("meaning"),
- context_jp=pairs.get("context_jp"),
- context_en=pairs.get("context_en"),
- level=pairs.get("level"),
- )
- except Exception as e:
- print(f"[red]Error parsing mapping: {e}[/red]", file=sys.stderr)
- sys.exit(1)
+ pairs = dict(pair.split("=") for pair in mapping.split(","))
+ field_mapping = SourceMapping(
+ term=pairs.get("term"),
+ reading=pairs.get("reading"),
+ meaning=pairs.get("meaning"),
+ context_jp=pairs.get("context_jp"),
+ context_en=pairs.get("context_en"),
+ level=pairs.get("level"),
+ )
else:
# Infer mapping from CSV content
+ with open(input_path, "r", encoding="utf-8") as f:
+ content = f.read()
+ separator, df = read_csv(content)
result = infer_field_mapping(df)
field_mapping = SourceMapping.model_validate(result["suggested_mapping"])
print(
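A note on the flashcards_from_csv refactor above: the try/except around the mapping parse is gone, so a malformed --mapping string now surfaces as an uncaught ValueError rather than the old friendly error message. The format, per the removed comment, is comma-separated key=column pairs. A quick worked example of the parse (column names invented):

mapping = "term=word,reading=kana,meaning=gloss"
pairs = dict(pair.split("=") for pair in mapping.split(","))
assert pairs == {"term": "word", "reading": "kana", "meaning": "gloss"}
# A piece like "term" (no "=") now raises ValueError, because dict()
# needs each element to split into exactly one (key, value) pair.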
105 changes: 83 additions & 22 deletions srt/analyze_pdf.py
@@ -1,13 +1,17 @@
import base64
+ import enum
import json
import logging
+ import multiprocessing.dummy
+ from dataclasses import dataclass
from pathlib import Path
- from typing import Generator, List
+ from typing import Generator, List, Sequence

import litellm
+ from openai import audio

from srt.config import cached_completion
- from srt.generate_pdf import PDFGeneratorConfig, create_flashcard_pdf
+ from srt.generate_pdf_html import PDFGeneratorConfig, create_flashcard_pdf
from srt.images_from_pdf import ImageData, PdfOptions, extract_images_from_pdf
from srt.lib import (
ConversionProgress,
@@ -81,12 +85,18 @@
"""


- def analyze_pdf_images(images: List[ImageData]) -> List[RawFlashCard]:
- """Analyze all PDF page images and extract flashcards"""
+ class CardType(enum.StrEnum):
+ VOCAB = "vocab"
+ SENTENCE = "sentence"

- # Convert all images to base64

+ def _analyze_image_batch(
+ image_batch: Sequence[ImageData], card_type: CardType
+ ) -> List[RawFlashCard]:
+ """Process a batch of images into flashcards"""
+ # Convert images to base64
image_contents = []
- for image in images:
+ for image in image_batch:
image_b64 = base64.b64encode(image.content).decode("utf-8")
image_contents.append(
{
@@ -95,13 +105,33 @@ def analyze_pdf_images(images: List[ImageData]) -> List[RawFlashCard]:
}
)

+ prompt = VOCAB_PROMPT if card_type == CardType.VOCAB else SENTENCE_PROMPT

response = cached_completion(
messages=[
{
"role": "user",
"content": [{"type": "text", "text": SENTENCE_PROMPT}] + image_contents,
"content": [{"type": "text", "text": prompt}] + image_contents,
}
],
+ safety_settings=[
+ {
+ "category": "HARM_CATEGORY_HARASSMENT",
+ "threshold": "BLOCK_ONLY_HIGH",
+ },
+ {
+ "category": "HARM_CATEGORY_HATE_SPEECH",
+ "threshold": "BLOCK_ONLY_HIGH",
+ },
+ {
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+ "threshold": "BLOCK_ONLY_HIGH",
+ },
+ {
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+ "threshold": "BLOCK_ONLY_HIGH",
+ },
+ ],
response_format={"type": "json_object"},
)

@@ -116,14 +146,40 @@ def analyze_pdf_images(images: List[ImageData]) -> List[RawFlashCard]:
back_context=card_data.get("back_context"),
)
cards.append(card)

- # dedup
- cards = list({card.front: card for card in cards}.values())
return cards


+ def analyze_pdf_images(
+ images: List[ImageData], card_type: CardType, batch_size: int = 4
+ ) -> List[RawFlashCard]:
+ """Analyze all PDF page images and extract flashcards using parallel processing"""
+
+ with multiprocessing.dummy.Pool(processes=4) as pool:
+ # Process images in batches
+ batch_results = pool.starmap(
+ _analyze_image_batch,
+ [
+ (images[i : i + batch_size], card_type)
+ for i in range(0, len(images), batch_size)
+ ],
+ chunksize=1,
+ )
+
+ # Flatten results and remove duplicates
+ all_cards = [card for batch in batch_results for card in batch]
+ return list({card.front: card for card in all_cards}.values())


+ @dataclass
+ class AnalyzePdfConfig:
+ pdf_path: Path
+ output_path: Path
+ output_format: OutputFormat
+ card_type: CardType


def process_pdf_images(
- pdf_path: Path, output_path: Path, output_format: OutputFormat
+ config: AnalyzePdfConfig,
) -> Generator[ConversionProgress, None, None]:
"""Process PDF file by converting to images and analyzing with Gemini Vision"""

Expand All @@ -138,7 +194,7 @@ def process_pdf_images(
progress=10,
)

- images = extract_images_from_pdf(pdf_path, PdfOptions())
+ images = extract_images_from_pdf(config.pdf_path, PdfOptions())

# Analyze all images together
yield ConversionProgress(
@@ -147,29 +203,34 @@ def process_pdf_images(
progress=50,
)

- all_cards = analyze_pdf_images(images)
+ all_cards = analyze_pdf_images(images, card_type=config.card_type)

# Export cards
yield ConversionProgress(
stage=ConversionStage.PROCESSING,
message=f"Exporting to {output_format}",
message=f"Exporting to {config.output_format}",
progress=90,
)

- if output_format == OutputFormat.ANKI_PKG:
- create_anki_package(output_path, all_cards, pdf_path.stem, audio_mapping={})
+ if config.output_format == OutputFormat.ANKI_PKG:
+ create_anki_package(
+ config.output_path,
+ vocab_items=all_cards,
+ deck_name=config.pdf_path.stem,
+ audio_mapping={},
+ )
else:
- config = PDFGeneratorConfig(
- columns=2,
- rows=4,
+ pdf_config = PDFGeneratorConfig(
+ columns=3 if config.card_type == CardType.VOCAB else 2,
+ rows=6 if config.card_type == CardType.VOCAB else 4,
cards=all_cards,
- output_path=output_path,
+ output_path=config.output_path,
)
- create_flashcard_pdf(config)
+ create_flashcard_pdf(pdf_config)

yield ConversionProgress(
stage=ConversionStage.COMPLETE,
message="Processing complete",
progress=100,
- filename=output_path.name,
+ filename=config.output_path.name,
)
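Two notes on this file. CardType is an enum.StrEnum, which requires Python 3.11 or newer. And despite its name, multiprocessing.dummy.Pool is a thread pool that merely mirrors the multiprocessing.Pool API; threads are a good fit here because each batch blocks on a network-bound LLM call rather than on CPU. A minimal, self-contained sketch of the same batch-then-flatten pattern (the squaring worker is a stand-in for _analyze_image_batch):

import multiprocessing.dummy
from typing import List, Sequence


def process_batch(batch: Sequence[int]) -> List[int]:
    # Stand-in for _analyze_image_batch: one slow, I/O-bound call per batch.
    return [x * x for x in batch]


items = list(range(10))
batch_size = 4
batches = [items[i : i + batch_size] for i in range(0, len(items), batch_size)]

# Four worker threads, one batch per task, mirroring the diff above.
with multiprocessing.dummy.Pool(processes=4) as pool:
    results = pool.map(process_batch, batches, chunksize=1)

# Flatten, as analyze_pdf_images does; its dict-comprehension dedup then
# keeps the last card seen for each front text.
flat = [y for batch in results for y in batch]
print(flat)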
2 changes: 1 addition & 1 deletion srt/config.py
@@ -31,7 +31,7 @@ class Settings(BaseSettings):
default=100, description="Number of text blocks to process in each LLM batch"
)
llm_model: str = Field(
default="gemini/gemini-2.0-flash-exp",
default="gemini/gemini-1.5-flash",
description="LLM model to use for vocabulary analysis",
)

18 changes: 9 additions & 9 deletions srt/generate_pdf.py
@@ -156,7 +156,8 @@ def draw_wrapped_text(
font_name,
font_size,
)
- current_line = []
+ current_line = [item]
+ current_x = line_start_x + item.get_width(canvas)
continue

# split the line, draw the current line and move on
@@ -167,7 +168,7 @@
remaining_width - space_width,
)

- if current_text.text:
+ if current_text:
current_line.append(current_text)

current_y = draw_text(
@@ -236,25 +237,24 @@ def wrap_text(
- current_line is Text object that fits within remaining_width
- remaining_text is Text object that needs to wrap to next line
"""
- if not text_obj.text:
- return Text(
- text="", font_name=text_obj.font_name, font_size=text_obj.font_size
- ), Text(text="", font_name=text_obj.font_name, font_size=text_obj.font_size)

words = text_obj.text.split()
current_line = []
- next_line = []

current_width = 0
space_width = canvas.stringWidth(" ", text_obj.font_name, text_obj.font_size)

- for word in words:
+ while words:
+ word = words[0]
word_width = canvas.stringWidth(word, text_obj.font_name, text_obj.font_size)
if current_width + word_width <= remaining_width:
current_line.append(word)
current_width += word_width + space_width
+ words.pop(0)
else:
- next_line.append(word)
+ break

+ next_line = words

return (
Text(
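The wrap_text rewrite fixes a subtle ordering bug: the old for-loop kept scanning after the first overflow, so a later, shorter word could still pass the width check and land on the current line out of order. The new while-loop stops at the first word that does not fit and hands everything unconsumed to the next line; the companion fix in draw_wrapped_text keeps the item that forced the wrap (current_line = [item]) instead of discarding it. A standalone sketch of the fixed loop, with width_of standing in for canvas.stringWidth:

def wrap_words(words, width_of, remaining_width, space_width):
    current_line = []
    current_width = 0.0
    while words:
        word = words[0]
        word_width = width_of(word)
        if current_width + word_width <= remaining_width:
            current_line.append(word)
            current_width += word_width + space_width
            words.pop(0)
        else:
            break  # first word that does not fit ends the line
    return current_line, words  # leftover words wrap to the next line


line, rest = wrap_words("the quick brown fox".split(), len, 9, 1)
assert line == ["the", "quick"] and rest == ["brown", "fox"]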
(The remaining four file diffs, including the new srt/generate_pdf_html.py and the HTML flashcard template, did not load.)
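Among those unloaded files is the HTML template whose row-flipping logic the commit message says was corrected. The standard approach for double-sided flashcards, sketched here in Python for illustration (in the real code this logic would live in the Jinja2 template): lay fronts out left to right, then mirror each row on the back page so every back lands behind its front on a long-edge duplex print.

def paginate(cards, columns, rows):
    """Yield (front_rows, back_rows) per sheet; back rows are mirrored."""
    per_page = columns * rows
    for start in range(0, len(cards), per_page):
        page = cards[start : start + per_page]
        front_rows = [page[i : i + columns] for i in range(0, len(page), columns)]
        # Reverse each row, not the whole page: duplex printing flips
        # horizontally within a row but keeps the vertical order.
        back_rows = [list(reversed(row)) for row in front_rows]
        yield front_rows, back_rows


fronts, backs = next(paginate(list("ABCDEF"), columns=3, rows=2))
assert fronts == [["A", "B", "C"], ["D", "E", "F"]]
assert backs == [["C", "B", "A"], ["F", "E", "D"]]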
