Commit d6becaf

Added Squeeze/Excite and BatchNorm recipes.
1 parent e6ac9f3


README.md — 73 additions, 1 deletion
```diff
@@ -24,7 +24,8 @@ Table of Contents
 - [KL-Divergence](#kld)
 - [Make parallel](#make_parallel)
 - [Leaky Relu](#leaky_relu)
-
+- [Batch normalization](#batch_norm)
+- [Squeeze and excitation](#squeeze_excite)
 ---

 _We aim to gradually expand this series by adding new articles and keep the content up to date with the latest releases of TensorFlow API. If you have suggestions on how to improve this series or find the explanations ambiguous, feel free to create an issue, send patches, or reach out by email._
```
@@ -1409,3 +1410,74 @@

The new recipes are appended after the existing Leaky Relu recipe, which ends with:

```python
def leaky_relu(tensor, alpha=0.1):
  """Computes the leaky rectified linear activation."""
  return tf.maximum(tensor, alpha * tensor)
```

## Batch normalization <a name="batch_norm"></a>

```python
def batch_normalization(tensor, training=False, epsilon=0.001, momentum=0.9,
                        fused_batch_norm=False, name=None):
  """Performs batch normalization on given 4-D tensor.

  The features are assumed to be in NHWC format. Note that you need to
  run UPDATE_OPS in order for this function to perform correctly, e.g.:

    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
      train_op = optimizer.minimize(loss)
  """
  with tf.variable_scope(name, default_name="batch_normalization"):
    channels = tensor.shape.as_list()[-1]
    axes = list(range(tensor.shape.ndims - 1))

    # Learnable per-channel offset (beta) and scale (gamma).
    beta = tf.get_variable(
        'beta', channels, initializer=tf.zeros_initializer())
    gamma = tf.get_variable(
        'gamma', channels, initializer=tf.ones_initializer())

    # Moving averages of the batch statistics, used at inference time.
    avg_mean = tf.get_variable(
        "avg_mean", channels, initializer=tf.zeros_initializer(),
        trainable=False)
    avg_variance = tf.get_variable(
        "avg_variance", channels, initializer=tf.ones_initializer(),
        trainable=False)

    if training:
      if fused_batch_norm:
        # The fused kernel computes the batch statistics internally.
        mean, variance = None, None
      else:
        mean, variance = tf.nn.moments(tensor, axes=axes)
    else:
      mean, variance = avg_mean, avg_variance

    if fused_batch_norm:
      tensor, mean, variance = tf.nn.fused_batch_norm(
          tensor, scale=gamma, offset=beta, mean=mean, variance=variance,
          epsilon=epsilon, is_training=training)
    else:
      tensor = tf.nn.batch_normalization(
          tensor, mean, variance, beta, gamma, epsilon)

    if training:
      # Update the moving averages and register the ops so they run
      # under UPDATE_OPS together with the train op.
      update_mean = tf.assign(
          avg_mean, avg_mean * momentum + mean * (1.0 - momentum))
      update_variance = tf.assign(
          avg_variance, avg_variance * momentum + variance * (1.0 - momentum))

      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)
      tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_variance)

  return tensor
```
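
As a usage sketch (the `images` placeholder, the stand-in loss, and `AdamOptimizer` below are illustrative assumptions, not part of this commit), the moving-average updates registered in `UPDATE_OPS` can be wired into the train op like this:

```python
import tensorflow as tf

# Hypothetical NHWC input batch and a stand-in loss, for illustration only.
images = tf.placeholder(tf.float32, [None, 32, 32, 3])
normalized = batch_normalization(images, training=True)
loss = tf.reduce_mean(tf.square(normalized))

# Run the avg_mean/avg_variance updates together with each training step.
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
  train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
```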

## Squeeze and excitation <a name="squeeze_excite"></a>

```python
def squeeze_and_excite(tensor, ratio):
  """Squeeze and excite layer."""
  original = tensor
  units = tensor.shape.as_list()[-1]
  # Squeeze: global average pooling over the spatial dimensions.
  tensor = tf.reduce_mean(tensor, [1, 2], keepdims=True)
  # Excite: a bottleneck of units // ratio hidden units, then a sigmoid
  # producing one gate in (0, 1) per channel.
  tensor = tf.layers.dense(tensor, units // ratio, use_bias=False)
  tensor = tf.nn.relu(tensor)
  tensor = tf.layers.dense(tensor, units, use_bias=False)
  tensor = tf.nn.sigmoid(tensor)
  # Rescale the original features channel-wise.
  tensor = original * tensor
  return tensor
```
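
A minimal usage sketch (the convolution and `ratio=4` below are illustrative assumptions, not part of this commit); the block recalibrates each channel of a feature map by its learned gate:

```python
import tensorflow as tf

# Hypothetical NHWC feature map produced by a convolutional layer.
inputs = tf.placeholder(tf.float32, [None, 56, 56, 64])
features = tf.layers.conv2d(inputs, filters=64, kernel_size=3, padding="same")

# With ratio=4 the bottleneck has 64 // 4 = 16 hidden units.
features = squeeze_and_excite(features, ratio=4)
```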
