Skip to content

Commit

Permalink
feat: holder test for nougat
Browse files Browse the repository at this point in the history
Signed-off-by: 117503445 <[email protected]>
  • Loading branch information
117503445 committed Oct 14, 2023
1 parent 5b51a9b commit e6b5765
Show file tree
Hide file tree
Showing 6 changed files with 396 additions and 11 deletions.
10 changes: 1 addition & 9 deletions experiment/nougat-demo/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime

# RUN apt-get update && apt-get install gcc git ffmpeg libsm6 libxext6 -y
# RUN pip uninstall numpy pillow -y
# RUN pip install Pillow==9.1
# RUN python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.8/index.html
# RUN pip install layoutparser pymupdf
# COPY download.py /scripts/download.py
# RUN python /scripts/download.py

RUN pip install nougat-ocr[api]
# RUN pip install nougat-ocr[api]

WORKDIR /workspace/nougat-demo

Expand Down
149 changes: 149 additions & 0 deletions experiment/nougat-demo/img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import sys
from functools import partial
from http import HTTPStatus
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from pathlib import Path
import hashlib
from fastapi.middleware.cors import CORSMiddleware
import pypdfium2
import torch
from nougat import NougatModel
from nougat.postprocessing import markdown_compatible, close_envs
from nougat.utils.dataset import ImageDataset
from nougat.utils.checkpoint import get_checkpoint
from nougat.dataset.rasterize import rasterize_paper
from nougat.utils.device import move_to_device, default_batch_size
from tqdm import tqdm


# Root directory for cached results; predict() stores each document in a
# sub-directory named after the MD5 digest of the PDF bytes.
SAVE_DIR = Path("./pdfs")
# Inference batch size, overridable through the NOUGAT_BATCHSIZE environment
# variable. load_model() disables CUDA and resets this to 1 when it is <= 0.
BATCHSIZE = int(os.environ.get("NOUGAT_BATCHSIZE", default_batch_size()))
NOUGAT_CHECKPOINT = get_checkpoint()
if NOUGAT_CHECKPOINT is None:
    # Nothing can run without a checkpoint, so abort at import time. The
    # diagnostic goes to stderr so it is not mixed into regular output.
    print(
        "Set environment variable 'NOUGAT_CHECKPOINT' with a path to the model checkpoint!",
        file=sys.stderr,
    )
    sys.exit(1)

# Shared model instance; stays None until load_model() has been called.
model = None

def load_model(
    checkpoint: str = NOUGAT_CHECKPOINT,
):
    """
    Initialise the module-level ``model`` the first time it is called.

    Subsequent calls are no-ops. CUDA is requested only when ``BATCHSIZE`` is
    positive; otherwise ``BATCHSIZE`` is reset to 1.

    Args:
        checkpoint (str, optional): Checkpoint passed to
            ``NougatModel.from_pretrained``. Defaults to ``NOUGAT_CHECKPOINT``.
    """
    global model, BATCHSIZE
    if model is not None:
        # Already loaded; nothing to do.
        return

    use_cuda = BATCHSIZE > 0
    model = NougatModel.from_pretrained(checkpoint)
    model = move_to_device(model, cuda=use_cuda)
    if not use_cuda:
        # A non-positive batch size only signals "no CUDA"; the DataLoader
        # still needs a valid batch size.
        BATCHSIZE = 1
    model.eval()

