
Commit 33f61d7

add entmax
1 parent fc71ed7 commit 33f61d7

2 files changed: 180 additions & 2 deletions

rasa/utils/tensorflow/layers.py

Lines changed: 178 additions & 0 deletions
@@ -79,6 +79,7 @@ def build(self, input_shape: tf.TensorShape) -> None:
         )
         self.kernel_mask = tf.Variable(initial_value=kernel_mask, trainable=False)
 
+
     def call(self, inputs: tf.Tensor) -> tf.Tensor:
         # set some weights to 0 according to precomputed mask
         self.kernel.assign(self.kernel * self.kernel_mask)
@@ -521,9 +522,14 @@ def _loss_softmax(
             [sim_pos, sim_neg_il, sim_neg_ll, sim_neg_ii, sim_neg_li], -1
         )
 
+        # pos_labels = tf.ones_like(logits[..., :1])
+        # neg_labels = tf.zeros_like(logits[..., 1:])
+        # labels = tf.concat([pos_labels, neg_labels], -1)
+
         # create label_ids for softmax
         label_ids = tf.zeros_like(logits[..., 0], tf.int32)
 
+        # loss = entmax15_loss_with_logits(labels, logits)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=label_ids, logits=logits
        )
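The commented-out lines above sketch how this softmax cross-entropy could later be swapped for the entmax 1.5 loss added at the bottom of this file. Assembled for illustration only (this path is not enabled by the commit), and reusing `logits` from `_loss_softmax`, where the first column holds the positive pair:

    # hypothetical dense-label variant built from the commented-out lines
    pos_labels = tf.ones_like(logits[..., :1])   # positive pair -> target 1
    neg_labels = tf.zeros_like(logits[..., 1:])  # negative samples -> target 0
    labels = tf.concat([pos_labels, neg_labels], -1)

    loss = entmax15_loss_with_logits(labels, logits)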
@@ -601,3 +607,175 @@ def call(
         )
 
         return loss, acc
+
+
+# https://gist.github.com/justheuristic/60167e77a95221586be315ae527c3cbd
+def entmax15(inputs, axis=-1):
+    """
+    Entmax 1.5 implementation, heavily inspired by
+     * paper: https://arxiv.org/pdf/1905.05702.pdf
+     * pytorch code: https://github.com/deep-spin/entmax
+    :param inputs: similar to softmax logits, but for entmax 1.5
+    :param axis: entmax 1.5 outputs will sum to 1 over this axis
+    :return: entmax activations of same shape as inputs
+    """
+
+    @tf.custom_gradient
+    def _entmax_inner(inputs):
+        with tf.name_scope('entmax'):
+            inputs = inputs / 2  # divide by 2 so as to solve the actual entmax
+            inputs -= tf.reduce_max(inputs, axis, keepdims=True)  # subtract max for stability
+
+            threshold, _ = entmax_threshold_and_support(inputs, axis)
+            outputs_sqrt = tf.nn.relu(inputs - threshold)
+            outputs = tf.square(outputs_sqrt)
+
+        def grad_fn(d_outputs):
+            with tf.name_scope('entmax_grad'):
+                d_inputs = d_outputs * outputs_sqrt
+                q = tf.reduce_sum(d_inputs, axis=axis, keepdims=True)
+                q = q / tf.reduce_sum(outputs_sqrt, axis=axis, keepdims=True)
+                d_inputs -= q * outputs_sqrt
+                return d_inputs
+
+        return outputs, grad_fn
+
+    return _entmax_inner(inputs)
+
+
+@tf.custom_gradient
+def sparse_entmax15_loss_with_logits(labels, logits):
+    """
+    Computes sample-wise entmax 1.5 loss
+    :param labels: reference answers vector int64[batch_size] in [0, num_classes)
+    :param logits: output matrix float32[batch_size, num_classes] (not actually logits :)
+    :returns: elementwise loss, float32[batch_size]
+    """
+    assert labels.shape.ndims == logits.shape.ndims - 1
+    with tf.name_scope('entmax_loss'):
+        p_star = entmax15(logits, axis=-1)
+        omega_entmax15 = (1 - (tf.reduce_sum(p_star * tf.sqrt(p_star), axis=-1))) / 0.75
+        p_incr = p_star - tf.one_hot(labels, depth=tf.shape(logits)[-1], axis=-1)
+        # loss = omega_entmax15 + tf.einsum("ij,ij->i", p_incr, logits)
+        loss = omega_entmax15 + tf.reduce_sum(p_star * logits, axis=-1)
+
+    def grad_fn(grad_output):
+        with tf.name_scope('entmax_loss_grad'):
+            return None, grad_output[..., None] * p_incr
+
+    return loss, grad_fn
+
+
+@tf.custom_gradient
+def entmax15_loss_with_logits(labels, logits):
+    """
+    Computes sample-wise entmax 1.5 loss
+    :param labels: reference answers indicators, float32[batch_size, num_classes]
+    :param logits: "logits" matrix float32[batch_size, num_classes]
+    :returns: elementwise loss, float32[batch_size]
+    WARNING: this function does not propagate gradients through :labels:
+    This is the same behavior as softmax_crossentropy_with_logits v1
+    It may become an issue if you do something like co-distillation
+    """
+    assert labels.shape.ndims == logits.shape.ndims
+    with tf.name_scope('entmax_loss'):
+        p_star = entmax15(logits, axis=-1)
+        omega_entmax15 = (1 - (tf.reduce_sum(p_star * tf.sqrt(p_star), axis=-1))) / 0.75
+        p_incr = p_star - labels
+        # loss = omega_entmax15 + tf.einsum("ij,ij->i", p_incr, logits)
+        loss = omega_entmax15 + tf.reduce_sum(p_star * logits, axis=-1)
+
+    def grad_fn(grad_output):
+        with tf.name_scope('entmax_loss_grad'):
+            return None, grad_output[..., None] * p_incr
+
+    return loss, grad_fn
+
+
+def top_k_over_axis(inputs, k, axis=-1, **kwargs):
+    """ performs tf.nn.top_k over any chosen axis """
+    with tf.name_scope('top_k_along_axis'):
+        if axis == -1:
+            return tf.nn.top_k(inputs, k, **kwargs)
+
+        perm_order = list(range(inputs.shape.ndims))
+        perm_order.append(perm_order.pop(axis))
+        inv_order = [perm_order.index(i) for i in range(len(perm_order))]
+
+        input_perm = tf.transpose(inputs, perm_order)
+        input_perm_sorted, sort_indices_perm = tf.nn.top_k(
+            input_perm, k=k, **kwargs)
+
+        input_sorted = tf.transpose(input_perm_sorted, inv_order)
+        sort_indices = tf.transpose(sort_indices_perm, inv_order)
+        return input_sorted, sort_indices
+
+
+def _make_ix_like(inputs, axis=-1):
+    """ creates indices 1, ..., inputs.shape[axis] broadcastable to the input's dimensions """
+    assert inputs.shape.ndims is not None
+    rho = tf.cast(tf.range(1, tf.shape(inputs)[axis] + 1), dtype=inputs.dtype)
+    view = [1] * inputs.shape.ndims
+    view[axis] = -1
+    return tf.reshape(rho, view)
+
+
+def gather_over_axis(values, indices, gather_axis):
+    """
+    replicates the behavior of torch.gather for tf<=1.8;
+    for newer versions use tf.gather with batch_dims
+    :param values: tensor [d0, ..., dn]
+    :param indices: int64 tensor of same shape as values except for gather_axis
+    :param gather_axis: performs gather along this axis
+    :returns: gathered values, same shape as values except for gather_axis
+        If gather_axis == 2
+        gathered_values[i, j, k, ...] = values[i, j, indices[i, j, k, ...], ...]
+        see torch.gather for more details
+    """
+    assert indices.shape.ndims is not None
+    assert indices.shape.ndims == values.shape.ndims
+
+    ndims = indices.shape.ndims
+    gather_axis = gather_axis % ndims
+    shape = tf.shape(indices)
+
+    selectors = []
+    for axis_i in range(ndims):
+        if axis_i == gather_axis:
+            selectors.append(indices)
+        else:
+            index_i = tf.range(tf.cast(shape[axis_i], dtype=indices.dtype), dtype=indices.dtype)
+            index_i = tf.reshape(index_i, [-1 if i == axis_i else 1 for i in range(ndims)])
+            index_i = tf.tile(index_i, [shape[i] if i != axis_i else 1 for i in range(ndims)])
+            selectors.append(index_i)
+
+    return tf.gather_nd(values, tf.stack(selectors, axis=-1))
+
+
+def entmax_threshold_and_support(inputs, axis=-1):
+    """
+    Computes clipping threshold for entmax 1.5 over the specified axis
+    NOTE this implementation uses the same heuristic as
+    the original code: https://tinyurl.com/pytorch-entmax-line-203
+    :param inputs: (entmax 1.5 inputs - max) / 2
+    :param axis: entmax 1.5 outputs will sum to 1 over this axis
+    """
+
+    with tf.name_scope('entmax_threshold_and_support'):
+        num_outcomes = tf.shape(inputs)[axis]
+        inputs_sorted, _ = top_k_over_axis(inputs, k=num_outcomes, axis=axis, sorted=True)
+
+        rho = _make_ix_like(inputs, axis=axis)
+
+        mean = tf.cumsum(inputs_sorted, axis=axis) / rho
+
+        mean_sq = tf.cumsum(tf.square(inputs_sorted), axis=axis) / rho
+        delta = (1 - rho * (mean_sq - tf.square(mean))) / rho
+
+        delta_nz = tf.nn.relu(delta)
+        tau = mean - tf.sqrt(delta_nz)
+
+        support_size = tf.reduce_sum(
+            tf.cast(tf.less_equal(tau, inputs_sorted), dtype=tf.int64), axis=axis, keepdims=True
+        )
+
+        tau_star = gather_over_axis(tau, support_size - 1, axis)
+        return tau_star, support_size
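For orientation, a minimal usage sketch of the new activation (toy values; assumes TensorFlow 2.x eager mode and the import path this commit also uses in transformer.py). Unlike softmax, entmax 1.5 can assign exactly zero probability to low-scoring entries while its outputs still sum to 1 over the chosen axis:

    import tensorflow as tf
    from rasa.utils.tensorflow.layers import entmax15

    logits = tf.constant([[2.0, 1.0, 0.5, -1.0]])

    sparse_probs = entmax15(logits, axis=-1)      # roughly [[0.81, 0.16, 0.02, 0.0]]; lowest entry exactly 0
    dense_probs = tf.nn.softmax(logits, axis=-1)  # strictly positive everywhere

    print(tf.reduce_sum(sparse_probs, axis=-1))   # ~1.0, still a valid probability distribution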

rasa/utils/tensorflow/transformer.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.keras import backend as K
 import numpy as np
-from rasa.utils.tensorflow.layers import DenseWithSparseWeights
+from rasa.utils.tensorflow.layers import DenseWithSparseWeights, entmax15
 
 
 # from https://www.tensorflow.org/tutorials/text/transformer
@@ -264,7 +264,7 @@ def _scaled_dot_product_attention(
 
         # softmax is normalized on the last axis (seq_len_k) so that the scores
         # add up to 1.
-        attention_weights = tf.nn.softmax(
+        attention_weights = entmax15(
             logits, axis=-1
         )  # (..., seq_len_q, seq_len_k)
 
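With entmax 1.5 in place of softmax, the attention weights over seq_len_k can be exactly zero for some key positions while each query's weights still sum to 1. A rough standalone illustration with assumed toy shapes (not taken from the Rasa code):

    import tensorflow as tf
    from rasa.utils.tensorflow.layers import entmax15

    # assumed shapes: (batch, num_heads, seq_len_q, seq_len_k)
    logits = tf.random.normal((1, 2, 3, 4)) * 5.0

    attention_weights = entmax15(logits, axis=-1)  # (..., seq_len_q, seq_len_k)
    # each row sums to 1, and some entries can be exactly 0.0,
    # i.e. those key positions receive no attention at all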