forked from meituan/YOLOv6
Add train_batch/val_predictions visualization to tensorboard
Showing 6 changed files with 127 additions and 8 deletions.
@@ -7,7 +7,9 @@

from tqdm import tqdm

import cv2
import numpy as np
import math
import torch
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP

@@ -17,11 +19,12 @@
from yolov6.data.data_load import create_dataloader
from yolov6.models.yolo import build_model
from yolov6.models.loss import ComputeLoss
-from yolov6.utils.events import LOGGER, NCOLS, load_yaml, write_tblog
+from yolov6.utils.events import LOGGER, NCOLS, load_yaml, write_tblog, write_tbimg
from yolov6.utils.ema import ModelEMA, de_parallel
from yolov6.utils.checkpoint import load_state_dict, save_checkpoint, strip_optimizer
from yolov6.solver.build import build_optimizer, build_lr_scheduler
from yolov6.utils.RepOptimizer import extract_scales, RepVGGOptimizer
from yolov6.utils.nms import xywh2xyxy


class Trainer:

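The newly imported write_tbimg helper is added in yolov6/utils/events.py, one of the six changed files that is not shown in this view. Based only on how the Trainer calls it below (a single HWC mosaic for type='train', a list of HWC image tensors for type='val'), a minimal sketch could look like the following; the function body and the tag names are assumptions, only the call sites are from the diff.

```python
# Hypothetical sketch of write_tbimg (the real one lives in yolov6/utils/events.py).
from torch.utils.tensorboard import SummaryWriter


def write_tbimg(tblogger: SummaryWriter, imgs, step, type='train'):
    """Push train-batch mosaics or per-image val predictions to TensorBoard."""
    if type == 'train':
        # imgs is one HWC uint8 mosaic built by plot_train_batch (drawn in BGR by OpenCV)
        tblogger.add_image('train_batch', imgs, step, dataformats='HWC')
    elif type == 'val':
        # imgs is a list of HWC RGB tensors built by plot_val_pred
        for idx, img in enumerate(imgs):
            tblogger.add_image(f'val_img_{idx + 1}', img, step, dataformats='HWC')
```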
@@ -72,6 +75,9 @@ def __init__(self, args, cfg, device):
        self.batch_size = args.batch_size
        self.img_size = args.img_size

        # set color for classnames
        self.color = [tuple(np.random.choice(range(256), size=3)) for _ in range(self.model.nc)]

    # Training Process

    def train(self):

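The per-class colors are drawn once at init time. Note that np.random.choice returns NumPy integer scalars, which is why the drawing code later in this diff casts each channel with int() before handing the tuple to OpenCV. A standalone illustration (the class count here is an arbitrary placeholder for self.model.nc):

```python
import numpy as np

num_classes = 80  # placeholder; the Trainer uses self.model.nc
colors = [tuple(np.random.choice(range(256), size=3)) for _ in range(num_classes)]

# OpenCV drawing functions are happiest with plain Python ints, so cast per channel:
color_for_cv2 = tuple(int(c) for c in colors[0])
```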
@@ -105,6 +111,11 @@ def train_in_loop(self):
    # Training loop for batchdata
    def train_in_steps(self):
        images, targets = self.prepro_data(self.batch_data, self.device)

        # plot train_batch and save to tensorboard
        self.plot_train_batch(images, targets)
        write_tbimg(self.tblogger, self.vis_train_batch, self.step + self.max_stepnum * self.epoch, type='train')

        # forward
        with amp.autocast(enabled=self.device != 'cpu'):
            preds = self.model(images)

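write_tbimg is called on every training step, and the x-axis value it receives is the flat batch index across the whole run rather than the epoch-local step. A quick sanity check of that arithmetic, with variable names mirroring the Trainer attributes:

```python
# max_stepnum = number of batches per epoch (len(train_loader) in the Trainer)
max_stepnum = 500

def global_step(epoch, step):
    # epoch-local step -> monotonically increasing TensorBoard step
    return step + max_stepnum * epoch

assert global_step(0, 10) == 10
assert global_step(3, 10) == 1510
```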
@@ -141,8 +152,93 @@ def eval_and_save(self):
            # log for tensorboard
            write_tblog(self.tblogger, self.epoch, self.evaluate_results, self.mean_loss)

            # save validation predictions to tensorboard
            write_tbimg(self.tblogger, self.vis_imgs_list, self.epoch, type='val')

    def plot_train_batch(self, images, targets, max_size=1920, max_subplots=16):
        # Plot train_batch with labels
        if isinstance(images, torch.Tensor):
            images = images.cpu().float().numpy()
        if isinstance(targets, torch.Tensor):
            targets = targets.cpu().numpy()
        if np.max(images[0]) <= 1:
            images *= 255  # de-normalise (optional)
        bs, _, h, w = images.shape  # batch size, _, height, width
        bs = min(bs, max_subplots)  # limit plot images
        ns = np.ceil(bs ** 0.5)  # number of subplots (square)
        paths = self.batch_data[2]  # image paths

        # Build Image
        mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8)  # init
        for i, im in enumerate(images):
            if i == max_subplots:  # stop once the mosaic grid is full
                break
            x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
            im = im.transpose(1, 2, 0)
            mosaic[y:y + h, x:x + w, :] = im

        # Resize (optional)
        scale = max_size / ns / max(h, w)
        if scale < 1:
            h = math.ceil(scale * h)
            w = math.ceil(scale * w)
            mosaic = cv2.resize(mosaic, tuple(int(x * ns) for x in (w, h)))

        for i in range(bs):  # iterate only over the images actually plotted
            x, y = int(w * (i // ns)), int(h * (i % ns))  # block origin
            cv2.rectangle(mosaic, (x, y), (x + w, y + h), (255, 255, 255), thickness=2)  # borders
            cv2.putText(mosaic, f"{os.path.basename(paths[i])[:40]}", (x + 5, y + 15),
                        cv2.FONT_HERSHEY_COMPLEX, 0.5, color=(220, 220, 220), thickness=1)  # filename
            if len(targets) > 0:
                ti = targets[targets[:, 0] == i]  # image targets
                boxes = xywh2xyxy(ti[:, 2:6]).T
                classes = ti[:, 1].astype('int')
                labels = ti.shape[1] == 6  # labels if no conf column

                if boxes.shape[1]:
                    if boxes.max() <= 1.01:  # if normalized with tolerance 0.01
                        boxes[[0, 2]] *= w  # scale to pixels
                        boxes[[1, 3]] *= h
                    elif scale < 1:  # absolute coords need scale if image scales
                        boxes *= scale
                boxes[[0, 2]] += x
                boxes[[1, 3]] += y
                for j, box in enumerate(boxes.T.tolist()):
                    box = [int(k) for k in box]
                    cls = classes[j]
                    color = tuple([int(x) for x in self.color[cls]])
                    cls = self.data_dict['names'][cls] if self.data_dict['names'] else cls
                    if labels:
                        label = f'{cls}'
                        cv2.rectangle(mosaic, (box[0], box[1]), (box[2], box[3]), color, thickness=1)
                        cv2.putText(mosaic, label, (box[0], box[1] - 5), cv2.FONT_HERSHEY_COMPLEX, 0.5, color, thickness=1)

        self.vis_train_batch = mosaic.copy()

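plot_train_batch relies on xywh2xyxy from yolov6.utils.nms (imported above) to turn the dataloader's center-based label format into corner coordinates before scaling them into the mosaic. That helper is part of the existing code base and is not shown in this diff; a sketch of the standard conversion is given here for reference, with the caveat that the upstream implementation may differ in detail.

```python
import numpy as np
import torch


def xywh2xyxy(x):
    # [x_center, y_center, width, height] -> [x1, y1, x2, y2] (top-left, bottom-right)
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top-left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top-left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom-right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom-right y
    return y
```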
    def plot_val_pred(self, vis_outputs, vis_paths, vis_conf=0.3, vis_max_box_num=5):
        # plot validation predictions
        self.vis_imgs_list = []
        for (vis_output, vis_path) in zip(vis_outputs, vis_paths):
            vis_output_array = vis_output.cpu().numpy()  # xyxy
            ori_img = cv2.imread(vis_path)

            for bbox_idx, vis_bbox in enumerate(vis_output_array):
                x_tl = int(vis_bbox[0])
                y_tl = int(vis_bbox[1])
                x_br = int(vis_bbox[2])
                y_br = int(vis_bbox[3])
                box_score = vis_bbox[4]
                cls_id = int(vis_bbox[5])

                # draw top n bbox
                if box_score < vis_conf or bbox_idx > vis_max_box_num:
                    break
                cv2.rectangle(ori_img, (x_tl, y_tl), (x_br, y_br), tuple([int(x) for x in self.color[cls_id]]), thickness=1)
                cv2.putText(ori_img, f"{self.data_dict['names'][cls_id]}: {box_score:.2f}", (x_tl, y_tl - 10), cv2.FONT_HERSHEY_COMPLEX, 0.5, tuple([int(x) for x in self.color[cls_id]]), thickness=1)
            self.vis_imgs_list.append(torch.from_numpy(ori_img[:, :, ::-1].copy()))

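plot_val_pred draws with OpenCV (BGR channel order) while TensorBoard expects RGB, which is why each finished image is flipped along the channel axis and copied before being wrapped in a tensor; without the .copy(), torch.from_numpy would reject the negatively-strided view. A self-contained illustration with a dummy image in place of cv2.imread:

```python
import numpy as np
import torch

bgr = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for cv2.imread(vis_path)
bgr[:, :, 0] = 255                             # pure blue in BGR

rgb = bgr[:, :, ::-1].copy()                   # channel flip + copy (contiguous memory)
img_tensor = torch.from_numpy(rgb)             # HWC uint8, ready for add_image(..., dataformats='HWC')
assert img_tensor[0, 0, 2] == 255              # blue now sits in the last channel of RGB
```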
    def eval_model(self):
-        results = eval.run(self.data_dict,
+        results, vis_outputs, vis_paths = eval.run(self.data_dict,
                            batch_size=self.batch_size // self.world_size * 2,
                            img_size=self.img_size,
                            model=self.ema.ema,

@@ -153,6 +249,9 @@ def eval_model(self):
        LOGGER.info(f"Epoch: {self.epoch} | mAP@0.5: {results[0]} | mAP@0.50:0.95: {results[1]}")
        self.evaluate_results = results[:2]

        # plot validation predictions
        self.plot_val_pred(vis_outputs, vis_paths)

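eval.run (defined elsewhere in the repository and modified in another file of this commit) now returns visualization data alongside the metrics. From the way plot_val_pred consumes it, each element of vis_outputs appears to be a per-image post-NMS tensor of [x1, y1, x2, y2, confidence, class_id] rows in pixel coordinates, with vis_paths holding the matching source image paths. A toy value shaped the same way (the path is purely hypothetical):

```python
import torch

# One image with two detections, format inferred from plot_val_pred's indexing
vis_outputs = [torch.tensor([[ 48.0,  60.0, 320.0, 410.0, 0.92, 0.0],
                             [210.0,  35.0, 500.0, 380.0, 0.81, 2.0]])]
vis_paths = ['path/to/val_image.jpg']  # hypothetical path

for det, path in zip(vis_outputs, vis_paths):
    for x1, y1, x2, y2, score, cls_id in det.tolist():
        print(path, int(cls_id), round(score, 2), (int(x1), int(y1), int(x2), int(y2)))
```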
    def train_before_loop(self):
        LOGGER.info('Training start...')
        self.start_time = time.time()
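With these additions, the training mosaics and the drawn validation predictions show up under the Images tab in TensorBoard, next to the scalars already written by write_tblog. Pointing TensorBoard at the experiment's output directory is enough to browse them, for example: tensorboard --logdir runs/train/exp (adjust the path to wherever your run saves its logs).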