initial commit
Kyubyong authored Jun 17, 2017
1 parent dba8a11 commit 50b45ef
Showing 8 changed files with 216,599 additions and 0 deletions.
126 changes: 126 additions & 0 deletions data_load.py
@@ -0,0 +1,126 @@
# -*- coding: utf-8 -*-
# /usr/bin/python2
'''
By kyubyong park. [email protected].
https://www.github.com/kyubyong/transformer
'''
from __future__ import print_function
from hyperparams import Hyperparams as hp
import tensorflow as tf
import numpy as np
import codecs
import re, regex

def load_de_vocab():
    vocab = [line.split()[0] for line in codecs.open('preprocessed/de.vocab.tsv', 'r', 'utf-8').read().splitlines() if int(line.split()[1])>=hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word

def load_en_vocab():
    vocab = [line.split()[0] for line in codecs.open('preprocessed/en.vocab.tsv', 'r', 'utf-8').read().splitlines() if int(line.split()[1])>=hp.min_cnt]
    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word

def create_data(source_sents, target_sents):
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    # Index
    x_list, y_list, Sources, Targets = [], [], [], []
    for source_sent, target_sent in zip(source_sents, target_sents):
        x = [de2idx.get(word, 1) for word in (source_sent + u" </S>").split()]  # 1: OOV, </S>: End of sentence
        y = [en2idx.get(word, 1) for word in (target_sent + u" </S>").split()]
        if max(len(x), len(y)) <= hp.maxlen and 1 not in x and 1 not in y:  # drop long sentences and those with OOV words
            x_list.append(np.array(x))
            y_list.append(np.array(y))
            Sources.append(source_sent)
            Targets.append(target_sent)

    # Pad
    X = np.zeros([len(x_list), hp.maxlen], np.int32)
    Y = np.zeros([len(y_list), hp.maxlen], np.int32)
    for i, (x, y) in enumerate(zip(x_list, y_list)):
        X[i] = np.lib.pad(x, [0, hp.maxlen-len(x)], 'constant', constant_values=(0, 0))
        Y[i] = np.lib.pad(y, [0, hp.maxlen-len(y)], 'constant', constant_values=(0, 0))

    print("X.shape =", X.shape)
    print("Y.shape =", Y.shape)

    return X, Y, Sources, Targets

# def create_eval_data(source_sents, target_sents):
# word2idx, idx2word = load_vocab()
#
# # Index
# x_list, y_list, Sources, Targets = [], [], [], []
# for source_sent, target_sent in zip(source_sents, target_sents):
# x = [word2idx.get(word, 3) for word in source_sent + u" </S>"] # 3: OOV, ␃: End of Text
# y = [word2idx.get(word, 3) for word in target_sent + u"␃"]
# if max(len(x), len(y)) <= hp.maxlen:
# x_list.append(np.array(x))
# y_list.append(np.array(y))
# Sources.append(source_sent)
# Targets.append(target_sent)
#
# # Pad
# X = np.zeros([len(x_list), hp.maxlen], np.int32)
# Y = np.zeros([len(y_list), hp.maxlen], np.int32)
# for i, (x, y) in enumerate(zip(x_list, y_list)):
# X[i] = np.lib.pad(x, [0, hp.maxlen-len(x)], 'constant', constant_values=(0, 0))
# Y[i] = np.lib.pad(y, [0, hp.maxlen-len(y)], 'constant', constant_values=(0, 0))
#
# print("X.shape =", X.shape)
# print("Y.shape =", Y.shape)
#
# return X, Y, Sources, Targets

def load_train_data():
    de_sents = [regex.sub("[^\s\p{Latin}']", "", line) for line in codecs.open(hp.de_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]
    en_sents = [regex.sub("[^\s\p{Latin}']", "", line) for line in codecs.open(hp.en_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Y

def load_test_data():
    def _remove_tags(line):
        line = re.sub("<[^>]+>", "", line)
        return line.strip()

    if hp.sanity_check:  # evaluate on a small slice of the training data
        de_sents = [line for line in codecs.open(hp.de_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]
        en_sents = [line for line in codecs.open(hp.en_train, 'r', 'utf-8').read().split("\n") if line and line[0] != "<"]
        X, Y, Sources, Targets = create_data(de_sents, en_sents)
        return X[:128], Sources[:128], Targets[:128]

    de_sents = [_remove_tags(line) for line in codecs.open(hp.de_test, 'r', 'utf-8').read().split("\n") if line and line[:4] == "<seg"]
    en_sents = [_remove_tags(line) for line in codecs.open(hp.en_test, 'r', 'utf-8').read().split("\n") if line and line[:4] == "<seg"]

    X, Y, Sources, Targets = create_data(de_sents, en_sents)
    return X, Sources, Targets # (1064, 150)

def get_batch_data():
    # Load data
    X, Y = load_train_data()

    # calc total batch count
    num_batch = len(X) // hp.batch_size

    # Convert to tensor
    X = tf.convert_to_tensor(X, tf.int32)
    Y = tf.convert_to_tensor(Y, tf.int32)

    # Create Queues
    input_queues = tf.train.slice_input_producer([X, Y])

    # create batch queues
    x, y = tf.train.shuffle_batch(input_queues,
                                  num_threads=8,
                                  batch_size=hp.batch_size,
                                  capacity=hp.batch_size*64,
                                  min_after_dequeue=hp.batch_size*32,
                                  allow_smaller_final_batch=False)

    return x, y, num_batch # (N, T), (N, T), ()
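
For orientation (not part of this commit), a minimal sketch of how a queue-based pipeline like get_batch_data() is typically driven under TensorFlow 1.x. The dummy variable and placeholder loss below are assumptions; the real model and loss live in train.py's Graph class, which is not shown in this diff.

# Hedged usage sketch, assuming TF 1.x queue runners.
import tensorflow as tf
from hyperparams import Hyperparams as hp
from data_load import get_batch_data

x, y, num_batch = get_batch_data()               # (N, T) int32 batches dequeued each step
w = tf.Variable(0.0)                             # dummy variable so minimize() has something to train
loss = tf.reduce_mean(tf.cast(x, tf.float32)) * w  # placeholder loss; the real one comes from train.Graph
train_op = tf.train.AdamOptimizer(hp.lr).minimize(loss)

sv = tf.train.Supervisor()                       # managed_session starts the queue runners
with sv.managed_session() as sess:               # behind slice_input_producer/shuffle_batch
    for epoch in range(hp.num_epochs):
        for step in range(num_batch):
            if sv.should_stop(): break
            sess.run(train_op)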

84 changes: 84 additions & 0 deletions eval.py
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
'''
By kyubyong park. [email protected].
https://www.github.com/kyubyong/transformer
'''

from __future__ import print_function
import codecs
import os

import tensorflow as tf
import numpy as np

from hyperparams import Hyperparams as hp
from data_load import *
from train import Graph
from nltk.translate.bleu_score import corpus_bleu

def eval():
    # Load graph
    g = Graph(is_training=False)
    print("Graph loaded")

    # Load data
    X, Sources, Targets = load_test_data()
    de2idx, idx2de = load_de_vocab()
    en2idx, idx2en = load_en_vocab()

    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            # Restore parameters
            sv.saver.restore(sess, tf.train.latest_checkpoint(hp.logdir))
            print("Restored!")

            # Get model name from the checkpoint file
            mname = open(hp.logdir + '/checkpoint', 'r').read().split('"')[1]

            if not os.path.exists('results'): os.mkdir('results')
            with codecs.open("results/" + mname, "w", "utf-8") as fout:
                list_of_refs, hypotheses = [], []
                for i in range(len(X) // hp.batch_size):

                    # Get mini-batches
                    x = X[i*hp.batch_size: (i+1)*hp.batch_size]
                    sources = Sources[i*hp.batch_size: (i+1)*hp.batch_size]
                    targets = Targets[i*hp.batch_size: (i+1)*hp.batch_size]

                    # Autoregressive greedy decoding
                    preds = np.zeros((hp.batch_size, hp.maxlen), np.int32)
                    for j in range(hp.maxlen):
                        _preds = sess.run(g.preds, {g.x: x, g.y: preds})
                        preds[:, j] = _preds[:, j]

                    # Write to file
                    for source, target, pred in zip(sources, targets, preds): # sentence-wise
                        got = " ".join(idx2en[idx] for idx in pred).split("</S>")[0].strip()
                        print("==", pred, ">>>")
                        print("==", got, ">>>")
                        fout.write("- source: " + source + "\n")
                        fout.write("- expected: " + target + "\n")
                        fout.write("- got: " + got + "\n\n")
                        fout.flush()

                        # For bleu score
                        ref = target.split()
                        hypothesis = got.split()
                        if len(ref) > 3 and len(hypothesis) > 3:
                            list_of_refs.append([ref])
                            hypotheses.append(hypothesis)

                # Get bleu score
                score = corpus_bleu(list_of_refs, hypotheses)
                fout.write("Bleu Score = " + str(100*score))

if __name__ == '__main__':
    eval()
    print("Done")


30 changes: 30 additions & 0 deletions hyperparams.py
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# /usr/bin/python2
'''
By kyubyong park. [email protected].
https://www.github.com/kyubyong/transformer
'''
class Hyperparams:
    '''Hyperparameters'''
    # data
    de_train = 'corpora/train.tags.de-en.de'
    en_train = 'corpora/train.tags.de-en.en'
    de_test = 'corpora/IWSLT16.TED.tst2014.de-en.de.xml'
    en_test = 'corpora/IWSLT16.TED.tst2014.de-en.en.xml'

    # training
    batch_size = 32 # alias = N
    lr = 0.0005
    logdir = 'logdir'

    # model
    maxlen = 10 # Maximum sentence length. alias = T
    hidden_units = 512 # alias = C
    num_blocks = 6
    num_epochs = 200
    num_heads = 8

    sanity_check = True # if True, evaluate on a slice of the training data instead of the test set
    min_cnt = 100 # words occurring fewer than min_cnt times are dropped from the vocabulary


