Skip to content

Commit

Permalink
Merge pull request WZMIAOMIAO#30 from WZMIAOMIAO/dev
Browse files Browse the repository at this point in the history
update read_ckpt.py
  • Loading branch information
WZMIAOMIAO authored Jul 2, 2020
2 parents 1959fd5 + ec9db30 commit b87fc2a
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 10 deletions.
90 changes: 90 additions & 0 deletions tensorflow_classification/Test3_vgg/fine_train_vgg16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from model import vgg
import tensorflow as tf
import json
import os


data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root path
image_path = data_root + "/data_set/flower_data/" # flower data set path
train_dir = image_path + "train"
validation_dir = image_path + "val"

# create direction for saving weights
if not os.path.exists("save_weights"):
os.makedirs("save_weights")

im_height = 224
im_width = 224
batch_size = 32
epochs = 10

_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94


def pre_function(img):
# img = im.open('test.jpg')
# img = np.array(img).astype(np.float32)
img = img - [_R_MEAN, _G_MEAN, _B_MEAN]

return img


# data generator with data augmentation
train_image_generator = ImageDataGenerator(horizontal_flip=True,
preprocessing_function=pre_function)
validation_image_generator = ImageDataGenerator(preprocessing_function=pre_function)

train_data_gen = train_image_generator.flow_from_directory(directory=train_dir,
batch_size=batch_size,
shuffle=True,
target_size=(im_height, im_width),
class_mode='categorical')
total_train = train_data_gen.n

# get class dict
class_indices = train_data_gen.class_indices

# transform value and key of dict
inverse_dict = dict((val, key) for key, val in class_indices.items())
# write dict into json file
json_str = json.dumps(inverse_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
json_file.write(json_str)

val_data_gen = train_image_generator.flow_from_directory(directory=validation_dir,
batch_size=batch_size,
shuffle=True,
target_size=(im_height, im_width),
class_mode='categorical')
total_val = val_data_gen.n

model = vgg("vgg16", 224, 224, 5)
model.load_weights('./pretrain_weights.ckpt')
for layer_t in model.layers:
if layer_t.name == 'feature':
layer_t.trainable = False
break

model.summary()

# using keras high level api for training
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
metrics=["accuracy"])

callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='./save_weights/myAlex_{epoch}.h5',
save_best_only=True,
save_weights_only=True,
monitor='val_loss')]

# tensorflow2.1 recommend to using fit
history = model.fit(x=train_data_gen,
steps_per_epoch=total_train // batch_size,
epochs=epochs,
validation_data=val_data_gen,
validation_steps=total_val // batch_size,
callbacks=callbacks)
72 changes: 72 additions & 0 deletions tensorflow_classification/Test3_vgg/read_ckpt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import tensorflow as tf


def rename_var(ckpt_path, new_ckpt_path, num_classes=5):
with tf.Graph().as_default(), tf.compat.v1.Session().as_default() as sess:
var_list = tf.train.list_variables(ckpt_path)
new_var_list = []

for var_name, shape in var_list:
# print(var_name)
if var_name in except_list:
continue

var = tf.train.load_variable(ckpt_path, var_name)
new_var_name = var_name.replace('vgg_16', 'feature')
new_var_name = new_var_name.replace("weights", "kernel")
new_var_name = new_var_name.replace("biases", "bias")

new_var_name = new_var_name.replace("conv1/conv1_1", "conv2d")
new_var_name = new_var_name.replace("conv1/conv1_2", "conv2d_1")

new_var_name = new_var_name.replace("conv2/conv2_1", "conv2d_2")
new_var_name = new_var_name.replace("conv2/conv2_2", "conv2d_3")

new_var_name = new_var_name.replace("conv3/conv3_1", "conv2d_4")
new_var_name = new_var_name.replace("conv3/conv3_2", "conv2d_5")
new_var_name = new_var_name.replace("conv3/conv3_3", "conv2d_6")

new_var_name = new_var_name.replace("conv4/conv4_1", "conv2d_7")
new_var_name = new_var_name.replace("conv4/conv4_2", "conv2d_8")
new_var_name = new_var_name.replace("conv4/conv4_3", "conv2d_9")

new_var_name = new_var_name.replace("conv5/conv5_1", "conv2d_10")
new_var_name = new_var_name.replace("conv5/conv5_2", "conv2d_11")
new_var_name = new_var_name.replace("conv5/conv5_3", "conv2d_12")

if 'fc' in new_var_name:
# new_var_name = new_var_name.replace("feature/fc6", "dense")
# new_var_name = new_var_name.replace("feature/fc7", "dense_1")
# new_var_name = new_var_name.replace("fc8", "dense_2")
continue

# print(new_var_name)
re_var = tf.Variable(var, name=new_var_name)
new_var_list.append(re_var)

re_var = tf.Variable(tf.keras.initializers.he_uniform()([25088, 2048]), name="dense/kernel")
new_var_list.append(re_var)
re_var = tf.Variable(tf.keras.initializers.he_uniform()([2048]), name="dense/bias")
new_var_list.append(re_var)

