Skip to content

Commit

Permalink
Merge pull request graykode#50 from Yuhuishishishi/fix-bi-lstm-shape-…
Browse files Browse the repository at this point in the history
…comment

fix bi-LSTM hidden and cell state shape comments
  • Loading branch information
graykode authored Aug 13, 2020
2 parents 7a3c8d8 + 90dc12a commit f468977
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions 3-3.Bi-LSTM/Bi-LSTM-Torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(self):
def forward(self, X):
input = X.transpose(0, 1) # input : [n_step, batch_size, n_class]

hidden_state = Variable(torch.zeros(1*2, len(X), n_hidden)) # [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
cell_state = Variable(torch.zeros(1*2, len(X), n_hidden)) # [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
hidden_state = Variable(torch.zeros(1*2, len(X), n_hidden)) # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]
cell_state = Variable(torch.zeros(1*2, len(X), n_hidden)) # [num_layers(=1) * num_directions(=2), batch_size, n_hidden]

outputs, (_, _) = self.lstm(input, (hidden_state, cell_state))
outputs = outputs[-1] # [batch_size, n_hidden]
Expand All @@ -75,4 +75,4 @@ def forward(self, X):

predict = model(input_batch).data.max(1, keepdim=True)[1]
print(sentence)
print([number_dict[n.item()] for n in predict.squeeze()])
print([number_dict[n.item()] for n in predict.squeeze()])

0 comments on commit f468977

Please sign in to comment.