From 6a0dda1ff86915d15cba0c9c12a9fc8a5e71a1a7 Mon Sep 17 00:00:00 2001
From: Jaeman
Date: Wed, 29 Aug 2018 08:08:07 +0900
Subject: [PATCH] Fix bug on distributed training in mnist using MirroredStrategy API (#5183)

* Fix bug on distributed training in mnist using MirroredStrategy API

* Remove unnecessary code and change distribution strategy source

- Remove multi-gpu
- Remove TowerOptimizer
- Change from MirroredStrategy to distribution_utils.get_distribution_strategy
---
 official/mnist/mnist.py | 28 ++++++++++------------------
 1 file changed, 10 insertions(+), 18 deletions(-)

diff --git a/official/mnist/mnist.py b/official/mnist/mnist.py
index 8cbaa6a96b8..089a5493e32 100644
--- a/official/mnist/mnist.py
+++ b/official/mnist/mnist.py
@@ -89,6 +89,7 @@ def create_model(data_format):
 
 def define_mnist_flags():
   flags_core.define_base()
+  flags_core.define_performance(num_parallel_calls=False)
   flags_core.define_image()
   flags.adopt_module_key_flags(flags_core)
   flags_core.set_defaults(data_dir='/tmp/mnist_data',
@@ -119,10 +120,6 @@ def model_fn(features, labels, mode, params):
   if mode == tf.estimator.ModeKeys.TRAIN:
     optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
 
-    # If we are running multi-GPU, we need to wrap the optimizer.
-    if params.get('multi_gpu'):
-      optimizer = tf.contrib.estimator.TowerOptimizer(optimizer)
-
     logits = model(image, training=True)
     loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
     accuracy = tf.metrics.accuracy(
@@ -162,21 +159,16 @@ def run_mnist(flags_obj):
   model_helpers.apply_clean(flags_obj)
   model_function = model_fn
 
-  # Get number of GPUs as defined by the --num_gpus flags and the number of
-  # GPUs available on the machine.
-  num_gpus = flags_core.get_num_gpus(flags_obj)
-  multi_gpu = num_gpus > 1
+  session_config = tf.ConfigProto(
+      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
+      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
+      allow_soft_placement=True)
 
-  if multi_gpu:
-    # Validate that the batch size can be split into devices.
-    distribution_utils.per_device_batch_size(flags_obj.batch_size, num_gpus)
+  distribution_strategy = distribution_utils.get_distribution_strategy(
+      flags_core.get_num_gpus(flags_obj), flags_obj.all_reduce_alg)
 
-    # There are two steps required if using multi-GPU: (1) wrap the model_fn,
-    # and (2) wrap the optimizer. The first happens here, and (2) happens
-    # in the model_fn itself when the optimizer is defined.
-    model_function = tf.contrib.estimator.replicate_model_fn(
-        model_fn, loss_reduction=tf.losses.Reduction.MEAN,
-        devices=["/device:GPU:%d" % d for d in range(num_gpus)])
+  run_config = tf.estimator.RunConfig(
+      train_distribute=distribution_strategy, session_config=session_config)
 
   data_format = flags_obj.data_format
   if data_format is None:
@@ -185,9 +177,9 @@
   mnist_classifier = tf.estimator.Estimator(
       model_fn=model_function,
       model_dir=flags_obj.model_dir,
+      config=run_config,
       params={
           'data_format': data_format,
-          'multi_gpu': multi_gpu
       })
 
   # Set up training and evaluation input functions.
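
For context, the pattern this patch moves to is: build a single distribution strategy object and hand it to the Estimator through RunConfig's train_distribute argument, instead of wrapping the optimizer in TowerOptimizer and the model_fn in replicate_model_fn. The following is a minimal, self-contained sketch of that pattern, not code from the patch: it assumes a TensorFlow 1.10-1.12 environment (where MirroredStrategy lived under tf.contrib.distribute) and uses a toy model_fn and synthetic input_fn in place of the MNIST code.

# Minimal sketch of DistributionStrategy + Estimator wiring (TF 1.10-1.12 era).
# The toy model_fn and input_fn are illustrative stand-ins, not from the patch.
import tensorflow as tf


def toy_model_fn(features, labels, mode, params):
  """A plain model_fn: no TowerOptimizer or replicate_model_fn wrapping."""
  logits = tf.layers.dense(features, 10)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  if mode == tf.estimator.ModeKeys.TRAIN:
    optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    train_op = optimizer.minimize(loss, tf.train.get_or_create_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
  return tf.estimator.EstimatorSpec(mode, loss=loss)


def toy_input_fn():
  """Synthetic features/labels so the sketch runs without downloading MNIST."""
  features = tf.random_normal([128, 784])
  labels = tf.random_uniform([128], maxval=10, dtype=tf.int64)
  return tf.data.Dataset.from_tensor_slices((features, labels)).repeat().batch(32)


# One strategy object replaces the old multi-GPU plumbing. num_gpus=2 assumes
# two visible GPUs; adjust it, or use tf.contrib.distribute.OneDeviceStrategy
# for a single device.
strategy = tf.contrib.distribute.MirroredStrategy(num_gpus=2)
run_config = tf.estimator.RunConfig(train_distribute=strategy)

classifier = tf.estimator.Estimator(model_fn=toy_model_fn, config=run_config)
classifier.train(input_fn=toy_input_fn, steps=10)

In the patch itself the strategy comes from distribution_utils.get_distribution_strategy, which in the official models utilities of this era returns a OneDeviceStrategy for zero or one GPU and a MirroredStrategy otherwise, so the single code path above covers CPU, single-GPU, and multi-GPU training, and the model_fn no longer needs a multi_gpu parameter.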