forked from 117503445/flow-pdf
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: 117503445 <[email protected]>
- Loading branch information
Showing
6 changed files
with
396 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
""" | ||
Copyright (c) Meta Platforms, Inc. and affiliates. | ||
This source code is licensed under the MIT license found in the | ||
LICENSE file in the root directory of this source tree. | ||
""" | ||
import os | ||
import sys | ||
from functools import partial | ||
from http import HTTPStatus | ||
from fastapi import FastAPI, File, UploadFile | ||
from PIL import Image | ||
from pathlib import Path | ||
import hashlib | ||
from fastapi.middleware.cors import CORSMiddleware | ||
import pypdfium2 | ||
import torch | ||
from nougat import NougatModel | ||
from nougat.postprocessing import markdown_compatible, close_envs | ||
from nougat.utils.dataset import ImageDataset | ||
from nougat.utils.checkpoint import get_checkpoint | ||
from nougat.dataset.rasterize import rasterize_paper | ||
from nougat.utils.device import move_to_device, default_batch_size | ||
from tqdm import tqdm | ||
|
||
|
||
SAVE_DIR = Path("./pdfs") | ||
BATCHSIZE = int(os.environ.get("NOUGAT_BATCHSIZE", default_batch_size())) | ||
NOUGAT_CHECKPOINT = get_checkpoint() | ||
if NOUGAT_CHECKPOINT is None: | ||
print( | ||
"Set environment variable 'NOUGAT_CHECKPOINT' with a path to the model checkpoint!" | ||
) | ||
sys.exit(1) | ||
|
||
model = None | ||
|
||
def load_model( | ||
checkpoint: str = NOUGAT_CHECKPOINT, | ||
): | ||
global model, BATCHSIZE | ||
if model is None: | ||
model = NougatModel.from_pretrained(checkpoint) | ||
model = move_to_device(model, cuda=BATCHSIZE > 0) | ||
if BATCHSIZE <= 0: | ||
BATCHSIZE = 1 | ||
model.eval() | ||
|
||
def predict() -> str: | ||
""" | ||
Perform predictions on a PDF document and return the extracted text in Markdown format. | ||
Args: | ||
file (UploadFile): The uploaded PDF file to process. | ||
start (int, optional): The starting page number for prediction. | ||
stop (int, optional): The ending page number for prediction. | ||
Returns: | ||
str: The extracted text in Markdown format. | ||
""" | ||
|
||
with open('hotstuff.pdf', 'rb') as f: | ||
pdfbin = f.read() | ||
|
||
pdf = pypdfium2.PdfDocument(pdfbin) | ||
md5 = hashlib.md5(pdfbin).hexdigest() | ||
save_path = SAVE_DIR / md5 | ||
|
||
pages = list(range(len(pdf))) | ||
predictions = [""] * len(pages) | ||
dellist = [] | ||
if save_path.exists(): | ||
for computed in (save_path / "pages").glob("*.mmd"): | ||
try: | ||
idx = int(computed.stem) - 1 | ||
if idx in pages: | ||
i = pages.index(idx) | ||
print("skip page", idx + 1) | ||
predictions[i] = computed.read_text(encoding="utf-8") | ||
dellist.append(idx) | ||
except Exception as e: | ||
print(e) | ||
compute_pages = pages.copy() | ||
for el in dellist: | ||
compute_pages.remove(el) | ||
images = rasterize_paper(pdf, pages=compute_pages) | ||
global model | ||
|
||
dataset = ImageDataset( | ||
images, | ||
partial(model.encoder.prepare_input, random_padding=False), | ||
) | ||
|
||
dataloader = torch.utils.data.DataLoader( | ||
dataset, | ||
batch_size=BATCHSIZE, | ||
pin_memory=True, | ||
shuffle=False, | ||
) | ||
|
||
for idx, sample in tqdm(enumerate(dataloader), total=len(dataloader)): | ||
if sample is None: | ||
continue | ||
model_output = model.inference(image_tensors=sample) | ||
for j, output in enumerate(model_output["predictions"]): | ||
if model_output["repeats"][j] is not None: | ||
if model_output["repeats"][j] > 0: | ||
disclaimer = "\n\n+++ ==WARNING: Truncated because of repetitions==\n%s\n+++\n\n" | ||
else: | ||
disclaimer = ( | ||
"\n\n+++ ==ERROR: No output for this page==\n%s\n+++\n\n" | ||
) | ||
rest = close_envs(model_output["repetitions"][j]).strip() | ||
if len(rest) > 0: | ||
disclaimer = disclaimer % rest | ||
else: | ||
disclaimer = "" | ||
else: | ||
disclaimer = "" | ||
|
||
predictions[pages.index(compute_pages[idx * BATCHSIZE + j])] = ( | ||
markdown_compatible(output) + disclaimer | ||
) | ||
|
||
(save_path / "pages").mkdir(parents=True, exist_ok=True) | ||
pdf.save(save_path / "doc.pdf") | ||
if len(images) > 0: | ||
thumb = Image.open(images[0]) | ||
thumb.thumbnail((400, 400)) | ||
thumb.save(save_path / "thumb.jpg") | ||
for idx, page_num in enumerate(pages): | ||
(save_path / "pages" / ("%02d.mmd" % (page_num + 1))).write_text( | ||
predictions[idx], encoding="utf-8" | ||
) | ||
final = "".join(predictions).strip() | ||
(save_path / "doc.mmd").write_text(final, encoding="utf-8") | ||
return final | ||
|
||
|
||
def main(): | ||
load_model() | ||
|
||
final = predict() | ||
with open('hotstuff.mmd', 'w') as f: | ||
f.write(final) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
""" | ||
Copyright (c) Meta Platforms, Inc. and affiliates. | ||
This source code is licensed under the MIT license found in the | ||
LICENSE file in the root directory of this source tree. | ||
""" | ||
import sys | ||
from pathlib import Path | ||
import logging | ||
import re | ||
import argparse | ||
import re | ||
from functools import partial | ||
import torch | ||
from torch.utils.data import ConcatDataset | ||
from tqdm import tqdm | ||
from nougat import NougatModel | ||
from nougat.utils.dataset import LazyDataset | ||
from nougat.utils.device import move_to_device, default_batch_size | ||
from nougat.utils.checkpoint import get_checkpoint | ||
from nougat.postprocessing import markdown_compatible | ||
import pypdf | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
|
||
|
||
def get_args(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--batchsize", | ||
"-b", | ||
type=int, | ||
default=default_batch_size(), | ||
help="Batch size to use.", | ||
) | ||
parser.add_argument( | ||
"--checkpoint", | ||
"-c", | ||
type=Path, | ||
default=None, | ||
help="Path to checkpoint directory.", | ||
) | ||
parser.add_argument( | ||
"--model", | ||
"-m", | ||
type=str, | ||
default="0.1.0-small", | ||
help=f"Model tag to use.", | ||
) | ||
parser.add_argument("--out", "-o", type=Path, help="Output directory.") | ||
parser.add_argument( | ||
"--recompute", | ||
action="store_true", | ||
help="Recompute already computed PDF, discarding previous predictions.", | ||
) | ||
parser.add_argument( | ||
"--markdown", | ||
action="store_true", | ||
help="Add postprocessing step for markdown compatibility.", | ||
) | ||
parser.add_argument( | ||
"--no-skipping", | ||
dest="skipping", | ||
action="store_false", | ||
help="Don't apply failure detection heuristic.", | ||
) | ||
parser.add_argument( | ||
"--pages", | ||
"-p", | ||
type=str, | ||
help="Provide page numbers like '1-4,7' for pages 1 through 4 and page 7. Only works for single PDF input.", | ||
) | ||
parser.add_argument("pdf", nargs="+", type=Path, help="PDF(s) to process.") | ||
args = parser.parse_args() | ||
if args.checkpoint is None or not args.checkpoint.exists(): | ||
args.checkpoint = get_checkpoint(args.checkpoint, model_tag=args.model) | ||
if args.out is None: | ||
logging.warning("No output directory. Output will be printed to console.") | ||
else: | ||
if not args.out.exists(): | ||
logging.info("Output directory does not exist. Creating output directory.") | ||
args.out.mkdir(parents=True) | ||
if not args.out.is_dir(): | ||
logging.error("Output has to be directory.") | ||
sys.exit(1) | ||
if len(args.pdf) == 1 and not args.pdf[0].suffix == ".pdf": | ||
# input is a list of pdfs | ||
try: | ||
args.pdf = [ | ||
Path(l) for l in open(args.pdf[0]).read().split("\n") if len(l) > 0 | ||
] | ||
except: | ||
pass | ||
if args.pages and len(args.pdf) == 1: | ||
pages = [] | ||
for p in args.pages.split(","): | ||
if "-" in p: | ||
start, end = p.split("-") | ||
pages.extend(range(int(start) - 1, int(end))) | ||
else: | ||
pages.append(int(p) - 1) | ||
args.pages = pages | ||
else: | ||
args.pages = None | ||
return args | ||
|
||
|
||
def main(): | ||
args = get_args() | ||
model = NougatModel.from_pretrained(args.checkpoint) | ||
if args.batchsize > 0: | ||
model = move_to_device(model) | ||
else: | ||
# set batch size to 1. Need to check if there are benefits for CPU conversion for >1 | ||
args.batchsize = 1 | ||
model.eval() | ||
datasets = [] | ||
for pdf in args.pdf: | ||
if not pdf.exists(): | ||
continue | ||
if args.out: | ||
out_path = args.out / pdf.with_suffix(".mmd").name | ||
if out_path.exists() and not args.recompute: | ||
logging.info( | ||
f"Skipping {pdf.name}, already computed. Run with --recompute to convert again." | ||
) | ||
continue | ||
try: | ||
dataset = LazyDataset( | ||
pdf, | ||
partial(model.encoder.prepare_input, random_padding=False), | ||
args.pages, | ||
) | ||
except pypdf.errors.PdfStreamError: | ||
logging.info(f"Could not load file {str(pdf)}.") | ||
continue | ||
datasets.append(dataset) | ||
if len(datasets) == 0: | ||
return | ||
dataloader = torch.utils.data.DataLoader( | ||
ConcatDataset(datasets), | ||
batch_size=args.batchsize, | ||
shuffle=False, | ||
collate_fn=LazyDataset.ignore_none_collate, | ||
) | ||
|
||
predictions = [] | ||
file_index = 0 | ||
page_num = 0 | ||
for i, (sample, is_last_page) in enumerate(tqdm(dataloader)): | ||
model_output = model.inference( | ||
image_tensors=sample, early_stopping=args.skipping | ||
) | ||
# check if model output is faulty | ||
for j, output in enumerate(model_output["predictions"]): | ||
if page_num == 0: | ||
logging.info( | ||
"Processing file %s with %i pages" | ||
% (datasets[file_index].name, datasets[file_index].size) | ||
) | ||
page_num += 1 | ||
if output.strip() == "[MISSING_PAGE_POST]": | ||
# uncaught repetitions -- most likely empty page | ||
predictions.append(f"\n\n[MISSING_PAGE_EMPTY:{page_num}]\n\n") | ||
elif args.skipping and model_output["repeats"][j] is not None: | ||
if model_output["repeats"][j] > 0: | ||
# If we end up here, it means the output is most likely not complete and was truncated. | ||
logging.warning(f"Skipping page {page_num} due to repetitions.") | ||
predictions.append(f"\n\n[MISSING_PAGE_FAIL:{page_num}]\n\n") | ||
else: | ||
# If we end up here, it means the document page is too different from the training domain. | ||
# This can happen e.g. for cover pages. | ||
predictions.append( | ||
f"\n\n[MISSING_PAGE_EMPTY:{i*args.batchsize+j+1}]\n\n" | ||
) | ||
else: | ||
if args.markdown: | ||
output = markdown_compatible(output) | ||
predictions.append(output) | ||
if is_last_page[j]: | ||
out = "".join(predictions).strip() | ||
out = re.sub(r"\n{3,}", "\n\n", out).strip() | ||
if args.out: | ||
out_path = args.out / Path(is_last_page[j]).with_suffix(".mmd").name | ||
out_path.parent.mkdir(parents=True, exist_ok=True) | ||
out_path.write_text(out, encoding="utf-8") | ||
else: | ||
print(out, "\n\n") | ||
predictions = [] | ||
page_num = 0 | ||
file_index += 1 | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
data | ||
data | ||
Oops, something went wrong.