Skip to content

Commit

Permalink
graykode#18 modify wrong comment in Seq2Seq torch
Browse files Browse the repository at this point in the history
  • Loading branch information
graykode committed Mar 29, 2019
1 parent 35e4924 commit 11c5601
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions 4-1.Seq2Seq/Seq2Seq-Torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __init__(self):
self.fc = nn.Linear(n_hidden, n_class)

def forward(self, enc_input, enc_hidden, dec_input):
enc_input = enc_input.transpose(0, 1) # enc_input: [max_len(=n_step, time step), batch_size, n_hidden]
dec_input = dec_input.transpose(0, 1) # dec_input: [max_len(=n_step, time step), batch_size, n_hidden]
enc_input = enc_input.transpose(0, 1) # enc_input: [max_len(=n_step, time step), batch_size, n_class]
dec_input = dec_input.transpose(0, 1) # dec_input: [max_len(=n_step, time step), batch_size, n_class]

# enc_states : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
_, enc_states = self.enc_cell(enc_input, enc_hidden)
Expand All @@ -71,8 +71,8 @@ def forward(self, enc_input, enc_hidden, dec_input):
hidden = Variable(torch.zeros(1, batch_size, n_hidden))

optimizer.zero_grad()
# input_batch : [batch_size, max_len(=n_step, time step), n_hidden]
# output_batch : [batch_size, max_len+1(=n_step, time step) (becase of 'S' or 'E'), n_hidden]
# input_batch : [batch_size, max_len(=n_step, time step), n_class]
# output_batch : [batch_size, max_len+1(=n_step, time step) (becase of 'S' or 'E'), n_class]
# target_batch : [batch_size, max_len+1(=n_step, time step)], not one-hot
output = model(input_batch, hidden, output_batch)
# output : [max_len+1, batch_size, num_directions(=1) * n_hidden]
Expand Down

0 comments on commit 11c5601

Please sign in to comment.