Skip to content

Commit

Permalink
add test_batch function - close graykode#69
Browse files Browse the repository at this point in the history
  • Loading branch information
karim-moon committed Jul 24, 2021
1 parent 8ec1aeb commit 402e315
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions 4-1.Seq2Seq/Seq2Seq.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# %%
# code by Tae Hwan Jung @graykode
import argparse
import numpy as np
import torch
import torch.nn as nn
Expand All @@ -27,6 +26,19 @@ def make_batch():
# make tensor
return torch.FloatTensor(input_batch), torch.FloatTensor(output_batch), torch.LongTensor(target_batch)

# make test batch
def make_testbatch(input_word):
input_batch, output_batch = [], []

input_w = input_word + 'P' * (n_step - len(input_word))
input = [num_dic[n] for n in input_w]
output = [num_dic[n] for n in 'S' + 'P' * n_step]

input_batch = np.eye(n_class)[input]
output_batch = np.eye(n_class)[output]

return torch.FloatTensor(input_batch).unsqueeze(0), torch.FloatTensor(output_batch).unsqueeze(0)

# Model
class Seq2Seq(nn.Module):
def __init__(self):
Expand Down Expand Up @@ -87,11 +99,11 @@ def forward(self, enc_input, enc_hidden, dec_input):
optimizer.step()

# Test
def translate(word, args):
input_batch, output_batch, _ = make_batch([[word, 'P' * len(word)]], args)
def translate(word):
input_batch, output_batch = make_testbatch(word)

# make hidden shape [num_layers * num_directions, batch_size, n_hidden]
hidden = torch.zeros(1, 1, args.n_hidden)
hidden = torch.zeros(1, 1, n_hidden)
output = model(input_batch, hidden, output_batch)
# output : [max_len+1(=6), batch_size(=1), n_class]

Expand Down

0 comments on commit 402e315

Please sign in to comment.