forked from WZMIAOMIAO/deep-learning-for-image-processing
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request WZMIAOMIAO#30 from WZMIAOMIAO/dev
update read_ckpt.py
- Loading branch information
Showing
4 changed files
with
176 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from tensorflow.keras.preprocessing.image import ImageDataGenerator | ||
import matplotlib.pyplot as plt | ||
from model import vgg | ||
import tensorflow as tf | ||
import json | ||
import os | ||
|
||
|
||
data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root path | ||
image_path = data_root + "/data_set/flower_data/" # flower data set path | ||
train_dir = image_path + "train" | ||
validation_dir = image_path + "val" | ||
|
||
# create direction for saving weights | ||
if not os.path.exists("save_weights"): | ||
os.makedirs("save_weights") | ||
|
||
im_height = 224 | ||
im_width = 224 | ||
batch_size = 32 | ||
epochs = 10 | ||
|
||
_R_MEAN = 123.68 | ||
_G_MEAN = 116.78 | ||
_B_MEAN = 103.94 | ||
|
||
|
||
def pre_function(img): | ||
# img = im.open('test.jpg') | ||
# img = np.array(img).astype(np.float32) | ||
img = img - [_R_MEAN, _G_MEAN, _B_MEAN] | ||
|
||
return img | ||
|
||
|
||
# data generator with data augmentation | ||
train_image_generator = ImageDataGenerator(horizontal_flip=True, | ||
preprocessing_function=pre_function) | ||
validation_image_generator = ImageDataGenerator(preprocessing_function=pre_function) | ||
|
||
train_data_gen = train_image_generator.flow_from_directory(directory=train_dir, | ||
batch_size=batch_size, | ||
shuffle=True, | ||
target_size=(im_height, im_width), | ||
class_mode='categorical') | ||
total_train = train_data_gen.n | ||
|
||
# get class dict | ||
class_indices = train_data_gen.class_indices | ||
|
||
# transform value and key of dict | ||
inverse_dict = dict((val, key) for key, val in class_indices.items()) | ||
# write dict into json file | ||
json_str = json.dumps(inverse_dict, indent=4) | ||
with open('class_indices.json', 'w') as json_file: | ||
json_file.write(json_str) | ||
|
||
val_data_gen = train_image_generator.flow_from_directory(directory=validation_dir, | ||
batch_size=batch_size, | ||
shuffle=True, | ||
target_size=(im_height, im_width), | ||
class_mode='categorical') | ||
total_val = val_data_gen.n | ||
|
||
model = vgg("vgg16", 224, 224, 5) | ||
model.load_weights('./pretrain_weights.ckpt') | ||
for layer_t in model.layers: | ||
if layer_t.name == 'feature': | ||
layer_t.trainable = False | ||
break | ||
|
||
model.summary() | ||
|
||
# using keras high level api for training | ||
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), | ||
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False), | ||
metrics=["accuracy"]) | ||
|
||
callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='./save_weights/myAlex_{epoch}.h5', | ||
save_best_only=True, | ||
save_weights_only=True, | ||
monitor='val_loss')] | ||
|
||
# tensorflow2.1 recommend to using fit | ||
history = model.fit(x=train_data_gen, | ||
steps_per_epoch=total_train // batch_size, | ||
epochs=epochs, | ||
validation_data=val_data_gen, | ||
validation_steps=total_val // batch_size, | ||
callbacks=callbacks) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import tensorflow as tf | ||
|
||
|
||
def rename_var(ckpt_path, new_ckpt_path, num_classes=5): | ||
with tf.Graph().as_default(), tf.compat.v1.Session().as_default() as sess: | ||
var_list = tf.train.list_variables(ckpt_path) | ||
new_var_list = [] | ||
|
||
for var_name, shape in var_list: | ||
# print(var_name) | ||
if var_name in except_list: | ||
continue | ||
|
||
var = tf.train.load_variable(ckpt_path, var_name) | ||
new_var_name = var_name.replace('vgg_16', 'feature') | ||
new_var_name = new_var_name.replace("weights", "kernel") | ||
new_var_name = new_var_name.replace("biases", "bias") | ||
|
||
new_var_name = new_var_name.replace("conv1/conv1_1", "conv2d") | ||
new_var_name = new_var_name.replace("conv1/conv1_2", "conv2d_1") | ||
|
||
new_var_name = new_var_name.replace("conv2/conv2_1", "conv2d_2") | ||
new_var_name = new_var_name.replace("conv2/conv2_2", "conv2d_3") | ||
|
||
new_var_name = new_var_name.replace("conv3/conv3_1", "conv2d_4") | ||
new_var_name = new_var_name.replace("conv3/conv3_2", "conv2d_5") | ||
new_var_name = new_var_name.replace("conv3/conv3_3", "conv2d_6") | ||
|
||
new_var_name = new_var_name.replace("conv4/conv4_1", "conv2d_7") | ||
new_var_name = new_var_name.replace("conv4/conv4_2", "conv2d_8") | ||
new_var_name = new_var_name.replace("conv4/conv4_3", "conv2d_9") | ||
|
||
new_var_name = new_var_name.replace("conv5/conv5_1", "conv2d_10") | ||
new_var_name = new_var_name.replace("conv5/conv5_2", "conv2d_11") | ||
new_var_name = new_var_name.replace("conv5/conv5_3", "conv2d_12") | ||
|
||
if 'fc' in new_var_name: | ||
# new_var_name = new_var_name.replace("feature/fc6", "dense") | ||
# new_var_name = new_var_name.replace("feature/fc7", "dense_1") | ||
# new_var_name = new_var_name.replace("fc8", "dense_2") | ||
continue | ||
|
||
# print(new_var_name) | ||
re_var = tf.Variable(var, name=new_var_name) | ||
new_var_list.append(re_var) | ||
|
||
re_var = tf.Variable(tf.keras.initializers.he_uniform()([25088, 2048]), name="dense/kernel") | ||
new_var_list.append(re_var) | ||
re_var = tf.Variable(tf.keras.initializers.he_uniform()([2048]), name="dense/bias") | ||
new_var_list.append(re_var) | ||
|
||
re_var = tf.Variable(tf.keras.initializers.he_uniform()([2048, 2048]), name="dense_1/kernel") | ||
new_var_list.append(re_var) | ||
re_var = tf.Variable(tf.keras.initializers.he_uniform()([2048]), name="dense_1/bias") | ||
new_var_list.append(re_var) | ||
|
||
re_var = tf.Variable(tf.keras.initializers.he_uniform()([2048, num_classes]), name="dense_2/kernel") | ||
new_var_list.append(re_var) | ||
re_var = tf.Variable(tf.keras.initializers.he_uniform()([num_classes]), name="dense_2/bias") | ||
new_var_list.append(re_var) | ||
|
||
saver = tf.compat.v1.train.Saver(new_var_list) | ||
sess.run(tf.compat.v1.global_variables_initializer()) | ||
saver.save(sess, save_path=new_ckpt_path, write_meta_graph=False, write_state=False) | ||
|
||
|
||
except_list = ['global_step', 'vgg_16/mean_rgb', 'vgg_16/fc8/biases', 'vgg_16/fc8/weights'] | ||
# http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz | ||
ckpt_path = './vgg_16.ckpt' | ||
new_ckpt_path = './pretrain_weights.ckpt' | ||
num_classes = 5 | ||
rename_var(ckpt_path, new_ckpt_path, num_classes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters