File tree Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Original file line number Diff line number Diff line change @@ -187,17 +187,22 @@ def train_on_batch(
187
187
) -> None :
188
188
"""Train on batch"""
189
189
190
+ # calculate supervision and regularization losses separately
190
191
with tf .GradientTape (persistent = True ) as tape :
191
192
prediction_loss = self .batch_loss (batch_in )
192
193
regularization_loss = tf .math .add_n (self .losses )
193
194
total_loss = prediction_loss + regularization_loss
194
195
195
196
self .total_loss .update_state (total_loss )
196
197
198
+ # calculate the gradients that comes from supervision signal
197
199
prediction_gradients = tape .gradient (prediction_loss , self .trainable_variables )
200
+ # calculate the gradients that comes from regularization
198
201
regularization_gradients = tape .gradient (
199
202
regularization_loss , self .trainable_variables
200
203
)
204
+ # delete gradient tape manually
205
+ # since it was created with `persistent=True` option
201
206
del tape
202
207
203
208
gradients = []
You can’t perform that action at this time.
0 commit comments