
Commit

Merge pull request lukas-blecher#14 from lukas-blecher/gui
Add GUI
lukas-blecher authored May 11, 2021
2 parents 57ae396 + 648b84b commit c7898ab
Showing 11 changed files with 14,766 additions and 30 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -134,3 +134,4 @@ dataset/data/**
wandb/
checkpoints/**
!**/.gitkeep
.vscode
31 changes: 19 additions & 12 deletions README.md
@@ -1,5 +1,5 @@
# pix2tex - LaTeX OCR
The goal of this project is to create a learning-based system that takes an image of a math formula and returns the corresponding LaTeX code. As a physics student I often find myself writing down LaTeX code from a reference image. I wanted to streamline my workflow and began looking into solutions, but besides the freemium [Mathpix](https://mathpix.com/) I could not find anything ready-to-use that runs locally. That's why I decided to create it myself.

![header](https://user-images.githubusercontent.com/55287601/109183599-69431f00-778e-11eb-9809-d42b9451e018.png)

@@ -23,12 +23,18 @@ In order to render the math in many different fonts we use XeLaTeX, generate a
2. For now you need to install the Python dependencies specified in `requirements.txt` (see [above](#Requirements))
3. Download `weights.pth` (and optionally `image_resizer.pth`) from my [Google Drive](https://drive.google.com/drive/folders/1cgmyiaT5uwQJY2pB0ngebuTcK5ivKXIb) and place the file(s) in the `checkpoints` directory

The `pix2tex.py` file offers a quick way to get the model prediction for an image. First you need to copy the formula image into the clipboard, for example by using a snipping tool (on Windows, the built-in `Win`+`Shift`+`S`). Next, just call the script with `python pix2tex.py`. It will print out the predicted LaTeX code for that image and also copy it into your clipboard.
Thanks to [@katie-lim](https://github.com/katie-lim), you can use a nice user interface as a quick way to get the model prediction. Just call the GUI with `python gui.py`. From here you can take a screenshot, and the predicted LaTeX code is rendered using [MathJax](https://www.mathjax.org/) and copied to your clipboard.

![demo](https://user-images.githubusercontent.com/55287601/117812740-77b7b780-b262-11eb-81f6-fc19766ae2ae.gif)

If the model is unsure about what's in the image, it might output a different prediction every time you click "Retry". With the `temperature` parameter you can control this behavior (a low temperature will produce the same result every time).
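For reference, `temperature` here is the usual softmax sampling temperature: the decoder's logits are divided by it before normalization, so low values make predictions near-deterministic and higher values make them more varied. A minimal, self-contained sketch of the effect (not the project's actual decoding code):

```python
import math

def softmax_with_temperature(logits, temperature):
    # Divide logits by the temperature before normalizing:
    # a low temperature sharpens the distribution (same token almost every time),
    # a high temperature flattens it (more varied predictions on "Retry").
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5]
sharp = softmax_with_temperature(logits, 0.1)
flat = softmax_with_temperature(logits, 1.0)
print(max(sharp) > max(flat))  # → True: low temperature concentrates probability mass
```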

Alternatively, you can use `pix2tex.py`, which offers similar functionality to `gui.py` but as a command-line tool. In that case you don't need to install PyQt5. The script can also process existing images from disk.

**Note:** As of right now, the model works best with images of smaller resolution. Don't zoom in all the way before taking a picture. Double-check the result carefully. If the answer was wrong, you can retry the prediction with a different resolution.

**Update:** I have trained an image classifier on randomly scaled images of the training data to predict the original size.
This model will automatically resize the input image to best resemble the training data and thus improve performance on images found in the wild. To use this preprocessing step, all you have to do is download the second weights file mentioned above. You should then be able to take bigger (or smaller) pictures of a formula and still get a satisfactory result.
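The idea behind this preprocessing step can be pictured as follows. This is only a sketch: `predicted_scale` stands in for the classifier's output, and the function name is hypothetical, not the project's actual API:

```python
def resize_to_training_scale(width, height, predicted_scale):
    # The classifier (hypothetically) estimates how much larger the input is
    # than the resolution the recognition model was trained on; dividing by
    # that factor brings the image back to a familiar scale.
    new_w = max(1, round(width / predicted_scale))
    new_h = max(1, round(height / predicted_scale))
    return new_w, new_h

# A screenshot taken fully zoomed in (2x) is scaled back down:
print(resize_to_training_scale(1200, 400, 2.0))  # → (600, 200)
```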

## Training the model
1. First we need to combine the images with their ground-truth labels. I wrote a dataset class (which needs further improvement) that saves the relative paths to the images together with the LaTeX code they were rendered with. To generate the dataset pickle file run
@@ -50,9 +50,9 @@ python train.py --config path_to_config_file
The model consists of a ViT [[1](#References)] encoder with a ResNet backbone and a Transformer [[2](#References)] decoder.

### Performance
| BLEU score | normed edit distance |
| ---------- | -------------------- |
| 0.88 | 0.10 |
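For context, normed edit distance is the Levenshtein distance between the predicted and ground-truth LaTeX strings, divided by the longer string's length, so lower is better. A sketch of the metric (the project's evaluation code may tokenize or normalize the strings differently):

```python
def normed_edit_distance(pred, target):
    # Single-row Levenshtein DP; the result is normalized to [0, 1],
    # where 0.0 is an exact match and 1.0 a complete mismatch.
    m, n = len(pred), len(target)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1,        # deletion
                        dp[j - 1] + 1,    # insertion
                        prev + (pred[i - 1] != target[j - 1]))  # substitution
            prev = cur
    return dp[n] / max(m, n, 1)

print(normed_edit_distance(r"a^2+b^2", r"a^2+b^2"))  # → 0.0
```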

## Data
We need paired data for the network to learn. Luckily there is a lot of LaTeX code on the internet, e.g. on [Wikipedia](https://www.wikipedia.org) or the [arXiv](https://www.arxiv.org). We also use the formulae from the [im2latex-100k](https://zenodo.org/record/56198#.V2px0jXT6eA) dataset.
@@ -63,21 +69,22 @@ Latin Modern Math, GFSNeohellenicMath.otf, Asana Math, XITS Math, Cambria Math


## TODO
- [x] add more evaluation metrics
- [x] create a GUI
- [ ] add beam search
- [ ] support handwritten formulae
- [ ] reduce model size (distillation)
- [ ] find optimal hyperparameters
- [ ] tweak model structure
- [ ] fix data scraping and scrape more data
- [ ] trace the model
- [ ] create a standalone application


## Contribution
Contributions of any kind are welcome.

## Acknowledgment
Code taken and modified from [lucidrains](https://github.com/lucidrains), [rwightman](https://github.com/rwightman/pytorch-image-models), [im2markup](https://github.com/harvardnlp/im2markup), [arxiv_leaks](https://github.com/soskek/arxiv_leaks), [pkra: Mathjax](https://github.com/pkra/MathJax-single-file), [harupy: snipping tool](https://github.com/harupy/snipping-tool)

## References
[1] [An Image is Worth 16x16 Words](https://arxiv.org/abs/2010.11929)
269 changes: 269 additions & 0 deletions gui.py
@@ -0,0 +1,269 @@
import sys
import os
import argparse
from PyQt5 import QtCore, QtGui
from PyQt5.QtCore import QObject, Qt, pyqtSlot, pyqtSignal, QThread
from PyQt5.QtWebEngineWidgets import QWebEngineView
from PyQt5.QtWidgets import QMainWindow, QApplication, QMessageBox, QVBoxLayout, QWidget,\
QPushButton, QTextEdit, QLineEdit, QFormLayout, QHBoxLayout, QCheckBox, QSpinBox, QDoubleSpinBox
from resources import resources
from pynput.mouse import Controller

from PIL import ImageGrab
import numpy as np
from screeninfo import get_monitors
import pix2tex

QApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling, True)
QApplication.setAttribute(QtCore.Qt.AA_UseHighDpiPixmaps, True)


class App(QMainWindow):
def __init__(self, args=None):
super().__init__()
self.args = args
self.initModel()
self.initUI()
self.snipWidget = SnipWidget(self)

self.show()

def initModel(self):
args, *objs = pix2tex.initialize(self.args)
self.args = args
self.objs = objs

def initUI(self):
self.setWindowTitle("LaTeX OCR")
QApplication.setWindowIcon(QtGui.QIcon(':/icons/icon.svg'))
self.left = 300
self.top = 300
self.width = 500
self.height = 300
self.setGeometry(self.left, self.top, self.width, self.height)

# Create LaTeX display
self.webView = QWebEngineView()
self.webView.setHtml("")
self.webView.setMinimumHeight(80)

# Create textbox
self.textbox = QTextEdit(self)
self.textbox.textChanged.connect(self.displayPrediction)
self.textbox.setMinimumHeight(40)

# Create temperature text input
self.tempField = QDoubleSpinBox(self)
self.tempField.setValue(self.args.get('temperature', 0.25))
self.tempField.setRange(0, 1)
self.tempField.setSingleStep(0.1)

# Create snip button
self.snipButton = QPushButton('Snip', self)
self.snipButton.clicked.connect(self.onClick)

# Create retry button
self.retryButton = QPushButton('Retry', self)
self.retryButton.setEnabled(False)
self.retryButton.clicked.connect(self.returnSnip)

# Create layout
centralWidget = QWidget()
centralWidget.setMinimumWidth(200)
self.setCentralWidget(centralWidget)

lay = QVBoxLayout(centralWidget)
lay.addWidget(self.webView, stretch=4)
lay.addWidget(self.textbox, stretch=2)
buttons = QHBoxLayout()
buttons.addWidget(self.snipButton)
buttons.addWidget(self.retryButton)
lay.addLayout(buttons)
settings = QFormLayout()
settings.addRow('Temperature:', self.tempField)
lay.addLayout(settings)

@pyqtSlot()
def onClick(self):
self.close()
self.snipWidget.snip()

def returnSnip(self, img=None):
# Show processing icon
pageSource = """<center>
<img src="qrc:/icons/processing-icon-anim.svg" width="50", height="50">
</center>"""
self.webView.setHtml(pageSource)

self.snipButton.setEnabled(False)
self.retryButton.setEnabled(False)

self.show()
try:
self.args.temperature = self.tempField.value()
if self.args.temperature == 0:
self.args.temperature = 1e-8
except Exception:  # args may not support attribute assignment; keep the old temperature
pass
# Run the model in a separate thread
self.thread = ModelThread(img=img, args=self.args, objs=self.objs)
self.thread.finished.connect(self.returnPrediction)
self.thread.finished.connect(self.thread.deleteLater)

self.thread.start()

def returnPrediction(self, result):
self.snipButton.setEnabled(True)

success, prediction = result["success"], result["prediction"]

if success:
self.displayPrediction(prediction)
self.retryButton.setEnabled(True)
else:
self.webView.setHtml("")
msg = QMessageBox()
msg.setWindowTitle(" ")
msg.setText("Prediction failed.")
msg.exec_()

def displayPrediction(self, prediction=None):
if prediction is not None:
self.textbox.setText("${equation}$".format(equation=prediction))
else:
prediction = self.textbox.toPlainText().strip('$')
pageSource = """
<html>
<head><script id="MathJax-script" src="qrc:MathJax.js"></script>
<script>
MathJax.Hub.Config({messageStyle: 'none',tex2jax: {preview: 'none'}});
MathJax.Hub.Queue(
function () {
document.getElementById("equation").style.visibility = "";
}
);
</script>
</head> """ + """
<body>
<div id="equation" style="font-size:1em; visibility:hidden">$${equation}$$</div>
</body>
</html>
""".format(equation=prediction)
self.webView.setHtml(pageSource)


class ModelThread(QThread):
finished = pyqtSignal(dict)

def __init__(self, img, args, objs):
super().__init__()
self.img = img
self.args = args
self.objs = objs

def run(self):
try:
prediction = pix2tex.call_model(self.args, *self.objs, img=self.img)
self.finished.emit({"success": True, "prediction": prediction})
except Exception as e:
print(e)
self.finished.emit({"success": False, "prediction": None})


class SnipWidget(QMainWindow):
isSnipping = False

def __init__(self, parent):
super().__init__()
self.parent = parent

monitors = get_monitors()
bboxes = np.array([[m.x, m.y, m.width, m.height] for m in monitors])
x, y, _, _ = bboxes.min(0)
w, h = bboxes[:, [0, 2]].sum(1).max(), bboxes[:, [1, 3]].sum(1).max()
self.setGeometry(x, y, w-x, h-y)

self.begin = QtCore.QPoint()
self.end = QtCore.QPoint()

self.mouse = Controller()

def snip(self):
self.isSnipping = True
self.setWindowFlags(Qt.WindowStaysOnTopHint)
QApplication.setOverrideCursor(QtGui.QCursor(QtCore.Qt.CrossCursor))

self.show()

def paintEvent(self, event):
if self.isSnipping:
brushColor = (0, 180, 255, 100)
lw = 3
opacity = 0.3
else:
brushColor = (255, 255, 255, 0)
lw = 3
opacity = 0

self.setWindowOpacity(opacity)
qp = QtGui.QPainter(self)
qp.setPen(QtGui.QPen(QtGui.QColor('black'), lw))
qp.setBrush(QtGui.QColor(*brushColor))
qp.drawRect(QtCore.QRect(self.begin, self.end))

def keyPressEvent(self, event):
if event.key() == QtCore.Qt.Key_Escape:
QApplication.restoreOverrideCursor()
self.close()
self.parent.show()
event.accept()

def mousePressEvent(self, event):
self.startPos = self.mouse.position

self.begin = event.pos()
self.end = self.begin
self.update()

def mouseMoveEvent(self, event):
self.end = event.pos()
self.update()

def mouseReleaseEvent(self, event):
self.isSnipping = False
QApplication.restoreOverrideCursor()

startPos = self.startPos
endPos = self.mouse.position

x1 = min(startPos[0], endPos[0])
y1 = min(startPos[1], endPos[1])
x2 = max(startPos[0], endPos[0])
y2 = max(startPos[1], endPos[1])

self.repaint()
QApplication.processEvents()
img = ImageGrab.grab(bbox=(x1, y1, x2, y2), all_screens=True)
QApplication.processEvents()

self.close()
self.begin = QtCore.QPoint()
self.end = QtCore.QPoint()
self.parent.returnSnip(img)


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='GUI arguments')
parser.add_argument('-t', '--temperature', type=float, default=.2, help='Softmax sampling temperature')
parser.add_argument('-c', '--config', type=str, default='settings/config.yaml', help='path to config file')
parser.add_argument('-m', '--checkpoint', type=str, default='checkpoints/weights.pth', help='path to weights file')
parser.add_argument('--no-cuda', action='store_true', help='Compute on CPU')
parser.add_argument('--no-resize', action='store_true', help='Disable the image resizing step')
arguments = parser.parse_args()
latexocr_path = os.path.dirname(sys.argv[0])
if latexocr_path != '':
sys.path.insert(0, latexocr_path)
os.chdir(latexocr_path)
app = QApplication(sys.argv)
ex = App(arguments)
sys.exit(app.exec_())
