Skip to content

Commit

Permalink
Merge branch 'main' into api
Browse files Browse the repository at this point in the history
  • Loading branch information
lukas-blecher committed Apr 20, 2022
2 parents f002a65 + 59903b1 commit 279cb2f
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 81 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ To run the model you need Python 3.7+
Install the package `pix2tex`:

```
pip install pix2tex
pip install pix2tex[gui]
```

Model checkpoints will be downloaded automatically.
Expand Down Expand Up @@ -73,7 +73,6 @@ In order to render the math in many different fonts we use XeLaTeX, generate a
* [XeLaTeX](https://www.ctan.org/pkg/xetex)
* [ImageMagick](https://imagemagick.org/) with [Ghostscript](https://www.ghostscript.com/index.html). (for converting pdf to png)
* [Node.js](https://nodejs.org/) to run [KaTeX](https://github.com/KaTeX/KaTeX) (for normalizing Latex code)
* [`de-macro`](https://www.ctan.org/pkg/de-macro) >= 1.4 (only for parsing arxiv papers)
* Python 3.7+ & dependencies (specified in `setup.py`)

### Fonts
Expand Down
41 changes: 17 additions & 24 deletions pix2tex/dataset/arxiv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@
import sys
import argparse
import logging
import shutil
import subprocess
import tarfile
import tempfile
import chardet
import logging
import requests
import urllib.request
from tqdm import tqdm
from urllib.error import HTTPError
from pix2tex.dataset.extract_latex import *
from pix2tex.dataset.scraping import *
from pix2tex.dataset.extract_latex import find_math
from pix2tex.dataset.scraping import recursive_search
from pix2tex.dataset.demacro import *

# logging.getLogger().setLevel(logging.INFO)
Expand Down Expand Up @@ -49,7 +48,7 @@ def download(url, dir_path='./'):
return 0


def read_tex_files(file_path, demacro=True):
def read_tex_files(file_path):
tex = ''
try:
with tempfile.TemporaryDirectory() as tempdir:
Expand All @@ -58,18 +57,11 @@ def read_tex_files(file_path, demacro=True):
tf.extractall(tempdir)
tf.close()
texfiles = [os.path.abspath(x) for x in glob.glob(os.path.join(tempdir, '**', '*.tex'), recursive=True)]
# de-macro
if demacro:
ret = subprocess.run(['de-macro', *texfiles], cwd=tempdir, capture_output=True)
if ret.returncode == 0:
texfiles = glob.glob(os.path.join(tempdir, '**', '*-clean.tex'), recursive=True)
except tarfile.ReadError as e:
texfiles = [file_path] # [os.path.join(tempdir, file_path+'.tex')]
#shutil.move(file_path, texfiles[0])

for texfile in texfiles:
try:
tex += open(texfile, 'r', encoding=chardet.detect(open(texfile, 'br').readline())['encoding']).read()
tex += open(texfile, 'r', encoding=chardet.detect(open(texfile, 'br').readline())['encoding']).read()
except UnicodeDecodeError:
pass
tex = unfold(convert(tex))
Expand All @@ -85,32 +77,32 @@ def download_paper(arxiv_id, dir_path='./'):
return download(url, dir_path)


def read_paper(targz_path, delete=True, demacro=True):
def read_paper(targz_path, delete=True):
paper = ''
if targz_path != 0:
paper = read_tex_files(targz_path, demacro)
paper = read_tex_files(targz_path)
if delete:
os.remove(targz_path)
return paper


def parse_arxiv(id, demacro=True):
def parse_arxiv(id):
tempdir = tempfile.gettempdir()
text = read_paper(download_paper(id, tempdir), demacro=demacro)
text = read_paper(download_paper(id, tempdir))
#print(text, file=open('paper.tex', 'w'))
#linked = list(set([l for l in re.findall(arxiv_id, text)]))

return find_math(text, wiki=False), []


if __name__ == '__main__':
# logging.getLogger().setLevel(logging.DEBUG)
parser = argparse.ArgumentParser(description='Extract math from arxiv')
parser.add_argument('-m', '--mode', default='top100', choices=['top100', 'id', 'dir'],
parser.add_argument('-m', '--mode', default='top100', choices=['top100', 'ids', 'dir'],
help='Where to extract code from. top100: current 100 arxiv papers, id: specific arxiv ids. \
Usage: `python arxiv.py -m id id001 id002`, dir: a folder full of .tar.gz files. Usage: `python arxiv.py -m dir directory`')
parser.add_argument(nargs='+', dest='args', default=[])
parser.add_argument(nargs='*', dest='args', default=[])
parser.add_argument('-o', '--out', default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data'), help='output directory')
parser.add_argument('-d', '--no-demacro', dest='demacro', action='store_false', help='Use de-macro (Slows down extraction but improves quality)')
args = parser.parse_args()
if '.' in args.out:
args.out = os.path.dirname(args.out)
Expand All @@ -122,7 +114,7 @@ def parse_arxiv(id, demacro=True):
if args.mode == 'ids':
visited, math = recursive_search(parse_arxiv, args.args, skip=skip, unit='paper')
elif args.mode == 'top100':
url = 'https://arxiv.org/list/hep-th/2012?skip=0&show=100' # https://arxiv.org/list/hep-th/2012?skip=0&show=100
url = 'https://arxiv.org/list/physics/pastweek?skip=0&show=100' #'https://arxiv.org/list/hep-th/2203?skip=0&show=100'
ids = get_all_arxiv_ids(requests.get(url).text)
math, visited = [], ids
for id in tqdm(ids):
Expand All @@ -133,15 +125,16 @@ def parse_arxiv(id, demacro=True):
math, visited = [], []
for f in tqdm(dirs):
try:
text = read_paper(os.path.join(args.args[0], f), False, args.demacro)
text = read_paper(os.path.join(args.args[0], f), False)
math.extend(find_math(text, wiki=False))
visited.append(os.path.basename(f))
visited.append(os.path.basename(f))
except Exception as e:
logging.debug(e)
pass
else:
raise NotImplementedError

print('\n'.join(math))
sys.exit(0)
for l, name in zip([visited, math], ['visited_arxiv.txt', 'math_arxiv.txt']):
f = os.path.join(args.out, name)
if not os.path.exists(f):
Expand Down
67 changes: 51 additions & 16 deletions pix2tex/dataset/demacro.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@

import argparse
import re
from pix2tex.dataset.extract_latex import remove_labels


def main():
args = parse_command_line()
data = read(args.input)
data = convert(data)
if args.demacro:
data = unfold(data)
write(args.output, data)
data = unfold(data)
if args.output is not None:
write(args.output, data)
else:
print(data)


def parse_command_line():
parser = argparse.ArgumentParser(description='Replace \\def with \\newcommand where possible.')
parser.add_argument('input', help='TeX input file with \\def')
parser.add_argument('--output', '-o', required=True, help='TeX output file with \\newcommand')
parser.add_argument('--demacro', action='store_true', help='replace all commands with their definition')

parser.add_argument('--output', '-o', default=None, help='TeX output file with \\newcommand')
return parser.parse_args()


Expand All @@ -37,27 +38,61 @@ def convert(data):
)


def unfold(t):
cmds = re.findall(r'\\(?:re)?newcommand\*?{\\(.+?)}\s*(\[\d\])?(\[.+?\])?{(.+?)}\n', t)
cmds = sorted(cmds, key=lambda x: len(x[0]))
# print(cmds)
def bracket_replace(string: str) -> str:
'''
replaces all layered brackets with special symbols
'''
layer = 0
out = list(string)
for i, c in enumerate(out):
if c == '{':
if layer > 0:
out[i] = 'Ḋ'
layer += 1
elif c == '}':
layer -= 1
if layer > 0:
out[i] = 'Ḍ'
return ''.join(out)


def undo_bracket_replace(string):
return string.replace('Ḋ', '{').replace('Ḍ', '}')


def sweep(t, cmds):
num_matches = 0
for c in cmds:
nargs = int(c[1][1]) if c[1] != r'' else 0
# print(c)
optional = c[2] != r''
if nargs == 0:
#t = t.replace(r'\\%s' % c[0], c[-1])
t = re.sub(r'\\%s([\W_^\d])' % c[0], r'%s\1' % c[-1].replace('\\', r'\\'), t)
else:
matches = re.findall(r'(\\%s(?:\[(.+?)\])?' % c[0]+r'{(.+?)}'*(nargs-(1 if c[2] != r'' else 0))+r')', t)
# print(matches)
matches = re.findall(r'(\\%s(?:\[(.+?)\])?' % c[0]+r'{(.+?)}'*(nargs-(1 if optional else 0))+r')', t)
num_matches += len(matches)
for i, m in enumerate(matches):
r = c[-1]
if m[1] == r'':
matches[i] = (m[0], c[2][1:-1], *m[2:])
for j in range(1, nargs+1):
r = r.replace(r'#%i' % j, matches[i][j])
r = r.replace(r'#%i' % j, matches[i][j+int(not optional)])
t = t.replace(matches[i][0], r)
return t
return t, num_matches


def unfold(t):
t = remove_labels(t).replace('\n', 'Ċ')

cmds = re.findall(r'\\(?:re)?newcommand\*?{\\(.+?)}\s*(\[\d\])?(\[.+?\])?{(.+?)}Ċ', t)
cmds = sorted(cmds, key=lambda x: len(x[0]))
for _ in range(10):
# check for up to 10 nested commands
t = bracket_replace(t)
t, N = sweep(t, cmds)
t = undo_bracket_replace(t)
if N == 0:
break
return t.replace('Ċ', '\n')


def replace(match):
Expand Down
11 changes: 6 additions & 5 deletions pix2tex/dataset/extract_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
displaymath = re.compile(r'(\\displaystyle)(.{%i,%i}?)(\}(?:<|"))' % (1, MAX_CHARS))
outer_whitespace = re.compile(
r'^\\,|\\,$|^~|~$|^\\ |\\ $|^\\thinspace|\\thinspace$|^\\!|\\!$|^\\:|\\:$|^\\;|\\;$|^\\enspace|\\enspace$|^\\quad|\\quad$|^\\qquad|\\qquad$|^\\hspace{[a-zA-Z0-9]+}|\\hspace{[a-zA-Z0-9]+}$|^\\hfill|\\hfill$')

label_names = [re.compile(r'\\%s\s?\{(.*?)\}' % s) for s in ['ref', 'cite', 'label', 'caption', 'eqref']]

def check_brackets(s):
a = []
Expand Down Expand Up @@ -39,17 +39,18 @@ def check_brackets(s):
else:
return s

def remove_labels(string):
for s in label_names:
string = re.sub(s, '', string)
return string

def clean_matches(matches, min_chars=MIN_CHARS):
template = r'\\%s\s?\{(.*?)\}'
sub = [re.compile(template % s) for s in ['ref', 'cite', 'label', 'caption']]
faulty = []
for i in range(len(matches)):
if 'tikz' in matches[i]: # do not support tikz at the moment
faulty.append(i)
continue
for s in sub:
matches[i] = re.sub(s, '', matches[i])
matches[i] = remove_labels(matches[i])
matches[i] = matches[i].replace('\n', '').replace(r'\notag', '').replace(r'\nonumber', '')
matches[i] = re.sub(outer_whitespace, '', matches[i])
if len(matches[i]) < min_chars:
Expand Down
4 changes: 1 addition & 3 deletions pix2tex/dataset/scraping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
import html
import requests
import re
import tempfile
from pix2tex.dataset.arxiv import *
from pix2tex.dataset.extract_latex import *
from pix2tex.dataset.extract_latex import find_math

wikilinks = re.compile(r'href="/wiki/(.*?)"')
htmltags = re.compile(r'<(noscript|script)>.*?<\/\1>', re.S)
Expand Down
16 changes: 8 additions & 8 deletions pix2tex/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,11 @@ def num_model_params(model):

@contextlib.contextmanager
def in_model_path():
from importlib.resources import path
with path('pix2tex', 'model') as model_path:
saved = os.getcwd()
os.chdir(model_path)
try:
yield
finally:
os.chdir(saved)
import pix2tex
model_path = os.path.join(os.path.dirname(pix2tex.__file__), 'model')
saved = os.getcwd()
os.chdir(model_path)
try:
yield
finally:
os.chdir(saved)
45 changes: 22 additions & 23 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# read the contents of your README file
from pathlib import Path
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()
long_description = (this_directory / 'README.md').read_text()

gui = [
"PyQt5",
Expand All @@ -21,8 +21,8 @@

setuptools.setup(
name='pix2tex',
version='0.0.12',
description="pix2tex: Using a ViT to convert images of equations into LaTeX code.",
version='0.0.14',
description='pix2tex: Using a ViT to convert images of equations into LaTeX code.',
long_description=long_description,
long_description_content_type='text/markdown',
author='Lukas Blecher',
Expand All @@ -43,26 +43,25 @@
]
},
install_requires=[
"tqdm>=4.47.0",
"munch>=2.5.0",
"torch>=1.7.1",
"torchvision>=0.8.1",
"opencv_python_headless>=4.1.1.26",
"requests>=2.22.0",
"einops>=0.3.0",
"chardet>=3.0.4",
"x_transformers==0.15.0",
"imagesize>=1.2.0",
"transformers==4.2.2",
"tokenizers==0.9.4",
"numpy>=1.19.5",
"Pillow>=9.1.0",
"PyYAML>=5.4.1",
"torchtext>=0.6.0",
"albumentations>=0.5.2",
"pandas>=1.0.0",
"timm",
"python-Levenshtein>=0.12.2",
'tqdm>=4.47.0',
'munch>=2.5.0',
'torch>=1.7.1',
'opencv_python_headless>=4.1.1.26',
'requests>=2.22.0',
'einops>=0.3.0',
'x_transformers==0.15.0',
'transformers>=4.18.0',
'tokenizers==0.12.1',
'numpy>=1.19.5',
'Pillow>=9.1.0',
'PyYAML>=5.4.1',
'pandas>=1.0.0',
'timm',
'chardet>=3.0.4',
'python-Levenshtein>=0.12.2',
'torchtext>=0.6.0',
'albumentations>=0.5.2',
'imagesize>=1.2.0',
],
extras_require={
"all": gui+api,
Expand Down

0 comments on commit 279cb2f

Please sign in to comment.