Skip to content

Commit

Permalink
Merge pull request zzh8829#158 from osljw/ljw_map_fn
Browse files Browse the repository at this point in the history
change return of map_fn
  • Loading branch information
zzh8829 authored Jan 14, 2020
2 parents 7c3ede5 + f41e7c3 commit a50fe6d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions yolov3_tf2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,11 @@ def yolo_loss(y_true, y_pred):
# 4. calculate all masks
obj_mask = tf.squeeze(true_obj, -1)
# ignore false positive when iou is over threshold
best_iou, _, _ = tf.map_fn(
lambda x: (tf.reduce_max(broadcast_iou(x[0], tf.boolean_mask(
x[1], tf.cast(x[2], tf.bool))), axis=-1), 0, 0),
(pred_box, true_box, obj_mask))
best_iou = tf.map_fn(
lambda x: tf.reduce_max(broadcast_iou(x[0], tf.boolean_mask(
x[1], tf.cast(x[2], tf.bool))), axis=-1),
(pred_box, true_box, obj_mask),
tf.float32)
ignore_mask = tf.cast(best_iou < ignore_thresh, tf.float32)

# 5. calculate all losses
Expand Down

0 comments on commit a50fe6d

Please sign in to comment.