Skip to content

Commit

Permalink
Reduce test flakiness
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Jan 4, 2018
1 parent 45c838c commit 4da8c81
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions tests/keras/layers/normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@ def test_batchnorm_correctness_1d():
model = Sequential()
norm = normalization.BatchNormalization(input_shape=(10,), momentum=0.8)
model.add(norm)
model.compile(loss='mse', optimizer='sgd')
model.compile(loss='mse', optimizer='rmsprop')

# centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
model.fit(x, x, epochs=4, verbose=0)
model.fit(x, x, epochs=5, verbose=0)
out = model.predict(x)
out -= K.eval(norm.beta)
out /= K.eval(norm.gamma)
Expand All @@ -67,11 +67,11 @@ def test_batchnorm_correctness_2d():
model = Sequential()
norm = normalization.BatchNormalization(axis=1, input_shape=(10, 6), momentum=0.8)
model.add(norm)
model.compile(loss='mse', optimizer='sgd')
model.compile(loss='mse', optimizer='rmsprop')

# centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 6))
model.fit(x, x, epochs=4, verbose=0)
model.fit(x, x, epochs=5, verbose=0)
out = model.predict(x)
out -= np.reshape(K.eval(norm.beta), (1, 10, 1))
out /= np.reshape(K.eval(norm.gamma), (1, 10, 1))
Expand Down

0 comments on commit 4da8c81

Please sign in to comment.