Skip to content

Commit

Permalink
Fix bug on distributed training in mnist using MirroredStrategy API (t…
Browse files Browse the repository at this point in the history
…ensorflow#5183)

* Fix bug on distributed training in mnist using MirroredStrategy API

* Remove unnecessary codes and chagne distribution strategy source
- Remove multi-gpu
- Remove TowerOptimizer
- Change from MirroredStrategy to distribution_utils.get_distribution_strategy
  • Loading branch information
parkjaeman authored and Taylor Robie committed Aug 28, 2018
1 parent 0d105c3 commit 6a0dda1
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions official/mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def create_model(data_format):

def define_mnist_flags():
flags_core.define_base()
flags_core.define_performance(num_parallel_calls=False)
flags_core.define_image()
flags.adopt_module_key_flags(flags_core)
flags_core.set_defaults(data_dir='/tmp/mnist_data',
Expand Down Expand Up @@ -119,10 +120,6 @@ def model_fn(features, labels, mode, params):
if mode == tf.estimator.ModeKeys.TRAIN:
optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)

# If we are running multi-GPU, we need to wrap the optimizer.
if params.get('multi_gpu'):
optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)

logits = model(image, training=True)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
accuracy = tf.metrics.accuracy(
Expand Down Expand Up @@ -162,21 +159,16 @@ def run_mnist(flags_obj):
model_helpers.apply_clean(flags_obj)
model_function = model_fn

# Get number of GPUs as defined by the --num_gpus flags and the number of
# GPUs available on the machine.
num_gpus = flags_core.get_num_gpus(flags_obj)
multi_gpu = num_gpus > 1
session_config = tf.ConfigProto(
inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
allow_soft_placement=True)

if multi_gpu:
# Validate that the batch size can be split into devices.
distribution_utils.per_device_batch_size(flags_obj.batch_size, num_gpus)
distribution_strategy = distribution_utils.get_distribution_strategy(
flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)

# There are two steps required if using multi-GPU: (1) wrap the model_fn,
# and (2) wrap the optimizer. The first happens here, and (2) happens
# in the model_fn itself when the optimizer is defined.
model_function = tf.contrib.estimator.replicate_model_fn(
model_fn, loss_reduction=tf.losses.Reduction.MEAN,
devices=["/device:GPU:%d" % d for d in range(num_gpus)])
run_config = tf.estimator.RunConfig(
train_distribute=distribution_strategy, session_config=session_config)

data_format = flags_obj.data_format
if data_format is None:
Expand All @@ -185,9 +177,9 @@ def run_mnist(flags_obj):
mnist_classifier = tf.estimator.Estimator(
model_fn=model_function,
model_dir=flags_obj.model_dir,
config=run_config,
params={
'data_format': data_format,
'multi_gpu': multi_gpu
})

# Set up training and evaluation input functions.
Expand Down

0 comments on commit 6a0dda1

Please sign in to comment.