Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wz authored and wz committed Oct 24, 2020
1 parent 351460f commit d43b321
Show file tree
Hide file tree
Showing 6 changed files with 423 additions and 35 deletions.
86 changes: 59 additions & 27 deletions pytorch_object_detection/yolov3_spp/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ def train(hyp):
print("Using {} device training.".format(device.type))

wdir = "weights" + os.sep # weights dir
last = wdir + "last.pt"
best = wdir + "best.pt"
results_file = "results.txt"

Expand Down Expand Up @@ -55,8 +54,8 @@ def train(hyp):
hyp["cls"] *= nc / 80 # update coco-tuned hyp['cls'] to current dataset
hyp["obj"] *= imgsz_test / 320

# remove previous results
for f in glob.glob("*_batch*.jpg") + glob.glob(results_file):
# Remove previous results
for f in glob.glob(results_file):
os.remove(f)

# Initialize model
Expand All @@ -72,17 +71,27 @@ def train(hyp):
(x not in output_layer_indices) and
(x - 1 not in output_layer_indices)]
# Freeze non-output layers
# 总共训练3x2=6个parameters
for idx in freeze_layer_indeces:
for parameter in model.module_list[idx].parameters():
parameter.requires_grad_(False)
else:
# 如果freeze_layer为False,默认仅训练除darknet53之后的部分
# 若要训练全部权重,删除以下代码
darknet_end_layer = 74 # only yolov3spp cfg
# Freeze darknet53 layers
# 总共训练21x3+3x2=69个parameters
for idx in range(darknet_end_layer + 1): # [0, 74]
for parameter in model.module_list[idx].parameters():
parameter.requires_grad_(False)

# optimizer
pg = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(pg, lr=hyp["lr0"], momentum=hyp["momentum"],
weight_decay=hyp["weight_decay"], nesterov=True)

start_epoch = 0
best_fitness = 0.0
best_map = 0.0
if weights.endswith(".pt"):
ckpt = torch.load(weights, map_location=device)

Expand All @@ -98,7 +107,7 @@ def train(hyp):
# load optimizer
if ckpt["optimizer"] is not None:
optimizer.load_state_dict(ckpt["optimizer"])
best_fitness = ckpt["best_fitness"]
best_map = ckpt["best_map"]

# load results
if ckpt.get("training_results") is not None:
Expand Down Expand Up @@ -173,9 +182,9 @@ def train(hyp):

# start training
# caching val_data when you have plenty of memory(RAM)
print("caching val_data for evaluation.")
# coco = None
coco = get_coco_api_from_dataset(val_dataset)

print("starting traning for %g epochs..." % epochs)
print('Using %g dataloader workers' % nw)
for epoch in range(start_epoch, epochs):
Expand All @@ -192,28 +201,54 @@ def train(hyp):
# update scheduler
scheduler.step()

# evaluate on the test dataset
result_info = train_util.evaluate(model, val_datasetloader,
coco=coco, device=device)

# write into tensorboard
if tb_writer:
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', 'train/loss', "learning_rate",
"mAP@[IoU=0.50:0.95]", "mAP@[IoU=0.5]", "mAR@[IoU=0.50:0.95]"]
if opt.notest is False or epoch == epochs - 1:
# evaluate on the test dataset
result_info = train_util.evaluate(model, val_datasetloader,
coco=coco, device=device)

coco_mAP = result_info[0]
voc_mAP = result_info[1]
coco_mAR = result_info[8]

for x, tag in zip(mloss.tolist() + [lr, coco_mAP, voc_mAP, coco_mAR], tags):
tb_writer.add_scalar(tag, x, epoch)

# save weights
save_files = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch}
torch.save(save_files, "./weights/yolov3spp-{}.pth".format(epoch))
# write into tensorboard
if tb_writer:
tags = ['train/giou_loss', 'train/obj_loss', 'train/cls_loss', 'train/loss', "learning_rate",
"mAP@[IoU=0.50:0.95]", "mAP@[IoU=0.5]", "mAR@[IoU=0.50:0.95]"]

for x, tag in zip(mloss.tolist() + [lr, coco_mAP, voc_mAP, coco_mAR], tags):
tb_writer.add_scalar(tag, x, epoch)

# write into txt
with open(results_file, "a") as f:
result_info = [str(round(i, 4)) for i in result_info]
txt = "epoch:{} {}".format(epoch, ' '.join(result_info))
f.write(txt + "\n")

# update best mAP(IoU=0.50:0.95)
if coco_mAP > best_map:
best_map = coco_mAP

if opt.savebest is False:
# save weights every epoch
with open(results_file, 'r') as f:
save_files = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'training_results': f.read(),
'epoch': epoch,
'best_map': best_map}
torch.save(save_files, "./weights/yolov3spp-{}.pt".format(epoch))
else:
# only save best weights
if best_map == coco_mAP:
with open(results_file, 'r') as f:
save_files = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'training_results': f.read(),
'epoch': epoch,
'best_map': best_map}
torch.save(save_files, best.format(epoch))


if __name__ == '__main__':
Expand All @@ -227,10 +262,8 @@ def train(hyp):
help='adjust (67%% - 150%%) img_size every 10 batches')
parser.add_argument('--img-size', type=int, default=512, help='test size')
parser.add_argument('--rect', action='store_true', help='rectangular training')
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
parser.add_argument('--savebest', type=bool, default=False, help='only save best checkpoint')
parser.add_argument('--notest', action='store_true', help='only test final epoch')
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
parser.add_argument('--weights', type=str, default='weights/yolov3-spp-ultralytics-512.pt',
help='initial weights path')
Expand All @@ -244,7 +277,6 @@ def train(hyp):
opt.cfg = check_file(opt.cfg)
opt.data = check_file(opt.data)
opt.hyp = check_file(opt.hyp)

print(opt)

with open(opt.hyp) as f:
Expand Down
Loading

0 comments on commit d43b321

Please sign in to comment.