forked from lukas-blecher/LaTeX-OCR

Commit: showing 11 changed files with 39,716 additions and 108 deletions.
`.gitignore`

```diff
@@ -134,3 +134,4 @@ dataset/data/**
 wandb/
 checkpoints/**
 !**/.gitkeep
+.vscode
```
`README.md`
@@ -1,85 +1,13 @@
# pix2tex - LaTeX OCR
The goal of this project is to create a learning-based system that takes an image of a math formula and returns the corresponding LaTeX code. As a physics student I often find myself writing down LaTeX code from a reference image. I wanted to streamline my workflow and began looking into solutions, but besides the freemium [Mathpix](https://mathpix.com/) I could not find anything ready-to-use that runs locally. That's why I decided to create it myself.
Convert images into LaTeX code. A basic desktop app for https://github.com/lukas-blecher/LaTeX-OCR
![header](https://user-images.githubusercontent.com/55287601/109183599-69431f00-778e-11eb-9809-d42b9451e018.png)
## Demo
![demo](demo.gif)
## Requirements
### Model
* PyTorch (tested on v1.7.1)
* Python 3.7+ & dependencies (`requirements.txt`)
```
pip install -r requirements.txt
```
### Dataset
In order to render the math in many different fonts, we use XeLaTeX, generate a PDF, and finally convert it to a PNG. For the last step we need some third-party tools (a sketch of the pipeline follows the list below):
* [XeLaTeX](https://www.ctan.org/pkg/xetex)
* [ImageMagick](https://imagemagick.org/) with [Ghostscript](https://www.ghostscript.com/index.html) (for converting PDF to PNG)
* [Node.js](https://nodejs.org/) to run [KaTeX](https://github.com/KaTeX/KaTeX) (for normalizing LaTeX code)
* [`de-macro`](https://www.ctan.org/pkg/de-macro) >= 1.4 (only for parsing arXiv papers)
* Python 3.7+ & dependencies (`requirements.txt`)
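A minimal sketch of that rendering pipeline, for illustration only (the function name, file layout, and DPI value are assumptions, not taken from this repository):

```python
import subprocess
from pathlib import Path

def render_formula(tex_file: Path, out_png: Path, dpi: int = 200) -> None:
    # XeLaTeX compiles the wrapped formula to <name>.pdf next to the source file
    subprocess.run(['xelatex', '-interaction=nonstopmode', tex_file.name],
                   cwd=tex_file.parent, check=True)
    # ImageMagick rasterizes the PDF to PNG, delegating PDF parsing to Ghostscript
    subprocess.run(['convert', '-density', str(dpi),
                    str(tex_file.with_suffix('.pdf')), str(out_png)], check=True)
```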
## Usage
Follow the [usage instructions here](https://github.com/lukas-blecher/LaTeX-OCR#using-the-model) (note this project has extra dependencies!) and run `main.py`.
## Using the model
1. Download/clone this repository
2. For now you need to install the Python dependencies specified in `requirements.txt` (see [above](#requirements))
3. Download the `weights.pth` (and optionally `image_resizer.pth`) file from my [Google Drive](https://drive.google.com/drive/folders/1cgmyiaT5uwQJY2pB0ngebuTcK5ivKXIb) and place it in the `checkpoints` directory

The `pix2tex.py` file offers a quick way to get the model's prediction for an image. First copy the formula image to the clipboard, for example with a snipping tool (built into Windows as `Win`+`Shift`+`S`). Then call the script with `python pix2tex.py`. It will print the predicted LaTeX code for that image and also copy it to your clipboard.
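The core of that workflow looks roughly like this (a sketch; `initialize` and `call_model` are the same entry points the desktop app in this commit uses, while `grabclipboard` availability depends on the platform):

```python
from PIL import ImageGrab
import pix2tex

args, *objs = pix2tex.initialize()   # load weights and tokenizer once
img = ImageGrab.grabclipboard()      # image put there by the snipping tool
if img is not None:
    latex = pix2tex.call_model(img, args, *objs)
    print(latex)                     # pix2tex.py also copies this to the clipboard
```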
**Note:** As of right now it works best on images of smaller resolution. Don't zoom in all the way before taking the picture, and double-check the result carefully. You can redo the prediction at another resolution if the answer was wrong.
**Update:** I have trained an image classifier on randomly scaled images of the training data to predict the original size.
This model automatically resizes a custom image to best resemble the training data and thus increases performance on images found in the wild. To use this preprocessing step, all you have to do is download the second weights file mentioned above. You should then be able to take bigger (or smaller) pictures of a formula and still get a satisfying result.
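A rough sketch of the idea (hypothetical names and width buckets, not the repository's implementation): a classifier predicts which width the screenshot originally had, and the image is resized to that bucket before OCR.

```python
import torch
from PIL import Image
from torchvision import transforms

WIDTH_BUCKETS = [128, 192, 256, 320, 384, 448, 512]  # illustrative buckets

def resize_to_training_scale(img: Image.Image, resizer: torch.nn.Module) -> Image.Image:
    x = transforms.ToTensor()(img.convert('L')).unsqueeze(0)  # (1, 1, H, W)
    with torch.no_grad():
        bucket = resizer(x).argmax(-1).item()                 # predicted width class
    w = WIDTH_BUCKETS[bucket]
    h = int(img.height * w / img.width)                       # keep aspect ratio
    return img.resize((w, h), Image.BILINEAR)
```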
## Training the model
1. First we need to combine the images with their ground-truth labels. I wrote a dataset class (which needs further improving) that saves the relative paths of the images together with the LaTeX code they were rendered from. To generate the dataset pickle file, run
```
python dataset/dataset.py --equations path_to_textfile --images path_to_images --tokenizer path_to_tokenizer --out dataset.pkl
```
You can find my generated training data on the [Google Drive](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO) as well (formulae.zip - images, math.txt - labels). Repeat the step for the validation and test data; all three use the same label text file.
2. Edit the `data` entry in the config file to point to the newly generated `.pkl` file. Change other hyperparameters if you want to; see `settings/default.yaml` for a template.
3. Now for the actual training, run
```
python train.py --config path_to_config_file
```
## Model
The model consists of a ViT [[1](#references)] encoder with a ResNet backbone and a Transformer [[2](#references)] decoder.
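As a self-contained sketch of this encoder-decoder shape in plain PyTorch (the real model is assembled from the libraries credited under Acknowledgement; every hyperparameter below is a placeholder, and `batch_first`/`weights=None` assume PyTorch >= 1.9 with a recent torchvision):

```python
import torch
import torch.nn as nn
import torchvision

class OCRModel(nn.Module):
    def __init__(self, vocab_size=8000, d_model=256, nhead=8, depth=4):
        super().__init__()
        # ResNet backbone: keep the convolutional feature map, drop pool/fc
        resnet = torchvision.models.resnet18(weights=None)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.proj = nn.Conv2d(512, d_model, kernel_size=1)  # match model width
        # ViT-style encoder over the flattened feature-map "patches"
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, nhead, batch_first=True), depth)
        # autoregressive Transformer decoder over LaTeX tokens
        self.embed = nn.Embedding(vocab_size, d_model)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead, batch_first=True), depth)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, image, tokens):
        # image: (B, 3, H, W), tokens: (B, T); positional embeddings omitted
        feats = self.proj(self.backbone(image))      # (B, d, H/32, W/32)
        memory = self.encoder(feats.flatten(2).transpose(1, 2))
        causal = torch.triu(torch.full((tokens.size(1), tokens.size(1)),
                                       float('-inf')), diagonal=1)
        hidden = self.decoder(self.embed(tokens), memory, tgt_mask=causal)
        return self.out(hidden)                      # next-token logits
```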
### Performance
| BLEU score | normed edit distance |
| - | - |
| 0.88 | 0.10 |
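For reference, a normed edit distance is commonly computed as the Levenshtein distance divided by the length of the longer sequence (an assumption about the exact normalization used here):

```python
def normed_edit_distance(pred: str, truth: str) -> float:
    # single-row Levenshtein DP; dp[j] holds dist(prefix of pred, truth[:j])
    m, n = len(pred), len(truth)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1,                      # deletion
                        dp[j - 1] + 1,                  # insertion
                        prev + (pred[i - 1] != truth[j - 1]))  # substitution
            prev = cur
    return dp[n] / max(m, n, 1)
```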
## Data
We need paired data for the network to learn. Luckily there is a lot of LaTeX code on the internet, e.g. on [Wikipedia](https://www.wikipedia.org) or [arXiv](https://arxiv.org). We also use the formulae from the [im2latex-100k](https://zenodo.org/record/56198#.V2px0jXT6eA) dataset.
All of it can be found [here](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO).
### Fonts
Latin Modern Math, GFSNeohellenicMath.otf, Asana Math, XITS Math, Cambria Math
## TODO
- [ ] support handwritten formulae
- [ ] reduce model size (distillation)
- [ ] find optimal hyperparameters
- [ ] tweak model structure
- [x] add more evaluation metrics
- [ ] fix data scraping and scrape more data
- [ ] trace the model
- [ ] create a standalone application
## Contribution
Contributions of any kind are welcome.

## Acknowledgement
Code taken and modified from [lucidrains](https://github.com/lucidrains), [rwightman](https://github.com/rwightman/pytorch-image-models), [im2markup](https://github.com/harvardnlp/im2markup), [arxiv_leaks](https://github.com/soskek/arxiv_leaks)
## References
[1] [An Image is Worth 16x16 Words](https://arxiv.org/abs/2010.11929)

[2] [Attention Is All You Need](https://arxiv.org/abs/1706.03762)

## Acknowledgements
This project uses code from
- https://github.com/pkra/MathJax-single-file
- https://github.com/harupy/snipping-tool
`main.py` (new file)
@@ -0,0 +1,234 @@
```python
import sys
from PyQt5 import QtCore, QtGui
from PyQt5.QtCore import QObject, Qt, pyqtSlot, pyqtSignal, QThread
from PyQt5.QtWebEngineWidgets import QWebEngineView
from PyQt5.QtWidgets import QMainWindow, QApplication, QMessageBox, QVBoxLayout, QWidget, QPushButton, QTextEdit
import resources  # compiled Qt resource module (icons, MathJax) referenced via qrc:/
from pynput.mouse import Controller

import tkinter as tk
from PIL import ImageGrab

import pix2tex

QApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling, True)
QApplication.setAttribute(QtCore.Qt.AA_UseHighDpiPixmaps, True)

class App(QMainWindow):
    def __init__(self):
        super().__init__()
        self.initModel()
        self.initUI()
        self.snipWidget = SnipWidget(self)

        self.show()

    def initModel(self):
        # load the pix2tex model once; args and model objects are
        # reused for every prediction
        args, *objs = pix2tex.initialize()
        self.args = args
        self.objs = objs
    def initUI(self):
        self.setWindowTitle("LaTeX OCR")
        QApplication.setWindowIcon(QtGui.QIcon(':/icons/icon.svg'))
        self.left = 300
        self.top = 300
        self.width = 300
        self.height = 200
        self.setGeometry(self.left, self.top, self.width, self.height)

        # Create LaTeX display
        self.webView = QWebEngineView()
        self.webView.setHtml("")
        self.webView.setMinimumHeight(70)

        # Create textbox
        self.textbox = QTextEdit(self)

        # Create snip button
        self.snipButton = QPushButton('Snip', self)
        self.snipButton.clicked.connect(self.onClick)

        # Create layout
        centralWidget = QWidget()
        centralWidget.setMinimumWidth(200)
        self.setCentralWidget(centralWidget)

        lay = QVBoxLayout(centralWidget)
        lay.addWidget(self.webView, stretch=2)
        lay.addWidget(self.textbox, stretch=3)
        lay.addWidget(self.snipButton)
    @pyqtSlot()
    def onClick(self):
        # hide the main window and start a screen snip
        self.close()
        self.snipWidget.snip()
    def returnSnip(self, img):
        # Show processing icon
        pageSource = """<center>
        <img src="qrc:/icons/processing-icon-anim.svg" width="50" height="50">
        </center>"""
        self.webView.setHtml(pageSource)
        self.textbox.setText("")

        self.snipButton.setEnabled(False)

        self.show()

        # Run the model in a separate thread to keep the UI responsive
        self.thread = ModelThread(img, self.args, self.objs)
        self.thread.finished.connect(self.returnPrediction)
        self.thread.finished.connect(self.thread.deleteLater)

        self.thread.start()
    def returnPrediction(self, result):
        self.snipButton.setEnabled(True)

        success, prediction = result["success"], result["prediction"]

        if success:
            self.displayPrediction(prediction)
        else:
            self.webView.setHtml("")
            msg = QMessageBox()
            msg.setWindowTitle(" ")
            msg.setText("Prediction failed.")
            msg.exec_()
    def displayPrediction(self, prediction):
        self.textbox.setText("${equation}$".format(equation=prediction))

        # render the prediction with the bundled MathJax; the first string is
        # concatenated, not formatted, so its JS braces survive .format()
        pageSource = """
        <html>
        <head><script id="MathJax-script" src="qrc:MathJax.js"></script>
        <script>
        MathJax.Hub.Config({messageStyle: 'none', tex2jax: {preview: 'none'}});
        MathJax.Hub.Queue(
            function () {
                document.getElementById("equation").style.visibility = "";
            }
        );
        </script>
        </head> """ + """
        <body>
        <div id="equation" style="font-size:1em; visibility:hidden">$${equation}$$</div>
        </body>
        </html>
        """.format(equation=prediction)
        self.webView.setHtml(pageSource)

class ModelThread(QThread):
    finished = pyqtSignal(dict)

    def __init__(self, img, args, objs):
        super().__init__()
        self.img = img
        self.args = args
        self.objs = objs

    def run(self):
        try:
            prediction = pix2tex.call_model(self.img, self.args, *self.objs)
            self.finished.emit({"success": True, "prediction": prediction})
        except Exception:
            # report failure to the UI thread instead of crashing the worker
            self.finished.emit({"success": False, "prediction": None})

class SnipWidget(QMainWindow):
    isSnipping = False

    def __init__(self, parent):
        super().__init__()
        self.parent = parent

        # tkinter is used only to query the screen size for the overlay
        root = tk.Tk()
        screenWidth = root.winfo_screenwidth()
        screenHeight = root.winfo_screenheight()
        self.setGeometry(0, 0, screenWidth, screenHeight)

        self.begin = QtCore.QPoint()
        self.end = QtCore.QPoint()

        self.mouse = Controller()
    def snip(self):
        self.isSnipping = True
        self.setWindowFlags(Qt.WindowStaysOnTopHint)
        QApplication.setOverrideCursor(QtGui.QCursor(QtCore.Qt.CrossCursor))

        self.show()
    def paintEvent(self, event):
        if self.isSnipping:
            brushColor = (0, 180, 255, 100)  # translucent blue while selecting
            lw = 3
            opacity = 0.3
        else:
            brushColor = (0, 200, 0, 128)
            lw = 3
            opacity = 0.3

        self.setWindowOpacity(opacity)
        qp = QtGui.QPainter(self)
        qp.setPen(QtGui.QPen(QtGui.QColor('black'), lw))
        qp.setBrush(QtGui.QColor(*brushColor))
        qp.drawRect(QtCore.QRect(self.begin, self.end))
    def keyPressEvent(self, event):
        # Escape cancels the snip and restores the main window
        if event.key() == QtCore.Qt.Key_Escape:
            QApplication.restoreOverrideCursor()
            self.close()
            self.parent.show()
        event.accept()
    def mousePressEvent(self, event):
        self.startPos = self.mouse.position  # global cursor position (pynput)

        self.begin = event.pos()
        self.end = self.begin
        self.update()

    def mouseMoveEvent(self, event):
        self.end = event.pos()
        self.update()
    def mouseReleaseEvent(self, event):
        self.isSnipping = False
        QApplication.restoreOverrideCursor()

        startPos = self.startPos
        endPos = self.mouse.position

        # normalize the corners so the bbox is valid regardless of drag direction
        x1 = min(startPos[0], endPos[0])
        y1 = min(startPos[1], endPos[1])
        x2 = max(startPos[0], endPos[0])
        y2 = max(startPos[1], endPos[1])

        self.repaint()
        QApplication.processEvents()
        img = ImageGrab.grab(bbox=(x1, y1, x2, y2))
        QApplication.processEvents()

        self.close()
        self.begin = QtCore.QPoint()
        self.end = QtCore.QPoint()
        self.parent.returnSnip(img)

if __name__ == '__main__':
    app = QApplication(sys.argv)
    ex = App()
    sys.exit(app.exec_())
```
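Note: the `import resources` line above expects a Qt resource module compiled from a `.qrc` file, presumably included among the commit's other files, using PyQt5's resource compiler, e.g. `pyrcc5 resources.qrc -o resources.py`.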