Skip to content

Commit 178480e

Browse files
authored
Remove batch norm weight decay + a few other fixes (tensorflow#2755)
1 parent 4c37264 commit 178480e

File tree

2 files changed

+9
-7
lines changed

2 files changed

+9
-7
lines changed

official/resnet/imagenet_main.py

+4 -3
Original file line number | Diff line number | Diff line change
@@ -184,10 +184,11 @@ def resnet_model_fn(features, labels, mode, params):
184184
tf.identity(cross_entropy, name='cross_entropy')
185185
tf.summary.scalar('cross_entropy', cross_entropy)
186186

187-
# Add weight decay to the loss. We perform weight decay on all trainable
188-
# variables, which includes batch norm beta and gamma variables.
187+
# Add weight decay to the loss. We exclude the batch norm variables because
188+
# doing so leads to a small improvement in accuracy.
189189
loss = cross_entropy + _WEIGHT_DECAY * tf.add_n(
190-
[tf.nn.l2_loss(v) for v in tf.trainable_variables()])
190+
[tf.nn.l2_loss(v) for v in tf.trainable_variables()
191+
if 'batch_normalization' not in v.name])
191192

192193
if mode == tf.estimator.ModeKeys.TRAIN:
193194
# Scale the learning rate linearly with the batch size. When the batch size

official/resnet/resnet_model.py

+5 -4
Original file line number | Diff line number | Diff line change
@@ -242,8 +242,8 @@ def cifar10_resnet_v2_generator(resnet_size, num_classes, data_format=None):
242242
def model(inputs, is_training):
243243
"""Constructs the ResNet model given the inputs."""
244244
if data_format == 'channels_first':
245-
# Convert from channels_last (NHWC) to channels_first (NCHW). This
246-
# provides a large performance boost on GPU. See
245+
# Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
246+
# This provides a large performance boost on GPU. See
247247
# https://www.tensorflow.org/performance/performance_guide#data_formats
248248
inputs = tf.transpose(inputs, [0, 3, 1, 2])
249249

@@ -302,8 +302,9 @@ def imagenet_resnet_v2_generator(block_fn, layers, num_classes,
302302
def model(inputs, is_training):
303303
"""Constructs the ResNet model given the inputs."""
304304
if data_format == 'channels_first':
305-
# Convert from channels_last (NHWC) to channels_first (NCHW). This
306-
# provides a large performance boost on GPU.
305+
# Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
306+
# This provides a large performance boost on GPU. See
307+
# https://www.tensorflow.org/performance/performance_guide#data_formats
307308
inputs = tf.transpose(inputs, [0, 3, 1, 2])
308309

309310
inputs = conv2d_fixed_padding(

0 commit comments

Comments
 (0)