From b6f95fa1aaccf05622d764acb1cf3ea41f645d0f Mon Sep 17 00:00:00 2001 From: liangliang <107097683+mtjhl@users.noreply.github.com> Date: Sat, 6 May 2023 23:30:21 +0800 Subject: [PATCH] optimize function name of engine and resume logic (#809) optimize function name and resume logic --- yolov6/core/engine.py | 50 ++++++++++++++++++++---------------------- yolov6/utils/events.py | 8 +++---- 2 files changed, 28 insertions(+), 30 deletions(-) diff --git a/yolov6/core/engine.py b/yolov6/core/engine.py index 667d1865..5966e519 100644 --- a/yolov6/core/engine.py +++ b/yolov6/core/engine.py @@ -40,6 +40,7 @@ def __init__(self, args, cfg, device): self.args = args self.cfg = cfg self.device = device + self.max_epoch = args.epochs if args.resume: self.ckpt = torch.load(args.resume, map_location='cpu') @@ -84,10 +85,14 @@ def __init__(self, args, cfg, device): if self.main_process: self.ema.ema.load_state_dict(self.ckpt['ema'].float().state_dict()) self.ema.updates = self.ckpt['updates'] + if self.start_epoch > (self.max_epoch - self.args.stop_aug_last_n_epoch): + self.cfg.data_aug.mosaic = 0.0 + self.cfg.data_aug.mixup = 0.0 + self.train_loader, self.val_loader = self.get_data_loader(self.args, self.cfg, self.data_dict) + self.model = self.parallel_model(args, model, device) self.model.nc, self.model.names = self.data_dict['nc'], self.data_dict['names'] - self.max_epoch = args.epochs self.max_stepnum = len(self.train_loader) self.batch_size = args.batch_size self.img_size = args.img_size @@ -106,9 +111,11 @@ def __init__(self, args, cfg, device): # Training Process def train(self): try: - self.train_before_loop() + self.before_train_loop() for self.epoch in range(self.start_epoch, self.max_epoch): - self.train_in_loop(self.epoch) + self.before_epoch() + self.train_one_epoch(self.epoch) + self.after_epoch() self.strip_model() except Exception as _: @@ -118,22 +125,16 @@ def train(self): self.train_after_loop() # Training loop for each epoch - def train_in_loop(self, epoch_num): + def train_one_epoch(self, epoch_num): try: - self.prepare_for_steps() for self.step, self.batch_data in self.pbar: self.train_in_steps(epoch_num, self.step) self.print_details() except Exception as _: LOGGER.error('ERROR in training steps.') raise - try: - self.eval_and_save() - except Exception as _: - LOGGER.error('ERROR in evaluate and save model.') - raise - # Training loop for batchdata + # Training one batch data. def train_in_steps(self, epoch_num, step_num): images, targets = self.prepro_data(self.batch_data, self.device) # plot train_batch and save to tensorboard once an epoch @@ -165,12 +166,15 @@ def train_in_steps(self, epoch_num, step_num): self.loss_items = loss_items self.update_optimizer() - def eval_and_save(self): - remaining_epochs = self.max_epoch - 1 - self.epoch # self.epoch is start from 0 - eval_interval = self.args.eval_interval if remaining_epochs >= self.args.heavy_eval_range else 3 - is_val_epoch = (remaining_epochs == 0) or ((not self.args.eval_final_only) and ((self.epoch + 1) % eval_interval == 0)) + def after_epoch(self): + lrs_of_this_epoch = [x['lr'] for x in self.optimizer.param_groups] + self.scheduler.step() # update lr if self.main_process: self.ema.update_attr(self.model, include=['nc', 'names', 'stride']) # update attributes for ema model + + remaining_epochs = self.max_epoch - 1 - self.epoch # self.epoch is start from 0 + eval_interval = self.args.eval_interval if remaining_epochs >= self.args.heavy_eval_range else 3 + is_val_epoch = (remaining_epochs == 0) or ((not self.args.eval_final_only) and ((self.epoch + 1) % eval_interval == 0)) if is_val_epoch: self.eval_model() self.ap = self.evaluate_results[1] @@ -198,12 +202,11 @@ def eval_and_save(self): save_checkpoint(ckpt, False, save_ckpt_dir, model_name='best_stop_aug_ckpt') del ckpt - # log for learning rate - lr = [x['lr'] for x in self.optimizer.param_groups] - self.evaluate_results = list(self.evaluate_results) + lr + + self.evaluate_results = list(self.evaluate_results) # log for tensorboard - write_tblog(self.tblogger, self.epoch, self.evaluate_results, self.mean_loss) + write_tblog(self.tblogger, self.epoch, self.evaluate_results, lrs_of_this_epoch, self.mean_loss) # save validation predictions to tensorboard write_tbimg(self.tblogger, self.vis_imgs_list, self.epoch, type='val') @@ -250,7 +253,7 @@ def get_cfg_value(cfg_dict, value_str, default_value): self.plot_val_pred(vis_outputs, vis_paths) - def train_before_loop(self): + def before_train_loop(self): LOGGER.info('Training start...') self.start_time = time.time() self.warmup_stepnum = max(round(self.cfg.solver.warmup_epochs * self.max_stepnum), 1000) if self.args.quant is False else 0 @@ -301,12 +304,7 @@ def train_before_loop(self): distill_feat = self.args.distill_feat, ) - def prepare_for_steps(self): - if self.epoch > self.start_epoch: - self.scheduler.step() - elif hasattr(self, "ckpt") and self.epoch == self.start_epoch: # resume first epoch, load lr - for k, param in enumerate(self.optimizer.param_groups): - param['lr'] = self.scheduler.get_lr()[k] + def before_epoch(self): #stop strong aug like mosaic and mixup from last n epoch by recreate dataloader if self.epoch == self.max_epoch - self.args.stop_aug_last_n_epoch: self.cfg.data_aug.mosaic = 0.0 diff --git a/yolov6/utils/events.py b/yolov6/utils/events.py index 4120a15f..bbc007af 100644 --- a/yolov6/utils/events.py +++ b/yolov6/utils/events.py @@ -30,7 +30,7 @@ def save_yaml(data_dict, save_path): yaml.safe_dump(data_dict, f, sort_keys=False) -def write_tblog(tblogger, epoch, results, losses): +def write_tblog(tblogger, epoch, results, lrs, losses): """Display mAP and loss information to log.""" tblogger.add_scalar("val/mAP@0.5", results[0], epoch + 1) tblogger.add_scalar("val/mAP@0.50:0.95", results[1], epoch + 1) @@ -39,9 +39,9 @@ def write_tblog(tblogger, epoch, results, losses): tblogger.add_scalar("train/dist_focalloss", losses[1], epoch + 1) tblogger.add_scalar("train/cls_loss", losses[2], epoch + 1) - tblogger.add_scalar("x/lr0", results[2], epoch + 1) - tblogger.add_scalar("x/lr1", results[3], epoch + 1) - tblogger.add_scalar("x/lr2", results[4], epoch + 1) + tblogger.add_scalar("x/lr0", lrs[0], epoch + 1) + tblogger.add_scalar("x/lr1", lrs[1], epoch + 1) + tblogger.add_scalar("x/lr2", lrs[2], epoch + 1) def write_tbimg(tblogger, imgs, step, type='train'):