Skip to content

Commit

Permalink
MXNet: improve imagenet example performance (horovod#794)
Browse files Browse the repository at this point in the history
* use local variables for rank and local_rank

* set device_id in ImageRecordIter to use CPUPinned context
  • Loading branch information
yuxihu authored and alsrgv committed Jan 30, 2019
1 parent b58347b commit 1feb955
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions examples/mxnet_imagenet_resnet50.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@
# Horovod: initialize Horovod
hvd.init()
num_workers = hvd.size()
rank = hvd.rank()
local_rank = hvd.local_rank()

num_classes = 1000
num_training_samples = 1281167
Expand Down Expand Up @@ -129,7 +131,7 @@
raise ValueError('Invalid lr mode')

# Horovod: pin GPU to local rank
context = mx.cpu() if args.no_cuda else mx.gpu(hvd.local_rank())
context = mx.cpu() if args.no_cuda else mx.gpu(local_rank)
kwargs = {'ctx': context, 'pretrained': args.use_pretrained,
'classes': num_classes}
if args.last_gamma:
Expand Down Expand Up @@ -178,8 +180,9 @@ def batch_fn(batch, ctx):
saturation=jitter_param,
contrast=jitter_param,
pca_noise=lighting_param,
num_parts=hvd.size(),
part_index=hvd.rank()
num_parts=num_workers,
part_index=rank,
device_id=local_rank
)
# Kept each node to use full val data to make it easy to monitor results
val_data = mx.io.ImageRecordIter(
Expand All @@ -195,7 +198,8 @@ def batch_fn(batch, ctx):
data_shape=(3, 224, 224),
mean_r=mean_rgb[0],
mean_g=mean_rgb[1],
mean_b=mean_rgb[2]
mean_b=mean_rgb[2],
device_id=local_rank
)

return train_data, val_data, batch_fn
Expand Down Expand Up @@ -331,7 +335,7 @@ def train():
epoch_callback = None
if args.save_frequency > 0:
epoch_callback = mx.callback.do_checkpoint(
'%s-%d' % (args.model, hvd.rank()),
'%s-%d' % (args.model, rank),
period=args.save_frequency)

# Train model
Expand All @@ -351,7 +355,7 @@ def train():
res = mod.score(val_data, [acc_top1, acc_top5])
for name, val in res:
logging.info('Epoch[%d] Rank[%d] Validation-%s=%f',
args.num_epochs - 1, hvd.rank(), name, val)
args.num_epochs - 1, rank, name, val)


if __name__ == '__main__':
Expand Down

0 comments on commit 1feb955

Please sign in to comment.