Added 'prepare_data' cache
Useful when trying to fit the model into VRAM by changing the vocab size
daniel-kukiela committed Mar 26, 2018
1 parent 54f7683 commit e51a3ad
Showing 2 changed files with 121 additions and 96 deletions.
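
The change wraps the two expensive preparation steps behind pickle caches in the train folder (cache_data_vocab.pickle for the per-file token counts, cache_temp_vocab.pickle for the temporary BPE vocab), gated by the new cache_preparation setting. A minimal sketch of that guard pattern, assuming a hypothetical build_vocab_from_corpus() in place of the real tokenization pass:

import pickle
from collections import Counter
from pathlib import Path


def load_or_build_vocab(train_folder, cache_preparation=True):
    """Return cached token counts if present, otherwise build and cache them."""
    cache_file = Path(train_folder) / 'cache_data_vocab.pickle'

    # Cache hit - skip the expensive tokenization/counting pass entirely
    if cache_preparation and cache_file.is_file():
        with open(cache_file, 'rb') as f:
            return pickle.load(f)

    data_vocab = build_vocab_from_corpus()  # expensive pass (hypothetical helper)

    if cache_preparation:
        with open(cache_file, 'wb') as f:
            pickle.dump(data_vocab, f)

    return data_vocab


def build_vocab_from_corpus():
    # Stand-in for reading, tokenizing and counting the source files
    return Counter({'hello': 3, 'world': 2})

A second call with the same train_folder loads the pickle instead of re-counting, which is the behaviour prepare_data.py now gets for its first pass.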
setup/prepare_data.py: 210 changes (116 additions, 94 deletions)
@@ -29,94 +29,104 @@ def prepare():
raise

# Ensure that model/log folder exists
train_log_dir = os.path.join(hparams['out_dir'], 'train_log')
train_log_dir = hparams['out_dir'] + 'train_log/'
try:
os.makedirs(train_log_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise

data_vocab = Counter()
if not preprocessing['cache_preparation'] or not Path('{}/cache_data_vocab.pickle'.format(preprocessing['train_folder'])).exists() or not Path('{}/cache_data_vocab.pickle'.format(preprocessing['train_folder'])).is_file():

# Iterate thru files and prepare them
for file_name, amounts in files.items():
data_vocab = Counter()

vocab = Counter()
# Iterate thru files and prepare them
for file_name, amounts in files.items():

print("File: {}{}{}".format(colorama.Fore.GREEN, file_name, colorama.Fore.RESET))
vocab = Counter()

# Output file handler
out_file = open('{}/{}'.format(preprocessing['train_folder'], file_name), 'w', encoding='utf-8', buffering=131072)
print("File: {}{}{}".format(colorama.Fore.GREEN, file_name, colorama.Fore.RESET))

# Output file handler
out_file = open('{}{}'.format(preprocessing['train_folder'], file_name), 'w', encoding='utf-8', buffering=131072)

# Maximum number of lines
read = 0
amount = int(min(amounts['amount'] * preprocessing['samples'] if preprocessing['samples'] > 0 else 10 ** 20, amounts['up_to'] if amounts['up_to'] > 0 else 10 ** 20))

# Prepare thread variables
write_thread = None
vocab_thread = None
written_lines = 0

# Maximum number of lines
read = 0
amount = int(min(amounts['amount'] * preprocessing['samples'] if preprocessing['samples'] > 0 else 10 ** 20, amounts['up_to'] if amounts['up_to'] > 0 else 10 ** 20))
# We are going to use multiprocessing for tokenization, as it's cpu intensive
with Pool(processes=preprocessing['cpu_count']) as pool:

# Prepare thread variables
write_thread = None
vocab_thread = None
written_lines = 0
# Count number of lines in file
progress = tqdm(ascii=True, unit=' lines', total=min(amount, sum(1 for _ in open('{}{}'.format(preprocessing['source_folder'], file_name), 'r', encoding='utf-8', buffering=131072))))

# We are going to use multiprocessing for tokenization, as it's cpu intensive
with Pool(processes=preprocessing['cpu_count']) as pool:
# Open input file
with open('{}{}'.format(preprocessing['source_folder'], file_name), 'r', encoding='utf-8', buffering=131072) as in_file:

# Count number of lines in file
progress = tqdm(ascii=True, unit=' lines', total=min(amount, sum(1 for _ in open('{}/{}'.format(preprocessing['source_folder'], file_name), 'r', encoding='utf-8', buffering=131072))))
last_batch = False

# Open input file
with open('{}/{}'.format(preprocessing['source_folder'], file_name), 'r', encoding='utf-8', buffering=131072) as in_file:
# Iterate every 10k lines
for rows in read_lines(in_file, 30000, ''):

last_batch = False
# If number of lines is greater than limit - break
read += len(rows)
if read >= amount:
rows = rows[:amount-read+len(rows)]
last_batch = True

# Iterate every 10k lines
for rows in read_lines(in_file, 30000, ''):
# Process using multiprocessing
rows = pool.map(tokenize, rows, 500)

# If number of lines is greater than limit - break
read += len(rows)
if read >= amount:
rows = rows[:amount-read+len(rows)]
last_batch = True
# Process vocab using multiprocessing
vocab_part = pool.map(sentence_split, rows, 500)

# Process using multiprocessing
rows = pool.map(tokenize, rows, 500)
# Join running threads from previous loop
if write_thread is not None:
write_thread.join()
vocab_thread.join()
progress.update(written_lines)

# Process vocab using multiprocessing
vocab_part = pool.map(sentence_split, rows, 500)
# Thread for vocab update
vocab_thread = Thread(target=append_vocab, args=(vocab_part,))
vocab_thread.start()

# Join running threads from previous loop
if write_thread is not None:
write_thread.join()
vocab_thread.join()
progress.update(written_lines)
# And thread for saving tokenized data to output file
write_thread = Thread(target=write_lines, args=(out_file, rows, written_lines == 0))
write_thread.start()

# Thread for vocab update
vocab_thread = Thread(target=append_vocab, args=(vocab_part,))
vocab_thread.start()
# Last batch - break / exit loop
if last_batch:
break

# And thread for saving tokenized data to output file
write_thread = Thread(target=write_lines, args=(out_file, rows, written_lines == 0))
write_thread.start()
# Join running threads and update progress bar
write_thread.join()
vocab_thread.join()
progress.update(written_lines)
progress.close()

# Last batch - break / exit loop
if last_batch:
break
# If it's train file, save vocab
if file_name == '{}.{}'.format(hparams['train_prefix'].replace('.bpe', ''), hparams['src']).replace(preprocessing['train_folder'], '').lstrip('\\/'):
data_vocab[hparams['src']] = vocab
elif file_name == '{}.{}'.format(hparams['train_prefix'].replace('.bpe', ''), hparams['tgt']).replace(preprocessing['train_folder'], '').lstrip('\\/'):
data_vocab[hparams['tgt']] = vocab

# Join running threads and update progress bar
write_thread.join()
vocab_thread.join()
progress.update(written_lines)
progress.close()
# If joined vocab - add counters
if preprocessing['joined_vocab']:
data_vocab[hparams['src']] += data_vocab[hparams['tgt']]
del data_vocab[hparams['tgt']]

# If it's train file, save vocab
if file_name == '{}.{}'.format(hparams['train_prefix'].replace('.bpe', ''), hparams['src']).replace(preprocessing['train_folder'], '').lstrip('\\/'):
data_vocab[hparams['src']] = vocab
elif file_name == '{}.{}'.format(hparams['train_prefix'].replace('.bpe', ''), hparams['tgt']).replace(preprocessing['train_folder'], '').lstrip('\\/'):
data_vocab[hparams['tgt']] = vocab
with open('{}/cache_data_vocab.pickle'.format(preprocessing['train_folder']), 'wb') as f:
pickle.dump(data_vocab, f)

# If joined vocab - add counters
if preprocessing['joined_vocab']:
data_vocab[hparams['src']] += data_vocab[hparams['tgt']]
del data_vocab[hparams['tgt']]
else:
print('Using cached data')
with open('{}/cache_data_vocab.pickle'.format(preprocessing['train_folder']), 'rb') as f:
data_vocab = pickle.load(f)

# BPE/WPM-like tokenization
# inspired by and based on https://github.com/rsennrich/subword-nmt
@@ -133,32 +143,43 @@ def prepare():
# Learn BPE for both vocabs (or common vocab)
for source, raw_vocab in data_vocab.items():

# Pair stats
stats = Counter()
if not preprocessing['cache_preparation'] or not Path('{}/cache_temp_vocab.pickle'.format(preprocessing['train_folder'])).exists() or not Path('{}/cache_temp_vocab.pickle'.format(preprocessing['train_folder'])).is_file():

# Pair stats
stats = Counter()

# Pair indexes
indices = defaultdict(lambda: defaultdict(int))

# Build 'new' vocab used for BPE learning (train_vocab will be a final vocab for NMT)
vocab = []
train_vocab[source] = Counter()

# Pair indexes
indices = defaultdict(lambda: defaultdict(int))
# Build vocab for BPE learning purpose
print("Building temporary vocab ({})".format(hparams['src'] if preprocessing['joined_vocab'] else source))
for i, (entity, freq) in tqdm(enumerate(raw_vocab.most_common()), ascii=True, unit=' tokens'):

# Build 'new' vocab used for BPE learning (train_vocab will be a final vocab for NMT)
vocab = []
train_vocab[source] = Counter()
# Split vocab token
entity = tuple(entity.split())

# Build vocab for BPE learning purpose
print("Building temporary vocab ({})".format(hparams['src'] if preprocessing['joined_vocab'] else source))
for i, (entity, freq) in tqdm(enumerate(raw_vocab.most_common()), ascii=True, unit=' tokens'):
# Make pairs ("ABCD" -> (A, B), (B, C), (C, D)), stats, indexes and train vocab
prev_char = entity[0]
train_vocab[source][prev_char] += freq
for char in entity[1:]:
stats[prev_char, char] += freq
indices[prev_char, char][i] += 1
train_vocab[source][char] += freq
prev_char = char
vocab.append((entity, freq))

# Split vocab token
entity = tuple(entity.split())
with open('{}/cache_temp_vocab.pickle'.format(preprocessing['train_folder']), 'wb') as f:
pickle.dump((stats, dict(indices), train_vocab, vocab), f)

# Make pairs ("ABCD" -> (A, B), (B, C), (C, D)), stats, indexes and train vocab
prev_char = entity[0]
train_vocab[source][prev_char] += freq
for char in entity[1:]:
stats[prev_char, char] += freq
indices[prev_char, char][i] += 1
train_vocab[source][char] += freq
prev_char = char
vocab.append((entity, freq))
else:
print('Using cached data')
with open('{}/cache_temp_vocab.pickle'.format(preprocessing['train_folder']), 'rb') as f:
stats, indices, train_vocab, vocab = pickle.load(f)
indices = defaultdict(lambda: defaultdict(int), indices)

print("Learning BPE for vocab of {} tokens".format(preprocessing['vocab_size']))

@@ -336,15 +357,15 @@ def prepare():

# Save list of joins to a file (joined vocab) and replace main vocabs
if preprocessing['joined_vocab']:
with open('{}/{}'.format(preprocessing['train_folder'], 'bpe_joins.common.json'), 'w', encoding='utf-8', buffering=131072) as bpe_file:
with open('{}{}'.format(preprocessing['train_folder'], 'bpe_joins.common.json'), 'w', encoding='utf-8', buffering=131072) as bpe_file:
json.dump({json.dumps(k):v for k,v in joins[hparams['src']].items()}, bpe_file)
data_vocab[hparams['src']] = train_vocab[hparams['src']]

# Save list of joins to files (separated vocab)
else:
with open('{}/{}'.format(preprocessing['train_folder'], 'bpe_joins.{}.json'.format(hparams['src'])), 'w', encoding='utf-8', buffering=131072) as bpe_file:
with open('{}{}'.format(preprocessing['train_folder'], 'bpe_joins.{}.json'.format(hparams['src'])), 'w', encoding='utf-8', buffering=131072) as bpe_file:
json.dump({json.dumps(k):v for k,v in joins[hparams['src']].items()}, bpe_file)
with open('{}/{}'.format(preprocessing['train_folder'], 'bpe_joins.{}.json'.format(hparams['tgt'])), 'w', encoding='utf-8', buffering=131072) as bpe_file:
with open('{}{}'.format(preprocessing['train_folder'], 'bpe_joins.{}.json'.format(hparams['tgt'])), 'w', encoding='utf-8', buffering=131072) as bpe_file:
json.dump({json.dumps(k):v for k,v in joins[hparams['tgt']].items()}, bpe_file)
data_vocab[hparams['src']] = train_vocab[hparams['src']]
data_vocab[hparams['tgt']] = train_vocab[hparams['tgt']]
@@ -370,7 +391,7 @@ def prepare():
print("File: {}{}{}".format(colorama.Fore.GREEN, file_name, colorama.Fore.RESET))

# Output file handler
out_file = open('{}/{}'.format(preprocessing['train_folder'], file_name), 'w', encoding='utf-8', buffering=131072)
out_file = open('{}{}'.format(preprocessing['train_folder'], file_name), 'w', encoding='utf-8', buffering=131072)

# Prepare thread variables
write_thread = None
@@ -380,10 +401,10 @@ def prepare():
with Pool(processes=preprocessing['cpu_count'], initializer=apply_bpe_init, initargs=(joins[source],)) as pool:

# Progress bar
progress = tqdm(ascii=True, unit=' lines', total=sum(1 for _ in open('{}/{}'.format(preprocessing['train_folder'], file_name.replace('.bpe.', '.')), 'r', encoding='utf-8', buffering=131072)))
progress = tqdm(ascii=True, unit=' lines', total=sum(1 for _ in open('{}{}'.format(preprocessing['train_folder'], file_name.replace('.bpe.', '.')), 'r', encoding='utf-8', buffering=131072)))

# Open input file
with open('{}/{}'.format(preprocessing['train_folder'], file_name.replace('.bpe.', '.')), 'r', encoding='utf-8', buffering=131072) as in_file:
with open('{}{}'.format(preprocessing['train_folder'], file_name.replace('.bpe.', '.')), 'r', encoding='utf-8', buffering=131072) as in_file:

# Iterate every 10k lines
for rows in read_lines(in_file, 10000, ''):
@@ -409,7 +430,8 @@ def prepare():
progress.close()

# Remove unnecessary train file (BPE one will be used by NMT)
os.remove('{}/{}'.format(preprocessing['train_folder'], file_name.replace('.bpe.', '.')))
if not preprocessing['cache_preparation']:
os.remove('{}{}'.format(preprocessing['train_folder'], file_name.replace('.bpe.', '.')))

print(colorama.Fore.GREEN + "\nPostprocessing and saving vocabs" + colorama.Fore.RESET)

@@ -436,20 +458,20 @@ def prepare():
data_vocab[source] = [entity for entity, _ in data_vocab[source].most_common()]

# Write entities to a file
with open('{}/{}'.format(preprocessing['train_folder'], vocab_file_name), 'w', encoding='utf-8', buffering=131072) as vocab_file:
with open('{}{}'.format(preprocessing['train_folder'], vocab_file_name), 'w', encoding='utf-8', buffering=131072) as vocab_file:
vocab_file.write("<unk>\n<s>\n</s>\n" + "\n".join(data_vocab[source][:preprocessing['vocab_size']]))
with open('{}/{}'.format(preprocessing['train_folder'], vocab_file_name.replace('vocab', 'vocab_unused')), 'w', encoding='utf-8', buffering=131072) as vocab_file:
with open('{}{}'.format(preprocessing['train_folder'], vocab_file_name.replace('vocab', 'vocab_unused')), 'w', encoding='utf-8', buffering=131072) as vocab_file:
vocab_file.write("\n".join(data_vocab[source][preprocessing['vocab_size']:]))

print(colorama.Fore.GREEN + "\nWriting pbtxt file" + colorama.Fore.RESET)

# Write pbtxt file for metadata for embeddings
with open('{}/{}'.format(os.path.join(train_log_dir), 'projector_config.pbtxt'), 'w', encoding='utf-8', buffering=131072) as pbtxt_file:
with open(train_log_dir + 'projector_config.pbtxt', 'w', encoding='utf-8', buffering=131072) as pbtxt_file:
pbtxt_file.write(('''embeddings {{\n tensor_name: 'embeddings/decoder/embedding_decoder'\n '''+
'''metadata_path: '{}'\n}}\nembeddings {{\n '''+
'''tensor_name: 'embeddings/encoder/embedding_encoder'\n metadata_path: '{}'\n}}''').format(
'{}/{}'.format(preprocessing['train_folder'], vocab_files[0].replace('train', 'vocab')),
'{}/{}'.format(preprocessing['train_folder'], vocab_files[0 if preprocessing['joined_vocab'] else 1].replace('train', 'vocab'))
'{}{}'.format(preprocessing['train_folder'], vocab_files[0].replace('train', 'vocab')),
'{}{}'.format(preprocessing['train_folder'], vocab_files[0 if preprocessing['joined_vocab'] else 1].replace('train', 'vocab'))
))

print(colorama.Fore.GREEN + "\nAll done" + colorama.Fore.RESET)
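
One detail of the second cache worth noting: indices is a defaultdict(lambda: defaultdict(int)), and a lambda default factory cannot be pickled, which is why the diff dumps dict(indices) and rebuilds the defaultdict around the loaded dict. A small sketch of that round-trip (only the outer factory is the problem; the inner defaultdict(int) values pickle fine):

import pickle
from collections import defaultdict

# Pair indices keyed on (left, right) symbol pairs, as in prepare_data.py
indices = defaultdict(lambda: defaultdict(int))
indices[('a', 'b')][0] += 1
indices[('b', 'c')][3] += 2

# pickle.dumps(indices) would fail: the outer default_factory is a lambda.
payload = pickle.dumps(dict(indices))

# After loading, restore defaultdict behaviour around the plain dict
restored = defaultdict(lambda: defaultdict(int), pickle.loads(payload))
assert restored[('a', 'b')][0] == 1
assert restored[('new', 'pair')][7] == 0  # unseen pairs default to 0 again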
setup/settings.py: 7 changes (5 additions, 2 deletions)
@@ -45,7 +45,10 @@
'test_size': 100,


## You don't need to change anything below (internal settings)
## You don't normally need to change anything below (internal settings)

# Cache 'prepairing training set' and 'building temporary vocab' steps
'cache_preparation': True,

# Source (raw) data folder
'source_folder': source_dir,
@@ -87,7 +90,7 @@
'num_translations_per_input': 20,
# 'num_keep_ckpts': 5,

## You don't need to change anything below (internal settings)
## You don't normally need to change anything below (internal settings)
'src': 'from',
'tgt': 'to',
'vocab_prefix': os.path.join(train_dir, "vocab"),
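
With cache_preparation left at its default of True, a re-run of prepare_data.py after changing vocab_size skips the raw-corpus tokenization and token-counting passes and reuses the pickled counts; BPE learning and the final vocab trimming still run against the new size, which is what makes vocab-size experiments for fitting the model into VRAM faster. A loose sketch of the trimming step that stays cheap once the counts are cached (write_vocab_lists is a hypothetical helper, not part of the repo):

from collections import Counter


def write_vocab_lists(vocab_counter, vocab_size, special_tokens=("<unk>", "<s>", "</s>")):
    # Order tokens by frequency, keep the first vocab_size entries,
    # and set the rest aside - loosely mirroring the tail of prepare().
    ordered = [token for token, _ in vocab_counter.most_common()]
    used = list(special_tokens) + ordered[:vocab_size]
    unused = ordered[vocab_size:]
    return used, unused


# The same cached counts trimmed to two candidate vocab sizes
counts = Counter({'hello': 50, 'world': 40, 'foo': 10, 'bar': 5, 'baz': 1})
for size in (2, 4):
    used, unused = write_vocab_lists(counts, size)
    print(size, used, unused)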
