Commit

Modify data augmentation
yuzhiyizhan authored Oct 15, 2020
1 parent 6d02774 commit eff3ff6
Showing 1 changed file with 199 additions and 46 deletions.
245 changes: 199 additions & 46 deletions New_work.py
@@ -4,14 +4,21 @@
def callback(work_path, project_name):
return f"""import re
import os
import numpy as np
import tensorflow as tf
from loguru import logger
from tqdm.keras import TqdmCallback
from tensorflow.keras import backend as K
from {work_path}.{project_name}.settings import LR
from {work_path}.{project_name}.settings import EPOCHS
from {work_path}.{project_name}.settings import BATCH_SIZE
from {work_path}.{project_name}.settings import log_dir
from {work_path}.{project_name}.settings import csv_path
from {work_path}.{project_name}.settings import train_path
from {work_path}.{project_name}.settings import UPDATE_FREQ
from {work_path}.{project_name}.settings import LR_PATIENCE
from {work_path}.{project_name}.settings import EARLY_PATIENCE
from {work_path}.{project_name}.settings import COSINE_SCHEDULER
from {work_path}.{project_name}.settings import checkpoint_path
from {work_path}.{project_name}.settings import checkpoint_file_path
from {work_path}.{project_name}.utils import Image_Processing
@@ -31,14 +38,27 @@ def calculate_the_best_weight(self):
value = Image_Processing.extraction_image(checkpoint_path)
extract_num = [os.path.splitext(os.path.split(i)[-1])[0] for i in value]
num = [re.split('-', i) for i in extract_num]
accs = [0 - float(i[-1]) for i in num]
losses = [-abs(float(i[-2])) for i in num]
# accs are computed but the ranking uses the negated loss only: the maximum of index is the checkpoint with the smallest loss
index = [loss for acc, loss in zip(accs, losses)]
model_dict = dict((ind, val) for ind, val in zip(index, value))
return model_dict.get(max(index))
else:
logger.debug('No checkpoint available')
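The ranking above assumes the loss and accuracy are the last two dash-separated fields of each checkpoint filename. A minimal standalone sketch of the selection logic, using hypothetical filenames (the real naming pattern is defined by checkpoint_file_path in settings):

import re

# hypothetical 'epoch-loss-accuracy' checkpoint names, not from the repository
value = ['10-0.50-0.90.h5', '20-0.12-0.98.h5', '30-0.30-0.95.h5']
num = [re.split('-', v.rsplit('.', 1)[0]) for v in value]
losses = [-abs(float(i[-2])) for i in num]
# the maximum of the negated losses is the checkpoint with the smallest loss
print(dict(zip(losses, value))[max(losses)])  # -> '20-0.12-0.98.h5'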
@classmethod
def cosine_scheduler(self):
train_number = len(Image_Processing.extraction_image(train_path))
warmup_epoch = int(EPOCHS * 0.2)
total_steps = int(EPOCHS * train_number / BATCH_SIZE)
warmup_steps = int(warmup_epoch * train_number / BATCH_SIZE)
cosine_scheduler_callback = WarmUpCosineDecayScheduler(learning_rate_base=LR, total_steps=total_steps,
warmup_learning_rate=LR * 0.1,
warmup_steps=warmup_steps,
hold_base_rate_steps=train_number,
min_learn_rate=LR * 0.2)
return cosine_scheduler_callback
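Worked numbers, purely for illustration: with EPOCHS = 200, BATCH_SIZE = 16 and 1000 training images, this gives warmup_epoch = 40, total_steps = 12500 and warmup_steps = 2500, i.e. the warmup ramp occupies the first fifth of training and the rate is then held for 1000 steps before the cosine decay begins.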
@classmethod
def callback(self, model):
call = []
@@ -56,8 +76,10 @@ def callback(self, model):
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, write_images=True,
update_freq=UPDATE_FREQ, write_graph=False)
call.append(tensorboard_callback)
if COSINE_SCHEDULER:
lr_callback = self.cosine_scheduler()
else:
lr_callback = tf.keras.callbacks.ReduceLROnPlateau(factor=0.01, patience=LR_PATIENCE)
call.append(lr_callback)
csv_callback = tf.keras.callbacks.CSVLogger(filename=csv_path, append=True)
@@ -69,6 +91,95 @@ def callback(self, model):
return (model, call)
class WarmUpCosineDecayScheduler(tf.keras.callbacks.Callback):
def __init__(self, learning_rate_base, total_steps, global_step_init=0, warmup_learning_rate=0.0, warmup_steps=0,
hold_base_rate_steps=0, min_learn_rate=0., verbose=1):
super(WarmUpCosineDecayScheduler, self).__init__()
# base learning rate
self.learning_rate_base = learning_rate_base
# warmup parameters
self.warmup_learning_rate = warmup_learning_rate
# verbosity
self.verbose = verbose
# floor for the learning rate
self.min_learn_rate = min_learn_rate
# learning_rates records the rate after every update, so the schedule can be plotted
self.learning_rates = []
self.interval_epoch = [0.05, 0.15, 0.30, 0.50]
# step counter within the current interval; initialised here so on_batch_end can increment it before the first reset
self.global_step = global_step_init
# step counter spanning the whole run
self.global_step_for_interval = global_step_init
# total steps used for the warmup ramp
self.warmup_steps_for_interval = warmup_steps
# total steps held at the peak rate
self.hold_steps_for_interval = hold_base_rate_steps
# total steps of the whole training run
self.total_steps_for_interval = total_steps
self.interval_index = 0
# gaps between consecutive restart points (troughs)
self.interval_reset = [self.interval_epoch[0]]
for i in range(len(self.interval_epoch) - 1):
self.interval_reset.append(self.interval_epoch[i + 1] - self.interval_epoch[i])
self.interval_reset.append(1 - self.interval_epoch[-1])
def cosine_decay_with_warmup(self, global_step, learning_rate_base, total_steps, warmup_learning_rate=0.0,
warmup_steps=0,
hold_base_rate_steps=0, min_learn_rate=0.):
if total_steps < warmup_steps:
raise ValueError('total_steps must be larger or equal to '
'warmup_steps.')
# cosine annealing; taking the minimum learning rate as 0 simplifies the expression
learning_rate = 0.5 * learning_rate_base * (1 + np.cos(
    np.pi * (global_step - warmup_steps - hold_base_rate_steps) /
    float(total_steps - warmup_steps - hold_base_rate_steps)))
# if hold_base_rate_steps > 0, the learning rate is held constant for that many steps after warmup
if hold_base_rate_steps > 0:
learning_rate = np.where(global_step > warmup_steps + hold_base_rate_steps,
learning_rate, learning_rate_base)
if warmup_steps > 0:
if learning_rate_base < warmup_learning_rate:
raise ValueError('learning_rate_base must be larger or equal to '
'warmup_learning_rate.')
# linear warmup ramp
slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
warmup_rate = slope * global_step + warmup_learning_rate
# use the linearly growing warmup_rate only while global_step is still in the warmup phase; afterwards use the cosine-annealed learning_rate
learning_rate = np.where(global_step < warmup_steps, warmup_rate,
learning_rate)
learning_rate = max(learning_rate, min_learn_rate)
return learning_rate
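Evaluated in isolation (ignoring the periodic restarts handled in on_batch_begin below), the schedule ramps linearly up to the base rate, optionally holds it, then decays along a cosine. A self-contained sketch with illustrative values, not taken from the repository:

import numpy as np

def lr_at(step, base=1e-3, total=1000, warm=100, hold=50):
    if step < warm:
        return base / warm * step  # linear warmup from 0
    if step < warm + hold:
        return base  # hold at the peak
    # cosine decay towards 0 over the remaining steps
    return 0.5 * base * (1 + np.cos(np.pi * (step - warm - hold) / (total - warm - hold)))

for s in (0, 50, 100, 140, 500, 1000):
    print(s, round(lr_at(s), 6))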
# update global_step and record the current learning rate
def on_batch_end(self, batch, logs=None):
self.global_step = self.global_step + 1
self.global_step_for_interval = self.global_step_for_interval + 1
lr = K.get_value(self.model.optimizer.lr)
self.learning_rates.append(lr)
# update the learning rate
def on_batch_begin(self, batch, logs=None):
# reset the schedule parameters every time a restart point (trough) is reached
if self.global_step_for_interval in [0] + [int(i * self.total_steps_for_interval) for i in self.interval_epoch]:
self.total_steps = self.total_steps_for_interval * self.interval_reset[self.interval_index]
self.warmup_steps = self.warmup_steps_for_interval * self.interval_reset[self.interval_index]
self.hold_base_rate_steps = self.hold_steps_for_interval * self.interval_reset[self.interval_index]
self.global_step = 0
self.interval_index += 1
lr = self.cosine_decay_with_warmup(global_step=self.global_step,
learning_rate_base=self.learning_rate_base,
total_steps=self.total_steps,
warmup_learning_rate=self.warmup_learning_rate,
warmup_steps=self.warmup_steps,
hold_base_rate_steps=self.hold_base_rate_steps,
min_learn_rate=self.min_learn_rate)
K.set_value(self.model.optimizer.lr, lr)
if self.verbose > 0:
logger.info(f'Batch {{self.global_step}}: setting learning rate to {{lr}}.')
if __name__ == '__main__':
logger.debug(CallBack.calculate_the_best_weight())
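To use the scheduler on its own, it can be passed to model.fit like any Keras callback. A hedged sketch; model, x_train and y_train are placeholders rather than objects from this repository:

scheduler = WarmUpCosineDecayScheduler(learning_rate_base=1e-4,
                                       total_steps=2000,
                                       warmup_learning_rate=1e-5,
                                       warmup_steps=400,
                                       hold_base_rate_steps=100,
                                       min_learn_rate=2e-5)
model.fit(x_train, y_train, batch_size=16, epochs=10, callbacks=[scheduler])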
@@ -281,6 +392,7 @@ def utils(work_path, project_name):
from {work_path}.{project_name}.settings import IMAGE_WIDTH
from {work_path}.{project_name}.settings import CAPTCHA_LENGTH
from {work_path}.{project_name}.settings import IMAGE_CHANNALS
from {work_path}.{project_name}.settings import DATA_ENHANCEMENT
from concurrent.futures import ThreadPoolExecutor
right_value = 0
@@ -479,50 +591,87 @@ def rename_suffix(self, path: list):
name, mix = os.path.splitext(name)
os.rename(i, os.path.join(paths, name + '.jpg'))
# Augment the image
@classmethod
def preprosess_save_images(self, image, number):
logger.info(f'Start processing {{image}}')
with open(image, 'rb') as images:
im = Image.open(images)
blur_im = im.filter(ImageFilter.BLUR)
contour_im = im.filter(ImageFilter.CONTOUR)
detail_im = im.filter(ImageFilter.DETAIL)
edge_enhance_im = im.filter(ImageFilter.EDGE_ENHANCE)
edge_enhance_more_im = im.filter(ImageFilter.EDGE_ENHANCE_MORE)
emboss_im = im.filter(ImageFilter.EMBOSS)
find_edges_im = im.filter(ImageFilter.FIND_EDGES)
smooth_im = im.filter(ImageFilter.SMOOTH)
smooth_more_im = im.filter(ImageFilter.SMOOTH_MORE)
sharpen_im = im.filter(ImageFilter.SHARPEN)
maxfilter_im = im.filter(ImageFilter.MaxFilter)
minfilter_im = im.filter(ImageFilter.MinFilter)
modefilter_im = im.filter(ImageFilter.ModeFilter)
medianfilter_im = im.filter(ImageFilter.MedianFilter)
unsharpmask_im = im.filter(ImageFilter.UnsharpMask)
left_right_im = im.transpose(Image.FLIP_LEFT_RIGHT)
top_bottom_im = im.transpose(Image.FLIP_TOP_BOTTOM)
rotate_list = [im.rotate(i) for i in list(range(1, 360, 60))]
brightness_im = ImageEnhance.Brightness(im).enhance(0.5)
brightness_up_im = ImageEnhance.Brightness(im).enhance(1.5)
color_im = ImageEnhance.Color(im).enhance(0.5)
color_up_im = ImageEnhance.Color(im).enhance(1.5)
contrast_im = ImageEnhance.Contrast(im).enhance(0.5)
contrast_up_im = ImageEnhance.Contrast(im).enhance(1.5)
sharpness_im = ImageEnhance.Sharpness(im).enhance(0.5)
sharpness_up_im = ImageEnhance.Sharpness(im).enhance(1.5)
image_list = [im, blur_im, contour_im, detail_im, edge_enhance_im, edge_enhance_more_im, emboss_im,
              find_edges_im, smooth_im, smooth_more_im, sharpen_im, maxfilter_im, minfilter_im,
              modefilter_im, medianfilter_im, unsharpmask_im, left_right_im, top_bottom_im,
              brightness_im, brightness_up_im, color_im, color_up_im, contrast_im,
              contrast_up_im, sharpness_im, sharpness_up_im] + rotate_list
for index, file in enumerate(image_list):
paths, files = os.path.split(image)
files, suffix = os.path.splitext(files)
new_file = os.path.join(paths, train_enhance_path, files + str(index) + suffix)
file.save(new_file)
name = re.split('_', os.path.splitext(os.path.split(image)[-1])[0])[0]
datagen = tf.keras.preprocessing.image.ImageDataGenerator(featurewise_center=False,
samplewise_center=False,
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=False,
zca_epsilon=1e-6,
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
brightness_range=(0.7, 1.3),
shear_range=30,
zoom_range=0.2,
channel_shift_range=0.,
fill_mode='nearest',
cval=0.,
horizontal_flip=False,
vertical_flip=False,
rescale=1 / 255,
preprocessing_function=None,
data_format=None,
validation_split=0.0,
dtype=None)
shutil.copy(image, train_enhance_path)
img = tf.keras.preprocessing.image.load_img(image)
x = tf.keras.preprocessing.image.img_to_array(img)
x = np.expand_dims(x, 0)
i = 0
for _ in datagen.flow(x, batch_size=1, save_to_dir=train_enhance_path, save_prefix=name, save_format='jpg'):
i += 1
if i == DATA_ENHANCEMENT:
break
logger.success(f'Finished processing {{image}}; {{number}} images left to augment')
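Since the module already imports ThreadPoolExecutor, the augmentation can be driven concurrently over the whole training set. A sketch under the assumption that Image_Processing.extraction_image(train_path) returns the list of training image paths, as it does elsewhere in this file:

images = Image_Processing.extraction_image(train_path)
with ThreadPoolExecutor(max_workers=8) as executor:
    for number, image in enumerate(images):
        # the second argument counts how many images are still waiting
        executor.submit(Image_Processing.preprosess_save_images, image, len(images) - number)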
@classmethod
@@ -899,6 +1048,7 @@ def cheak_path(path):
return path
if not os.path.exists(path):
return path
"""


@@ -2714,6 +2864,9 @@ def settings(work_path, project_name):
# learning rate
LR = 1e-4
# whether to use cosine annealing for the learning rate (off by default; when False, ReduceLROnPlateau is used instead)
COSINE_SCHEDULER = False
# number of training epochs
EPOCHS = 200
@@ -2739,7 +2892,7 @@ def settings(work_path, project_name):
# length of the captcha
CAPTCHA_LENGTH = 8
# whether to use data augmentation (not needed with a large dataset; takes an integer giving how many augmented images to generate per source image)
DATA_ENHANCEMENT = False
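For example, enabling both of the new switches might look like this (values are illustrative only):

COSINE_SCHEDULER = True
DATA_ENHANCEMENT = 50  # generate 50 augmented copies of each source image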
## Model settings
@@ -3131,4 +3284,4 @@ def main(self):


if __name__ == '__main__':
New_Work(work_path='works', project_name='simple').main()