re_var = tf.Variable(tf.keras.initializers.he_uniform()([2048, 2048]), name="dense_1/kernel")
new_var_list.append(re_var)
re_var = tf.Variable(tf.keras.initializers.he_uniform()([2048]), name="dense_1/bias")
new_var_list.append(re_var)

re_var = tf.Variable(tf.keras.initializers.he_uniform()([2048, num_classes]), name="dense_2/kernel")
new_var_list.append(re_var)
re_var = tf.Variable(tf.keras.initializers.he_uniform()([num_classes]), name="dense_2/bias")
new_var_list.append(re_var)

saver = tf.compat.v1.train.Saver(new_var_list)
sess.run(tf.compat.v1.global_variables_initializer())
saver.save(sess, save_path=new_ckpt_path, write_meta_graph=False, write_state=False)


except_list = ['global_step', 'vgg_16/mean_rgb', 'vgg_16/fc8/biases', 'vgg_16/fc8/weights']
# http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz
ckpt_path = './vgg_16.ckpt'
new_ckpt_path = './pretrain_weights.ckpt'
num_classes = 5
rename_var(ckpt_path, new_ckpt_path, num_classes)
12 changes: 7 additions & 5 deletions tensorflow_classification/Test5_resnet/read_ckpt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import tensorflow as tf


def rename_var(ckpt_path, new_ckpt_path):
def rename_var(ckpt_path, new_ckpt_path, num_classes=5):
with tf.Graph().as_default(), tf.compat.v1.Session().as_default() as sess:
var_list = tf.train.list_variables(ckpt_path)
new_var_list = []

for var_name, shape in var_list:
print(var_name)
if var_name in except_list:
Expand All @@ -17,9 +19,9 @@ def rename_var(ckpt_path, new_ckpt_path):
re_var = tf.Variable(var, name=new_var_name)
new_var_list.append(re_var)

re_var = tf.Variable(tf.keras.initializers.he_uniform()([2048, 5]), name="logits/kernel")
re_var = tf.Variable(tf.keras.initializers.he_uniform()([2048, num_classes]), name="logits/kernel")
new_var_list.append(re_var)
re_var = tf.Variable(tf.keras.initializers.he_uniform()([5]), name="logits/bias")
re_var = tf.Variable(tf.keras.initializers.he_uniform()([num_classes]), name="logits/bias")
new_var_list.append(re_var)
saver = tf.compat.v1.train.Saver(new_var_list)
sess.run(tf.compat.v1.global_variables_initializer())
Expand All @@ -29,5 +31,5 @@ def rename_var(ckpt_path, new_ckpt_path):
except_list = ['global_step', 'resnet_v1_50/mean_rgb', 'resnet_v1_50/logits/biases', 'resnet_v1_50/logits/weights']
ckpt_path = './resnet_v1_50.ckpt'
new_ckpt_path = './pretrain_weights.ckpt'
new_var_list = []
rename_var(ckpt_path, new_ckpt_path)
num_classes = 5
rename_var(ckpt_path, new_ckpt_path, num_classes)
12 changes: 7 additions & 5 deletions tensorflow_classification/Test6_mobilenet/read_ckpt.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import tensorflow as tf


def rename_var(ckpt_path, new_ckpt_path):
def rename_var(ckpt_path, new_ckpt_path, num_classes=5):
with tf.Graph().as_default(), tf.compat.v1.Session().as_default() as sess:
var_list = tf.train.list_variables(ckpt_path)
new_var_list = []

for var_name, shape in var_list:
# print(var_name)
if var_name in except_list:
Expand All @@ -30,9 +32,9 @@ def rename_var(ckpt_path, new_ckpt_path):
re_var = tf.Variable(var, name=new_var_name)
new_var_list.append(re_var)

re_var = tf.Variable(tf.keras.initializers.he_uniform()([1280, 5]), name="Logits/kernel")
re_var = tf.Variable(tf.keras.initializers.he_uniform()([1280, num_classes]), name="Logits/kernel")
new_var_list.append(re_var)
re_var = tf.Variable(tf.keras.initializers.he_uniform()([5]), name="Logits/bias")
re_var = tf.Variable(tf.keras.initializers.he_uniform()([num_classes]), name="Logits/bias")

new_var_list.append(re_var)
tf.keras.initializers.he_uniform()
Expand All @@ -44,5 +46,5 @@ def rename_var(ckpt_path, new_ckpt_path):
except_list = ['global_step', 'MobilenetV2/Logits/Conv2d_1c_1x1/biases', 'MobilenetV2/Logits/Conv2d_1c_1x1/weights']
ckpt_path = './pretain_model/mobilenet_v2_1.0_224.ckpt'
new_ckpt_path = './pretrain_weights.ckpt'
new_var_list = []
rename_var(ckpt_path, new_ckpt_path)
num_classes = 5
rename_var(ckpt_path, new_ckpt_path, num_classes)

0 comments on commit b87fc2a

Please sign in to comment.