Skip to content

Commit 8b6f439

Browse files
committed
review comments
1 parent def314f commit 8b6f439

File tree

2 files changed

+32
-15
lines changed

2 files changed

+32
-15
lines changed

rasa/utils/tensorflow/crf.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,17 @@ def call(
5555
state = tf.expand_dims(state[0], 2)
5656
transition_scores = state + self._transition_params
5757
new_state = inputs + tf.reduce_max(transition_scores, [1])
58+
5859
backpointers = tf.argmax(transition_scores, 1)
5960
backpointers = tf.cast(backpointers, tf.float32)
61+
62+
# apply softmax to transition_scores to get scores in range from 0 to 1
6063
scores = tf.reduce_max(tf.nn.softmax(transition_scores, axis=1), [1])
64+
65+
# In the RNN implementation only the first value that is returned from a cell
66+
# is kept throughout the RNN, so that you will have the values from each time
67+
# step in the final output. As we need the backpointers as well as the scores
68+
# for each time step, we concatenate them.
6169
return tf.concat([backpointers, scores], axis=1), new_state
6270

6371

@@ -90,12 +98,12 @@ def crf_decode_forward(
9098

9199

92100
def crf_decode_backward(
93-
inputs: TensorLike, scores: TensorLike, state: TensorLike
101+
backpointers: TensorLike, scores: TensorLike, state: TensorLike
94102
) -> Tuple[tf.Tensor, tf.Tensor]:
95103
"""Computes backward decoding in a linear-chain CRF.
96104
97105
Args:
98-
inputs: A [batch_size, num_tags] matrix of backpointer of next step
106+
backpointers: A [batch_size, num_tags] matrix of backpointer of next step
99107
(in time order).
100108
scores: A [batch_size, num_tags] matrix of scores of next step (in time order).
101109
state: A [batch_size, 1] matrix of tag index of next step.
@@ -104,16 +112,17 @@ def crf_decode_backward(
104112
new_tags: A [batch_size, num_tags] tensor containing the new tag indices.
105113
new_scores: A [batch_size, num_tags] tensor containing the new score values.
106114
"""
107-
inputs = tf.transpose(inputs, [1, 0, 2])
115+
backpointers = tf.transpose(backpointers, [1, 0, 2])
108116
scores = tf.transpose(scores, [1, 0, 2])
109117

110-
def _scan_fn(state, inputs):
111-
state = tf.cast(tf.squeeze(state, axis=[1]), dtype=tf.int32)
112-
idxs = tf.stack([tf.range(tf.shape(inputs)[0]), state], axis=1)
113-
new_tags = tf.expand_dims(tf.gather_nd(inputs, idxs), axis=-1)
114-
return new_tags
118+
def _scan_fn(_state: TensorLike, _inputs: TensorLike) -> tf.Tensor:
119+
_state = tf.cast(tf.squeeze(_state, axis=[1]), dtype=tf.int32)
120+
idxs = tf.stack([tf.range(tf.shape(_inputs)[0]), _state], axis=1)
121+
return tf.expand_dims(tf.gather_nd(_inputs, idxs), axis=-1)
115122

116-
output_tags = tf.scan(_scan_fn, inputs, state)
123+
output_tags = tf.scan(_scan_fn, backpointers, state)
124+
# the dtype of the input parameters of tf.scan need to match
125+
# convert state to float32 to match the type of scores
117126
state = tf.cast(state, dtype=tf.float32)
118127
output_scores = tf.scan(_scan_fn, scores, state)
119128

@@ -122,7 +131,7 @@ def _scan_fn(state, inputs):
122131

123132
def crf_decode(
124133
potentials: TensorLike, transition_params: TensorLike, sequence_length: TensorLike
125-
) -> Tuple[tf.Tensor, tf.Tensor]:
134+
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
126135
"""Decode the highest scoring sequence of tags.
127136
128137
Args:
@@ -135,18 +144,21 @@ def crf_decode(
135144
Returns:
136145
decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
137146
Contains the highest scoring tag indices.
138-
scores: A [batch_size, max_seq_len] vector, containing the score of `decode_tags`.
147+
decode_scores: A [batch_size, max_seq_len] matrix, containing the score of
148+
`decode_tags`.
149+
best_score: A [batch_size] vector, containing the best score of `decode_tags`.
139150
"""
140151
sequence_length = tf.cast(sequence_length, dtype=tf.int32)
141152

142153
# If max_seq_len is 1, we skip the algorithm and simply return the
143154
# argmax tag and the max activation.
144155
def _single_seq_fn():
145156
decode_tags = tf.cast(tf.argmax(potentials, axis=2), dtype=tf.int32)
146-
best_score = tf.reshape(
157+
decode_scores = tf.reshape(
147158
tf.reduce_max(tf.nn.softmax(potentials, axis=2), axis=2), shape=[-1]
148159
)
149-
return decode_tags, best_score
160+
best_score = tf.reshape(tf.reduce_max(potentials, axis=2), shape=[-1])
161+
return decode_tags, decode_scores, best_score
150162

151163
def _multi_seq_fn():
152164
# Computes forward decoding. Get last score and backpointers.
@@ -162,6 +174,9 @@ def _multi_seq_fn():
162174
inputs, initial_state, transition_params, sequence_length_less_one
163175
)
164176

177+
# output is a matrix of size [batch-size, max-seq-length, num-tags * 2]
178+
# split the matrix on axis 2 to get the backpointers and scores, which are
179+
# both of size [batch-size, max-seq-length, num-tags]
165180
backpointers, scores = tf.split(output, 2, axis=2)
166181

167182
backpointers = tf.cast(backpointers, dtype=tf.int32)
@@ -189,7 +204,9 @@ def _multi_seq_fn():
189204
decode_scores = tf.concat([initial_score, decode_scores], axis=1)
190205
decode_scores = tf.reverse_sequence(decode_scores, sequence_length, seq_axis=1)
191206

192-
return decode_tags, decode_scores
207+
best_score = tf.reduce_max(last_score, axis=1)
208+
209+
return decode_tags, decode_scores, best_score
193210

194211
if potentials.shape[1] is not None:
195212
# shape is statically know, so we just execute

rasa/utils/tensorflow/layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ def call(
477477
A [batch_size, max_seq_len] matrix, with dtype `tf.float32`.
478478
Contains the confidence values of the highest scoring tag indices.
479479
"""
480-
predicted_ids, scores = rasa.utils.tensorflow.crf.crf_decode(
480+
predicted_ids, scores, _ = rasa.utils.tensorflow.crf.crf_decode(
481481
logits, self.transition_params, sequence_lengths
482482
)
483483
# set prediction index for padding to `0`

0 commit comments

Comments
 (0)