forked from jcjohnson/torch-rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: LanguageModel_test.lua
95 lines (78 loc) · 2.23 KB
/
LanguageModel_test.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
require 'torch'
require 'nn'
require 'LanguageModel'
local tests = {}
local tester = torch.Tester()

-- Assert that tensor x has exactly the dimensions listed in dims
-- (same number of axes, and each axis i has size dims[i]).
local function check_dims(x, dims)
  tester:assert(x:dim() == #dims)
  for i = 1, #dims do
    tester:assert(x:size(i) == dims[i])
  end
end
-- Just a smoke test to make sure model can run forward / backward
-- on a tiny batch without crashing, and that the score shape is right.
function tests.simpleTest()
  local batch, seq_len, wordvec_dim, hidden_dim, vocab = 2, 3, 4, 5, 6
  local idx_to_token = {'a', 'b', 'c', 'd', 'e', 'f'}
  local LM = nn.LanguageModel{
    idx_to_token=idx_to_token,
    model_type='rnn',
    wordvec_size=wordvec_dim,
    rnn_size=hidden_dim,
    num_layers=6,
    dropout=0,
    batchnorm=0,
  }
  local criterion = nn.CrossEntropyCriterion()
  local params, grad_params = LM:getParameters()
  local x = torch.Tensor(batch, seq_len):random(vocab)
  local y = torch.Tensor(batch, seq_len):random(vocab)

  -- Forward pass should produce per-timestep vocabulary scores (N, T, V).
  local scores = LM:forward(x)
  check_dims(scores, {batch, seq_len, vocab})

  -- Flatten to (N*T, V) scores and (N*T) targets for the criterion,
  -- then push gradients back through the model.
  local flat_scores = scores:view(batch * seq_len, vocab)
  local flat_y = y:view(batch * seq_len)
  local loss = criterion:forward(flat_scores, flat_y)
  local dscores = criterion:backward(flat_scores, flat_y):view(batch, seq_len, vocab)
  LM:backward(x, dscores)
end
-- Check that LM:sample returns a string of exactly the requested length
-- when seeded with start_text.
function tests.sampleTest()
  -- Only the wordvec and hidden sizes are needed here; the original
  -- declaration also bound unused N, T, V locals, now removed.
  local D, H = 4, 5
  local idx_to_token = {[1]='a', [2]='b', [3]='c', [4]='d', [5]='e', [6]='f'}
  local LM = nn.LanguageModel{
    idx_to_token=idx_to_token,
    model_type='rnn',
    wordvec_size=D,
    rnn_size=H,
    num_layers=6,
    dropout=0,
    batchnorm=0,
  }
  local TT = 100  -- total number of characters to sample
  local start_text = 'bad'  -- every char must exist in idx_to_token
  local sampled = LM:sample{start_text=start_text, length=TT}
  tester:assert(torch.type(sampled) == 'string')
  tester:assert(string.len(sampled) == TT)
end
-- Check that encode_string maps characters to the expected token indices
-- and that decode_string is its exact inverse.
function tests.encodeDecodeTest()
  local idx_to_token = {
    [1]='a', [2]='b', [3]='c', [4]='d',
    [5]='e', [6]='f', [7]='g', [8]=' ',
  }
  -- Only the wordvec and hidden sizes are needed here; the original
  -- declaration also bound unused N, T, V locals (with V=7 disagreeing
  -- with the 8-entry vocabulary above), now removed.
  local D, H = 4, 5
  local LM = nn.LanguageModel{
    idx_to_token=idx_to_token,
    model_type='rnn',
    wordvec_size=D,
    rnn_size=H,
    num_layers=6,
    dropout=0,
    batchnorm=0,
  }
  local s = 'a bad feed'
  local encoded = LM:encode_string(s)
  -- Indices taken from idx_to_token above: 'a'=1, ' '=8, 'b'=2, 'd'=4, 'f'=6, 'e'=5.
  local expected_encoded = torch.LongTensor{1, 8, 2, 1, 4, 8, 6, 5, 5, 4}
  tester:assert(torch.all(torch.eq(encoded, expected_encoded)))
  -- Round trip: decoding the encoded tensor must recover the original string.
  local s2 = LM:decode_string(encoded)
  tester:assert(s == s2)
end
-- Register every function in the tests table and run the whole suite.
tester:add(tests)
tester:run()