Commit

refactor preallocation and generalize it to all nodes in the LSTM - can be disabled by option, in which case it rolls back to the initial memory optimization - provides an additional 50% memory decrease.

For rnn_size=600, word_vec_size=500, 4 layers, attn=1, dropout on, max batch size 64, max source/target sentence length 50, the maximum memory usage observed is:
- initial: 6877MiB
- intermediate: 5243MiB
- current version: 2577MiB
Jean A. Senellart committed Sep 4, 2016
1 parent e25c7f5 commit 4258944
Showing 4 changed files with 121 additions and 90 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -189,6 +189,10 @@ GPU and the decoder is on the second GPU. This will allow you to train bigger mo
has much faster convolutions so this is highly recommended if using the character model.
* `save_every`: Save every this many epochs.
* `print_every`: Print various stats after this many batches.
* `seed`: Change the random seed for random numbers in torch - use this option to train alternate models for an ensemble.
* `prealloc`: when set to 1 (default), enables memory preallocation and sharing between clones - this greatly reduces memory usage, and there should be no
situation where you need to disable it. Also, since memory is preallocated, there is no (major)
memory increase during training. When set to 0, it rolls back to the original memory optimization (see the example below).
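For example, to train with preallocation disabled (the data and save paths below are placeholders): `th train.lua -data_file data/demo-train.hdf5 -val_data_file data/demo-val.hdf5 -savefile demo-model -prealloc 0`.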

#### Decoding options (`beam.lua`)

65 changes: 53 additions & 12 deletions s2sa/memory.lua
@@ -4,14 +4,14 @@
-- use :reuseMem() on the module to allow the feature
-- then apply setReuse after initialization
-- only applies if output and gradInput are of the same type
function nn.Module:reuseMem()
function nn.Module:reuseMem(name)
self.reuse = true
return self
end

function nn.Module:setReuse()
if self.reuse then
assert(type(self.output) == type(self.gradInput), "invalid use of reuseMem")
assert(type(self.output) == type(self.gradInput), "invalid use of reuseMem:")
self.gradInput = self.output
end
return self
@@ -20,30 +20,71 @@ end
-- usePrealloc is based on the same principle but uses memory pre-allocated at the beginning of the process that can be shared
-- between different objects
-- use it to preallocate gradInput or output - useful for intermediate calculations working on large input
preallocWarning = {}
preallocTable = {}
preallocTable = nil

function nn.Module:usePrealloc(preallocName)
function preallocateMemory(switch)
if switch == 1 then
preallocTable = {}
print('Switching on memory preallocation')
end
end

function preallocateTensor(name,D)
if #D > 1 then
local T={}
for i=1,#D do
table.insert(T,preallocateTensor(name,{D[i]}))
end
return T
else
D = D[1]
end
local t=torch.zeros(torch.LongStorage(D))
if opt.gpuid >= 0 then
if opt.gpuid2 >= 0 and string.sub(name,1,"4") == "dec_" then
cutorch.setDevice(opt.gpuid2)
else
cutorch.setDevice(opt.gpuid)
end
t=t:cuda()
end
return t
end

-- enable memory reuse - if preallocation is disabled, fall back to reuseMem by checking for 'reuse' in the name
function nn.Module:usePrealloc(preallocName, inputDim, outputDim)
if preallocTable == nil then
if string.find(preallocName, "reuse") then
self:reuseMem()
end
return self;
end
self.prealloc = preallocName
self.name = preallocName
self.preallocInputDim = inputDim
self.preallocOutputDim = outputDim
return self
end

function nn.Module:setPrealloc()
if self.prealloc then
if self.prealloc and (self.preallocInputDim ~= nil or self.preallocOutputDim ~= nil) then
if preallocTable[self.prealloc] == nil then
if not(preallocWarning[self.prealloc]) then
print('WARNING: no prealloc memory defined for \'' .. self.prealloc .. '\'')
preallocWarning[self.prealloc] = 1
preallocTable[self.prealloc] = {
}
if self.preallocInputDim ~= nil then
preallocTable[self.prealloc].GI = preallocateTensor(self.prealloc, self.preallocInputDim)
end
if self.preallocOutputDim ~= nil then
preallocTable[self.prealloc].O = preallocateTensor(self.prealloc, self.preallocOutputDim)
end
return
end
local memmap = preallocTable[self.prealloc]
if memmap["GI"] ~= nil then
assert(type(self.gradInput) == type(memmap.GI), "invalid use of usePrealloc")
assert(type(self.gradInput) == type(memmap.GI), "invalid use of usePrealloc ["..self.prealloc.."]/GI: "..type(self.gradInput).."/"..type(memmap.GI))
self.gradInput = memmap["GI"]
end
if memmap["O"] ~= nil then
assert(type(self.output) == type(memmap.O), "invalid use of usePrealloc")
assert(type(self.output) == type(memmap.O), "invalid use of usePrealloc ["..self.prealloc.."]/O:"..type(self.output).."/"..type(memmap.O))
self.output = memmap["O"]
end
end
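
Taken together, the new memory.lua API works in two phases: a graph node registers a named, dimensioned buffer with usePrealloc when the network is built, and the shared tensors are actually allocated (lazily, on first use) when setPrealloc runs over the clones. A minimal sketch of that lifecycle, assuming the s2sa package is on the Lua path and CPU-only training; the module, name and dimensions are illustrative only:

require 'nn'
require 's2sa.memory'

opt = { gpuid = -1 }         -- preallocateTensor reads the global opt to decide GPU placement
preallocateMemory(1)         -- 1 creates preallocTable; anything else falls back to reuseMem

-- declare a node whose output and gradInput should live in shared, preallocated buffers
local i2h = nn.Linear(600, 2400):usePrealloc("demo_i2h-reuse",
  {{64, 600}},               -- gradInput dimensions (one entry per input tensor)
  {{64, 2400}})              -- output dimensions

-- after cloning the graph across timesteps, bind each clone to the shared buffers;
-- the first call only allocates the tensors, later calls alias gradInput/output to them
i2h:setPrealloc()
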
129 changes: 57 additions & 72 deletions s2sa/models.lua
@@ -1,61 +1,13 @@
require 's2sa.util'
require 's2sa.memory'

-- the actual pre-allocation of memory
preallocTable = {}

function preallocateMemory(opt)
print('Preallocating memory...')
if opt.attn then
preallocTable["DEC_ATTN_MM1"] = {
GI = {
torch.zeros(opt.max_batch_l, opt.max_sent_l_src, opt.rnn_size),
torch.zeros(opt.rnn_size, opt.rnn_size, 1)
}
}
preallocTable["DEC_ATTN_MM2"] = {
GI = {
torch.zeros(opt.max_batch_l, 1, opt.max_sent_l_src),
torch.zeros(opt.max_batch_l, opt.max_sent_l_src, opt.rnn_size)
}
}
end
-- move on GPU according to gpuid, gpuid2 settings
if opt.gpuid >= 0 then
for k,t in pairs(preallocTable) do
if opt.gpuid2 >= 0 and string.sub(k,1,"4") == "DEC_" then
cutorch.setDevice(opt.gpuid2)
else
cutorch.setDevice(opt.gpuid)
end
if t.GI then
if type(t.GI) == "table" then
for i = 1,#t.GI do
t.GI[i] = t.GI[i]:cuda()
end
else
t.GI = t.GI:cuda()
end
end
if t.O then
if type(t.O) == "table" then
for i = 1,#t.O do
t.O[i] = t.O[i]:cuda()
end
else
t.O = t.O:cuda()
end
end
end
end
end

function make_lstm(data, opt, model, use_chars)
assert(model == 'enc' or model == 'dec')
local name = '_' .. model
local dropout = opt.dropout or 0
local n = opt.num_layers
local rnn_size = opt.rnn_size
local RnnD={opt.rnn_size,opt.rnn_size}
local input_size
if use_chars == 0 then
input_size = opt.word_vec_size
@@ -82,6 +34,7 @@ function make_lstm(data, opt, model, use_chars)
local x, input_size_L
local outputs = {}
for L = 1,n do
local nameL=model..'_L'..L..'_'
-- c,h from previous timesteps
local prev_c = inputs[L*2+offset]
local prev_h = inputs[L*2+1+offset]
@@ -111,8 +64,10 @@
input_size_L = input_size
if model == 'dec' then
if opt.input_feed == 1 then
x = nn.JoinTable(2)({x, inputs[1+offset]}) -- batch_size x (word_vec_size + rnn_size)
input_size_L = input_size + rnn_size
x = nn.JoinTable(2):usePrealloc("dec_inputfeed_join",
{{opt.max_batch_l, opt.word_vec_size},{opt.max_batch_l, opt.rnn_size}})
({x, inputs[1+offset]}) -- batch_size x (word_vec_size + rnn_size)
input_size_L = input_size_L + rnn_size
end
end
else
@@ -127,29 +82,43 @@
x = multi_attn({x, inputs[2]})
end
if dropout > 0 then
x = nn.Dropout(dropout, nil, false)(x)
x = nn.Dropout(dropout, nil, false):usePrealloc(nameL.."dropout",
{{opt.max_batch_l, input_size_L}})
(x)
end
end
-- evaluate the input sums at once for efficiency
local i2h = nn.Linear(input_size_L, 4 * rnn_size):reuseMem()(x)
local h2h = nn.LinearNoBias(rnn_size, 4 * rnn_size):reuseMem()(prev_h)
local all_input_sums = nn.CAddTable()({i2h, h2h})
local i2h = nn.Linear(input_size_L, 4 * rnn_size):usePrealloc(nameL.."i2h-reuse",
{{opt.max_batch_l, input_size_L}},
{{opt.max_batch_l, 4 * rnn_size}})
(x)
local h2h = nn.LinearNoBias(rnn_size, 4 * rnn_size):usePrealloc(nameL.."h2h-reuse",
{{opt.max_batch_l, rnn_size}},
{{opt.max_batch_l, 4 * rnn_size}})
(prev_h)
local all_input_sums = nn.CAddTable():usePrealloc(nameL.."allinput",
{{opt.max_batch_l, 4*rnn_size},{opt.max_batch_l, 4*rnn_size}},
{{opt.max_batch_l, 4 * rnn_size}})
({i2h, h2h})

local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)
local n1, n2, n3, n4 = nn.SplitTable(2):usePrealloc(nameL.."reshapesplit",
{{opt.max_batch_l, 4, rnn_size}})
(reshaped):split(4)
-- decode the gates
local in_gate = nn.Sigmoid():reuseMem()(n1)
local forget_gate = nn.Sigmoid():reuseMem()(n2)
local out_gate = nn.Sigmoid():reuseMem()(n3)
local in_gate = nn.Sigmoid():usePrealloc(nameL.."G1-reuse",{RnnD})(n1)
local forget_gate = nn.Sigmoid():usePrealloc(nameL.."G2-reuse",{RnnD})(n2)
local out_gate = nn.Sigmoid():usePrealloc(nameL.."G3-reuse",{RnnD})(n3)
-- decode the write inputs
local in_transform = nn.Tanh():reuseMem()(n4)
local in_transform = nn.Tanh():usePrealloc(nameL.."G4-reuse",{RnnD})(n4)
-- perform the LSTM update
local next_c = nn.CAddTable()({
nn.CMulTable()({forget_gate, prev_c}),
nn.CMulTable()({in_gate, in_transform})
local next_c = nn.CAddTable():usePrealloc(nameL.."G5a",{RnnD,RnnD})({
nn.CMulTable():usePrealloc(nameL.."G5b",{RnnD,RnnD})({forget_gate, prev_c}),
nn.CMulTable():usePrealloc(nameL.."G5c",{RnnD,RnnD})({in_gate, in_transform})
})
-- gated cells form the output
local next_h = nn.CMulTable()({out_gate, nn.Tanh():reuseMem()(next_c)})
local next_h = nn.CMulTable():usePrealloc(nameL.."G5d",{RnnD,RnnD})
({out_gate, nn.Tanh():usePrealloc(nameL.."G6-reuse",{RnnD})(next_c)})

table.insert(outputs, next_c)
table.insert(outputs, next_h)
@@ -166,7 +135,8 @@ function make_lstm(data, opt, model, use_chars)
decoder_out = nn.Tanh()(nn.LinearNoBias(opt.rnn_size*2, opt.rnn_size)(decoder_out))
end
if dropout > 0 then
decoder_out = nn.Dropout(dropout, nil, false)(decoder_out)
decoder_out = nn.Dropout(dropout, nil, false):usePrealloc("dec_dropout",{RnnD})
(decoder_out)
end
table.insert(outputs, decoder_out)
end
@@ -185,23 +155,38 @@ function make_decoder_attn(data, opt, simple)
simple = simple or 0
-- get attention

local attn = nn.MM():usePrealloc("DEC_ATTN_MM1")({context, nn.Replicate(1,3)(target_t)}) -- batch_l x source_l x 1
local attn = nn.MM():usePrealloc("dec_attn_mm1",
{{opt.max_batch_l, opt.max_sent_l_src, opt.rnn_size},{opt.rnn_size, opt.rnn_size, 1}},
{{opt.max_batch_l, opt.max_sent_l_src, 1}})
({context, nn.Replicate(1,3)(target_t)}) -- batch_l x source_l x 1
attn = nn.Sum(3)(attn)
local softmax_attn = nn.SoftMax()
softmax_attn.name = 'softmax_attn'
attn = softmax_attn(attn)
attn = nn.Replicate(1,2)(attn) -- batch_l x 1 x source_l

-- apply attention to context
local context_combined = nn.MM():usePrealloc("DEC_ATTN_MM2")({attn, context}) -- batch_l x 1 x rnn_size
context_combined = nn.Sum(2)(context_combined) -- batch_l x rnn_size
local context_combined = nn.MM():usePrealloc("dec_attn_mm2",
{{opt.max_batch_l, 1, opt.max_sent_l_src},{opt.max_batch_l, opt.max_sent_l_src, opt.rnn_size}},
{{opt.max_batch_l, 1, opt.rnn_size}})
({attn, context}) -- batch_l x 1 x rnn_size
context_combined = nn.Sum(2):usePrealloc("dec_attn_sum",
{{opt.max_batch_l, 1, opt.rnn_size}},
{{opt.max_batch_l, opt.rnn_size}})
(context_combined) -- batch_l x rnn_size
local context_output
if simple == 0 then
context_combined = nn.JoinTable(2)({context_combined, inputs[1]}) -- batch_l x rnn_size*2
context_output = nn.Tanh()(nn.LinearNoBias(opt.rnn_size*2,
opt.rnn_size)(context_combined))
context_combined = nn.JoinTable(2):usePrealloc("dec_attn_jointable",
{{opt.max_batch_l,opt.rnn_size},{opt.max_batch_l, opt.rnn_size}})
({context_combined, inputs[1]}) -- batch_l x rnn_size*2
context_output = nn.Tanh():usePrealloc("dec_noattn_tanh",{{opt.max_batch_l,opt.rnn_size}})
(nn.LinearNoBias(opt.rnn_size*2,opt.rnn_size):usePrealloc("dec_noattn_linear",
{{opt.max_batch_l,2*opt.rnn_size}})
(context_combined))
else
context_output = nn.CAddTable()({context_combined,inputs[1]})
context_output = nn.CAddTable():usePrealloc("dec_attn_caddtable1",
{{opt.max_batch_l, opt.rnn_size}, {opt.max_batch_l, opt.rnn_size}})
({context_combined,inputs[1]})
end
return nn.gModule(inputs, {context_output})
end
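
In models.lua the buffer names follow a convention: each LSTM node is prefixed with a per-layer name (e.g. enc_L2_, dec_L1_), names starting with dec_ are placed on the second GPU when gpuid2 is set (see preallocateTensor), and a -reuse suffix marks nodes that can fall back to the older in-place optimization when preallocation is switched off. A small sketch of that fallback path, with an illustrative name and dimensions:

require 'nn'
require 's2sa.memory'

-- preallocateMemory(1) was never called here, so preallocTable is still nil
local gate = nn.Sigmoid():usePrealloc("dec_L1_G1-reuse", {{64, 600}})
print(gate.reuse)            --> true: the "reuse" in the name triggered reuseMem()
gate:setReuse()              -- gradInput now aliases output, as in the original optimization
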
13 changes: 7 additions & 6 deletions train.lua
@@ -113,6 +113,7 @@ cmd:option('-cudnn', 0, [[Whether to use cudnn or not for convolutions (for the
cmd:option('-save_every', 1, [[Save every this many epochs]])
cmd:option('-print_every', 50, [[Print stats after this many batches]])
cmd:option('-seed', 3435, [[Seed for random initialization]])
cmd:option('-prealloc', 1, [[Use memory preallocation and sharing between cloned encoder/decoders]])

function zero_table(t)
for i = 1, #t do
@@ -212,17 +213,17 @@ function train(train_data, valid_data)
for i = 1, opt.max_sent_l_src do
if encoder_clones[i].apply then
encoder_clones[i]:apply(function(m) m:setReuse() end)
encoder_clones[i]:apply(function(m) m:setPrealloc() end)
if opt.prealloc == 1 then encoder_clones[i]:apply(function(m) m:setPrealloc() end) end
end
if opt.brnn == 1 then
encoder_bwd_clones[i]:apply(function(m) m:setReuse() end)
encoder_bwd_clones[i]:apply(function(m) m:setPrealloc() end)
if opt.prealloc == 1 then encoder_bwd_clones[i]:apply(function(m) m:setPrealloc() end) end
end
end
for i = 1, opt.max_sent_l_targ do
if decoder_clones[i].apply then
decoder_clones[i]:apply(function(m) m:setReuse() end)
decoder_clones[i]:apply(function(m) m:setPrealloc() end)
if opt.prealloc == 1 then decoder_clones[i]:apply(function(m) m:setPrealloc() end) end
end
end

@@ -808,6 +809,9 @@ function main()
print(string.format('Source max sent len: %d, Target max sent len: %d',
valid_data.source:size(2), valid_data.target:size(2)))

-- Enable memory preallocation - see memory.lua
preallocateMemory(opt.prealloc)

-- Build model
if opt.train_from:len() == 0 then
encoder = make_lstm(valid_data, opt, 'enc', opt.use_chars_enc)
@@ -835,9 +839,6 @@
_, criterion = make_generator(valid_data, opt)
end

-- call memory pre-allocation
preallocateMemory(opt)

layers = {encoder, decoder, generator}
if opt.brnn == 1 then
table.insert(layers, encoder_bwd)
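
One ordering detail is easy to miss: preallocateMemory(opt.prealloc) now runs before make_lstm, because usePrealloc only registers named buffers once preallocTable exists (which is why the old call site after model construction is removed above). A condensed sketch of the resulting flow, reusing the variable names from train.lua:

preallocateMemory(opt.prealloc)          -- must run before the graphs are built
encoder = make_lstm(valid_data, opt, 'enc', opt.use_chars_enc)
-- ... clone the encoder/decoder across timesteps ...
for i = 1, opt.max_sent_l_src do
  encoder_clones[i]:apply(function(m) m:setReuse() end)
  if opt.prealloc == 1 then
    encoder_clones[i]:apply(function(m) m:setPrealloc() end)
  end
end
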
