MXNet: add DistributedTrainer for Gluon (horovod#943)
* Add DistributedTrainer for Gluon

Signed-off-by: Yuxi Hu <[email protected]>

* check optimizer type

Signed-off-by: Yuxi Hu <[email protected]>

* update README

Signed-off-by: Yuxi Hu <[email protected]>

* address comment

Signed-off-by: Yuxi Hu <[email protected]>

* add doc for DistributedTrainer

Signed-off-by: Yuxi Hu <[email protected]>
yuxihu authored and alsrgv committed Mar 29, 2019
1 parent 3a1083e commit 5d8f80c
Showing 4 changed files with 43 additions and 19 deletions.
README.md: 15 changes (7 additions, 8 deletions)
@@ -202,12 +202,12 @@ See a full training [example](examples/tensorflow_mnist_estimator.py).

Horovod supports MXNet and regular TensorFlow in similar ways.

See full training [MNIST](examples/mxnet_mnist.py) and [ImageNet](examples/mxnet_imagenet_resnet50.py) examples.
See full training [MNIST](examples/mxnet_mnist.py) and [ImageNet](examples/mxnet_imagenet_resnet50.py) examples. The script below provides a simple code skeleton based on the MXNet Gluon API.

```python
import mxnet as mx
import horovod.mxnet as hvd
from mxnet import autograd, gluon
from mxnet import autograd

# Initialize Horovod
hvd.init()
@@ -220,12 +220,9 @@ num_workers = hvd.size()
model = ...
model.hybridize()

# Define hyper parameters
# Create optimizer
optimizer_params = ...

# Add Horovod Distributed Optimizer
opt = mx.optimizer.create('sgd', **optimizer_params)
opt = hvd.DistributedOptimizer(opt)

# Initialize parameters
model.initialize(initializer, ctx=context)
@@ -235,8 +232,10 @@ params = model.collect_params()
if params is not None:
hvd.broadcast_parameters(params, root_rank=0)

# Create trainer and loss function
trainer = gluon.Trainer(params, opt, kvstore=None)
# Create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)

# Create loss function
loss_fn = ...

# Train model
```
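
The diff above cuts off before the training loop. Below is a minimal sketch of how the skeleton is typically completed with the Gluon autograd API; `train_data`, `batch_size`, and `num_epochs` are illustrative placeholders, not part of this commit.

```python
# Train model -- a hedged sketch of the loop the skeleton leads into.
# `train_data`, `batch_size`, and `num_epochs` are assumed placeholders.
for epoch in range(num_epochs):
    for data, label in train_data:
        data = data.as_in_context(context)
        label = label.as_in_context(context)
        with autograd.record():
            output = model(data)
            loss = loss_fn(output, label)
        loss.backward()
        # DistributedTrainer averages gradients across workers before
        # applying the optimizer update.
        trainer.step(batch_size)
```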
examples/mxnet_imagenet_resnet50.py: 14 changes (8 additions, 6 deletions)
@@ -288,9 +288,6 @@ def reset(self):
optimizer_params['multi_precision'] = True
opt = mx.optimizer.create('sgd', **optimizer_params)

# Horovod: wrap optimizer with DistributedOptimizer
opt = hvd.DistributedOptimizer(opt)


def train_gluon():
def evaluate(epoch):
@@ -320,8 +317,10 @@ def evaluate(epoch):
if params is not None:
hvd.broadcast_parameters(params, root_rank=0)

# Create trainer, loss function and train metric
trainer = gluon.Trainer(params, opt, kvstore=None)
# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

@@ -410,6 +409,9 @@ def train_module():
hvd.broadcast_parameters(aux_params, root_rank=0)
mod.set_params(arg_params=arg_params, aux_params=aux_params)

# Horovod: wrap optimizer with DistributedOptimizer
dist_opt = hvd.DistributedOptimizer(opt)

# Setup validation data and callback during training
eval_data = None
if args.eval_epoch:
@@ -432,7 +434,7 @@ def train_module():
kvstore=None,
batch_end_callback=batch_callback,
epoch_end_callback=epoch_callback,
optimizer=opt)
optimizer=dist_opt)

# Evaluate performance if not using synthetic data
if args.use_rec:
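
For the Module API path in `train_module()`, this change wraps the optimizer inside the function and passes it to `mod.fit` with `kvstore=None`. The snippet below is a self-contained sketch of that pattern; the toy symbolic network, synthetic data, and hyperparameters are illustrative assumptions, not taken from the example.

```python
import mxnet as mx
import horovod.mxnet as hvd

hvd.init()

# Toy symbolic network (illustrative only)
data = mx.sym.Variable('data')
net = mx.sym.FullyConnected(data, num_hidden=10)
net = mx.sym.SoftmaxOutput(net, name='softmax')

# Synthetic data (illustrative only)
train_iter = mx.io.NDArrayIter(
    data=mx.nd.random.uniform(shape=(64, 20)),
    label=mx.nd.random.randint(0, 10, shape=(64,)).astype('float32'),
    batch_size=8)

mod = mx.mod.Module(net, context=mx.cpu())

# Horovod: wrap the optimizer with DistributedOptimizer, as in the diff
opt = mx.optimizer.create('sgd', learning_rate=0.01 * hvd.size())
dist_opt = hvd.DistributedOptimizer(opt)

# The full example also broadcasts the initialized arg/aux params from
# rank 0 before training (see the diff above); omitted here for brevity.
mod.fit(train_iter,
        kvstore=None,          # Horovod handles the communication
        optimizer=dist_opt,
        num_epoch=1)
```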
examples/mxnet_mnist.py: 10 changes (5 additions, 5 deletions)
@@ -117,21 +117,21 @@ def evaluate(model, data_iter, context):
'learning_rate': args.lr * hvd.size(),
'rescale_grad': 1.0 / args.batch_size}
opt = mx.optimizer.create('sgd', **optimizer_params)
# Horovod: wrap optimizer with DistributedOptimizer
opt = hvd.DistributedOptimizer(opt)

# Initialize parameters
initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in",
magnitude=2)
model.initialize(initializer, ctx=context)

# Fetch and broadcast parameters
# Horovod: fetch and broadcast parameters
params = model.collect_params()
if params is not None:
hvd.broadcast_parameters(params, root_rank=0)

# Create trainer, loss function and train metric
trainer = gluon.Trainer(params, opt, kvstore=None)
# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

horovod/mxnet/__init__.py: 23 changes (23 additions, 0 deletions)
@@ -31,6 +31,7 @@

import mxnet as mx
import types
import warnings


# This is where Horovod's DistributedOptimizer wrapper for MXNet goes
@@ -69,6 +70,28 @@ def set_wd_mult(self, args_wd_mult):
self._optimizer.set_wd_mult(args_wd_mult)


# DistributedTrainer, a subclass of MXNet gluon.Trainer.
# There are two differences between DistributedTrainer and Trainer:
# 1. DistributedTrainer aggregates gradients with the Horovod allreduce
# API, while Trainer does so with kvstore push/pull APIs;
# 2. DistributedTrainer performs allreduce (summation) and averaging,
# while Trainer only performs allreduce (summation).
class DistributedTrainer(mx.gluon.Trainer):
def __init__(self, params, optimizer, optimizer_params=None):
if isinstance(optimizer, DistributedOptimizer):
optimizer = optimizer._optimizer
warnings.warn("DistributedTrainer does not take DistributedOptimizer "
"as its optimizer. We have unwrapped it for you.")

super(DistributedTrainer, self).__init__(
params, optimizer, optimizer_params=optimizer_params, kvstore=None)

def _allreduce_grads(self):
for i, param in enumerate(self._params):
if param.grad_req != 'null':
allreduce_(param.list_grad()[0], average=True, name=str(i))


# Wrapper to inject Horovod broadcast after parameter initialization
def _append_broadcast_init(param, root_rank):
init_impl = getattr(param, '_init_impl')
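
A short usage sketch of the new class (not part of the diff): `DistributedTrainer` is a drop-in replacement for `gluon.Trainer`, and, per the type check in `__init__`, a `DistributedOptimizer` passed in is unwrapped with a warning. The tiny dense layer and learning rate below are illustrative assumptions.

```python
import mxnet as mx
import horovod.mxnet as hvd
from mxnet import gluon

hvd.init()

# Tiny illustrative model with a fixed input size so parameters
# initialize immediately.
net = gluon.nn.Dense(10, in_units=20)
net.initialize(mx.init.Xavier(), ctx=mx.cpu())

# Horovod: broadcast initial parameters from rank 0
params = net.collect_params()
hvd.broadcast_parameters(params, root_rank=0)

# Preferred: hand the plain optimizer straight to DistributedTrainer
opt = mx.optimizer.create('sgd', learning_rate=0.01 * hvd.size())
trainer = hvd.DistributedTrainer(params, opt)

# Passing hvd.DistributedOptimizer(opt) instead would also work: the
# constructor unwraps it and emits a warning.
```

Compared with the previous README flow, the `gluon.Trainer(params, opt, kvstore=None)` plus `DistributedOptimizer` combination is replaced by this single `DistributedTrainer`.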
