-- init.lua
require 'nn'
local optnet = require 'optnet.env'
require 'optnet.countUsedMemory'
require 'optnet.tests'
local utils = require 'optnet.utils'
local kNotUsed = 10000 -- -1
local kNotDefined = 0
local kMinimumForSharing = 2
local kAlwaysLive = 10000
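-- The analysis below assigns each tensor storage a liveness interval
-- [defined, used]: 'defined' is the first step at which the storage is
-- written, 'used' the last step at which it is read. The sentinels above
-- mark storages that are never read (kNotUsed), never written
-- (kNotDefined), or that must stay live for the whole run (kAlwaysLive);
-- tensors with fewer than kMinimumForSharing elements are never shared.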
local function analyse(net, input, opts)
opts = opts or {}
local mode = opts.mode or 'inference'
local track = {}
local analysis = {}
local function entry_fun(t, args)
local ptr = torch.pointer(t:storage())
local info = {used=kNotUsed, defined=kNotDefined,
name=args.name, ptr=ptr, tensor=t}
table.insert(analysis, info)
return info
end
local function fun(t, track, args)
local ptr = torch.pointer(t:storage())
local val = track[ptr][args.var]
if val == args.notUsed then
track[ptr][args.var] = args.c
else
track[ptr][args.var] = args.f(args.c,val)
end
end
local c = 1
local function apply_func(m)
local func = 'updateOutput'
local basefunc = m[func]
m[func] = function(self, input)
local opts = {
analysis=analysis, c=c, name=tostring(m),
kNotUsed=kNotUsed, kNotDefined=kNotDefined
}
if mode == 'inference' then
-- always keep track of the input
opts.var = 'used'; opts.f = math.max; opts.notUsed = kNotUsed
utils.keepTrack(input, track, entry_fun, fun, opts)
if not m.modules then
-- always keep track of the outputs of non-containers
opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined
utils.keepTrack(self.output, track, entry_fun, fun, opts)
elseif torch.typename(m) == 'nn.Concat' or
torch.typename(m) == 'nn.Parallel' or
torch.typename(m) == 'nn.DepthConcat' then
-- for containers that perform some operation on the input, we need to
-- track each branch's output upon entry into the module, as well as the
-- container's own output (which is a non-trivial function of the
-- children's outputs, unlike nn.Sequential for example)
opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined
utils.keepTrack(self.output, track, entry_fun, fun, opts)
for i,branch in ipairs(m.modules) do
local last_module
-- if the branch is a container, take its last module; otherwise take the branch itself
if branch.modules then
last_module = branch:get(branch:size())
else
last_module = branch
end
local out = last_module.output
opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined
utils.keepTrack(out, track, entry_fun, fun, opts)
end
end
end
c = c + 1
return basefunc(self,input)
end
for _, func in ipairs({'updateGradInput','accGradParameters','backward'}) do
local basefunc = m[func]
m[func] = function(self, input, gradOutput, scale)
local opts = {
analysis=analysis, c=c, name=tostring(m),
kNotUsed=kNotUsed, kNotDefined=kNotDefined
}
-- always keep track of the input
--opts.var = 'used'; opts.f = math.max; opts.notUsed = kNotUsed
--utils.keepTrack(input, track, entry_fun, fun, opts)
if not torch.isTypeOf(m, 'nn.Sequential') then
-- always keep track of the gradOutput
opts.var = 'used'; opts.f = math.max; opts.notUsed = kNotUsed
utils.keepTrack(gradOutput, track, entry_fun, fun, opts)
opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined
utils.keepTrack(self.gradInput, track, entry_fun, fun, opts)
end
--[[
if not m.modules then
-- always keep track of the gradInput of non-containers
opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined
utils.keepTrack(self.gradInput, track, entry_fun, fun, opts)
elseif torch.typename(m) == 'nn.Concat' or
torch.typename(m) == 'nn.Parallel' or
torch.typename(m) == 'nn.DepthConcat' then
-- for containers that perform some operation on the gradOutput, we need to
-- track each branch's gradInput upon entry into the module, as well as the
-- container's own gradInput (which is a non-trivial function of the
-- children's gradients, unlike nn.Sequential for example)
opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined
utils.keepTrack(self.gradInput, track, entry_fun, fun, opts)
for i,branch in ipairs(m.modules) do
local first_module = branch:get(1)
local out = first_module.gradInput
opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined
utils.keepTrack(out, track, entry_fun, fun, opts)
end
end
--]]
c = c + 1
return basefunc(self,input,gradOutput,scale)
end
end
end
net:apply(apply_func)
local out = net:forward(input)
local grad
if mode == 'training' then
grad = utils.recursiveClone(out)
net:backward(input, grad)
end
local function trackInputs(t, name)
if torch.isTensor(t) then
local f = function(a,b) return a end
utils.keepTrack(t, track, entry_fun, fun,
{var='used', c=kAlwaysLive,
f=f, notUsed=0, name=name})
utils.keepTrack(t, track, entry_fun, fun,
{var='defined', c=-kAlwaysLive,
f=f, notUsed=0, name=name})
else
for k,v in ipairs(t) do
trackInputs(v, name)
end
end
end
trackInputs(input,'input')
if grad then
trackInputs(grad,'grad')
end
-- clean up the modified function
net:apply(function(x)
x['updateOutput'] = nil
x['updateGradInput'] = nil
x['accGradParameters'] = nil
x['backward'] = nil
end)
-- disable the backward pass when optimizing for inference
if mode == 'inference' then
net:apply(function(m)
m.updateGradInput = function(self, input, gradOutput)
error([[Backward pass disabled!
You are using inference optimization.
Call optnet.removeOptimization(net) to enable backward again]])
end
end)
end
return analysis
end
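-- Illustrative shape of one entry of the returned analysis (the values
-- are made up for the example):
--   { name = 'nn.ReLU', ptr = <storage pointer>, tensor = <torch.Tensor>,
--     defined = 3,  -- step at which the storage is first written
--     used = 7 }    -- step at which the storage is last read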
local function isCompatible(candidate, assignment)
if candidate.used == kNotUsed then
return false
end
if candidate.tensor:numel() < kMinimumForSharing then
return false
end
local a_used = assignment[#assignment].used
return candidate.defined > a_used
end
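-- Example: a candidate first defined at step 5 can share storage with a
-- group whose last member was last read at step 4 (5 > 4), but not with
-- one still being read at step 5 or later.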
local function assign(net, analysis)
table.sort(analysis, function(a,b)
return a.used < b.used
end)
local assignments = {}
for _,candidate in ipairs(analysis) do
local assigned = false
local bestAssignment = 0
local minDist = math.huge
local candidateSize = candidate.tensor:numel()
for idx, assignment in ipairs(assignments) do
if isCompatible(candidate, assignment) then
assigned = true
local dist = math.abs(assignment.maxSize-candidateSize)
if dist < minDist then
minDist = dist
bestAssignment = idx
end
end
end
if assigned then
local assignment = assignments[bestAssignment]
table.insert(assignment, candidate)
assignment.maxSize = math.max(assignment.maxSize, candidateSize)
else
table.insert(assignments, {candidate, maxSize=candidateSize})
end
end
return assignments
end
local function applyAssignments(net, assignments)
for _, assignment in ipairs(assignments) do
local storage
for k, v in ipairs(assignment) do
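-- never reassign storages flagged as always live (network inputs/grads);
-- the rest of their group is skipped as well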
if v.used == kAlwaysLive and v.defined == -kAlwaysLive then
break
end
storage = storage or v.tensor.new(1):storage()
v.tensor:set(storage)
end
end
end
local function defaultValue(var, val)
if var == nil then
var = val
end
return var
end
-- enable in-place mode for modules that support it
local function setInplace(net, opts)
local inplace = defaultValue(opts.inplace, true)
if inplace then
net:apply(function(m)
if m.inplace ~= nil then
-- in-place is not always supported for nn.Threshold,
-- depending on its parameters; disabling it for now
if torch.typename(m) ~= 'nn.Threshold' then
m.inplace = true
end
end
end)
end
end
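-- For each module type, two lists of internal buffers that can be shared
-- across instances: the first applies in inference mode, the second in
-- training mode.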
local reusableBuffers = {
['nn.SpatialConvolution'] = {{'finput','fgradInput'},{'fgradInput'}},
['nn.SpatialConvolutionMM'] = {{'finput','fgradInput'},{'fgradInput'}},
['nn.Normalize'] = {{'norm','buffer','normp','_indices'},{}},
['nn.SpatialCrossMapLRN'] = {{'scale'},{}},
['nn.SpatialMaxPooling'] = {{'indices'},{}},
}
-- basic reuse scheme: shares the buffers listed above
-- across module instances, both in inference mode and
-- in training mode.
local function reuseStateBuffers(net, opts)
local reuseBuffers = defaultValue(opts.reuseBuffers, true)
local mode = defaultValue(opts.mode, 'inference')
local mode_idx = 1
if mode == 'training' then
mode_idx = 2
end
-- work around a SpatialMaxUnpooling corner case
-- https://github.com/fmassa/optimize-net/issues/14
local reusableBuffers = utils.copyTable(reusableBuffers)
if #net:findModules('nn.SpatialMaxUnpooling') > 0 then
reusableBuffers['nn.SpatialMaxPooling'] = {{},{}}
end
if reuseBuffers then
local reusedBuffers = {}
net:apply(function(m)
local name = torch.typename(m)
if reusableBuffers[name] then
local rb = reusableBuffers[name][mode_idx]
for k, v in ipairs(rb) do
if m[v] then
reusedBuffers[name..','..v] = reusedBuffers[name..','..v] or m[v]:storage()
if reusedBuffers[name..','..v] then
m[v]:set(reusedBuffers[name..','..v])
end
end
end
end
end)
end
end
-- needed for cudnn: force cudnn modules to recreate their input
-- descriptors, since sharing may have changed the underlying storages
local function resetInputDescriptors(net)
net:apply(function(m)
if torch.typename(m):find('cudnn') and
torch.typename(m.iSize) == 'torch.LongStorage' then
m.iSize:fill(0)
end
end)
end
-- need to keep a list of shared gradParams
-- to avoid problems when removing the optimization
local function removeGradParams(net, opts)
local removeGradParams = defaultValue(opts.removeGradParams, true)
local mode = defaultValue(opts.mode, 'inference')
if not removeGradParams then return end
if mode == 'training' then return end
local storages = {}
net:apply(function(m)
for _, k in ipairs({'gradWeight','gradBias'}) do
if m[k] and m[k]:storage() then
local strPtr = torch.pointer(m[k]:storage())
local strOffset = m[k]:storageOffset()
local strSize = m[k]:storage():size()
local tPtr = torch.pointer(m[k])
storages[tPtr] = {storage=strPtr, offSet=strOffset,
size=strSize, stride=m[k]:stride()}
m[k]:set()
end
end
-- disable getParameters while the gradient buffers are removed
m.getParameters = function(self)
error('getParameters was disabled by optnet '..
'(by option removeGradParams=true). '..
'Call optnet.removeOptimization(net) to restore it.')
end
end)
net.__gradParamsInfo = storages
end
local function addGradParams(net)
local storages = net.__gradParamsInfo
if not storages then return end
local createdStorages = {}
net:apply(function(m)
for k, v in pairs({weight='gradWeight',bias='gradBias'}) do
if m[v] then
local tPtr = torch.pointer(m[v])
local info = storages[tPtr]
if not createdStorages[info.storage] then
local strSize = info.size
createdStorages[info.storage] = m[v].new(strSize):storage()
end
local storage = createdStorages[info.storage]
local tSize = m[k]:size()
local tStride = info.stride
local tOffset = info.offSet
m[v]:set(storage, tOffset, tSize, tStride)
end
end
-- add back original getParameters
m.getParameters = nil
end)
net.__gradParamsInfo = nil
end
function optnet.optimizeMemory(net, input, opts)
opts = opts or {}
if net.__memoryOptimized then
print('Skipping memory optimization. '..
'Network is already optimized for '..net.__memoryOptimized..' mode.')
return
end
local mode = defaultValue(opts.mode,'inference')
local out = net:forward(input)
local grad
if mode == 'training' then
grad = utils.recursiveClone(out)
net:backward(input, grad)
end
setInplace(net, opts)
reuseStateBuffers(net, opts)
removeGradParams(net, opts)
-- share outputs
local analysis = analyse(net, input, opts)
--print(analysis)
local assignments = assign(net,analysis)
--print(assignments)
applyAssignments(net, assignments)
resetInputDescriptors(net)
-- flag the network as optimized, recording the mode
net.__memoryOptimized = mode
end
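-- Example usage (a minimal sketch; the model and input shape are
-- illustrative):
--   local optnet = require 'optnet'
--   local net = nn.Sequential()
--     :add(nn.SpatialConvolution(3, 16, 3, 3))
--     :add(nn.ReLU(true))
--   local input = torch.rand(1, 3, 32, 32)
--   optnet.optimizeMemory(net, input, {mode = 'inference'})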
function optnet.removeOptimization(net)
if not net.__memoryOptimized then
print('Skipping memory optimization removal, as the network was not optimized.')
return
end
local function rem(m)
if torch.isTensor(m) then
m:set()
end
if torch.type(m) == 'table' then
for k, v in ipairs(m) do
rem(v)
end
end
end
net:apply(function(m)
rem(m.output)
rem(m.gradInput)
local name = torch.typename(m)
if reusableBuffers[name] then
local rb = reusableBuffers[name][1]
for k, v in ipairs(rb) do
if m[v] then
m[v]:set()
end
end
end
-- remove the stub that blocks the backward pass
m.updateGradInput = nil
end)
resetInputDescriptors(net)
addGradParams(net)
net.__memoryOptimized = nil
end
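-- To undo all of the above (e.g. before serializing the network, or to
-- train a network that was optimized for inference):
--   optnet.removeOptimization(net)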
return optnet