Commit

reverting cudnn decoder to lstmcell
bmccann authored and soumith committed Mar 14, 2017
1 parent bf82a7b commit c90842c
Showing 1 changed file with 35 additions and 9 deletions.
44 changes: 35 additions & 9 deletions OpenNMT/onmt/Models.py
@@ -38,6 +38,34 @@ def forward(self, input, hidden=None):
         return hidden_t, outputs
 
 
+class StackedLSTM(nn.Module):
+    def __init__(self, num_layers, input_size, rnn_size, dropout):
+        super(StackedLSTM, self).__init__()
+        self.dropout = nn.Dropout(dropout)
+        self.num_layers = num_layers
+        self.layers = nn.ModuleList()
+
+        for i in range(num_layers):
+            self.layers.append(nn.LSTMCell(input_size, rnn_size))
+            input_size = rnn_size
+
+    def forward(self, input, hidden):
+        h_0, c_0 = hidden
+        h_1, c_1 = [], []
+        for i, layer in enumerate(self.layers):
+            h_1_i, c_1_i = layer(input, (h_0[i], c_0[i]))
+            input = h_1_i
+            if i != self.num_layers:
+                input = self.dropout(input)
+            h_1 += [h_1_i]
+            c_1 += [c_1_i]
+
+        h_1 = torch.stack(h_1)
+        c_1 = torch.stack(c_1)
+
+        return input, (h_1, c_1)
+
+
 class Decoder(nn.Module):
 
     def __init__(self, opt, dicts):
@@ -51,9 +79,7 @@ def __init__(self, opt, dicts):
         self.word_lut = nn.Embedding(dicts.size(),
                                      opt.word_vec_size,
                                      padding_idx=onmt.Constants.PAD)
-        self.rnn = nn.LSTM(input_size, opt.rnn_size,
-                           num_layers=opt.layers,
-                           dropout=opt.dropout)
+        self.rnn = StackedLSTM(opt.layers, input_size, opt.rnn_size, opt.dropout)
         self.attn = onmt.modules.GlobalAttention(opt.rnn_size)
         self.dropout = nn.Dropout(opt.dropout)
 
@@ -77,16 +103,16 @@ def forward(self, input, hidden, context, init_output):
         outputs = []
         output = init_output
         for i, emb_t in enumerate(emb.split(1)):
-            emb_t = emb_t
+            emb_t = emb_t.squeeze(0)
             if self.input_feed:
-                emb_t = torch.cat([emb_t, output], 2)
+                emb_t = torch.cat([emb_t, output], 1)
 
             output, hidden = self.rnn(emb_t, hidden)
-            output, attn = self.attn(output.squeeze(0), context.t())
-            output = self.dropout(output.unsqueeze(0))
+            output, attn = self.attn(output, context.t())
+            output = self.dropout(output)
             outputs += [output]
 
-        outputs = torch.cat(outputs, 0)
+        outputs = torch.stack(outputs)
         return outputs.transpose(0, 1), hidden, attn
 
 
@@ -105,7 +131,7 @@ def set_generate(self, enabled):
     def make_init_decoder_output(self, context):
         batch_size = context.size(1)
         h_size = (batch_size, self.decoder.hidden_size)
-        return Variable(context.data.new(1, *h_size).zero_(), requires_grad=False)
+        return Variable(context.data.new(*h_size).zero_(), requires_grad=False)
 
     def _fix_enc_hidden(self, h):
         # the encoder hidden is (layers*directions) x batch x dim
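
A minimal sketch (not part of the commit) of how the reverted decoder steps the new StackedLSTM one target token at a time. It assumes a recent PyTorch where plain tensors replace Variable, assumes the StackedLSTM class from the diff above is in scope, and uses illustrative sizes, dropout value, and variable names (batch, emb_dim, rnn_size, emb). It shows why the revert squeezes emb_t to 2-D and concatenates the input feed on dim 1: nn.LSTMCell consumes (batch, input_size) per step, whereas the removed nn.LSTM consumed (seq_len, batch, input_size).

    import torch

    # Illustrative sizes only; assumes StackedLSTM (defined in the diff above) is importable.
    num_layers, batch, emb_dim, rnn_size = 2, 4, 8, 16

    # With input feeding, the per-step input is the word embedding concatenated
    # with the previous attentional output, hence emb_dim + rnn_size.
    cell = StackedLSTM(num_layers, emb_dim + rnn_size, rnn_size, dropout=0.3)

    h = torch.zeros(num_layers, batch, rnn_size)   # h_0, one slice per layer
    c = torch.zeros(num_layers, batch, rnn_size)   # c_0, one slice per layer
    output = torch.zeros(batch, rnn_size)          # plays the role of init_output

    emb = torch.randn(5, batch, emb_dim)           # 5 decoding time steps
    for emb_t in emb.split(1):                     # each slice is (1, batch, emb_dim)
        emb_t = emb_t.squeeze(0)                   # (batch, emb_dim): LSTMCell wants 2-D input
        rnn_in = torch.cat([emb_t, output], 1)     # input feeding concatenates on dim 1, not 2
        output, (h, c) = cell(rnn_in, (h, c))      # one step through all stacked LSTMCells

Each iteration returns the top layer's hidden state as `output`, which is what the attention and dropout calls in Decoder.forward operate on directly, without the squeeze/unsqueeze pair the cuDNN nn.LSTM version needed.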
