Skip to content

Commit

Permalink
Fix tflite export (zzh8829#342)
Browse files Browse the repository at this point in the history
* Tflite converter fix

Co-authored-by: zwanto <[email protected]>
  • Loading branch information
Antoine HAMON and ZwAnto authored Apr 12, 2021
1 parent 71208e1 commit dbfdf41
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
7 changes: 6 additions & 1 deletion tools/export_tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
flags.DEFINE_integer('num_classes', 80, 'number of classes in the model')
flags.DEFINE_integer('size', 416, 'image size')

# TODO: This is broken DOES NOT WORK !!

def main(_argv):
if FLAGS.tiny:
yolo = YoloV3Tiny(size=FLAGS.size, classes=FLAGS.num_classes)
Expand All @@ -34,6 +34,11 @@ def main(_argv):
logging.info('weights loaded')

converter = tf.lite.TFLiteConverter.from_keras_model(yolo)

# Fix from https://stackoverflow.com/questions/64490203/tf-lite-non-max-suppression
converter.experimental_new_converter = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]

tflite_model = converter.convert()
open(FLAGS.output, 'wb').write(tflite_model)
logging.info("model saved to: {}".format(FLAGS.output))
Expand Down
12 changes: 11 additions & 1 deletion yolov3_tf2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ def yolo_output(x_in):
return yolo_output


# As tensorflow lite doesn't support tf.size used in tf.meshgrid,
# we reimplemented a simple meshgrid function that use basic tf function.
def _meshgrid(n_a, n_b):

return [
tf.reshape(tf.tile(tf.range(n_a), [n_b]), (n_b, n_a)),
tf.reshape(tf.repeat(tf.range(n_b), n_a), (n_b, n_a))
]


def yolo_boxes(pred, anchors, classes):
# pred: (batch_size, grid, grid, anchors, (x, y, w, h, obj, ...classes))
grid_size = tf.shape(pred)[1:3]
Expand All @@ -160,7 +170,7 @@ def yolo_boxes(pred, anchors, classes):
pred_box = tf.concat((box_xy, box_wh), axis=-1) # original xywh for loss

# !!! grid[x][y] == (y, x)
grid = tf.meshgrid(tf.range(grid_size[1]), tf.range(grid_size[0]))
grid = _meshgrid(grid_size[1],grid_size[0])
grid = tf.expand_dims(tf.stack(grid, axis=-1), axis=2) # [gx, gy, 1, 2]

box_xy = (box_xy + tf.cast(grid, tf.float32)) / \
Expand Down

0 comments on commit dbfdf41

Please sign in to comment.