forked from Kyubyong/transformer
Showing 8 changed files with 216,599 additions and 0 deletions.
data_load.py (new file, +126 lines):
# -*- coding: utf-8 -*-
# /usr/bin/python2
'''
By kyubyong park. [email protected].
https://www.github.com/kyubyong/transformer
'''
from __future__ import print_function
from hyperparams import Hyperparams as hp
import tensorflow as tf
import numpy as np
import codecs
import re, regex


def load_de_vocab():
    vocab = [line.split()[0] for line in codecs.open('preprocessed/de.vocab.tsv', 'r', 'utf-8').read().splitlines() if int(line.split()[1]) >= hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word


def load_en_vocab():
    vocab = [line.split()[0] for line in codecs.open('preprocessed/en.vocab.tsv', 'r', 'utf-8').read().splitlines() if int(line.split()[1]) >= hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word

def create_data(source_sents, target_sents):
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    # Index
    x_list, y_list, Sources, Targets = [], [], [], []
    for source_sent, target_sent in zip(source_sents, target_sents):
        x = [de2idx.get(word, 1) for word in (source_sent + u" </S>").split()]  # 1: OOV, </S>: End of Text
        y = [en2idx.get(word, 1) for word in (target_sent + u" </S>").split()]
        if max(len(x), len(y)) <= hp.maxlen and 1 not in x and 1 not in y:  # drop pairs that are too long or contain OOV
            x_list.append(np.array(x))
            y_list.append(np.array(y))
            Sources.append(source_sent)
            Targets.append(target_sent)

    # Pad
    X = np.zeros([len(x_list), hp.maxlen], np.int32)
    Y = np.zeros([len(y_list), hp.maxlen], np.int32)
    for i, (x, y) in enumerate(zip(x_list, y_list)):
        X[i] = np.lib.pad(x, [0, hp.maxlen-len(x)], 'constant', constant_values=(0, 0))
        Y[i] = np.lib.pad(y, [0, hp.maxlen-len(y)], 'constant', constant_values=(0, 0))

    print("X.shape =", X.shape)
    print("Y.shape =", Y.shape)

    return X, Y, Sources, Targets

# def create_eval_data(source_sents, target_sents):
#     word2idx, idx2word = load_vocab()
#
#     # Index
#     x_list, y_list, Sources, Targets = [], [], [], []
#     for source_sent, target_sent in zip(source_sents, target_sents):
#         x = [word2idx.get(word, 3) for word in source_sent + u" </S>"]  # 3: OOV, ␃: End of Text
#         y = [word2idx.get(word, 3) for word in target_sent + u"␃"]
#         if max(len(x), len(y)) <= hp.maxlen:
#             x_list.append(np.array(x))
#             y_list.append(np.array(y))
#             Sources.append(source_sent)
#             Targets.append(target_sent)
#
#     # Pad
#     X = np.zeros([len(x_list), hp.maxlen], np.int32)
#     Y = np.zeros([len(y_list), hp.maxlen], np.int32)
#     for i, (x, y) in enumerate(zip(x_list, y_list)):
#         X[i] = np.lib.pad(x, [0, hp.maxlen-len(x)], 'constant', constant_values=(0, 0))
#         Y[i] = np.lib.pad(y, [0, hp.maxlen-len(y)], 'constant', constant_values=(0, 0))
#
#     print("X.shape =", X.shape)
#     print("Y.shape =", Y.shape)
#
#     return X, Y, Sources, Targets

def load_train_data():
    de_sents = [regex.sub(r"[^\s\p{Latin}']", "", line) for line in codecs.open(hp.de_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]
    en_sents = [regex.sub(r"[^\s\p{Latin}']", "", line) for line in codecs.open(hp.en_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Y


def load_test_data():
    def _remove_tags(line):
        line = re.sub("<[^>]+>", "", line)
        return line.strip()

    if hp.sanity_check:
        # Evaluate on the first training examples instead of the real test set.
        de_sents = [line for line in codecs.open(hp.de_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]
        en_sents = [line for line in codecs.open(hp.en_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]
        X, Y, Sources, Targets = create_data(de_sents, en_sents)
        return X[:128], Sources[:128], Targets[:128]

    de_sents = [_remove_tags(line) for line in codecs.open(hp.de_test, 'r', 'utf-8').read().split("\n") if line and line[:4] == "<seg"]
    en_sents = [_remove_tags(line) for line in codecs.open(hp.en_test, 'r', 'utf-8').read().split("\n") if line and line[:4] == "<seg"]

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Sources, Targets  # (1064, 150)

def get_batch_data():
    # Load data
    X, Y = load_train_data()

    # calc total batch count
    num_batch = len(X) // hp.batch_size

    # Convert to tensor
    X = tf.convert_to_tensor(X, tf.int32)
    Y = tf.convert_to_tensor(Y, tf.int32)

    # Create Queues
    input_queues = tf.train.slice_input_producer([X, Y])

    # create batch queues
    x, y = tf.train.shuffle_batch(input_queues,
                                  num_threads=8,
                                  batch_size=hp.batch_size,
                                  capacity=hp.batch_size*64,
                                  min_after_dequeue=hp.batch_size*32,
                                  allow_smaller_final_batch=False)

    return x, y, num_batch  # (N, T), (N, T), ()
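Note that get_batch_data() only builds the queue ops; the queue runners behind slice_input_producer and shuffle_batch must be started before x and y produce values. A minimal consumption sketch under the same TF 1.x API (hypothetical driver code, not part of this commit; tf.train.Supervisor starts the queue runners inside managed_session):

import tensorflow as tf
from data_load import get_batch_data

x, y, num_batch = get_batch_data()           # (N, T), (N, T), scalar
sv = tf.train.Supervisor()                   # managed_session() initializes variables and starts the queue runners
with sv.managed_session() as sess:
    for _ in range(num_batch):
        x_batch, y_batch = sess.run([x, y])  # padded int32 index matrices of shape (batch_size, maxlen)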
Evaluation script (new file, +84 lines):
# -*- coding: utf-8 -*-
'''
By kyubyong park. [email protected].
https://www.github.com/kyubyong/tacotron
'''
from __future__ import print_function
import codecs
import os

import tensorflow as tf
import numpy as np

from hyperparams import Hyperparams as hp
from data_load import *
from train import Graph
from nltk.translate.bleu_score import corpus_bleu

def eval():
    # Load graph
    g = Graph(is_training=False)
    print("Graph loaded")

    # Load data
    X, Sources, Targets = load_test_data()
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            # Restore parameters
            sv.saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
            print("Restored!")

            # Get model
            mname = open(hp.logdir + '/checkpoint', 'r').read().split('"')[1]  # model name

            if not os.path.exists('results'): os.mkdir('results')
            with codecs.open("results/" + mname, "w", "utf-8") as fout:
                list_of_refs, hypotheses = [], []
                for i in range(len(X) // hp.batch_size):
                    # Get mini-batches
                    x = X[i*hp.batch_size: (i+1)*hp.batch_size]
                    sources = Sources[i*hp.batch_size: (i+1)*hp.batch_size]
                    targets = Targets[i*hp.batch_size: (i+1)*hp.batch_size]

                    # Autoregressive inference: feed the predictions made so far back in
                    # and keep only the next position's prediction each step.
                    preds = np.zeros((hp.batch_size, hp.maxlen), np.int32)
                    for j in range(hp.maxlen):
                        _preds = sess.run(g.preds, {g.x: x, g.y: preds})
                        preds[:, j] = _preds[:, j]

                    # Write to file
                    for source, target, pred in zip(sources, targets, preds):  # sentence-wise
                        got = " ".join(idx2en[idx] for idx in pred)  # .split(u"␃")[0]
                        print("==", pred, ">>>")
                        print("==", got, ">>>")
                        fout.write("- source: " + source + "\n")
                        fout.write("- expected: " + target + "\n")
                        fout.write("- got: " + got + "\n\n")
                        fout.flush()

                        # For bleu score
                        ref = target.split()
                        hypothesis = got.split()
                        if len(ref) > 3 and len(hypothesis) > 3:
                            list_of_refs.append([ref])
                            hypotheses.append(hypothesis)

                # Get bleu score
                score = corpus_bleu(list_of_refs, hypotheses)
                fout.write("Bleu Score = " + str(100*score))

if __name__ == '__main__':
    eval()
    print("Done")
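The [ref] wrapping above mirrors what NLTK's corpus_bleu expects: one list of reference translations per hypothesis (each reference is a token list), alongside a list of tokenized hypotheses. A tiny standalone example of that structure, unrelated to the model:

from nltk.translate.bleu_score import corpus_bleu

# two hypotheses, each paired with a single reference translation
list_of_refs = [[["the", "cat", "is", "on", "the", "mat"]],
                [["there", "is", "a", "cat", "on", "the", "mat"]]]
hypotheses = [["the", "cat", "sat", "on", "the", "mat"],
              ["there", "is", "a", "cat", "on", "the", "mat"]]
print(corpus_bleu(list_of_refs, hypotheses))  # corpus-level BLEU over both pairs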
hyperparams.py (new file, +30 lines):
# -*- coding: utf-8 -*-
# /usr/bin/python2
'''
By kyubyong park. [email protected].
https://www.github.com/kyubyong/transformer
'''
class Hyperparams:
    '''Hyperparameters'''
    # data
    de_train = 'corpora/train.tags.de-en.de'
    en_train = 'corpora/train.tags.de-en.en'
    de_test = 'corpora/IWSLT16.TED.tst2014.de-en.de.xml'
    en_test = 'corpora/IWSLT16.TED.tst2014.de-en.en.xml'

    # training
    batch_size = 32  # alias = N
    lr = 0.0005
    logdir = 'logdir'

    # model
    maxlen = 10  # Maximum sentence length. alias = T
    hidden_units = 512  # alias = C
    num_blocks = 6
    num_epochs = 200
    num_heads = 8

    sanity_check = True  # if True, load_test_data() scores the first training examples instead of the test set
    min_cnt = 100  # words occurring fewer than min_cnt times are dropped from the vocabulary
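These values feed directly into data_load.py above: min_cnt prunes the vocabulary files and maxlen (together with the OOV filter) drops sentence pairs, so tightening either can shrink the usable data sharply. A quick way to inspect the effect, assuming preprocessed/de.vocab.tsv and preprocessed/en.vocab.tsv exist as data_load.py requires:

from data_load import load_de_vocab, load_en_vocab, load_train_data

de2idx, _ = load_de_vocab()   # only words with count >= min_cnt survive
en2idx, _ = load_en_vocab()
X, Y = load_train_data()      # pairs longer than maxlen or containing OOV are dropped
print(len(de2idx), len(en2idx), X.shape, Y.shape)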