@@ -55,9 +55,17 @@ def call(
55
55
state = tf .expand_dims (state [0 ], 2 )
56
56
transition_scores = state + self ._transition_params
57
57
new_state = inputs + tf .reduce_max (transition_scores , [1 ])
58
+
58
59
backpointers = tf .argmax (transition_scores , 1 )
59
60
backpointers = tf .cast (backpointers , tf .float32 )
61
+
62
+ # apply softmax to transition_scores to get scores in range from 0 to 1
60
63
scores = tf .reduce_max (tf .nn .softmax (transition_scores , axis = 1 ), [1 ])
64
+
65
+ # In the RNN implementation only the first value that is returned from a cell
66
+ # is kept throughout the RNN, so that you will have the values from each time
67
+ # step in the final output. As we need the backpointers as well as the scores
68
+ # for each time step, we concatenate them.
61
69
return tf .concat ([backpointers , scores ], axis = 1 ), new_state
62
70
63
71
@@ -90,12 +98,12 @@ def crf_decode_forward(
90
98
91
99
92
100
def crf_decode_backward (
93
- inputs : TensorLike , scores : TensorLike , state : TensorLike
101
+ backpointers : TensorLike , scores : TensorLike , state : TensorLike
94
102
) -> Tuple [tf .Tensor , tf .Tensor ]:
95
103
"""Computes backward decoding in a linear-chain CRF.
96
104
97
105
Args:
98
- inputs : A [batch_size, num_tags] matrix of backpointer of next step
106
+ backpointers : A [batch_size, num_tags] matrix of backpointers of next step
99
107
(in time order).
100
108
scores: A [batch_size, num_tags] matrix of scores of next step (in time order).
101
109
state: A [batch_size, 1] matrix of tag index of next step.
@@ -104,16 +112,17 @@ def crf_decode_backward(
104
112
new_tags: A [batch_size, num_tags] tensor containing the new tag indices.
105
113
new_scores: A [batch_size, num_tags] tensor containing the new score values.
106
114
"""
107
- inputs = tf .transpose (inputs , [1 , 0 , 2 ])
115
+ backpointers = tf .transpose (backpointers , [1 , 0 , 2 ])
108
116
scores = tf .transpose (scores , [1 , 0 , 2 ])
109
117
110
- def _scan_fn (state , inputs ):
111
- state = tf .cast (tf .squeeze (state , axis = [1 ]), dtype = tf .int32 )
112
- idxs = tf .stack ([tf .range (tf .shape (inputs )[0 ]), state ], axis = 1 )
113
- new_tags = tf .expand_dims (tf .gather_nd (inputs , idxs ), axis = - 1 )
114
- return new_tags
118
+ def _scan_fn (_state : TensorLike , _inputs : TensorLike ) -> tf .Tensor :
119
+ _state = tf .cast (tf .squeeze (_state , axis = [1 ]), dtype = tf .int32 )
120
+ idxs = tf .stack ([tf .range (tf .shape (_inputs )[0 ]), _state ], axis = 1 )
121
+ return tf .expand_dims (tf .gather_nd (_inputs , idxs ), axis = - 1 )
115
122
116
- output_tags = tf .scan (_scan_fn , inputs , state )
123
+ output_tags = tf .scan (_scan_fn , backpointers , state )
124
+ # the dtypes of the input parameters of tf.scan need to match
125
+ # convert state to float32 to match the type of scores
117
126
state = tf .cast (state , dtype = tf .float32 )
118
127
output_scores = tf .scan (_scan_fn , scores , state )
119
128
@@ -122,7 +131,7 @@ def _scan_fn(state, inputs):
122
131
123
132
def crf_decode (
124
133
potentials : TensorLike , transition_params : TensorLike , sequence_length : TensorLike
125
- ) -> Tuple [tf .Tensor , tf .Tensor ]:
134
+ ) -> Tuple [tf .Tensor , tf .Tensor , tf . Tensor ]:
126
135
"""Decode the highest scoring sequence of tags.
127
136
128
137
Args:
@@ -135,18 +144,21 @@ def crf_decode(
135
144
Returns:
136
145
decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
137
146
Contains the highest scoring tag indices.
138
- scores: A [batch_size, max_seq_len] vector, containing the score of `decode_tags`.
147
+ decode_scores: A [batch_size, max_seq_len] matrix, containing the score of
148
+ `decode_tags`.
149
+ best_score: A [batch_size] vector, containing the best score of `decode_tags`.
139
150
"""
140
151
sequence_length = tf .cast (sequence_length , dtype = tf .int32 )
141
152
142
153
# If max_seq_len is 1, we skip the algorithm and simply return the
143
154
# argmax tag and the max activation.
144
155
def _single_seq_fn ():
145
156
decode_tags = tf .cast (tf .argmax (potentials , axis = 2 ), dtype = tf .int32 )
146
- best_score = tf .reshape (
157
+ decode_scores = tf .reshape (
147
158
tf .reduce_max (tf .nn .softmax (potentials , axis = 2 ), axis = 2 ), shape = [- 1 ]
148
159
)
149
- return decode_tags , best_score
160
+ best_score = tf .reshape (tf .reduce_max (potentials , axis = 2 ), shape = [- 1 ])
161
+ return decode_tags , decode_scores , best_score
150
162
151
163
def _multi_seq_fn ():
152
164
# Computes forward decoding. Get last score and backpointers.
@@ -162,6 +174,9 @@ def _multi_seq_fn():
162
174
inputs , initial_state , transition_params , sequence_length_less_one
163
175
)
164
176
177
+ # output is a matrix of size [batch-size, max-seq-length, num-tags * 2]
178
+ # split the matrix on axis 2 to get the backpointers and scores, which are
179
+ # both of size [batch-size, max-seq-length, num-tags]
165
180
backpointers , scores = tf .split (output , 2 , axis = 2 )
166
181
167
182
backpointers = tf .cast (backpointers , dtype = tf .int32 )
@@ -189,7 +204,9 @@ def _multi_seq_fn():
189
204
decode_scores = tf .concat ([initial_score , decode_scores ], axis = 1 )
190
205
decode_scores = tf .reverse_sequence (decode_scores , sequence_length , seq_axis = 1 )
191
206
192
- return decode_tags , decode_scores
207
+ best_score = tf .reduce_max (last_score , axis = 1 )
208
+
209
+ return decode_tags , decode_scores , best_score
193
210
194
211
if potentials .shape [1 ] is not None :
195
212
# shape is statically known, so we just execute
0 commit comments