Skip to content

Commit

Permalink
prune_only_conv2D
Browse files Browse the repository at this point in the history
  • Loading branch information
lucas-ecm authored Feb 11, 2023
1 parent af65c67 commit 10d0d0f
Showing 1 changed file with 36 additions and 13 deletions.
49 changes: 36 additions & 13 deletions save_pruned_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,23 +64,46 @@ def save_tf():
#end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs
end_step = 1000

# Define model for pruning.
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=FLAGS.initial_sparsity,
final_sparsity=FLAGS.final_sparsity,
begin_step=0,
end_step=end_step)
}
# # Define model for pruning.
# pruning_params = {
# 'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=FLAGS.initial_sparsity,
# final_sparsity=FLAGS.final_sparsity,
# begin_step=0,
# end_step=end_step)
# }

model_for_pruning = prune_low_magnitude(model, **pruning_params)
# model_for_pruning = prune_low_magnitude(model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
# # `prune_low_magnitude` requires a recompile.
# model_for_pruning.compile(optimizer='adam',
# loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
# metrics=['accuracy'])

model_for_pruning.summary()
# model_for_pruning.summary()


# Helper function uses `prune_low_magnitude` to make only the
# Dense layers train with pruning.
pruning_params = {
'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=FLAGS.initial_sparsity,
final_sparsity=FLAGS.final_sparsity,
begin_step=0,
end_step=end_step)
#}

def apply_pruning_to_dense(layer):
if isinstance(layer, tf.keras.layers.Conv2D):
return tfmot.sparsity.keras.prune_low_magnitude(layer)
return layer

# Use `tf.keras.models.clone_model` to apply `apply_pruning_to_dense`
# to the layers of the model.
model_for_pruning = tf.keras.models.clone_model(
base_model,
clone_function=apply_pruning_to_dense,
)

model_for_pruning.summary()

model_for_pruning.save(FLAGS.output)

Expand Down

0 comments on commit 10d0d0f

Please sign in to comment.