Skip to content

Commit

Permalink
optimize function name of engine and resume logic (meituan#809)
Browse files Browse the repository at this point in the history
optimize function name and resume logic
  • Loading branch information
mtjhl authored May 6, 2023
1 parent 94c846b commit b6f95fa
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 30 deletions.
50 changes: 24 additions & 26 deletions yolov6/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, args, cfg, device):
self.args = args
self.cfg = cfg
self.device = device
self.max_epoch = args.epochs

if args.resume:
self.ckpt = torch.load(args.resume, map_location='cpu')
Expand Down Expand Up @@ -84,10 +85,14 @@ def __init__(self, args, cfg, device):
if self.main_process:
self.ema.ema.load_state_dict(self.ckpt['ema'].float().state_dict())
self.ema.updates = self.ckpt['updates']
if self.start_epoch > (self.max_epoch - self.args.stop_aug_last_n_epoch):
self.cfg.data_aug.mosaic = 0.0
self.cfg.data_aug.mixup = 0.0
self.train_loader, self.val_loader = self.get_data_loader(self.args, self.cfg, self.data_dict)

self.model = self.parallel_model(args, model, device)
self.model.nc, self.model.names = self.data_dict['nc'], self.data_dict['names']

self.max_epoch = args.epochs
self.max_stepnum = len(self.train_loader)
self.batch_size = args.batch_size
self.img_size = args.img_size
Expand All @@ -106,9 +111,11 @@ def __init__(self, args, cfg, device):
# Training Process
def train(self):
try:
self.train_before_loop()
self.before_train_loop()
for self.epoch in range(self.start_epoch, self.max_epoch):
self.train_in_loop(self.epoch)
self.before_epoch()
self.train_one_epoch(self.epoch)
self.after_epoch()
self.strip_model()

except Exception as _:
Expand All @@ -118,22 +125,16 @@ def train(self):
self.train_after_loop()

# Training loop for each epoch
def train_in_loop(self, epoch_num):
def train_one_epoch(self, epoch_num):
try:
self.prepare_for_steps()
for self.step, self.batch_data in self.pbar:
self.train_in_steps(epoch_num, self.step)
self.print_details()
except Exception as _:
LOGGER.error('ERROR in training steps.')
raise
try:
self.eval_and_save()
except Exception as _:
LOGGER.error('ERROR in evaluate and save model.')
raise

# Training loop for batchdata
# Train on one batch of data.
def train_in_steps(self, epoch_num, step_num):
images, targets = self.prepro_data(self.batch_data, self.device)
# plot train_batch and save to tensorboard once an epoch
Expand Down Expand Up @@ -165,12 +166,15 @@ def train_in_steps(self, epoch_num, step_num):
self.loss_items = loss_items
self.update_optimizer()

def eval_and_save(self):
remaining_epochs = self.max_epoch - 1 - self.epoch # self.epoch is start from 0
eval_interval = self.args.eval_interval if remaining_epochs >= self.args.heavy_eval_range else 3
is_val_epoch = (remaining_epochs == 0) or ((not self.args.eval_final_only) and ((self.epoch + 1) % eval_interval == 0))
def after_epoch(self):
lrs_of_this_epoch = [x['lr'] for x in self.optimizer.param_groups]
self.scheduler.step() # update lr
if self.main_process:
self.ema.update_attr(self.model, include=['nc', 'names', 'stride']) # update attributes for ema model

remaining_epochs = self.max_epoch - 1 - self.epoch # self.epoch starts from 0
eval_interval = self.args.eval_interval if remaining_epochs >= self.args.heavy_eval_range else 3
is_val_epoch = (remaining_epochs == 0) or ((not self.args.eval_final_only) and ((self.epoch + 1) % eval_interval == 0))
if is_val_epoch:
self.eval_model()
self.ap = self.evaluate_results[1]
Expand Down Expand Up @@ -198,12 +202,11 @@ def eval_and_save(self):
save_checkpoint(ckpt, False, save_ckpt_dir, model_name='best_stop_aug_ckpt')

del ckpt
# log for learning rate
lr = [x['lr'] for x in self.optimizer.param_groups]
self.evaluate_results = list(self.evaluate_results) + lr

self.evaluate_results = list(self.evaluate_results)

# log for tensorboard
write_tblog(self.tblogger, self.epoch, self.evaluate_results, self.mean_loss)
write_tblog(self.tblogger, self.epoch, self.evaluate_results, lrs_of_this_epoch, self.mean_loss)
# save validation predictions to tensorboard
write_tbimg(self.tblogger, self.vis_imgs_list, self.epoch, type='val')

Expand Down Expand Up @@ -250,7 +253,7 @@ def get_cfg_value(cfg_dict, value_str, default_value):
self.plot_val_pred(vis_outputs, vis_paths)


def train_before_loop(self):
def before_train_loop(self):
LOGGER.info('Training start...')
self.start_time = time.time()
self.warmup_stepnum = max(round(self.cfg.solver.warmup_epochs * self.max_stepnum), 1000) if self.args.quant is False else 0
Expand Down Expand Up @@ -301,12 +304,7 @@ def train_before_loop(self):
distill_feat = self.args.distill_feat,
)

def prepare_for_steps(self):
if self.epoch > self.start_epoch:
self.scheduler.step()
elif hasattr(self, "ckpt") and self.epoch == self.start_epoch: # resume first epoch, load lr
for k, param in enumerate(self.optimizer.param_groups):
param['lr'] = self.scheduler.get_lr()[k]
def before_epoch(self):
# Stop strong augmentations (e.g. mosaic and mixup) for the last n epochs by recreating the dataloader
if self.epoch == self.max_epoch - self.args.stop_aug_last_n_epoch:
self.cfg.data_aug.mosaic = 0.0
Expand Down
8 changes: 4 additions & 4 deletions yolov6/utils/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def save_yaml(data_dict, save_path):
yaml.safe_dump(data_dict, f, sort_keys=False)


def write_tblog(tblogger, epoch, results, losses):
def write_tblog(tblogger, epoch, results, lrs, losses):
"""Display mAP and loss information to log."""
tblogger.add_scalar("val/[email protected]", results[0], epoch + 1)
tblogger.add_scalar("val/[email protected]:0.95", results[1], epoch + 1)
Expand All @@ -39,9 +39,9 @@ def write_tblog(tblogger, epoch, results, losses):
tblogger.add_scalar("train/dist_focalloss", losses[1], epoch + 1)
tblogger.add_scalar("train/cls_loss", losses[2], epoch + 1)

tblogger.add_scalar("x/lr0", results[2], epoch + 1)
tblogger.add_scalar("x/lr1", results[3], epoch + 1)
tblogger.add_scalar("x/lr2", results[4], epoch + 1)
tblogger.add_scalar("x/lr0", lrs[0], epoch + 1)
tblogger.add_scalar("x/lr1", lrs[1], epoch + 1)
tblogger.add_scalar("x/lr2", lrs[2], epoch + 1)


def write_tbimg(tblogger, imgs, step, type='train'):
Expand Down

0 comments on commit b6f95fa

Please sign in to comment.