forked from ultralytics/yolov5
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
629d370
commit 260b172
Showing
4 changed files
with
28 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,10 +20,12 @@ def test(data, | |
model=None, | ||
dataloader=None, | ||
fast=False, | ||
verbose=False): # 0 fast, 1 accurate | ||
verbose=False, | ||
half=False): # FP16 | ||
# Initialize/load model and set device | ||
if model is None: | ||
device = torch_utils.select_device(opt.device, batch_size=batch_size) | ||
half &= device.type != 'cpu' # half precision only supported on CUDA | ||
|
||
# Remove previous | ||
for f in glob.glob('test_batch*.jpg'): | ||
|
@@ -35,6 +37,8 @@ def test(data, | |
torch_utils.model_info(model) | ||
# model.fuse() | ||
model.to(device) | ||
if half: | ||
model.half() # to FP16 | ||
|
||
if device.type != 'cpu' and torch.cuda.device_count() > 1: | ||
model = nn.DataParallel(model) | ||
|
@@ -72,24 +76,27 @@ def test(data, | |
|
||
seen = 0 | ||
model.eval() | ||
_ = model(torch.zeros((1, 3, imgsz, imgsz), device=device)) if device.type != 'cpu' else None # run once | ||
img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img | ||
_ = model(img.half() if half else img) if device.type != 'cpu' else None # run once | ||
names = model.names if hasattr(model, 'names') else model.module.names | ||
coco91class = coco80_to_coco91_class() | ||
s = ('%20s' + '%12s' * 6) % ('Class', 'Images', 'Targets', 'P', 'R', '[email protected]', '[email protected]:.95') | ||
p, r, f1, mp, mr, map50, map, t0, t1 = 0., 0., 0., 0., 0., 0., 0., 0., 0. | ||
loss = torch.zeros(3, device=device) | ||
jdict, stats, ap, ap_class = [], [], [], [] | ||
for batch_i, (imgs, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)): | ||
imgs = imgs.to(device).float() / 255.0 # uint8 to float32, 0 - 255 to 0.0 - 1.0 | ||
for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)): | ||
img = img.to(device) | ||
img = img.half() if half else img.float() # uint8 to fp16/32 | ||
img /= 255.0 # 0 - 255 to 0.0 - 1.0 | ||
targets = targets.to(device) | ||
nb, _, height, width = imgs.shape # batch size, channels, height, width | ||
nb, _, height, width = img.shape # batch size, channels, height, width | ||
whwh = torch.Tensor([width, height, width, height]).to(device) | ||
|
||
# Disable gradients | ||
with torch.no_grad(): | ||
# Run model | ||
t = torch_utils.time_synchronized() | ||
inf_out, train_out = model(imgs, augment=augment) # inference and training outputs | ||
inf_out, train_out = model(img, augment=augment) # inference and training outputs | ||
t0 += torch_utils.time_synchronized() - t | ||
|
||
# Compute loss | ||
|
@@ -125,7 +132,7 @@ def test(data, | |
# [{"image_id": 42, "category_id": 18, "bbox": [258.15, 41.29, 348.26, 243.78], "score": 0.236}, ... | ||
image_id = int(Path(paths[si]).stem.split('_')[-1]) | ||
box = pred[:, :4].clone() # xyxy | ||
scale_coords(imgs[si].shape[1:], box, shapes[si][0], shapes[si][1]) # to original shape | ||
scale_coords(img[si].shape[1:], box, shapes[si][0], shapes[si][1]) # to original shape | ||
box = xyxy2xywh(box) # xywh | ||
box[:, :2] -= box[:, 2:] / 2 # xy center to top-left corner | ||
for p, b in zip(pred.tolist(), box.tolist()): | ||
|
@@ -168,9 +175,9 @@ def test(data, | |
# Plot images | ||
if batch_i < 1: | ||
f = 'test_batch%g_gt.jpg' % batch_i # filename | ||
plot_images(imgs, targets, paths, f, names) # ground truth | ||
plot_images(img, targets, paths, f, names) # ground truth | ||
f = 'test_batch%g_pred.jpg' % batch_i | ||
plot_images(imgs, output_to_target(output, width, height), paths, f, names) # predictions | ||
plot_images(img, output_to_target(output, width, height), paths, f, names) # predictions | ||
|
||
# Compute statistics | ||
stats = [np.concatenate(x, 0) for x in zip(*stats)] # to numpy | ||
|
@@ -241,6 +248,7 @@ def test(data, | |
parser.add_argument('--save-json', action='store_true', help='save a cocoapi-compatible JSON results file') | ||
parser.add_argument('--task', default='val', help="'val', 'test', 'study'") | ||
parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu') | ||
parser.add_argument('--half', action='store_true', help='half precision FP16 inference') | ||
parser.add_argument('--single-cls', action='store_true', help='treat as single-class dataset') | ||
parser.add_argument('--augment', action='store_true', help='augmented inference') | ||
parser.add_argument('--verbose', action='store_true', help='report mAP by class') | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters