add dmmiller612 greedy decoder for colab
graykode committed Feb 3, 2019
1 parent 2eb3317 commit 7704dec
Showing 3 changed files with 371 additions and 42 deletions.
65 changes: 24 additions & 41 deletions 5-1.Transformer/Transformer(Greedy_decoder)-Torch.py
@@ -1,5 +1,5 @@
'''
code by Tae Hwan Jung(Jeff Jung) @graykode
code by Tae Hwan Jung(Jeff Jung) @graykode, Derek Miller @dmmiller612
Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch
https://github.com/JayParks/transformer
'''
@@ -32,14 +32,12 @@
n_layers = 6 # number of Encoder and Decoder layers
n_heads = 8 # number of heads in Multi-Head Attention


def make_batch(sentences):
input_batch = [[src_vocab[n] for n in sentences[0].split()]]
output_batch = [[tgt_vocab[n] for n in sentences[1].split()]]
target_batch = [[tgt_vocab[n] for n in sentences[2].split()]]
return Variable(torch.LongTensor(input_batch)), Variable(torch.LongTensor(output_batch)), Variable(torch.LongTensor(target_batch))


def get_sinusoid_encoding_table(n_position, d_model):
def cal_angle(position, hid_idx):
return position / np.power(10000, 2 * (hid_idx // 2) / d_model)
@@ -51,14 +49,12 @@ def get_posi_angle_vec(position):
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table)
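# Illustrative sketch (not part of this commit): the table above is the fixed
# sinusoidal positional encoding from "Attention Is All You Need",
# PE[pos, 2i] = sin(pos / 10000^(2i / d_model)) and PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model)).
# A compact, self-contained variant for reference; the function name is hypothetical.
import numpy as np
import torch

def sinusoid_table_sketch(n_position, d_model):
    pos = np.arange(n_position)[:, None]                    # [n_position, 1]
    idx = np.arange(d_model)[None, :]                       # [1, d_model]
    angles = pos / np.power(10000, 2 * (idx // 2) / d_model)
    table = np.zeros((n_position, d_model))
    table[:, 0::2] = np.sin(angles[:, 0::2])                # even dimensions: sin
    table[:, 1::2] = np.cos(angles[:, 1::2])                # odd dimensions: cos
    return torch.FloatTensor(table)                         # shape [n_position, d_model]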


def get_attn_pad_mask(seq_q, seq_k):
batch_size, len_q = seq_q.size()
batch_size, len_k = seq_k.size()
pad_attn_mask = seq_k.data.eq(0).unsqueeze(1) # batch_size x 1 x len_k(=len_q)
return pad_attn_mask.expand(batch_size, len_q, len_k) # batch_size x len_q x len_k
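# Illustrative sketch (not part of this commit): get_attn_pad_mask flags key
# positions whose token id is 0 (the padding symbol) so attention can ignore them.
# A tiny standalone check with made-up ids:
import torch

seq_q = torch.LongTensor([[1, 2, 3, 4, 0]])   # 1 x len_q, last token is padding
seq_k = torch.LongTensor([[5, 6, 0, 0, 0]])   # 1 x len_k, three padded positions
pad_mask = seq_k.data.eq(0).unsqueeze(1).expand(seq_q.size(0), seq_q.size(1), seq_k.size(1))
print(pad_mask.shape)                         # torch.Size([1, 5, 5]), True where the key is padding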


class ScaledDotProductAttention(nn.Module):

def __init__(self):
@@ -72,7 +68,6 @@ def forward(self, Q, K, V, attn_mask=None):
context = torch.matmul(attn, V)
return context, attn
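# Illustrative sketch (not part of this commit): the forward above computes
# softmax(Q K^T / sqrt(d_k)) V, with masked positions pushed to -1e9 before the
# softmax so they receive near-zero attention weight. A minimal functional version:
import numpy as np
import torch
import torch.nn as nn

def scaled_dot_product_attention_sketch(Q, K, V, attn_mask=None):
    scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(K.size(-1))  # [..., len_q, len_k]
    if attn_mask is not None:
        scores = scores.masked_fill(attn_mask, -1e9)   # block padded / future positions
    attn = nn.Softmax(dim=-1)(scores)                  # attention weights over the keys
    context = torch.matmul(attn, V)                    # weighted sum of the values
    return context, attn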


class MultiHeadAttention(nn.Module):

def __init__(self):
@@ -97,7 +92,6 @@ def forward(self, Q, K, V, attn_mask=None):
output = nn.Linear(n_heads * d_v, d_model)(context)
return nn.LayerNorm(d_model)(output + residual), attn # output: [batch_size x len_q x d_model]
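# Illustrative sketch (not part of this commit): multi-head attention projects Q/K/V,
# reshapes to [batch, n_heads, len, d_k], attends per head, then concatenates the heads
# back to [batch, len, n_heads * d_v] before the output projection and residual
# LayerNorm seen in the two lines just above. The defaults mirror the constants at the
# top of the file; here the projections live in __init__ so their weights persist
# across calls.
import torch
import torch.nn as nn

class MultiHeadAttentionSketch(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.n_heads, self.d_k = n_heads, d_model // n_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask=None):
        residual, batch = Q, Q.size(0)
        # [batch, len, d_model] -> [batch, n_heads, len, d_k]
        q = self.W_Q(Q).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        k = self.W_K(K).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        v = self.W_V(V).view(batch, -1, self.n_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-1, -2)) / self.d_k ** 0.5
        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask.unsqueeze(1), -1e9)  # broadcast over heads
        attn = scores.softmax(dim=-1)
        context = torch.matmul(attn, v).transpose(1, 2).reshape(batch, -1, self.n_heads * self.d_k)
        return self.norm(self.W_O(context) + residual), attn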


class PoswiseFeedForwardNet(nn.Module):

def __init__(self):
@@ -111,7 +105,6 @@ def forward(self, inputs):
output = self.conv2(output).transpose(1, 2)
return nn.LayerNorm(d_model)(output + residual)


class EncoderLayer(nn.Module):

def __init__(self):
@@ -124,7 +117,6 @@ def forward(self, enc_inputs):
enc_outputs = self.pos_ffn(enc_outputs) # enc_outputs: [batch_size x len_q x d_model]
return enc_outputs, attn


class DecoderLayer(nn.Module):

def __init__(self):
@@ -139,7 +131,6 @@ def forward(self, dec_inputs, enc_outputs, enc_attn_mask, dec_attn_mask=None):
dec_outputs = self.pos_ffn(dec_outputs)
return dec_outputs, dec_self_attn, dec_enc_attn


class Encoder(nn.Module):

def __init__(self):
@@ -156,7 +147,6 @@ def forward(self, enc_inputs): # enc_inputs : [batch_size x source_len]
enc_self_attns.append(enc_self_attn)
return enc_outputs, enc_self_attns


class Decoder(nn.Module):

def __init__(self):
@@ -178,7 +168,6 @@ def forward(self, dec_inputs, enc_inputs, enc_outputs, dec_attn_mask=None): # de
dec_enc_attns.append(dec_enc_attn)
return dec_outputs, dec_self_attns, dec_enc_attns


class Transformer(nn.Module):

def __init__(self):
@@ -193,7 +182,6 @@ def forward(self, enc_inputs, dec_inputs, decoder_mask=None):
dec_logits = self.projection(dec_outputs) # dec_logits : [batch_size x tgt_len x tgt_vocab_size]
return dec_logits.view(-1, dec_logits.size(-1)), enc_self_attns, dec_self_attns, dec_enc_attns


def greedy_decoder(model, enc_input, start_symbol):
"""
For simplicity, a Greedy Decoder is Beam search when K=1. This is necessary for inference as we don't know the
@@ -218,7 +206,6 @@ def greedy_decoder(model, enc_input, start_symbol):
next_symbol = next_word[0]
return dec_input
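# Illustrative sketch (not part of this commit): greedy decoding is beam search with
# beam width K=1. At each step the decoder is re-run on the prefix generated so far
# and the single highest-scoring next token is appended. A hypothetical standalone
# variant, assuming `model` returns logits flattened to [batch * tgt_len, tgt_vocab_size]
# as Transformer.forward does above:
import torch

def greedy_decode_sketch(model, enc_input, start_symbol, max_len=5):
    dec_input = torch.zeros(1, max_len).long()       # running decoder input, batch size 1
    next_symbol = start_symbol
    for i in range(max_len):
        dec_input[0][i] = next_symbol
        logits, _, _, _ = model(enc_input, dec_input)
        logits = logits.view(max_len, -1)             # one row of logits per target position
        next_symbol = logits[i].argmax().item()       # greedy pick for position i
    return dec_input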


def showgraph(attn):
attn = attn[-1].squeeze(0)[0]
attn = attn.squeeze(0).data.numpy()
@@ -229,36 +216,32 @@ def showgraph(attn):
ax.set_yticklabels(['']+sentences[2].split(), fontdict={'fontsize': 14})
plt.show()

model = Transformer()

if __name__ == '__main__':

model = Transformer()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(100):
optimizer.zero_grad()
enc_inputs, dec_inputs, target_batch = make_batch(sentences)
outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
loss = criterion(outputs, target_batch.contiguous().view(-1))
if (epoch + 1) % 20 == 0:
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
loss.backward()
optimizer.step()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

for epoch in range(100):
optimizer.zero_grad()
enc_inputs, dec_inputs, target_batch = make_batch(sentences)
outputs, enc_self_attns, dec_self_attns, dec_enc_attns = model(enc_inputs, dec_inputs)
loss = criterion(outputs, target_batch.contiguous().view(-1))
if (epoch + 1) % 20 == 0:
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
loss.backward()
optimizer.step()

# Test
greedy_dec_input = greedy_decoder(model, enc_inputs, start_symbol=4)
predict, _, _, _ = model(enc_inputs, greedy_dec_input)
predict = predict.data.max(1, keepdim=True)[1]
print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])
# Test
greedy_dec_input = greedy_decoder(model, enc_inputs, start_symbol=4)
predict, _, _, _ = model(enc_inputs, greedy_dec_input)
predict = predict.data.max(1, keepdim=True)[1]
print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])

print('first head of last state enc_self_attns')
showgraph(enc_self_attns)
print('first head of last state enc_self_attns')
showgraph(enc_self_attns)

print('first head of last state dec_self_attns')
showgraph(dec_self_attns)
print('first head of last state dec_self_attns')
showgraph(dec_self_attns)

print('first head of last state dec_enc_attns')
showgraph(dec_enc_attns)
print('first head of last state dec_enc_attns')
showgraph(dec_enc_attns)