Skip to content

Commit

Permalink
Handle data format and graph compile properly for better GPU performance (tensorflow#6013)
Browse files Browse the repository at this point in the history

* Handle data format in Keras ResNet model properly for better performance on GPU; Compile only the training graph when skip_eval flag is True

* Added data format fix to Keras Cifar model; Removed unnecessary import

* Add a comment to the skip_eval flag per Priya's request
  • Loading branch information
haoyuz authored Jan 9, 2019
1 parent 56cbd1f commit 1cdc35c
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 11 deletions.
3 changes: 3 additions & 0 deletions official/resnet/keras/keras_cifar_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def run(flags_obj):
per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

tf.keras.backend.set_image_data_format(flags_obj.data_format)

if flags_obj.use_synthetic_data:
input_fn = keras_common.get_synth_input_fn(
height=cifar_main.HEIGHT,
Expand Down Expand Up @@ -160,6 +162,7 @@ def run(flags_obj):

validation_data = eval_input_dataset
if flags_obj.skip_eval:
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None
validation_data = None

Expand Down
6 changes: 6 additions & 0 deletions official/resnet/keras/keras_imagenet_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def run(flags_obj):
raise ValueError('dtype fp16 is not supported in Keras. Use the default '
'value(fp32).')

tf.keras.backend.set_image_data_format(flags_obj.data_format)

per_device_batch_size = distribution_utils.per_device_batch_size(
flags_obj.batch_size, flags_core.get_num_gpus(flags_obj))

Expand Down Expand Up @@ -149,6 +151,10 @@ def run(flags_obj):

validation_data = eval_input_dataset
if flags_obj.skip_eval:
# Only build the training graph. This reduces memory usage introduced by
# control flow ops in layers that have different implementations for
# training and inference (e.g., batch norm).
tf.keras.backend.set_learning_phase(1)
num_eval_steps = None
validation_data = None

Expand Down
12 changes: 7 additions & 5 deletions official/resnet/keras/resnet_cifar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,16 +187,18 @@ def resnet56(classes=100, training=None):
Returns:
A Keras model instance.
"""
# Determine proper input shape
input_shape = (32, 32, 3)
img_input = layers.Input(shape=input_shape)

if backend.image_data_format() == 'channels_first':
input_shape = (3, 32, 32)
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1
else: # channels_last
input_shape = (32, 32, 3)
x = img_input
bn_axis = 3

img_input = layers.Input(shape=input_shape)
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(img_input)
x = tf.keras.layers.ZeroPadding2D(padding=(1, 1), name='conv1_pad')(x)
x = tf.keras.layers.Conv2D(16, (3, 3),
strides=(1, 1),
padding='valid',
Expand Down
14 changes: 8 additions & 6 deletions official/resnet/keras/resnet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,18 @@ def resnet50(num_classes):
Returns:
A Keras model instance.
"""
# Determine proper input shape
input_shape = (224, 224, 3)
img_input = layers.Input(shape=input_shape)

if backend.image_data_format() == 'channels_first':
input_shape = (3, 224, 224)
x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
name='transpose')(img_input)
bn_axis = 1
else:
input_shape = (224, 224, 3)
else: # channels_last
x = img_input
bn_axis = 3

img_input = layers.Input(shape=input_shape)
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
x = layers.Conv2D(64, (7, 7),
strides=(2, 2),
padding='valid',
Expand Down

0 comments on commit 1cdc35c

Please sign in to comment.