Move string line encoding logic from tokenizer to Dictionary (unified diff). (facebookresearch#541)

Summary:
Pull Request resolved: facebookresearch#541

Just a combo of the stacked pair D14057943 & D14176011.
Made this as a separate diff because there seems to be some issue with porting a stacked change into the GitHub repo.

Differential Revision: D14251048

fbshipit-source-id: 0a47f534a69d6ab2ebe035fba40fd51748cccfb8
vlad-karpukhin authored and facebook-github-bot committed Feb 28, 2019
1 parent bc91927 commit f296824
Showing 13 changed files with 204 additions and 196 deletions.
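
For reference, a minimal sketch of the caller-side migration this commit implies. The toy dictionary and sentence are hypothetical; only encode_line and its arguments come from the diffs below:

    from fairseq.data import Dictionary

    # Build a tiny vocabulary by hand (hypothetical example data).
    vocab = Dictionary()
    for word in 'hello world'.split():
        vocab.add_symbol(word)

    sentence = 'hello world'

    # Old call path, removed by this commit:
    #   from fairseq.tokenizer import Tokenizer
    #   tokens = Tokenizer.tokenize(sentence, vocab, add_if_not_exist=False)

    # New call path: the Dictionary encodes the line itself.
    tokens = vocab.encode_line(sentence, add_if_not_exist=False)
    print(tokens)  # torch.IntTensor of symbol ids, with EOS appended by default
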
10 changes: 4 additions & 6 deletions docs/tutorial_classifying_names.rst
@@ -209,7 +209,6 @@ following contents::

from fairseq.data import Dictionary, LanguagePairDataset
from fairseq.tasks import FairseqTask, register_task
-from fairseq.tokenizer import Tokenizer


@register_task('simple_classification')
@@ -253,8 +252,8 @@ following contents::
sentence = line.strip()

# Tokenize the sentence, splitting on spaces
-            tokens = Tokenizer.tokenize(
-                sentence, self.input_vocab, add_if_not_exist=False,
+            tokens = self.input_vocab.encode_line(
+                sentence, add_if_not_exist=False,
)

sentences.append(tokens)
@@ -356,7 +355,6 @@ Finally we can write a short script to evaluate our model on new inputs. Create
a new file named :file:`eval_classifier.py` with the following contents::

from fairseq import data, options, tasks, utils
-from fairseq.tokenizer import Tokenizer

# Parse command-line arguments for generation
parser = options.get_generation_parser(default_task='simple_classification')
@@ -375,8 +373,8 @@ a new file named :file:`eval_classifier.py` with the following contents::

# Tokenize into characters
chars = ' '.join(list(sentence.strip()))
-        tokens = Tokenizer.tokenize(
-            chars, task.source_dictionary, add_if_not_exist=False,
+        tokens = task.source_dictionary.encode_line(
+            chars, add_if_not_exist=False,
)

# Build mini-batch to feed to the model
67 changes: 67 additions & 0 deletions fairseq/binarizer.py
@@ -0,0 +1,67 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from collections import Counter
import os

from fairseq.tokenizer import tokenize_line


def safe_readline(f):
pos = f.tell()
while True:
try:
return f.readline()
except UnicodeDecodeError:
pos -= 1
f.seek(pos) # search where this character begins


class Binarizer:

@staticmethod
def binarize(filename, dict, consumer, tokenize=tokenize_line, append_eos=True, reverse_order=False,
offset=0, end=-1):
nseq, ntok = 0, 0
replaced = Counter()

def replaced_consumer(word, idx):
if idx == dict.unk_index and word != dict.unk_word:
replaced.update([word])

with open(filename, 'r', encoding='utf-8') as f:
f.seek(offset)
# next(f) breaks f.tell(), hence readline() must be used
line = safe_readline(f)
while line:
if end > 0 and f.tell() > end:
break
ids = dict.encode_line(
line=line,
line_tokenizer=tokenize,
add_if_not_exist=False,
consumer=replaced_consumer,
append_eos=append_eos,
reverse_order=reverse_order,
)
nseq += 1
ntok += len(ids)
consumer(ids)
line = f.readline()
return {'nseq': nseq, 'nunk': sum(replaced.values()), 'ntok': ntok, 'replaced': replaced}

@staticmethod
def find_offsets(filename, num_chunks):
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_chunks
offsets = [0 for _ in range(num_chunks + 1)]
for i in range(1, num_chunks):
f.seek(chunk_size * i)
safe_readline(f)
offsets[i] = f.tell()
return offsets
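
As an illustration, a minimal usage sketch of the new Binarizer (the corpus and dictionary paths and the consumer are hypothetical; the binarize and find_offsets signatures are taken from the code above):

    from fairseq.binarizer import Binarizer
    from fairseq.data import Dictionary

    dictionary = Dictionary.load('data-bin/dict.en.txt')   # hypothetical dictionary file
    encoded = []                                           # toy consumer: collect the id tensors

    stats = Binarizer.binarize('train.en', dictionary, encoded.append)  # hypothetical corpus
    print(stats['nseq'], stats['ntok'], stats['nunk'])

    # For multi-worker preprocessing, find_offsets() returns chunk boundaries that
    # can be passed as offset/end to separate binarize() calls.
    offsets = Binarizer.find_offsets('train.en', num_chunks=4)
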
8 changes: 8 additions & 0 deletions fairseq/data/data_utils.py
@@ -213,3 +213,11 @@ def is_batch_full(num_tokens):

if len(batch) > 0:
yield batch


def process_bpe_symbol(sentence: str, bpe_symbol: str):
if bpe_symbol == 'sentencepiece':
sentence = sentence.replace('\u2581', ' ').strip()
elif bpe_symbol is not None:
sentence = (sentence + ' ').replace(bpe_symbol, '').rstrip()
return sentence
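
For illustration, a few hypothetical inputs run through the helper above (the expected outputs follow directly from the code shown):

    from fairseq.data.data_utils import process_bpe_symbol

    # Conventional BPE continuation marker (here '@@ '): the marker is stripped out.
    print(process_bpe_symbol('new@@ er model@@ s', '@@ '))              # -> 'newer models'

    # sentencepiece: '\u2581' word-boundary markers are turned back into spaces.
    print(process_bpe_symbol('\u2581new\u2581model', 'sentencepiece'))  # -> 'new model'

    # bpe_symbol=None leaves the sentence unchanged.
    print(process_bpe_symbol('plain text', None))                       # -> 'plain text'
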
100 changes: 86 additions & 14 deletions fairseq/data/dictionary.py
@@ -6,10 +6,15 @@
# can be found in the PATENTS file in the same directory.

from collections import Counter
from multiprocessing import Pool
import os

import torch

from fairseq.tokenizer import tokenize_line
from fairseq.binarizer import safe_readline
from fairseq.data import data_utils


class Dictionary(object):
"""A mapping from symbols to consecutive integers"""
@@ -57,14 +62,8 @@ def token_string(i):
else:
return self[i]

-        if bpe_symbol == 'sentencepiece':
-            sent = ''.join(token_string(i) for i in tensor if i != self.eos())
-            sent = sent.replace('\u2581', ' ').strip()
-        else:
-            sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
-        if bpe_symbol is not None and bpe_symbol != 'sentencepiece':
-            sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
-        return sent
+        sent = ''.join(token_string(i) for i in tensor if i != self.eos())
+        return data_utils.process_bpe_symbol(sent, bpe_symbol)

def unk_string(self, escape=False):
"""Return unknown string, optionally escaped as: <<unk>>"""
@@ -181,31 +180,104 @@ def load(cls, f, ignore_utf_errors=False):
"rebuild the dataset".format(f))

d = cls()
-        for line in f.readlines():
+        lines = f.readlines()
+        indices_start_line = d._load_meta(lines)
+        for line in lines[indices_start_line:]:
idx = line.rfind(' ')
if idx == -1:
raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
word = line[:idx]
-            count = int(line[idx+1:])
+            count = int(line[idx + 1:])
d.indices[word] = len(d.symbols)
d.symbols.append(word)
d.count.append(count)
return d

-    def save(self, f):
-        """Stores dictionary into a text file"""
+    def _save(self, f, kv_iterator):
if isinstance(f, str):
os.makedirs(os.path.dirname(f), exist_ok=True)
with open(f, 'w', encoding='utf-8') as fd:
return self.save(fd)
-        for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
-            print('{} {}'.format(symbol, count), file=f)
+        for k, v in kv_iterator:
+            print('{} {}'.format(k, v), file=f)

def _get_meta(self):
return [], []

def _load_meta(self, lines):
return 0

def save(self, f):
"""Stores dictionary into a text file"""
ex_keys, ex_vals = self._get_meta()
self._save(f, zip(ex_keys + self.symbols[self.nspecial:], ex_vals + self.count[self.nspecial:]))

def dummy_sentence(self, length):
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
t[-1] = self.eos()
return t

def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
consumer=None, append_eos=True, reverse_order=False):
words = line_tokenizer(line)
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)

for i, word in enumerate(words):
if add_if_not_exist:
idx = self.add_symbol(word)
else:
idx = self.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = self.eos_index
return ids

@staticmethod
def _add_file_to_dictionary_single_worker(filename, tokenize, eos_word, worker_id=0, num_workers=1):
counter = Counter()
with open(filename, 'r', encoding='utf-8') as f:
size = os.fstat(f.fileno()).st_size
chunk_size = size // num_workers
offset = worker_id * chunk_size
end = offset + chunk_size
f.seek(offset)
if offset > 0:
safe_readline(f) # drop first incomplete line
line = f.readline()
while line:
for word in tokenize(line):
counter.update([word])
counter.update([eos_word])
if f.tell() > end:
break
line = f.readline()
return counter

@staticmethod
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
def merge_result(counter):
for w, c in counter.items():
dict.add_symbol(w, c)

if num_workers > 1:
pool = Pool(processes=num_workers)
results = []
for worker_id in range(num_workers):
results.append(pool.apply_async(
Dictionary._add_file_to_dictionary_single_worker,
(filename, tokenize, dict.eos_word, worker_id, num_workers)
))
pool.close()
pool.join()
for r in results:
merge_result(r.get())
else:
merge_result(Dictionary._add_file_to_dictionary_single_worker(filename, tokenize, dict.eos_word))

class TruncatedDictionary(object):

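A minimal sketch of the vocabulary-building path that now lives on Dictionary rather than Tokenizer (file names, worker count, and output path are hypothetical; the signatures come from the diff above):

    from fairseq import tokenizer
    from fairseq.data import Dictionary

    d = Dictionary()
    for filename in ['train.en', 'valid.en']:   # hypothetical text files
        Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, num_workers=2)
    d.finalize(threshold=-1, nwords=-1, padding_factor=8)
    d.save('data-bin/dict.en.txt')              # hypothetical output path
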
6 changes: 2 additions & 4 deletions fairseq/data/indexed_dataset.py
@@ -11,8 +11,6 @@
import numpy as np
import torch

-from fairseq.tokenizer import Tokenizer


def read_longs(f, n):
a = np.empty(n, dtype=np.int64)
@@ -171,8 +169,8 @@ def read_data(self, path, dictionary):
with open(path, 'r', encoding='utf-8') as f:
for line in f:
self.lines.append(line.strip('\n'))
-                tokens = Tokenizer.tokenize(
-                    line, dictionary, add_if_not_exist=False,
+                tokens = dictionary.encode_line(
+                    line, add_if_not_exist=False,
append_eos=self.append_eos, reverse_order=self.reverse_order,
).long()
self.tokens_list.append(tokens)
3 changes: 1 addition & 2 deletions fairseq/tasks/fairseq_task.py
@@ -9,7 +9,6 @@

from fairseq import tokenizer
from fairseq.data import data_utils, FairseqDataset, iterators, Dictionary
-from fairseq.tokenizer import Tokenizer


class FairseqTask(object):
@@ -52,7 +51,7 @@ def build_dictionary(cls, filenames, workers=1, threshold=-1, nwords=-1, padding
"""
d = Dictionary()
for filename in filenames:
-            Tokenizer.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
+            Dictionary.add_file_to_dictionary(filename, d, tokenizer.tokenize_line, workers)
d.finalize(threshold=threshold, nwords=nwords, padding_factor=padding_factor)
return d

(Diffs for the remaining 7 changed files are not shown here.)
