Commit

update train vgg script
WZMIAOMIAO committed Nov 15, 2020
1 parent efee6d9 commit 44fad31
Showing 3 changed files with 87 additions and 84 deletions.
167 changes: 85 additions & 82 deletions tensorflow_classification/Test3_vgg/fine_train_vgg16.py
@@ -6,85 +6,88 @@
 import os
 
 
-data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
-image_path = data_root + "/data_set/flower_data/"  # flower data set path
-train_dir = image_path + "train"
-validation_dir = image_path + "val"
-
-# create directory for saving weights
-if not os.path.exists("save_weights"):
-    os.makedirs("save_weights")
-
-im_height = 224
-im_width = 224
-batch_size = 32
-epochs = 10
-
-_R_MEAN = 123.68
-_G_MEAN = 116.78
-_B_MEAN = 103.94
-
-
-def pre_function(img):
-    # img = im.open('test.jpg')
-    # img = np.array(img).astype(np.float32)
-    img = img - [_R_MEAN, _G_MEAN, _B_MEAN]
-
-    return img
-
-
-# data generator with data augmentation
-train_image_generator = ImageDataGenerator(horizontal_flip=True,
-                                           preprocessing_function=pre_function)
-validation_image_generator = ImageDataGenerator(preprocessing_function=pre_function)
-
-train_data_gen = train_image_generator.flow_from_directory(directory=train_dir,
-                                                           batch_size=batch_size,
-                                                           shuffle=True,
-                                                           target_size=(im_height, im_width),
-                                                           class_mode='categorical')
-total_train = train_data_gen.n
-
-# get class dict
-class_indices = train_data_gen.class_indices
-
-# transform value and key of dict
-inverse_dict = dict((val, key) for key, val in class_indices.items())
-# write dict into json file
-json_str = json.dumps(inverse_dict, indent=4)
-with open('class_indices.json', 'w') as json_file:
-    json_file.write(json_str)
-
-val_data_gen = train_image_generator.flow_from_directory(directory=validation_dir,
-                                                         batch_size=batch_size,
-                                                         shuffle=True,
-                                                         target_size=(im_height, im_width),
-                                                         class_mode='categorical')
-total_val = val_data_gen.n
-
-model = vgg("vgg16", 224, 224, 5)
-model.load_weights('./pretrain_weights.ckpt')
-for layer_t in model.layers:
-    if layer_t.name == 'feature':
-        layer_t.trainable = False
-        break
-
-model.summary()
-
-# use keras high-level api for training
-model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
-              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
-              metrics=["accuracy"])
-
-callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='./save_weights/myAlex_{epoch}.h5',
-                                                save_best_only=True,
-                                                save_weights_only=True,
-                                                monitor='val_loss')]
-
-# tensorflow 2.1 recommends using fit
-history = model.fit(x=train_data_gen,
-                    steps_per_epoch=total_train // batch_size,
-                    epochs=epochs,
-                    validation_data=val_data_gen,
-                    validation_steps=total_val // batch_size,
-                    callbacks=callbacks)
+def main():
+    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
+    image_path = data_root + "/data_set/flower_data/"  # flower data set path
+    train_dir = image_path + "train"
+    validation_dir = image_path + "val"
+
+    # create directory for saving weights
+    if not os.path.exists("save_weights"):
+        os.makedirs("save_weights")
+
+    im_height = 224
+    im_width = 224
+    batch_size = 32
+    epochs = 10
+
+    _R_MEAN = 123.68
+    _G_MEAN = 116.78
+    _B_MEAN = 103.94
+
+    def pre_function(img):
+        # img = im.open('test.jpg')
+        # img = np.array(img).astype(np.float32)
+        img = img - [_R_MEAN, _G_MEAN, _B_MEAN]
+
+        return img
+
+    # data generator with data augmentation
+    train_image_generator = ImageDataGenerator(horizontal_flip=True,
+                                               preprocessing_function=pre_function)
+    validation_image_generator = ImageDataGenerator(preprocessing_function=pre_function)
+
+    train_data_gen = train_image_generator.flow_from_directory(directory=train_dir,
+                                                               batch_size=batch_size,
+                                                               shuffle=True,
+                                                               target_size=(im_height, im_width),
+                                                               class_mode='categorical')
+    total_train = train_data_gen.n
+
+    # get class dict
+    class_indices = train_data_gen.class_indices
+
+    # transform value and key of dict
+    inverse_dict = dict((val, key) for key, val in class_indices.items())
+    # write dict into json file
+    json_str = json.dumps(inverse_dict, indent=4)
+    with open('class_indices.json', 'w') as json_file:
+        json_file.write(json_str)
+
+    val_data_gen = validation_image_generator.flow_from_directory(directory=validation_dir,
+                                                                  batch_size=batch_size,
+                                                                  shuffle=False,
+                                                                  target_size=(im_height, im_width),
+                                                                  class_mode='categorical')
+    total_val = val_data_gen.n
+
+    model = vgg("vgg16", 224, 224, 5)
+    model.load_weights('./pretrain_weights.ckpt')
+    for layer_t in model.layers:
+        if layer_t.name == 'feature':
+            layer_t.trainable = False
+            break
+
+    model.summary()
+
+    # use keras high-level api for training
+    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
+                  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
+                  metrics=["accuracy"])
+
+    callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath='./save_weights/myAlex_{epoch}.h5',
+                                                    save_best_only=True,
+                                                    save_weights_only=True,
+                                                    monitor='val_loss')]
+
+    # tensorflow 2.1 recommends using fit
+    history = model.fit(x=train_data_gen,
+                        steps_per_epoch=total_train // batch_size,
+                        epochs=epochs,
+                        validation_data=val_data_gen,
+                        validation_steps=total_val // batch_size,
+                        callbacks=callbacks)
+
+
+if __name__ == '__main__':
+    main()
2 changes: 1 addition & 1 deletion tensorflow_classification/Test3_vgg/train.py
@@ -45,7 +45,7 @@ def main():
 
     val_data_gen = validation_image_generator.flow_from_directory(directory=validation_dir,
                                                                   batch_size=batch_size,
-                                                                  shuffle=True,
+                                                                  shuffle=False,
                                                                   target_size=(im_height, im_width),
                                                                   class_mode='categorical')
     total_val = val_data_gen.n
2 changes: 1 addition & 1 deletion tensorflow_classification/Test3_vgg/trainGPU.py
@@ -73,7 +73,7 @@ def process_path(img_path, label):
    train_dataset = train_dataset.shuffle(buffer_size=train_num)\
                                 .map(process_path, num_parallel_calls=AUTOTUNE)\
                                 .repeat().batch(batch_size).prefetch(AUTOTUNE)
 
    # load validation dataset
    val_dataset = tf.data.Dataset.from_tensor_slices((val_image_list, val_label_list))
    val_dataset = val_dataset.map(process_path, num_parallel_calls=tf.data.experimental.AUTOTUNE)\
