Skip to content

Commit

Permalink
enable absl logging in efficientnet_weight_update_util.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 408701312
  • Loading branch information
haifeng-jin authored and tensorflower-gardener committed Nov 9, 2021
1 parent 5fe94ef commit 8328312
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions keras/applications/efficientnet_weight_update_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@
--ckpt noisy_student_efficientnet-b3/model.ckpt --o efficientnetb3_new.h5
"""

import tensorflow.compat.v2 as tf

import argparse
import warnings

from keras.utils import io_utils
import tensorflow.compat.v2 as tf
from tensorflow.keras.applications import efficientnet


Expand All @@ -65,11 +66,11 @@ def write_ckpt_to_h5(path_h5, path_ckpt, keras_model, use_ema=True):
keras_blocks = get_keras_blocks(keras_weight_names)
tf_blocks = get_tf_blocks(tf_weight_names)

print('check variables match in each block')
io_utils.print_msg('check variables match in each block')
for keras_block, tf_block in zip(keras_blocks, tf_blocks):
check_match(keras_block, tf_block, keras_weight_names, tf_weight_names,
model_name_tf)
print('{} and {} match.'.format(tf_block, keras_block))
io_utils.print_msg('{} and {} match.'.format(tf_block, keras_block))

block_mapping = {x[0]: x[1] for x in zip(keras_blocks, tf_blocks)}

Expand All @@ -89,9 +90,9 @@ def write_ckpt_to_h5(path_h5, path_ckpt, keras_model, use_ema=True):
tf_name = keras_name_to_tf_name_stem_top(
w.name, use_ema=use_ema, model_name_tf=model_name_tf)
elif 'normalization' in w.name:
print('skipping variable {}: normalization is a layer'
'in keras implementation, but preprocessing in '
'TF implementation.'.format(w.name))
io_utils.print_msg(
f'Skipping variable {w.name}: normalization is a Keras '
'preprocessing layer, which does not exist in the TF ckpt.')
continue
else:
raise ValueError('{} failed to parse.'.format(w.name))
Expand All @@ -111,7 +112,7 @@ def write_ckpt_to_h5(path_h5, path_ckpt, keras_model, use_ema=True):
raise ValueError('Fail to load {} from {}'.format(w.name, tf_name))

total_weights = len(keras_model.weights)
print('{}/{} weights updated'.format(changed_weights, total_weights))
io_utils.print_msg(f'{changed_weights}/{total_weights} weights updated')
keras_model.save_weights(path_h5)


Expand Down

0 comments on commit 8328312

Please sign in to comment.