forked from meituan/YOLOv6
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
da172e4
commit d080e13
Showing
3 changed files
with
270 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -94,6 +94,15 @@ def predict_model(self, model, dataloader, task): | |
self.speed_result = torch.zeros(4, device=self.device) | ||
pred_results = [] | ||
pbar = tqdm(dataloader, desc="Inferencing model in val datasets.", ncols=NCOLS) | ||
|
||
# whether to compute metric and plot PR curve and P、R、F1 curve under iou50 match rule | ||
metric_and_plot = True if task == "val" else False | ||
if metric_and_plot: | ||
stats, ap = [], [] | ||
seen = 0 | ||
iouv = torch.linspace(0.5, 0.95, 10) # iou vector for [email protected]:0.95 | ||
niou = iouv.numel() | ||
|
||
for i, (imgs, targets, paths, shapes) in enumerate(pbar): | ||
|
||
# pre-process | ||
|
@@ -114,6 +123,10 @@ def predict_model(self, model, dataloader, task): | |
self.speed_result[3] += time_sync() - t3 # post-process time | ||
self.speed_result[0] += len(outputs) | ||
|
||
if metric_and_plot: | ||
import copy | ||
eval_outputs = copy.deepcopy([x.detach().cpu() for x in outputs]) | ||
|
||
# save result | ||
pred_results.extend(self.convert_to_coco_format(outputs, imgs, paths, shapes, self.ids)) | ||
|
||
|
@@ -122,6 +135,71 @@ def predict_model(self, model, dataloader, task): | |
vis_num = min(len(imgs), 8) | ||
vis_outputs = outputs[:vis_num] | ||
vis_paths = paths[:vis_num] | ||
|
||
if not metric_and_plot: | ||
continue | ||
|
||
# Statistics per image | ||
# This code is based on | ||
# https://github.com/ultralytics/yolov5/blob/master/val.py | ||
for si, pred in enumerate(eval_outputs): | ||
labels = targets[targets[:, 0] == si, 1:] | ||
nl = len(labels) | ||
tcls = labels[:, 0].tolist() if nl else [] # target class | ||
seen += 1 | ||
|
||
if len(pred) == 0: | ||
if nl: | ||
stats.append((torch.zeros(0, niou, dtype=torch.bool), torch.Tensor(), torch.Tensor(), tcls)) | ||
continue | ||
|
||
# Predictions | ||
predn = pred.clone() | ||
self.scale_coords(imgs[si].shape[1:], predn[:, :4], shapes[si][0], shapes[si][1]) # native-space pred | ||
|
||
# Assign all predictions as incorrect | ||
correct = torch.zeros(pred.shape[0], niou, dtype=torch.bool) | ||
if nl: | ||
|
||
from yolov6.utils.nms import xywh2xyxy | ||
|
||
# target boxes | ||
tbox = xywh2xyxy(labels[:, 1:5]) | ||
tbox[:, [0, 2]] *= imgs[si].shape[1:][0] | ||
tbox[:, [1, 3]] *= imgs[si].shape[1:][1] | ||
|
||
self.scale_coords(imgs[si].shape[1:], tbox, shapes[si][0], shapes[si][1]) # native-space labels | ||
|
||
labelsn = torch.cat((labels[:, 0:1], tbox), 1) # native-space labels | ||
|
||
from yolov6.utils.metrics import process_batch | ||
|
||
correct = process_batch(predn, labelsn, iouv) | ||
|
||
# Append statistics (correct, conf, pcls, tcls) | ||
stats.append((correct.cpu(), pred[:, 4].cpu(), pred[:, 5].cpu(), tcls)) | ||
|
||
if metric_and_plot: | ||
# Compute statistics | ||
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy | ||
if len(stats) and stats[0].any(): | ||
|
||
from yolov6.utils.metrics import ap_per_class | ||
p, r, ap, f1, ap_class = ap_per_class(*stats, plot=metric_and_plot, save_dir=self.save_dir, names=model.names) | ||
AP50_F1_max_idx = f1.mean(0).argmax() | ||
LOGGER.info(f"IOU 50 best mF1 thershold near {AP50_F1_max_idx/1000.0}.") | ||
ap50, ap = ap[:, 0], ap.mean(1) # [email protected], [email protected]:0.95 | ||
mp, mr, map50, map = p[:, AP50_F1_max_idx].mean(), r[:, AP50_F1_max_idx].mean(), ap50.mean(), ap.mean() | ||
nt = np.bincount(stats[3].astype(np.int64), minlength=model.nc) # number of targets per class | ||
else: | ||
nt = torch.zeros(1) | ||
|
||
# Print results | ||
s = ('%s' + '%12s' * 4) % ('Class', 'Images', 'Labels', '[email protected]', '[email protected]') | ||
LOGGER.info(s) | ||
pf = '%s' + '%12i' * 2 + '%12.3g' * 2 # print format | ||
LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr)) | ||
|
||
return pred_results, vis_outputs, vis_paths | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
# Model validation metrics | ||
# This code is based on | ||
# https://github.com/ultralytics/yolov5/blob/master/utils/metrics.py | ||
|
||
from pathlib import Path | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import torch | ||
|
||
from . import general | ||
|
||
def ap_per_class(tp, conf, pred_cls, target_cls, plot=False, save_dir='.', names=()): | ||
""" Compute the average precision, given the recall and precision curves. | ||
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. | ||
# Arguments | ||
tp: True positives (nparray, nx1 or nx10). | ||
conf: Objectness value from 0-1 (nparray). | ||
pred_cls: Predicted object classes (nparray). | ||
target_cls: True object classes (nparray). | ||
plot: Plot precision-recall curve at [email protected] | ||
save_dir: Plot save directory | ||
# Returns | ||
The average precision as computed in py-faster-rcnn. | ||
""" | ||
|
||
# Sort by objectness | ||
i = np.argsort(-conf) | ||
tp, conf, pred_cls = tp[i], conf[i], pred_cls[i] | ||
|
||
# Find unique classes | ||
unique_classes = np.unique(target_cls) | ||
nc = unique_classes.shape[0] # number of classes, number of detections | ||
|
||
# Create Precision-Recall curve and compute AP for each class | ||
px, py = np.linspace(0, 1, 1000), [] # for plotting | ||
ap, p, r = np.zeros((nc, tp.shape[1])), np.zeros((nc, 1000)), np.zeros((nc, 1000)) | ||
for ci, c in enumerate(unique_classes): | ||
i = pred_cls == c | ||
n_l = (target_cls == c).sum() # number of labels | ||
n_p = i.sum() # number of predictions | ||
|
||
if n_p == 0 or n_l == 0: | ||
continue | ||
else: | ||
# Accumulate FPs and TPs | ||
fpc = (1 - tp[i]).cumsum(0) | ||
tpc = tp[i].cumsum(0) | ||
|
||
# Recall | ||
recall = tpc / (n_l + 1e-16) # recall curve | ||
r[ci] = np.interp(-px, -conf[i], recall[:, 0], left=0) # negative x, xp because xp decreases | ||
|
||
# Precision | ||
precision = tpc / (tpc + fpc) # precision curve | ||
p[ci] = np.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score | ||
|
||
# AP from recall-precision curve | ||
for j in range(tp.shape[1]): | ||
ap[ci, j], mpre, mrec = compute_ap(recall[:, j], precision[:, j]) | ||
if plot and j == 0: | ||
py.append(np.interp(px, mrec, mpre)) # precision at [email protected] | ||
|
||
# Compute F1 (harmonic mean of precision and recall) | ||
f1 = 2 * p * r / (p + r + 1e-16) | ||
if plot: | ||
plot_pr_curve(px, py, ap, Path(save_dir) / 'PR_curve.png', names) | ||
plot_mc_curve(px, f1, Path(save_dir) / 'F1_curve.png', names, ylabel='F1') | ||
plot_mc_curve(px, p, Path(save_dir) / 'P_curve.png', names, ylabel='Precision') | ||
plot_mc_curve(px, r, Path(save_dir) / 'R_curve.png', names, ylabel='Recall') | ||
|
||
# i = f1.mean(0).argmax() # max F1 index | ||
# return p[:, i], r[:, i], ap, f1[:, i], unique_classes.astype('int32') | ||
return p, r, ap, f1, unique_classes.astype('int32') | ||
|
||
|
||
def compute_ap(recall, precision): | ||
""" Compute the average precision, given the recall and precision curves | ||
# Arguments | ||
recall: The recall curve (list) | ||
precision: The precision curve (list) | ||
# Returns | ||
Average precision, precision curve, recall curve | ||
""" | ||
|
||
# Append sentinel values to beginning and end | ||
mrec = np.concatenate(([0.], recall, [recall[-1] + 0.01])) | ||
mpre = np.concatenate(([1.], precision, [0.])) | ||
|
||
# Compute the precision envelope | ||
mpre = np.flip(np.maximum.accumulate(np.flip(mpre))) | ||
|
||
# Integrate area under curve | ||
method = 'interp' # methods: 'continuous', 'interp' | ||
if method == 'interp': | ||
x = np.linspace(0, 1, 101) # 101-point interp (COCO) | ||
ap = np.trapz(np.interp(x, mrec, mpre), x) # integrate | ||
else: # 'continuous' | ||
i = np.where(mrec[1:] != mrec[:-1])[0] # points where x axis (recall) changes | ||
ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) # area under curve | ||
|
||
return ap, mpre, mrec | ||
|
||
# Plots ---------------------------------------------------------------------------------------------------------------- | ||
|
||
def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()): | ||
# Precision-recall curve | ||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) | ||
py = np.stack(py, axis=1) | ||
|
||
if 0 < len(names) < 21: # display per-class legend if < 21 classes | ||
for i, y in enumerate(py.T): | ||
ax.plot(px, y, linewidth=1, label=f'{names[i]} {ap[i, 0]:.3f}') # plot(recall, precision) | ||
else: | ||
ax.plot(px, py, linewidth=1, color='grey') # plot(recall, precision) | ||
|
||
ax.plot(px, py.mean(1), linewidth=3, color='blue', label='all classes %.3f [email protected]' % ap[:, 0].mean()) | ||
ax.set_xlabel('Recall') | ||
ax.set_ylabel('Precision') | ||
ax.set_xlim(0, 1) | ||
ax.set_ylim(0, 1) | ||
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") | ||
fig.savefig(Path(save_dir), dpi=250) | ||
|
||
|
||
def plot_mc_curve(px, py, save_dir='mc_curve.png', names=(), xlabel='Confidence', ylabel='Metric'): | ||
# Metric-confidence curve | ||
fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True) | ||
|
||
if 0 < len(names) < 21: # display per-class legend if < 21 classes | ||
for i, y in enumerate(py): | ||
ax.plot(px, y, linewidth=1, label=f'{names[i]}') # plot(confidence, metric) | ||
else: | ||
ax.plot(px, py.T, linewidth=1, color='grey') # plot(confidence, metric) | ||
|
||
y = py.mean(0) | ||
ax.plot(px, y, linewidth=3, color='blue', label=f'all classes {y.max():.2f} at {px[y.argmax()]:.3f}') | ||
ax.set_xlabel(xlabel) | ||
ax.set_ylabel(ylabel) | ||
ax.set_xlim(0, 1) | ||
ax.set_ylim(0, 1) | ||
plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left") | ||
fig.savefig(Path(save_dir), dpi=250) | ||
|
||
def process_batch(detections, labels, iouv): | ||
""" | ||
Return correct predictions matrix. Both sets of boxes are in (x1, y1, x2, y2) format. | ||
Arguments: | ||
detections (Array[N, 6]), x1, y1, x2, y2, conf, class | ||
labels (Array[M, 5]), class, x1, y1, x2, y2 | ||
Returns: | ||
correct (Array[N, 10]), for 10 IoU levels | ||
""" | ||
correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool) | ||
iou = general.box_iou(labels[:, 1:], detections[:, :4]) | ||
correct_class = labels[:, 0:1] == detections[:, 5] | ||
for i in range(len(iouv)): | ||
x = torch.where((iou >= iouv[i]) & correct_class) # IoU > threshold and classes match | ||
if x[0].shape[0]: | ||
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou] | ||
if x[0].shape[0] > 1: | ||
matches = matches[matches[:, 2].argsort()[::-1]] | ||
matches = matches[np.unique(matches[:, 1], return_index=True)[1]] | ||
# matches = matches[matches[:, 2].argsort()[::-1]] | ||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]] | ||
correct[matches[:, 1].astype(int), i] = True | ||
return torch.tensor(correct, dtype=torch.bool, device=iouv.device) |