Skip to content

Commit

Permalink
bug-fix cifar10_cnn_capsule.py missing K.sum() (keras-team#9520)
Browse files Browse the repository at this point in the history
Just add K.sum() around the margin_loss
  • Loading branch information
saralajew authored and fchollet committed Mar 1, 2018
1 parent 98ffe08 commit 500401b
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/cifar10_cnn_capsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def softmax(x, axis=-1):
# define the margin loss like hinge loss
def margin_loss(y_true, y_pred):
lamb, margin = 0.5, 0.1
return y_true * K.square(K.relu(1 - margin - y_pred)) + lamb * (
1 - y_true) * K.square(K.relu(y_pred - margin))
return K.sum(y_true * K.square(K.relu(1 - margin - y_pred)) + lamb * (
1 - y_true) * K.square(K.relu(y_pred - margin)), axis=-1)


class Capsule(Layer):
Expand Down

0 comments on commit 500401b

Please sign in to comment.