Skip to content

Commit

Permalink
Merge pull request facebookresearch#1111 from bernhardschaefer/infere…
Browse files Browse the repository at this point in the history
…nce-tta-device-fix

bugfix: use correct config for tta and device handling during inference
  • Loading branch information
botcs authored Oct 16, 2019
2 parents 77b06cb + a6e9634 commit b2a2a74
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
10 changes: 5 additions & 5 deletions maskrcnn_benchmark/engine/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import torch
from tqdm import tqdm

from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.data.datasets.evaluation import evaluate
from ..utils.comm import is_main_process, get_world_size
from ..utils.comm import all_gather
Expand All @@ -15,7 +14,7 @@
from .bbox_aug import im_detect_bbox_aug


def compute_on_dataset(model, data_loader, device, timer=None):
def compute_on_dataset(model, data_loader, device, bbox_aug, timer=None):
model.eval()
results_dict = {}
cpu_device = torch.device("cpu")
Expand All @@ -24,12 +23,12 @@ def compute_on_dataset(model, data_loader, device, timer=None):
with torch.no_grad():
if timer:
timer.tic()
if cfg.TEST.BBOX_AUG.ENABLED:
if bbox_aug:
output = im_detect_bbox_aug(model, images, device)
else:
output = model(images.to(device))
if timer:
if not cfg.MODEL.DEVICE == 'cpu':
if not device.type == 'cpu':
torch.cuda.synchronize()
timer.toc()
output = [o.to(cpu_device) for o in output]
Expand Down Expand Up @@ -67,6 +66,7 @@ def inference(
dataset_name,
iou_types=("bbox",),
box_only=False,
bbox_aug=False,
device="cuda",
expected_results=(),
expected_results_sigma_tol=4,
Expand All @@ -81,7 +81,7 @@ def inference(
total_timer = Timer()
inference_timer = Timer()
total_timer.tic()
predictions = compute_on_dataset(model, data_loader, device, inference_timer)
predictions = compute_on_dataset(model, data_loader, device, bbox_aug, inference_timer)
# wait for all processes to complete before measuring the time
synchronize()
total_time = total_timer.toc()
Expand Down
1 change: 1 addition & 0 deletions tools/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def main():
dataset_name=dataset_name,
iou_types=iou_types,
box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
bbox_aug=cfg.TEST.BBOX_AUG.ENABLED,
device=cfg.MODEL.DEVICE,
expected_results=cfg.TEST.EXPECTED_RESULTS,
expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
Expand Down
1 change: 1 addition & 0 deletions tools/train_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def run_test(cfg, model, distributed):
dataset_name=dataset_name,
iou_types=iou_types,
box_only=False if cfg.MODEL.RETINANET_ON else cfg.MODEL.RPN_ONLY,
bbox_aug=cfg.TEST.BBOX_AUG.ENABLED,
device=cfg.MODEL.DEVICE,
expected_results=cfg.TEST.EXPECTED_RESULTS,
expected_results_sigma_tol=cfg.TEST.EXPECTED_RESULTS_SIGMA_TOL,
Expand Down

0 comments on commit b2a2a74

Please sign in to comment.