Implementation of guided alignment
jungikim committed Sep 27, 2016
1 parent ae8cd33 commit 894aef3
Showing 4 changed files with 285 additions and 26 deletions.
63 changes: 59 additions & 4 deletions preprocess.py
@@ -163,7 +163,7 @@ def make_vocab(srcfile, targetfile, seqlength, max_word_l=0, chars=0, train=1):
src_indexer.vocab[word] += 1
return max_word_l, num_sents

def convert(srcfile, targetfile, batchsize, seqlength, outfile, num_sents,
def convert(srcfile, targetfile, alignfile, batchsize, seqlength, outfile, num_sents,
            max_word_l, max_sent_l=0, chars=0, unkfilter=0, shuffle=0):

def init_features_tensor(indexers):
@@ -188,6 +188,13 @@ def load_features(orig_features, indexers, seqlength):
return features

newseqlength = seqlength + 2 #add 2 for EOS and BOS

alignfile_hdl = None
alignments = None
if alignfile != '':
    alignfile_hdl = open(alignfile, 'r')
    # dense 0/1 alignment matrices: [sentence][source position][target position]
    alignments = np.zeros((num_sents, newseqlength, newseqlength), dtype=np.uint8)

targets = np.zeros((num_sents, newseqlength), dtype=int)
target_output = np.zeros((num_sents, newseqlength), dtype=int)
sources = np.zeros((num_sents, newseqlength), dtype=int)
@@ -211,6 +218,11 @@ def load_features(orig_features, indexers, seqlength):
targ = [target_indexer.BOS] + targ_orig.strip().split() + [target_indexer.EOS]
src = [src_indexer.BOS] + src_orig.strip().split() + [src_indexer.EOS]
max_sent_l = max(len(targ), len(src), max_sent_l)

align = []
if alignfile_hdl:
    # one whitespace-separated list of "src-tgt" index pairs per sentence pair
    align = alignfile_hdl.readline().strip().split(" ")

if len(targ) > newseqlength or len(src) > newseqlength or len(targ) < 3 or len(src) < 3:
    dropped += 1
    continue
@@ -269,6 +281,11 @@ def load_features(orig_features, indexers, seqlength):
for i in range(len(src_feature_indexers)):
    sources_features[i][sent_id] = np.array(source_features[i], dtype=int)

if alignfile_hdl:
    for pair in align:
        aFrom, aTo = pair.split('-')
        # offset by 1 for the BOS symbol prepended to source and target
        alignments[sent_id][int(aFrom) + 1][int(aTo) + 1] = 1

sent_id += 1
if sent_id % 100000 == 0:
    print("{}/{} sentences processed".format(sent_id, num_sents))
@@ -279,6 +296,8 @@ def load_features(orig_features, indexers, seqlength):
targets = targets[rand_idx]
target_output = target_output[rand_idx]
sources = sources[rand_idx]
if alignments is not None:
    alignments = alignments[rand_idx]
source_lengths = source_lengths[rand_idx]
target_lengths = target_lengths[rand_idx]
for i in range(len(sources_features)):
@@ -294,6 +313,8 @@ def load_features(orig_features, indexers, seqlength):
sources = sources[source_sort]
targets = targets[source_sort]
target_output = target_output[source_sort]
if alignments is not None:
    alignments = alignments[source_sort]
target_l = target_lengths[source_sort]
source_l = source_lengths[source_sort]

@@ -332,6 +353,35 @@ def load_features(orig_features, indexers, seqlength):
f["source"] = sources
f["target"] = targets
f["target_output"] = target_output
if alignments is not None:
    print("build alignment structure")
    # compress the alignment matrices: each distinct per-source-word alignment
    # vector is stored once in alignment_cc_val; alignment_cc_colidx maps every
    # source position to its vector's offset in alignment_cc_val; and
    # alignment_cc_sentidx maps every sentence to its first entry in alignment_cc_colidx
    alignment_cc_val = []
    alignment_cc_colidx = []
    alignment_cc_sentidx = []
    S = {}
    for k in range(sent_id):  # stored sentences occupy indices 0 .. sent_id-1
        alignment_cc_sentidx.append(len(alignment_cc_colidx))
        for i in xrange(0, source_l[k]):
            # encode source word i's alignment vector as a string key
            a = ''
            for j in xrange(0, newseqlength):
                a = a + chr(ord('0') + int(alignments[k][i][j]))
            # store the vector only if this column has not been seen before
            if a not in S:
                alignment_cc_colidx.append(len(alignment_cc_val))
                S[a] = len(alignment_cc_val)
                for j in xrange(0, newseqlength):
                    alignment_cc_val.append(alignments[k][i][j])
            else:
                alignment_cc_colidx.append(S[a])

    # offsets are stored as uint32 and must therefore fit in 32 bits
    assert(len(alignment_cc_val) < 4294967296 and len(alignment_cc_colidx) < 4294967296)
    f["alignment_cc_sentidx"] = np.array(alignment_cc_sentidx, dtype=np.uint32)
    f["alignment_cc_colidx"] = np.array(alignment_cc_colidx, dtype=np.uint32)
    f["alignment_cc_val"] = np.array(alignment_cc_val, dtype=np.uint8)

f["target_l"] = np.array(target_l_max, dtype=int)
f["target_l_all"] = target_l
f["batch_l"] = np.array(batch_l, dtype=int)
@@ -401,10 +451,10 @@ def load_features(orig_features, indexers, seqlength):
len(target_indexer.d)))

max_sent_l = 0
max_sent_l = convert(args.srcvalfile, args.targetvalfile, args.batchsize, args.seqlength,
max_sent_l = convert(args.srcvalfile, args.targetvalfile, args.alignvalfile, args.batchsize, args.seqlength,
                     args.outputfile + "-val.hdf5", num_sents_valid,
                     max_word_l, max_sent_l, args.chars, args.unkfilter, args.shuffle)
max_sent_l = convert(args.srcfile, args.targetfile, args.batchsize, args.seqlength,
max_sent_l = convert(args.srcfile, args.targetfile, args.alignfile, args.batchsize, args.seqlength,
                     args.outputfile + "-train.hdf5", num_sents_train, max_word_l,
                     max_sent_l, args.chars, args.unkfilter, args.shuffle)

@@ -461,7 +511,12 @@ def main(arguments):
parser.add_argument('--shuffle', help="If = 1, shuffle sentences before sorting (based on "
                                      "source length).",
                    type = int, default = 0)

parser.add_argument('--alignfile', help="Path to the source-to-target alignment of the training data, "
                                        "where each line holds the alignment pairs of "
                                        "one training instance.",
                    type = str, required = False, default = '')
parser.add_argument('--alignvalfile', help="Path to the source-to-target alignment of the validation data.",
                    type = str, required = False, default = '')
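The new options slot in alongside the existing source/target arguments; a hypothetical invocation (all paths are placeholders, not taken from the commit) could look like:

python preprocess.py --srcfile data/src-train.txt --targetfile data/targ-train.txt \
    --srcvalfile data/src-val.txt --targetvalfile data/targ-val.txt \
    --alignfile data/train.align --alignvalfile data/val.align \
    --outputfile data/demo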
args = parser.parse_args(arguments)
get_data(args)

49 changes: 47 additions & 2 deletions s2sa/data.lua
@@ -29,6 +29,29 @@ function features_on_gpu(features)
return clone
end

-- using the sentence ids, rebuild the dense batch_size x source_l x target_l
-- alignment tensor from the compressed column storage
function generate_aligns(batch_sent_idx, alignment_cc_colidx, alignment_cc_val, source_l, target_l, opt_start_symbol)
  if batch_sent_idx == nil then
    return nil
  end
  local batch_size = batch_sent_idx:size(1)

  -- when start_symbol == 0, BOS/EOS are stripped from the source side,
  -- so every source position is shifted by one column
  local src_offset = 0
  if opt_start_symbol == 0 then
    src_offset = 1
  end

  local t = torch.Tensor(batch_size, source_l, target_l)
  for k = 1, batch_size do
    local sent_idx = batch_sent_idx[k]
    for i = 0, source_l-1 do
      -- each colidx entry is the offset of a length-target_l slice in the value array
      t[k][i+1]:copy(alignment_cc_val:narrow(1, alignment_cc_colidx[sent_idx+1+i+src_offset]+1, target_l))
    end
  end

  return t
end

local data = torch.class("data")

function data:__init(opt, data_file)
@@ -71,6 +94,12 @@ function data:__init(opt, data_file)
self.source_size = f:read('source_size'):all()[1]
self.target_nonzeros = f:read('target_nonzeros'):all()

  if opt.guided_alignment == 1 then
    self.alignment_cc_sentidx = f:read('alignment_cc_sentidx'):all()
    self.alignment_cc_colidx = f:read('alignment_cc_colidx'):all()
    self.alignment_cc_val = f:read('alignment_cc_val'):all()
  end

  if opt.use_chars_enc == 1 then
    self.source_char = f:read('source_char'):all()
    self.char_size = f:read('char_size'):all()[1]
@@ -140,6 +169,11 @@ function data:__init(opt, data_file)
-- convert table of timesteps per feature to a table of features per timestep
source_features_i = features_per_timestep(source_feats)

    local alignment_i
    if opt.guided_alignment == 1 then
      -- sentence indices of this batch into the compressed alignment arrays
      alignment_i = self.alignment_cc_sentidx:sub(self.batch_idx[i], self.batch_idx[i]+self.batch_l[i]-1)
    end

    table.insert(self.batches, {target_i,
                                target_output_i:transpose(1,2),
                                self.target_nonzeros[i],
@@ -148,7 +182,8 @@ function data:__init(opt, data_file)
                                self.target_l[i],
                                self.source_l[i],
                                target_l_i,
                                source_features_i})
                                source_features_i,
                                alignment_i})
end
end

@@ -169,6 +204,13 @@ function data.__index(self, idx)
    local source_l = self.batches[idx][7]
    local target_l_all = self.batches[idx][8]
    local source_features = self.batches[idx][9]
    local alignment = generate_aligns(self.batches[idx][10],
                                      self.alignment_cc_colidx,
                                      self.alignment_cc_val,
                                      source_l,
                                      target_l,
                                      opt.start_symbol)

    if opt.gpuid >= 0 then -- if multi-gpu, source lives on gpuid1, the rest on gpuid2
      cutorch.setDevice(opt.gpuid)
      source_input = source_input:cuda()
@@ -179,9 +221,12 @@ function data.__index(self, idx)
      target_input = target_input:cuda()
      target_output = target_output:cuda()
      target_l_all = target_l_all:cuda()
      if opt.guided_alignment == 1 then
        alignment = alignment:cuda()
      end
    end
    return {target_input, target_output, nonzeros, source_input,
            batch_l, target_l, source_l, target_l_all, source_features}
            batch_l, target_l, source_l, target_l_all, source_features, alignment}
  end
end

20 changes: 18 additions & 2 deletions s2sa/models.lua
@@ -141,10 +141,15 @@ function make_lstm(data, opt, model, use_chars)
  if model == 'dec' then
    local top_h = outputs[#outputs]
    local decoder_out
    local attn_output
    if opt.attn == 1 then
      local decoder_attn = make_decoder_attn(data, opt)
      decoder_attn.name = 'decoder_attn'
      decoder_out = decoder_attn({top_h, inputs[2]})
      if opt.guided_alignment == 1 then
        -- the attention module now returns both the context vector and the
        -- attention distribution; split them into two graph nodes
        decoder_out, attn_output = decoder_attn({top_h, inputs[2]}):split(2)
      else
        decoder_out = decoder_attn({top_h, inputs[2]})
      end
    else
      decoder_out = nn.JoinTable(2)({top_h, inputs[2]})
      decoder_out = nn.Tanh()(nn.LinearNoBias(opt.rnn_size*2, opt.rnn_size)(decoder_out))
@@ -154,6 +159,9 @@ function make_lstm(data, opt, model, use_chars)
        (decoder_out)
    end
    table.insert(outputs, decoder_out)
    if opt.guided_alignment == 1 then
      table.insert(outputs, attn_output)
    end
  end
  return nn.gModule(inputs, outputs)
end
Expand All @@ -178,6 +186,10 @@ function make_decoder_attn(data, opt, simple)
  local softmax_attn = nn.SoftMax()
  softmax_attn.name = 'softmax_attn'
  attn = softmax_attn(attn)
  local attn_output
  if opt.guided_alignment == 1 then
    -- keep a handle on the normalized attention weights before they are reshaped
    attn_output = attn
  end
  attn = nn.Replicate(1,2)(attn) -- batch_l x 1 x source_l

-- apply attention to context
@@ -203,7 +215,11 @@ function make_decoder_attn(data, opt, simple)
        {{opt.max_batch_l, opt.rnn_size}, {opt.max_batch_l, opt.rnn_size}})
        ({context_combined, inputs[1]})
  end
  return nn.gModule(inputs, {context_output})
  if opt.guided_alignment == 1 then
    return nn.gModule(inputs, {context_output, attn_output})
  else
    return nn.gModule(inputs, {context_output})
  end
end

function make_generator(data, opt)
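The fourth changed file of this commit is not rendered on this page, but it presumably contains the training-side changes that consume the new attn_output. For orientation only (this is not code from the commit): a common guided-alignment objective, e.g. Chen et al. (2016), adds a cross-entropy term between the attention weights and the row-normalized reference alignments. A NumPy sketch with hypothetical names:

import numpy as np

def guided_alignment_loss(attn, align, eps=1e-8):
    # attn:  (target_l, source_l) attention weights; each row sums to 1
    # align: (target_l, source_l) 0/1 reference alignment matrix
    ref = align / np.maximum(align.sum(axis=1, keepdims=True), 1)  # rows with no alignment stay zero
    return -np.sum(ref * np.log(attn + eps)) / attn.shape[0]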