
add weighted crossentropy
aymericdamien committed Jun 6, 2017
1 parent 47d4507 commit 4eba935
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions tflearn/objectives.py
@@ -103,6 +103,54 @@ def binary_crossentropy(y_pred, y_true):
        logits=y_pred, labels=y_true))


def weighted_crossentropy(y_pred, y_true, pos_weight=1.):
    """ Weighted Crossentropy.

    Computes weighted sigmoid cross entropy between y_pred (logits) and
    y_true (labels). This is like sigmoid_cross_entropy_with_logits(),
    except that pos_weight allows one to trade off recall and precision by
    up- or down-weighting the cost of a positive error relative to a
    negative error.

    The usual cross-entropy cost is defined as:
    `targets * -log(sigmoid(logits)) + (1 - targets) * -log(1 - sigmoid(logits))`

    The argument pos_weight is used as a multiplier for the positive targets:
    `targets * -log(sigmoid(logits)) * pos_weight + (1 - targets) * -log(1 - sigmoid(logits))`

    For brevity, let x = logits, z = targets, q = pos_weight. The loss is:
    ```
    qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
    = qz * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
    = qz * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
    = qz * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
    = (1 - z) * x + (qz + 1 - z) * log(1 + exp(-x))
    = (1 - z) * x + (1 + (q - 1) * z) * log(1 + exp(-x))
    ```
    Setting l = (1 + (q - 1) * z), to ensure numerical stability and avoid
    overflow, the implementation uses:
    `(1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0))`

    logits and targets must have the same type and shape.

    Arguments:
        y_pred: `Tensor` of `float` type. Predicted values (logits).
        y_true: `Tensor` of `float` type. Targets (labels).
        pos_weight: `float`. Coefficient applied to the positive examples.
            Defaults to 1. (unweighted).

    """
    with tf.name_scope("WeightedCrossentropy"):
        # tf.nn.sigmoid_cross_entropy_with_logits has no pos_weight argument
        # and would ignore the weighting described above, so the weighted op
        # is required here (the TF 1.x signature takes `targets=`, not
        # `labels=`).
        return tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
            targets=y_true, logits=y_pred, pos_weight=pos_weight))
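As a quick numerical check of the derivation in the docstring, a minimal NumPy sketch (the sample logits, targets, and weight `q = 3.0` are illustrative, not from the commit) comparing the naive loss against the stable form:

```
import numpy as np

# Illustrative logits x, binary targets z, and positive-class weight q.
x = np.array([-2.0, -0.5, 0.5, 2.0])
z = np.array([1.0, 0.0, 1.0, 0.0])
q = 3.0

# Naive form: qz * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
sig = 1.0 / (1.0 + np.exp(-x))
naive = q * z * -np.log(sig) + (1.0 - z) * -np.log(1.0 - sig)

# Stable form: (1 - z) * x + l * (log(1 + exp(-abs(x))) + max(-x, 0)),
# with l = 1 + (q - 1) * z.
l = 1.0 + (q - 1.0) * z
stable = (1.0 - z) * x + l * (np.log1p(np.exp(-np.abs(x))) + np.maximum(-x, 0.0))

print(np.allclose(naive, stable))  # True: the two forms agree
```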


def mean_square(y_pred, y_true):
""" Mean Square Loss.
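For reference, the TensorFlow op underlying the fixed implementation can also be called directly. A minimal sketch against the TF 1.x API (which takes `targets=` rather than `labels=`); the logits, labels, and `pos_weight=3.0` are illustrative:

```
import tensorflow as tf

# pos_weight=3.0 makes a missed positive three times as costly as a
# false positive, trading precision for recall.
logits = tf.constant([[-2.0, 0.5], [1.0, -0.5]])
labels = tf.constant([[1.0, 0.0], [1.0, 1.0]])

loss = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(
    targets=labels, logits=logits, pos_weight=3.0))

with tf.Session() as sess:
    print(sess.run(loss))
```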
