-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathmain.lua
81 lines (72 loc) · 3.1 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
-- Main script that loads other scripts, sets some of the options automatically.
-- User should typically set the following options to start training/testing a model:
-- 'expName', 'dataset', 'split', 'stream'. See opts.lua for others.
require 'torch'
require 'cutorch'
require 'paths'
require 'xlua'
require 'optim'
require 'nn'
-- Load the plotting helper and the option parser, then derive the remaining
-- settings (stream-specific values, directories, tensor sizes) from the
-- parsed command-line options.
paths.dofile('trainplot/TrainPlotter.lua')
local opts = paths.dofile('opts.lua')

-- NOTE: `opt` is deliberately left global (no `local`); presumably the scripts
-- loaded later through paths.dofile read it -- confirm before changing.
opt = opts.parse(arg)

-- Stream-specific settings. Only 'flow' and 'rgb' are handled; any other
-- value used to leave nChannels nil and silently produce loadSize/sampleSize
-- tables with a nil first entry, so fail early with a clear message instead.
local nChannels
if opt.stream == 'flow' then
    opt.mean = 0
    nChannels = 2
elseif opt.stream == 'rgb' then
    opt.mean = 96
    nChannels = 3
    opt.coeff = 255
else
    error("Unsupported opt.stream '" .. tostring(opt.stream)
          .. "' (expected 'flow' or 'rgb')")
end

-- Output/cache directories and dataset locations derived from the roots.
opt.save         = paths.concat(opt.logRoot, opt.dataset, opt.expName)
opt.cache        = paths.concat(opt.logRoot, opt.dataset, 'cache', opt.stream)
opt.data         = paths.concat(opt.dataRoot, opt.dataset, 'splits', 'split' .. opt.split)
opt.framesRoot   = paths.concat(opt.dataRoot, opt.dataset, opt.stream, 't7')
opt.forceClasses = torch.load(paths.concat(opt.dataRoot, opt.dataset, 'annot/forceClasses.t7'))

-- Sizes are {channels, frames, height, width}.
opt.loadSize   = {nChannels, opt.nFrames, opt.loadHeight, opt.loadWidth}
opt.sampleSize = {nChannels, opt.nFrames, opt.sampleHeight, opt.sampleWidth}

-- Run the script named by opt.LRfile (presumably the learning-rate schedule;
-- its contents are not visible from this file).
paths.dofile(opt.LRfile)
-- Testing final predictions: redirect outputs to a per-model test directory
-- and point `retrain` at the saved model to be evaluated.
if opt.evaluate then
    local testTag   = 'test_' .. opt.modelNo .. '_slide' .. opt.slide
    local modelFile = 'model_' .. opt.modelNo .. '.t7'

    -- Paths used in evaluation mode.
    opt.save    = paths.concat(opt.logRoot, opt.dataset, opt.expName, testTag)
    opt.cache   = paths.concat(opt.logRoot, opt.dataset, 'cache', 'test', opt.stream)
    opt.retrain = paths.concat(opt.logRoot, opt.dataset, opt.expName, modelFile)

    -- opt.loadSize[2] is the frame count (set from opt.nFrames earlier in this file).
    opt.testDir = 'test_' .. opt.loadSize[2] .. '_' .. opt.slide

    -- Fixed switches in evaluation mode.
    opt.finetune = 'none'
    opt.scales   = false
    opt.crops10  = true
end
-- Continue training (epochNumber has to be set for this option): resume from
-- the model and optimizer state saved for the previous epoch, after copying
-- the existing logs into a timestamped backup directory.
if opt.continue then
    print('Continuing from epoch ' .. opt.epochNumber)

    local prevEpoch = opt.epochNumber - 1
    opt.retrain    = opt.save .. '/model_' .. prevEpoch .. '.t7'
    opt.optimState = opt.save .. '/optimState_' .. prevEpoch .. '.t7'
    opt.finetune   = 'none'

    -- Backup directory is named 'delete<unix time>' inside opt.save.
    local snapshotDir = opt.save .. '/delete' .. os.time()
    os.execute('mkdir -p ' .. snapshotDir)

    -- Copy one file from opt.save into the backup directory.
    local function keepCopy(fileName)
        os.execute('cp ' .. opt.save .. '/' .. fileName .. ' ' .. snapshotDir)
    end

    keepCopy('train.log')
    -- NOTE(review): in this file opt.testDir is only assigned in the evaluate
    -- branch; presumably opts.lua provides a default -- confirm.
    keepCopy(opt.testDir .. '.log')
    keepCopy('plot.json')
end
-- Create the output and cache directories.
os.execute('mkdir -p ' .. opt.save)
os.execute('mkdir -p ' .. opt.cache)

-- Plotter writing to plot.json inside the save directory.
opt.plotter = TrainPlotter.new(paths.concat(opt.save, 'plot.json'))
do
    -- Read the first line printed by the `date` command and close the pipe
    -- explicitly (it was previously left open until garbage collection).
    -- If the pipe cannot be opened or yields nothing, fall back to os.date().
    local createdTime
    local dateProc = io.popen('date')
    if dateProc then
        createdTime = dateProc:read()
        dateProc:close()
    end
    opt.plotter:info({created_time = createdTime or os.date(), tag = opt.expName})
end

-- Print the options and save them to a timestamped file for this run.
print(opt)
print('Saving everything to: ' .. opt.save)
torch.save(paths.concat(opt.save, 'opt' .. os.time() .. '.t7'), opt)

-- Torch defaults: float tensors, the GPU given by opt.GPU, fixed random seed.
torch.setdefaulttensortype('torch.FloatTensor')
cutorch.setDevice(opt.GPU)
torch.manualSeed(opt.manualSeed)

-- Load the remaining project scripts (their contents are not visible here).
paths.dofile('data.lua')
paths.dofile('model.lua')
paths.dofile('test.lua')
if opt.evaluate then
    -- Testing final predictions
    test()
else
    -- Training
    paths.dofile('train.lua')

    -- Copy the current plot.json into trainplot/plot-data/<dataset>/, named
    -- after the experiment with every non-alphanumeric character removed.
    local function publishPlot()
        local jsonName = opt.expName:gsub('%W','') .. '.json'
        local source = paths.concat(opt.save, 'plot.json')
        local target = paths.concat('trainplot/plot-data/', opt.dataset, jsonName)
        os.execute('scp ' .. source .. ' ' .. target)
    end

    -- `epoch` is a global counter; presumably train()/test() read it -- confirm.
    epoch = opt.epochNumber
    for _ = 1, opt.nEpochs do
        train()
        test()
        publishPlot()
        epoch = epoch + 1
    end
end