Skip to content

Commit

Permalink
Simplify implementation of BN layer.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Apr 24, 2017
1 parent b8134f5 commit 5be73f1
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 44 deletions.
90 changes: 46 additions & 44 deletions keras/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,55 +135,57 @@ def call(self, inputs, training=None):
# Determines whether broadcasting is needed.
needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

normed, mean, variance = K.normalize_batch_in_training(
def normalize_inference():
if needs_broadcasting:
# In this case we must explicitly broadcast all parameters.
broadcast_moving_mean = K.reshape(self.moving_mean,
broadcast_shape)
broadcast_moving_variance = K.reshape(self.moving_variance,
broadcast_shape)
if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
else:
broadcast_beta = None
if self.scale:
broadcast_gamma = K.reshape(self.gamma,
broadcast_shape)
else:
broadcast_gamma = None
return K.batch_normalization(
inputs,
broadcast_moving_mean,
broadcast_moving_variance,
broadcast_beta,
broadcast_gamma,
epsilon=self.epsilon)
else:
return K.batch_normalization(
inputs,
self.moving_mean,
self.moving_variance,
self.beta,
self.gamma,
epsilon=self.epsilon)

# If the learning phase is *static* and set to inference:
if training in {0, False}:
return normalize_inference()

# If the learning is either dynamic, or set to training:
normed_training, mean, variance = K.normalize_batch_in_training(
inputs, self.gamma, self.beta, reduction_axes,
epsilon=self.epsilon)

if training in {0, False}:
return normed
else:
self.add_update([K.moving_average_update(self.moving_mean,
mean,
self.momentum),
K.moving_average_update(self.moving_variance,
variance,
self.momentum)],
inputs)

def normalize_inference():
if needs_broadcasting:
# In this case we must explicitly broadcast all parameters.
broadcast_moving_mean = K.reshape(self.moving_mean,
broadcast_shape)
broadcast_moving_variance = K.reshape(self.moving_variance,
broadcast_shape)
if self.center:
broadcast_beta = K.reshape(self.beta, broadcast_shape)
else:
broadcast_beta = None
if self.scale:
broadcast_gamma = K.reshape(self.gamma,
broadcast_shape)
else:
broadcast_gamma = None
return K.batch_normalization(
inputs,
broadcast_moving_mean,
broadcast_moving_variance,
broadcast_beta,
broadcast_gamma,
epsilon=self.epsilon)
else:
return K.batch_normalization(
inputs,
self.moving_mean,
self.moving_variance,
self.beta,
self.gamma,
epsilon=self.epsilon)
self.add_update([K.moving_average_update(self.moving_mean,
mean,
self.momentum),
K.moving_average_update(self.moving_variance,
variance,
self.momentum)],
inputs)

# Pick the normalized form corresponding to the training phase.
return K.in_train_phase(normed,
return K.in_train_phase(normed_training,
normalize_inference,
training=training)

Expand Down
25 changes: 25 additions & 0 deletions tests/keras/layers/normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,31 @@ def test_batchnorm_correctness():
assert_allclose(out.std(), 1.0, atol=1e-1)


@keras_test
def test_batchnorm_training_argument():
    # A BN layer called with training=True must register moving-average updates.
    layer_train = normalization.BatchNormalization(input_shape=(10,))
    inp_train = Input(shape=(10,))
    out_train = layer_train(inp_train, training=True)
    assert layer_train.updates

    model = Model(inp_train, out_train)
    np.random.seed(123)
    data = np.random.normal(loc=5.0, scale=10.0, size=(20, 10))
    before_fit = model.predict(data)

    model.compile(loss='mse', optimizer='rmsprop')
    model.fit(data, data, epochs=1, verbose=0)

    # After one epoch the moving statistics have moved, so the output changes
    # and is approximately standardized.
    after_fit = model.predict(data)
    assert np.abs(np.sum(before_fit - after_fit)) > 0.1
    assert_allclose(after_fit.mean(), 0.0, atol=1e-1)
    assert_allclose(after_fit.std(), 1.0, atol=1e-1)

    # Conversely, training=False must not register any updates.
    layer_infer = normalization.BatchNormalization(input_shape=(10,))
    inp_infer = Input(shape=(10,))
    layer_infer(inp_infer, training=False)
    assert not layer_infer.updates


@keras_test
def test_batchnorm_mode_twice():
# This is a regression test for issue #4881 with the old
Expand Down

0 comments on commit 5be73f1

Please sign in to comment.