
Commit ec2a44e

Simplify entropy function.
1 parent 51eb429

File tree

1 file changed: +6 −7 lines


README.md

Lines changed: 6 additions & 7 deletions
````diff
@@ -933,12 +933,12 @@ Let's look at an example:
 ```python
 import tensorflow as tf
 
-def non_differentiable_entropy(logits):
+def non_differentiable_softmax_entropy(logits):
     probs = tf.nn.softmax(logits)
     return tf.nn.softmax_cross_entropy_with_logits(labels=probs, logits=logits)
 
 w = tf.get_variable("w", shape=[5])
-y = -non_differentiable_entropy(w)
+y = -non_differentiable_softmax_entropy(w)
 
 opt = tf.train.AdamOptimizer()
 train_op = opt.minimize(y)
````
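
Not part of the commit, but a minimal sketch of what the rename points at (assuming TF 1.x, as in the README): `tf.nn.softmax_cross_entropy_with_logits` backpropagates only into `logits` and treats `labels` as a constant, so when `labels` is the softmax of the same logits, the gradient of this entropy with respect to `w` collapses to zero.

```python
import tensorflow as tf

w = tf.get_variable("w", shape=[5])
probs = tf.nn.softmax(w)
# `labels` is treated as a constant here: no gradient flows through `probs`.
entropy = tf.nn.softmax_cross_entropy_with_logits(labels=probs, logits=w)
grad = tf.gradients(entropy, w)[0]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(grad))  # [0. 0. 0. 0. 0.]: softmax(w) - probs cancels exactly
```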
````diff
@@ -970,13 +970,12 @@ Now let's fix our function with a differentiable version of the entropy and check
 import tensorflow as tf
 import numpy as np
 
-def entropy(logits, dim=-1):
-    probs = tf.nn.softmax(logits, dim)
-    nplogp = probs * (tf.reduce_logsumexp(logits, dim, keep_dims=True) - logits)
-    return tf.reduce_sum(nplogp, dim)
+def softmax_entropy(logits, dim=-1):
+    plogp = tf.nn.softmax(logits, dim) * tf.nn.log_softmax(logits, dim)
+    return -tf.reduce_sum(plogp, dim)
 
 w = tf.get_variable("w", shape=[5])
-y = -entropy(w)
+y = -softmax_entropy(w)
 
 print(w.get_shape())
 print(y.get_shape())
````
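
For reference (not part of the commit), a quick check that the new `softmax_entropy` computes the same value as the removed implementation: since `log_softmax(logits) = logits - reduce_logsumexp(logits)`, the term `probs * (logsumexp - logits)` is exactly `-probs * log_softmax(logits)`. A TF 1.x sketch, with the old body copied from the removed lines and renamed `old_entropy` for the comparison:

```python
import numpy as np
import tensorflow as tf

def old_entropy(logits, dim=-1):
    # Removed implementation: entropy via logsumexp.
    probs = tf.nn.softmax(logits, dim)
    nplogp = probs * (tf.reduce_logsumexp(logits, dim, keep_dims=True) - logits)
    return tf.reduce_sum(nplogp, dim)

def softmax_entropy(logits, dim=-1):
    # New implementation: entropy as -sum(p * log p).
    plogp = tf.nn.softmax(logits, dim) * tf.nn.log_softmax(logits, dim)
    return -tf.reduce_sum(plogp, dim)

logits = tf.constant(np.random.randn(3, 5), dtype=tf.float32)
with tf.Session() as sess:
    old_val, new_val = sess.run([old_entropy(logits), softmax_entropy(logits)])
    print(np.allclose(old_val, new_val))  # True: the two formulations agree
```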