def predict(pdf_path="hotstuff.pdf") -> str:
    """
    Run the loaded model on every page of a PDF and return the result as Markdown.

    Per-page outputs are cached as ``NN.mmd`` files under
    ``SAVE_DIR/<md5 of the PDF bytes>/pages``. Pages that already have a cached
    file are read back and are not sent through the model again. After
    inference the PDF copy, a thumbnail (when at least one page was
    rasterized), every page file and the joined ``doc.mmd`` are written to
    that directory.

    Args:
        pdf_path (str or Path, optional): Path of the PDF to process.
            Defaults to ``"hotstuff.pdf"`` in the current working directory.

    Returns:
        str: The concatenated, stripped text of all pages.

    Raises:
        RuntimeError: If ``load_model()`` has not been called yet.
    """
    if model is None:
        raise RuntimeError("Model is not loaded; call load_model() before predict().")

    pdfbin = Path(pdf_path).read_bytes()
    pdf = pypdfium2.PdfDocument(pdfbin)
    # The digest is only a cache key for this document, not a security measure.
    md5 = hashlib.md5(pdfbin).hexdigest()
    save_path = SAVE_DIR / md5
    pages_dir = save_path / "pages"

    num_pages = len(pdf)
    predictions = [""] * num_pages

    # Reuse results from an earlier run. A set keeps the bookkeeping correct
    # even if two cache files map to the same page (e.g. "1.mmd" and "01.mmd").
    cached = set()
    if save_path.exists():
        for computed in pages_dir.glob("*.mmd"):
            try:
                idx = int(computed.stem) - 1  # file names are 1-based
                if 0 <= idx < num_pages:
                    predictions[idx] = computed.read_text(encoding="utf-8")
                    cached.add(idx)
                    print("skip page", idx + 1)
            except (ValueError, OSError) as e:
                # Best effort: an unreadable or oddly named cache file just
                # means the page is computed again.
                print(e)
    compute_pages = [page for page in range(num_pages) if page not in cached]

    images = rasterize_paper(pdf, pages=compute_pages)

    dataset = ImageDataset(
        images,
        partial(model.encoder.prepare_input, random_padding=False),
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCHSIZE,
        pin_memory=True,
        shuffle=False,
    )

    for idx, sample in tqdm(enumerate(dataloader), total=len(dataloader)):
        if sample is None:
            continue
        model_output = model.inference(image_tensors=sample)
        for j, output in enumerate(model_output["predictions"]):
            repeats = model_output["repeats"][j]
            disclaimer = ""
            if repeats is not None:
                if repeats > 0:
                    template = "\n\n+++ ==WARNING: Truncated because of repetitions==\n%s\n+++\n\n"
                else:
                    template = (
                        "\n\n+++ ==ERROR: No output for this page==\n%s\n+++\n\n"
                    )
                rest = close_envs(model_output["repetitions"][j]).strip()
                # The notice is only attached when there is leftover text to show.
                if rest:
                    disclaimer = template % rest

            # Map the position inside the batch back to the document page index.
            page = compute_pages[idx * BATCHSIZE + j]
            predictions[page] = markdown_compatible(output) + disclaimer

    pages_dir.mkdir(parents=True, exist_ok=True)
    pdf.save(save_path / "doc.pdf")
    if images:
        thumb = Image.open(images[0])
        thumb.thumbnail((400, 400))
        thumb.save(save_path / "thumb.jpg")
    for page_num, text in enumerate(predictions):
        (pages_dir / ("%02d.mmd" % (page_num + 1))).write_text(
            text, encoding="utf-8"
        )
    final = "".join(predictions).strip()
    (save_path / "doc.mmd").write_text(final, encoding="utf-8")
    return final


def main():
    """Load the model, run predict() and write its result to 'hotstuff.mmd'."""
    load_model()

    final = predict()
    # Explicit UTF-8 matches the other files written by predict() and avoids
    # encoding errors on platforms whose default encoding is not UTF-8.
    with open('hotstuff.mmd', 'w', encoding='utf-8') as f:
        f.write(final)


if __name__ == "__main__":
    main()
7 changes: 6 additions & 1 deletion experiment/nougat-demo/note.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,9 @@ python main.py

# ARG DEBIAN_FRONTEND=noninteractive
apt update && apt install nodejs npm git
npm config set registry https://registry.npmmirror.com/
npm config set registry https://registry.npmmirror.com/

python predict.py hotstuff.pdf -o out -m 0.1.0-base --markdown


