@@ -79,6 +79,7 @@ def build(self, input_shape: tf.TensorShape) -> None:
         )
         self.kernel_mask = tf.Variable(initial_value=kernel_mask, trainable=False)

+
     def call(self, inputs: tf.Tensor) -> tf.Tensor:
         # set some weights to 0 according to precomputed mask
         self.kernel.assign(self.kernel * self.kernel_mask)
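A minimal standalone sketch of the masking mechanism shown in this hunk, assuming the surrounding class (its name, layer sizes and mask density below are illustrative, not taken from the diff): a non-trainable binary mask is built once in build() and the kernel is multiplied by it on every forward pass, so masked weights stay at zero.

    import tensorflow as tf

    class MaskedDense(tf.keras.layers.Layer):  # hypothetical name, for illustration only
        def __init__(self, units: int, keep_prob: float = 0.8) -> None:
            super().__init__()
            self.units = units
            self.keep_prob = keep_prob

        def build(self, input_shape: tf.TensorShape) -> None:
            self.kernel = self.add_weight(
                name="kernel", shape=(input_shape[-1], self.units), trainable=True
            )
            # precompute a fixed binary mask and store it as a non-trainable variable
            kernel_mask = tf.cast(
                tf.random.uniform(self.kernel.shape) < self.keep_prob, tf.float32
            )
            self.kernel_mask = tf.Variable(initial_value=kernel_mask, trainable=False)

        def call(self, inputs: tf.Tensor) -> tf.Tensor:
            # zero out masked weights before using the kernel
            self.kernel.assign(self.kernel * self.kernel_mask)
            return tf.matmul(inputs, self.kernel)

    # usage: y = MaskedDense(16)(tf.ones((2, 8)))
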
@@ -521,9 +522,14 @@ def _loss_softmax(
             [sim_pos, sim_neg_il, sim_neg_ll, sim_neg_ii, sim_neg_li], -1
         )

+        # pos_labels = tf.ones_like(logits[..., :1])
+        # neg_labels = tf.zeros_like(logits[..., 1:])
+        # labels = tf.concat([pos_labels, neg_labels], -1)
+
         # create label_ids for softmax
         label_ids = tf.zeros_like(logits[..., 0], tf.int32)

+        # loss = entmax15_loss_with_logits(labels, logits)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=label_ids, logits=logits
        )
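A small standalone sketch of the labeling convention in this hunk (toy logits, names are assumptions): column 0 holds the positive similarity, so the sparse labels are all zeros, and the commented-out dense labels are their one-hot form, which is the shape the entmax loss added below would expect.

    import tensorflow as tf

    logits = tf.constant([[2.0, 0.5, -1.0], [1.0, 3.0, 0.0]])  # [batch, 1 + num_neg]

    # sparse form used by tf.nn.sparse_softmax_cross_entropy_with_logits
    label_ids = tf.zeros_like(logits[..., 0], tf.int32)              # -> [0, 0]

    # dense form sketched in the commented-out lines of the diff
    pos_labels = tf.ones_like(logits[..., :1])
    neg_labels = tf.zeros_like(logits[..., 1:])
    labels = tf.concat([pos_labels, neg_labels], -1)                 # -> [[1,0,0],[1,0,0]]

    sparse_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label_ids, logits=logits)
    dense_loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    # both yield the same per-example loss; the dense labels are what
    # entmax15_loss_with_logits (added in the hunk below) takes as input
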
@@ -601,3 +607,175 @@ def call(
         )

         return loss, acc
+
+
+# https://gist.github.com/justheuristic/60167e77a95221586be315ae527c3cbd
+def entmax15(inputs, axis=-1):
+    """
+    Entmax 1.5 implementation, heavily inspired by
+     * paper: https://arxiv.org/pdf/1905.05702.pdf
+     * pytorch code: https://github.com/deep-spin/entmax
+    :param inputs: similar to softmax logits, but for entmax1.5
+    :param axis: entmax1.5 outputs will sum to 1 over this axis
+    :return: entmax activations of same shape as inputs
+    """
+
+    @tf.custom_gradient
+    def _entmax_inner(inputs):
+        with tf.name_scope('entmax'):
+            inputs = inputs / 2  # divide by 2 so as to solve actual entmax
+            inputs -= tf.reduce_max(inputs, axis, keepdims=True)  # subtract max for stability
+
+            threshold, _ = entmax_threshold_and_support(inputs, axis)
+            outputs_sqrt = tf.nn.relu(inputs - threshold)
+            outputs = tf.square(outputs_sqrt)
+
+        def grad_fn(d_outputs):
+            with tf.name_scope('entmax_grad'):
+                d_inputs = d_outputs * outputs_sqrt
+                q = tf.reduce_sum(d_inputs, axis=axis, keepdims=True)
+                q = q / tf.reduce_sum(outputs_sqrt, axis=axis, keepdims=True)
+                d_inputs -= q * outputs_sqrt
+                return d_inputs
+
+        return outputs, grad_fn
+
+    return _entmax_inner(inputs)
+
+
+@tf.custom_gradient
+def sparse_entmax15_loss_with_logits(labels, logits):
+    """
+    Computes sample-wise entmax1.5 loss
+    :param labels: reference answers vector int64[batch_size] \in [0, num_classes)
+    :param logits: output matrix float32[batch_size, num_classes] (not actually logits :)
+    :returns: elementwise loss, float32[batch_size]
+    """
+    assert labels.shape.ndims == logits.shape.ndims - 1
+    with tf.name_scope('entmax_loss'):
+        p_star = entmax15(logits, axis=-1)
+        omega_entmax15 = (1 - (tf.reduce_sum(p_star * tf.sqrt(p_star), axis=-1))) / 0.75
+        p_incr = p_star - tf.one_hot(labels, depth=tf.shape(logits)[-1], axis=-1)
+        # loss = omega_entmax15 + tf.einsum("ij,ij->i", p_incr, logits)
+        loss = omega_entmax15 + tf.reduce_sum(p_star * logits, axis=-1)
+
+        def grad_fn(grad_output):
+            with tf.name_scope('entmax_loss_grad'):
+                return None, grad_output[..., None] * p_incr
+
+        return loss, grad_fn
+
+
+@tf.custom_gradient
+def entmax15_loss_with_logits(labels, logits):
+    """
+    Computes sample-wise entmax1.5 loss
+    :param logits: "logits" matrix float32[batch_size, num_classes]
+    :param labels: reference answers indicators, float32[batch_size, num_classes]
+    :returns: elementwise loss, float32[batch_size]
+    WARNING: this function does not propagate gradients through :labels:
+    This behavior is the same as softmax_cross_entropy_with_logits v1
+    It may become an issue if you do something like co-distillation
+    """
+    assert labels.shape.ndims == logits.shape.ndims
+    with tf.name_scope('entmax_loss'):
+        p_star = entmax15(logits, axis=-1)
+        omega_entmax15 = (1 - (tf.reduce_sum(p_star * tf.sqrt(p_star), axis=-1))) / 0.75
+        p_incr = p_star - labels
+        # loss = omega_entmax15 + tf.einsum("ij,ij->i", p_incr, logits)
+        loss = omega_entmax15 + tf.reduce_sum(p_star * logits, axis=-1)
+
+        def grad_fn(grad_output):
+            with tf.name_scope('entmax_loss_grad'):
+                return None, grad_output[..., None] * p_incr
+
+        return loss, grad_fn
+
+
+def top_k_over_axis(inputs, k, axis=-1, **kwargs):
+    """ performs tf.nn.top_k over any chosen axis """
+    with tf.name_scope('top_k_along_axis'):
+        if axis == -1:
+            return tf.nn.top_k(inputs, k, **kwargs)
+
+        perm_order = list(range(inputs.shape.ndims))
+        perm_order.append(perm_order.pop(axis))
+        inv_order = [perm_order.index(i) for i in range(len(perm_order))]
+
+        input_perm = tf.transpose(inputs, perm_order)
+        input_perm_sorted, sort_indices_perm = tf.nn.top_k(
+            input_perm, k=k, **kwargs)
+
+        input_sorted = tf.transpose(input_perm_sorted, inv_order)
+        sort_indices = tf.transpose(sort_indices_perm, inv_order)
+        return input_sorted, sort_indices
+
+
+def _make_ix_like(inputs, axis=-1):
+    """ creates indices 1, ..., input[axis] unsqueezed to input dimensions """
+    assert inputs.shape.ndims is not None
+    rho = tf.cast(tf.range(1, tf.shape(inputs)[axis] + 1), dtype=inputs.dtype)
+    view = [1] * inputs.shape.ndims
+    view[axis] = -1
+    return tf.reshape(rho, view)
+
+
+def gather_over_axis(values, indices, gather_axis):
+    """
+    replicates the behavior of torch.gather for tf<=1.8;
+    for newer versions use tf.gather with batch_dims
+    :param values: tensor [d0, ..., dn]
+    :param indices: int64 tensor of same shape as values except for gather_axis
+    :param gather_axis: performs gather along this axis
+    :returns: gathered values, same shape as values except for gather_axis
+    If gather_axis == 2
+    gathered_values[i, j, k, ...] = values[i, j, indices[i, j, k, ...], ...]
+    see torch.gather for more details
+    """
+    assert indices.shape.ndims is not None
+    assert indices.shape.ndims == values.shape.ndims
+
+    ndims = indices.shape.ndims
+    gather_axis = gather_axis % ndims
+    shape = tf.shape(indices)
+
+    selectors = []
+    for axis_i in range(ndims):
+        if axis_i == gather_axis:
+            selectors.append(indices)
+        else:
+            index_i = tf.range(tf.cast(shape[axis_i], dtype=indices.dtype), dtype=indices.dtype)
+            index_i = tf.reshape(index_i, [-1 if i == axis_i else 1 for i in range(ndims)])
+            index_i = tf.tile(index_i, [shape[i] if i != axis_i else 1 for i in range(ndims)])
+            selectors.append(index_i)
+
+    return tf.gather_nd(values, tf.stack(selectors, axis=-1))
+
+
+def entmax_threshold_and_support(inputs, axis=-1):
+    """
+    Computes clipping threshold for entmax1.5 over specified axis
+    NOTE this implementation uses the same heuristic as
+    the original code: https://tinyurl.com/pytorch-entmax-line-203
+    :param inputs: (entmax1.5 inputs - max) / 2
+    :param axis: entmax1.5 outputs will sum to 1 over this axis
+    """
+
+    with tf.name_scope('entmax_threshold_and_support'):
+        num_outcomes = tf.shape(inputs)[axis]
+        inputs_sorted, _ = top_k_over_axis(inputs, k=num_outcomes, axis=axis, sorted=True)
+
+        rho = _make_ix_like(inputs, axis=axis)
+
+        mean = tf.cumsum(inputs_sorted, axis=axis) / rho
+
+        mean_sq = tf.cumsum(tf.square(inputs_sorted), axis=axis) / rho
+        delta = (1 - rho * (mean_sq - tf.square(mean))) / rho
+
+        delta_nz = tf.nn.relu(delta)
+        tau = mean - tf.sqrt(delta_nz)
+
+        support_size = tf.reduce_sum(tf.cast(tf.less_equal(tau, inputs_sorted), dtype=tf.int64), axis=axis, keepdims=True)
+
+        tau_star = gather_over_axis(tau, support_size - 1, axis)
+        return tau_star, support_size
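
A quick usage sketch of the added activation (toy values, assuming the functions from this diff are in scope): entmax15 acts like a sparser softmax, so low-scoring classes can receive exactly zero probability while the output still sums to 1 along the chosen axis.

    import tensorflow as tf

    logits = tf.constant([[3.0, 2.5, -1.0, -2.0]])
    p_softmax = tf.nn.softmax(logits, axis=-1)       # all entries strictly positive
    p_entmax = entmax15(logits, axis=-1)             # low-scoring entries become exactly 0.0
    print(tf.reduce_sum(p_entmax, axis=-1))          # ~1.0
    print(tf.math.count_nonzero(p_entmax, axis=-1))  # support size, here smaller than num classes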