train.py

#!/usr/bin/env python
"""
Trains a model using one or more GPUs.
"""
from multiprocessing import Process

import caffe


def train(
        solver,  # solver proto definition
        snapshot,  # solver snapshot to restore
        gpus,  # list of device ids
        timing=False,  # show timing info for compute and communications
):
    # NCCL uses a uid to identify a session
    uid = caffe.NCCL.new_uid()

    caffe.init_log()
    caffe.log('Using devices %s' % str(gpus))

    procs = []
    for rank in range(len(gpus)):
        p = Process(target=solve,
                    args=(solver, snapshot, gpus, timing, uid, rank))
        p.daemon = True
        p.start()
        procs.append(p)
    for p in procs:
        p.join()
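
# Illustrative call (the solver path is hypothetical):
#
#   train('path/to/solver.prototxt', snapshot=None, gpus=[0, 1])
#
# forks one daemon process per listed device; each process builds its own
# solver from the same proto, and the group synchronizes through the shared
# NCCL uid generated above.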


def time(solver, nccl):
    # one timer per layer for forward and backward, plus overall totals
    fprop = []
    bprop = []
    total = caffe.Timer()
    allrd = caffe.Timer()
    for _ in range(len(solver.net.layers)):
        fprop.append(caffe.Timer())
        bprop.append(caffe.Timer())
    display = solver.param.display

    def show_time():
        if solver.iter % display == 0:
            s = '\n'
            for i in range(len(solver.net.layers)):
                s += 'forw %3d %8s ' % (i, solver.net._layer_names[i])
                s += ': %.2f\n' % fprop[i].ms
            for i in range(len(solver.net.layers) - 1, -1, -1):
                s += 'back %3d %8s ' % (i, solver.net._layer_names[i])
                s += ': %.2f\n' % bprop[i].ms
            s += 'solver total: %.2f\n' % total.ms
            s += 'allreduce: %.2f\n' % allrd.ms
            caffe.log(s)

    # start/stop each layer's timer around its forward and backward passes
    solver.net.before_forward(lambda layer: fprop[layer].start())
    solver.net.after_forward(lambda layer: fprop[layer].stop())
    solver.net.before_backward(lambda layer: bprop[layer].start())
    solver.net.after_backward(lambda layer: bprop[layer].stop())
    # time the full step; allrd runs from gradients-ready until after allreduce
    solver.add_callback(lambda: total.start(), lambda: (total.stop(), allrd.start()))
    solver.add_callback(nccl)
    solver.add_callback(lambda: '', lambda: (allrd.stop(), show_time()))
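
# The registration order of the three add_callback calls above carries the
# logic: the intent is that `allrd` starts just before the nccl callback's
# gradient allreduce and stops just after it, so `allrd.ms` isolates the
# communication cost while `total.ms` covers the whole solver step.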


def solve(proto, snapshot, gpus, timing, uid, rank):
    # pin this process to its GPU and describe the process group to Caffe
    caffe.set_mode_gpu()
    caffe.set_device(gpus[rank])
    caffe.set_solver_count(len(gpus))
    caffe.set_solver_rank(rank)
    caffe.set_multiprocess(True)

    solver = caffe.SGDSolver(proto)
    if snapshot and len(snapshot) != 0:
        solver.restore(snapshot)

    # join the NCCL session and broadcast rank 0's initial weights to all ranks
    nccl = caffe.NCCL(solver, uid)
    nccl.bcast()

    if timing and rank == 0:
        time(solver, nccl)
    else:
        solver.add_callback(nccl)

    if solver.param.layer_wise_reduce:
        solver.net.after_backward(nccl)
    solver.step(solver.param.max_iter)
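
# When layer_wise_reduce is enabled in the solver proto, hooking nccl into
# after_backward reduces each layer's gradients as soon as that layer's
# backward pass finishes, overlapping communication with the remaining
# backpropagation instead of waiting for one reduction at the end of the step.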


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument("--solver", required=True, help="Solver proto definition.")
    parser.add_argument("--snapshot", help="Solver snapshot to restore.")
    parser.add_argument("--gpus", type=int, nargs='+', default=[0],
                        help="List of device ids.")
    parser.add_argument("--timing", action='store_true', help="Show timing info.")
    args = parser.parse_args()

    train(args.solver, args.snapshot, args.gpus, args.timing)
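
# Example invocations (file paths are illustrative):
#
#   python train.py --solver path/to/solver.prototxt --gpus 0 1
#   python train.py --solver path/to/solver.prototxt \
#       --snapshot path/to/solver_iter_1000.solverstate --gpus 0 1 2 3 --timing
#
# --gpus takes one or more device ids; with --timing, rank 0 logs per-layer
# forward/backward times, the total step time, and the allreduce cost every
# `display` iterations of the solver.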