nougat hotstuff-holder.pdf -o out -m 0.1.0-base --markdown
195 changes: 195 additions & 0 deletions experiment/nougat-demo/predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import sys
from pathlib import Path
import logging
import re
import argparse
import re
from functools import partial
import torch
from torch.utils.data import ConcatDataset
from tqdm import tqdm
from nougat import NougatModel
from nougat.utils.dataset import LazyDataset
from nougat.utils.device import move_to_device, default_batch_size
from nougat.utils.checkpoint import get_checkpoint
from nougat.postprocessing import markdown_compatible
import pypdf

logging.basicConfig(level=logging.INFO)


def _parse_page_spec(spec):
    """Turn a 1-based page spec such as ``"1-4,7"`` into a list of 0-based indices."""
    pages = []
    for part in spec.split(","):
        if "-" in part:
            start, end = part.split("-")
            # The range is inclusive on both ends in the 1-based input.
            pages.extend(range(int(start) - 1, int(end)))
        else:
            pages.append(int(part) - 1)
    return pages


def get_args():
    """
    Parse and normalise the command-line arguments.

    Besides plain parsing this:
      * resolves ``args.checkpoint`` through ``get_checkpoint`` when no
        existing path was given,
      * creates the output directory if needed and exits with status 1 when
        ``--out`` exists but is not a directory,
      * treats a single positional argument without a ``.pdf`` suffix as a
        text file listing one PDF path per line,
      * converts ``--pages`` into a list of 0-based page indices (only when
        exactly one PDF is given; otherwise ``args.pages`` is set to None).

    Returns:
        argparse.Namespace: The processed arguments.

    Raises:
        ValueError: If ``--pages`` contains something that is not a number or
            a ``start-end`` pair.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batchsize",
        "-b",
        type=int,
        default=default_batch_size(),
        help="Batch size to use.",
    )
    parser.add_argument(
        "--checkpoint",
        "-c",
        type=Path,
        default=None,
        help="Path to checkpoint directory.",
    )
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        default="0.1.0-small",
        help="Model tag to use.",
    )
    parser.add_argument("--out", "-o", type=Path, help="Output directory.")
    parser.add_argument(
        "--recompute",
        action="store_true",
        help="Recompute already computed PDF, discarding previous predictions.",
    )
    parser.add_argument(
        "--markdown",
        action="store_true",
        help="Add postprocessing step for markdown compatibility.",
    )
    parser.add_argument(
        "--no-skipping",
        dest="skipping",
        action="store_false",
        help="Don't apply failure detection heuristic.",
    )
    parser.add_argument(
        "--pages",
        "-p",
        type=str,
        help="Provide page numbers like '1-4,7' for pages 1 through 4 and page 7. Only works for single PDF input.",
    )
    parser.add_argument("pdf", nargs="+", type=Path, help="PDF(s) to process.")
    args = parser.parse_args()

    if args.checkpoint is None or not args.checkpoint.exists():
        args.checkpoint = get_checkpoint(args.checkpoint, model_tag=args.model)

    if args.out is None:
        logging.warning("No output directory. Output will be printed to console.")
    else:
        if not args.out.exists():
            logging.info("Output directory does not exist. Creating output directory.")
            args.out.mkdir(parents=True)
        if not args.out.is_dir():
            logging.error("Output has to be directory.")
            sys.exit(1)

    if len(args.pdf) == 1 and args.pdf[0].suffix != ".pdf":
        # A single non-PDF argument is interpreted as a list of PDF paths.
        list_file = args.pdf[0]
        try:
            lines = list_file.read_text().split("\n")
        except (OSError, ValueError) as err:
            # Best effort: keep the argument untouched if it cannot be read
            # as text (ValueError covers decoding failures).
            logging.warning("Could not read %s as a list of PDFs: %s", list_file, err)
        else:
            args.pdf = [Path(line) for line in lines if line]

    if args.pages and len(args.pdf) == 1:
        args.pages = _parse_page_spec(args.pages)
    else:
        args.pages = None
    return args


def main():
    """
    Convert the PDFs given on the command line with the Nougat model.

    Every input PDF is wrapped in a ``LazyDataset``; all datasets are
    concatenated and run through the model in batches. Page outputs are
    collected until the last page of a file is reached, then joined and either
    written to ``<out>/<pdf name>.mmd`` or printed when no output directory
    was given. PDFs that do not exist, cannot be loaded, or already have an
    output file (without ``--recompute``) are skipped.
    """
    args = get_args()
    model = NougatModel.from_pretrained(args.checkpoint)
    if args.batchsize > 0:
        model = move_to_device(model)
    else:
        # set batch size to 1. Need to check if there are benefits for CPU conversion for >1
        args.batchsize = 1
    model.eval()

    datasets = []
    for pdf in args.pdf:
        if not pdf.exists():
            logging.warning("Skipping %s: file does not exist.", pdf)
            continue
        if args.out:
            out_path = args.out / pdf.with_suffix(".mmd").name
            if out_path.exists() and not args.recompute:
                logging.info(
                    f"Skipping {pdf.name}, already computed. Run with --recompute to convert again."
                )
                continue
        try:
            dataset = LazyDataset(
                pdf,
                partial(model.encoder.prepare_input, random_padding=False),
                args.pages,
            )
        except pypdf.errors.PdfStreamError:
            logging.info(f"Could not load file {str(pdf)}.")
            continue
        datasets.append(dataset)
    if len(datasets) == 0:
        return

    dataloader = torch.utils.data.DataLoader(
        ConcatDataset(datasets),
        batch_size=args.batchsize,
        shuffle=False,
        collate_fn=LazyDataset.ignore_none_collate,
    )

    predictions = []
    file_index = 0
    # Counts the pages seen so far in the current file; reset after each file.
    page_num = 0
    for i, (sample, is_last_page) in enumerate(tqdm(dataloader)):
        model_output = model.inference(
            image_tensors=sample, early_stopping=args.skipping
        )
        # check if model output is faulty
        for j, output in enumerate(model_output["predictions"]):
            if page_num == 0:
                logging.info(
                    "Processing file %s with %i pages"
                    % (datasets[file_index].name, datasets[file_index].size)
                )
            page_num += 1
            if output.strip() == "[MISSING_PAGE_POST]":
                # uncaught repetitions -- most likely empty page
                predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n")
            elif args.skipping and model_output["repeats"][j] is not None:
                if model_output["repeats"][j] > 0:
                    # If we end up here, it means the output is most likely not complete and was truncated.
                    logging.warning(f"Skipping page {page_num} due to repetitions.")
                    predictions.append(f"\n\n[MISSING_PAGE_FAIL:{page_num}]\n\n")
                else:
                    # If we end up here, it means the document page is too different from the training domain.
                    # This can happen e.g. for cover pages.
                    # Use the per-file counter like the other markers; the
                    # global batch index is wrong once several PDFs are
                    # processed in one run.
                    predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n")
            else:
                if args.markdown:
                    output = markdown_compatible(output)
                predictions.append(output)
            if is_last_page[j]:
                # is_last_page[j] is truthy on the final page of a file and is
                # used below as that file's path.
                out = "".join(predictions).strip()
                out = re.sub(r"\n{3,}", "\n\n", out).strip()
                if args.out:
                    out_path = args.out / Path(is_last_page[j]).with_suffix(".mmd").name
                    out_path.parent.mkdir(parents=True, exist_ok=True)
                    out_path.write_text(out, encoding="utf-8")
                else:
                    print(out, "\n\n")
                predictions = []
                page_num = 0
                file_index += 1


if __name__ == "__main__":
    main()
3 changes: 2 additions & 1 deletion src/layout-parser/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
data
data
*.pdf
Loading

0 comments on commit e6b5765

Please sign in to comment.