Skip to content

Commit

Permalink
Update save_pruned_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
lucas-ecm authored Feb 25, 2023
1 parent 836113e commit a8baf6b
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions save_pruned_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
flags.DEFINE_boolean('tiny', False, 'is yolo-tiny or not')
flags.DEFINE_integer('input_size', 416, 'define input size of export model')
flags.DEFINE_float('score_thres', 0.2, 'define score threshold')
flags.DEFINE_string('method', 'constant', 'constant or polynomial_decay')
flags.DEFINE_float('target_sparsity', 0.5, 'target_sparsity')
flags.DEFINE_float('initial_sparsity', 0.5, 'initial_sparsity')
flags.DEFINE_float('final_sparsity', 0.8, 'final_sparsity')
flags.DEFINE_string('framework', 'tf', 'define what framework do you want to convert (tf, trt, tflite)')
Expand Down Expand Up @@ -84,12 +86,22 @@ def save_tf():

# Helper function uses `prune_low_magnitude` to make only the
# Dense layers train with pruning.
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=FLAGS.initial_sparsity,
final_sparsity=FLAGS.final_sparsity,
begin_step=0,
end_step=end_step)
}

if FLAGS.method == 'constant':
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(
target_sparsity=FLAGS.target_sparsity,
begin_step=0,
end_step=end_step)
}

else:
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=FLAGS.initial_sparsity,
final_sparsity=FLAGS.final_sparsity,
begin_step=0,
end_step=end_step)
}

def apply_pruning_to_dense(layer):
if isinstance(layer, tf.keras.layers.Conv2D):
Expand All @@ -104,9 +116,18 @@ def apply_pruning_to_dense(layer):
)

model_for_pruning.summary()
model_for_pruning.compile()

model_for_pruning.save(FLAGS.output)


model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

model_for_export.save(FLAGS.output+'stripped')

print("Size of gzipped non pruned model: %.2f bytes" % (get_gzipped_model_size(model)))
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning)))
print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export)))

def main(_argv):
save_tf()

Expand All @@ -115,4 +136,3 @@ def main(_argv):
app.run(main)
except SystemExit:
pass

0 comments on commit a8baf6b

Please sign in to comment.