Skip to content

Commit 88a7efa

Browse files
committed
add comments
1 parent c68dc92 commit 88a7efa

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

rasa/utils/tensorflow/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,17 +187,22 @@ def train_on_batch(
187187
) -> None:
188188
"""Train on batch"""
189189

190+
# calculate supervision and regularization losses separately
190191
with tf.GradientTape(persistent=True) as tape:
191192
prediction_loss = self.batch_loss(batch_in)
192193
regularization_loss = tf.math.add_n(self.losses)
193194
total_loss = prediction_loss + regularization_loss
194195

195196
self.total_loss.update_state(total_loss)
196197

198+
# calculate the gradients that comes from supervision signal
197199
prediction_gradients = tape.gradient(prediction_loss, self.trainable_variables)
200+
# calculate the gradients that comes from regularization
198201
regularization_gradients = tape.gradient(
199202
regularization_loss, self.trainable_variables
200203
)
204+
# delete gradient tape manually
205+
# since it was created with `persistent=True` option
201206
del tape
202207

203208
gradients = []

0 commit comments

Comments
 (0)