Skip to content

Commit

Permalink
fix some tests
Browse files Browse the repository at this point in the history
Summary:
Fix some interesting numerical issues.
Use latest torch release in docker.

Reviewed By: rbgirshick

Differential Revision: D21175202

fbshipit-source-id: f47be6ae647139db46079d79329bf5a8402ac8a0
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Apr 22, 2020
1 parent 8a2b870 commit 0c045f2
Show file tree
Hide file tree
Showing 11 changed files with 111 additions and 41 deletions.
2 changes: 1 addition & 1 deletion INSTALL.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ also installs detectron2 with a few simple commands.

### Requirements
- Linux or macOS with Python ≥ 3.6
- PyTorch ≥ 1.3
- PyTorch ≥ 1.4
- [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch installation.
You can install them together at [pytorch.org](https://pytorch.org) to make sure of this.
- OpenCV, optional, needed by demo and visualization
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ at::Tensor box_iou_rotated_cuda(
const at::Tensor& boxes1,
const at::Tensor& boxes2) {
using scalar_t = float;
AT_ASSERTM(
boxes1.scalar_type() == at::kFloat, "boxes1 must be a float tensor");
AT_ASSERTM(
boxes2.scalar_type() == at::kFloat, "boxes2 must be a float tensor");
AT_ASSERTM(boxes1.type().is_cuda(), "boxes1 must be a CUDA tensor");
AT_ASSERTM(boxes2.type().is_cuda(), "boxes2 must be a CUDA tensor");
at::cuda::CUDAGuard device_guard(boxes1.device());
Expand Down
41 changes: 31 additions & 10 deletions detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
return A.x * B.x + A.y * B.y;
}

template <typename T>
HOST_DEVICE_INLINE T cross_2d(const Point<T>& A, const Point<T>& B) {
return A.x * B.y - B.x * A.y;
// R: result type. can be different from input type
template <typename T, typename R = T>
HOST_DEVICE_INLINE R cross_2d(const Point<T>& A, const Point<T>& B) {
return static_cast<R>(A.x) * static_cast<R>(B.y) -
static_cast<R>(B.x) * static_cast<R>(A.y);
}

template <typename T>
Expand Down Expand Up @@ -190,11 +192,13 @@ HOST_DEVICE_INLINE int convex_hull_graham(
// (essentially sorting according to angles)
// If the angles are the same, sort according to their distance to origin
T dist[24];
#ifdef __CUDACC__
// compute distance to origin before sort, and sort them together with the
// points
for (int i = 0; i < num_in; i++) {
dist[i] = dot_2d<T>(q[i], q[i]);
}

#ifdef __CUDACC__
// CUDA version
// In the future, we can potentially use thrust
// for sorting here to improve speed (though not guaranteed)
Expand Down Expand Up @@ -223,6 +227,10 @@ HOST_DEVICE_INLINE int convex_hull_graham(
return temp > 0;
}
});
// compute distance to origin after sort, since the points are now different.
for (int i = 0; i < num_in; i++) {
dist[i] = dot_2d<T>(q[i], q[i]);
}
#endif

// Step 4:
Expand All @@ -249,9 +257,22 @@ HOST_DEVICE_INLINE int convex_hull_graham(
// until the 3-point relationship is convex again, or
// until the stack only contains two points
for (int i = k + 1; i < num_in; i++) {
while (m > 1 && cross_2d<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
m--;
while (m > 1) {
auto q1 = q[i] - q[m - 2], q2 = q[m - 1] - q[m - 2];
// cross_2d() uses FMA and therefore computes round(round(q1.x*q2.y) -
// q2.x*q1.y) So it may not return 0 even when q1==q2. Therefore we
// compare round(q1.x*q2.y) and round(q2.x*q1.y) directly. (round means
// round to nearest floating point).
if (q1.x * q2.y >= q2.x * q1.y)
m--;
else
break;
}
// Using double also helps, but float can solve the issue for now.
// while (m > 1 && cross_2d<T, double>(q[i] - q[m - 2], q[m - 1] - q[m - 2])
// >= 0) {
// m--;
// }
q[m++] = q[i];
}

Expand Down Expand Up @@ -328,14 +349,14 @@ single_box_iou_rotated(T const* const box1_raw, T const* const box2_raw) {
box2.h = box2_raw[3];
box2.a = box2_raw[4];

const T area1 = box1.w * box1.h;
const T area2 = box2.w * box2.h;
T area1 = box1.w * box1.h;
T area2 = box2.w * box2.h;
if (area1 < 1e-14 || area2 < 1e-14) {
return 0.f;
}

const T intersection = rotated_boxes_intersection<T>(box1, box2);
const T iou = intersection / (area1 + area2 - intersection);
T intersection = rotated_boxes_intersection<T>(box1, box2);
T iou = intersection / (area1 + area2 - intersection);
return iou;
}

Expand Down
2 changes: 1 addition & 1 deletion detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ at::Tensor nms_rotated_cuda(
dim3 threads(threadsPerBlock);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
AT_DISPATCH_FLOATING_TYPES(
dets_sorted.type(), "nms_rotated_kernel_cuda", [&] {
nms_rotated_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
dets_num,
Expand Down
1 change: 0 additions & 1 deletion detectron2/layers/rotated_boxes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from __future__ import absolute_import, division, print_function, unicode_literals

# import torch
from detectron2 import _C


Expand Down
3 changes: 2 additions & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ RUN wget https://bootstrap.pypa.io/get-pip.py && \

# install dependencies
# See https://pytorch.org/ for other options if you use a different version of CUDA
RUN pip install --user torch torchvision tensorboard cython
RUN pip install --user tensorboard cython
RUN pip install --user torch==1.5+cu101 torchvision==0.6+cu101 -f https://download.pytorch.org/whl/torch_stable.html
RUN pip install --user 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'

RUN pip install --user 'git+https://github.com/facebookresearch/fvcore'
Expand Down
3 changes: 2 additions & 1 deletion docker/Dockerfile-circleci
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ RUN wget -q https://bootstrap.pypa.io/get-pip.py && \

# install dependencies
# See https://pytorch.org/ for other options if you use a different version of CUDA
RUN pip install torch torchvision tensorboard cython
RUN pip install tensorboard cython
RUN pip install torch==1.5+cu101 torchvision==0.6+cu101 -f https://download.pytorch.org/whl/torch_stable.html
RUN pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
2 changes: 1 addition & 1 deletion docs/tutorials/deployment.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ cd tools/deploy/ && ./caffe2_converter.py --config-file ../../configs/COCO-Insta
Note that:
1. The conversion needs valid sample inputs & weights to trace the model. That's why the script requires the dataset.
You can modify the script to obtain sample inputs in other ways.
2. GPU conversion is supported only with Pytorch's master. So we use `MODEL.DEVICE cpu`.
2. GPU conversion is supported only with Pytorch ≥ 1.5. So we use `MODEL.DEVICE cpu`.
3. With the `--run-eval` flag, it will evaluate the converted models to verify its accuracy.
The accuracy is typically slightly different (within 0.1 AP) from PyTorch due to
numerical precisions between different implementations.
Expand Down
47 changes: 38 additions & 9 deletions tests/layers/test_nms_rotated.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,42 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
import unittest
import torch
from torchvision import ops

from detectron2.layers import batched_nms, batched_nms_rotated, nms_rotated


def nms_edit_distance(keep1, keep2):
"""
Compare the "keep" result of two nms call.
They are allowed to be different in terms of edit distance
due to floating point precision issues, e.g.,
if a box happen to have an IoU of 0.5 with another box,
one implentation may choose to keep it while another may discard it.
"""
if torch.equal(keep1, keep2):
# they should be equal most of the time
return 0
keep1, keep2 = tuple(keep1.cpu()), tuple(keep2.cpu())
m, n = len(keep1), len(keep2)

# edit distance with DP
f = [np.arange(n + 1), np.arange(n + 1)]
for i in range(m):
cur_row = i % 2
other_row = (i + 1) % 2
f[other_row][0] = i + 1
for j in range(n):
f[other_row][j + 1] = (
f[cur_row][j]
if keep1[i] == keep2[j]
else min(min(f[cur_row][j], f[cur_row][j + 1]), f[other_row][j]) + 1
)
return f[m % 2][n]


class TestNMSRotated(unittest.TestCase):
def reference_horizontal_nms(self, boxes, scores, iou_threshold):
"""
Expand Down Expand Up @@ -43,7 +73,6 @@ def _create_tensors(self, N):
return boxes, scores

def test_batched_nms_rotated_0_degree_cpu(self):
# torch.manual_seed(0)
N = 2000
num_classes = 50
boxes, scores = self._create_tensors(N)
Expand All @@ -63,11 +92,10 @@ def test_batched_nms_rotated_0_degree_cpu(self):
assert torch.allclose(
rotated_boxes, backup
), "rotated_boxes modified by batched_nms_rotated"
assert torch.equal(keep, keep_ref), err_msg.format(iou)
self.assertLessEqual(nms_edit_distance(keep, keep_ref), 1, err_msg.format(iou))

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_batched_nms_rotated_0_degree_cuda(self):
# torch.manual_seed(0)
N = 2000
num_classes = 50
boxes, scores = self._create_tensors(N)
Expand All @@ -81,13 +109,14 @@ def test_batched_nms_rotated_0_degree_cuda(self):
for iou in [0.2, 0.5, 0.8]:
backup = boxes.clone()
keep_ref = batched_nms(boxes.cuda(), scores.cuda(), idxs, iou)
assert torch.allclose(boxes, backup), "boxes modified by batched_nms"
self.assertTrue(torch.allclose(boxes, backup), "boxes modified by batched_nms")
backup = rotated_boxes.clone()
keep = batched_nms_rotated(rotated_boxes.cuda(), scores.cuda(), idxs, iou)
assert torch.allclose(
rotated_boxes, backup
), "rotated_boxes modified by batched_nms_rotated"
assert torch.equal(keep, keep_ref), err_msg.format(iou)
self.assertTrue(
torch.allclose(rotated_boxes, backup),
"rotated_boxes modified by batched_nms_rotated",
)
self.assertLessEqual(nms_edit_distance(keep, keep_ref), 1, err_msg.format(iou))

def test_nms_rotated_0_degree_cpu(self):
N = 1000
Expand All @@ -101,7 +130,7 @@ def test_nms_rotated_0_degree_cpu(self):
for iou in [0.5]:
keep_ref = self.reference_horizontal_nms(boxes, scores, iou)
keep = nms_rotated(rotated_boxes, scores, iou)
assert torch.equal(keep, keep_ref), err_msg.format(iou)
self.assertLessEqual(nms_edit_distance(keep, keep_ref), 1, err_msg.format(iou))

def test_nms_rotated_90_degrees_cpu(self):
N = 1000
Expand Down
43 changes: 27 additions & 16 deletions tests/structures/test_rotated_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,22 +57,14 @@ def test_iou_half_overlap_cuda(self):
ious_cuda = pairwise_iou_rotated(boxes1.cuda(), boxes2.cuda())
self.assertTrue(torch.allclose(ious_cuda.cpu(), expected_ious))

def test_iou_precision_cpu(self):
boxes1 = torch.tensor([[565, 565, 10, 10, 0]], dtype=torch.float32)
boxes2 = torch.tensor([[565, 565, 10, 8.3, 0]], dtype=torch.float32)
iou = 8.3 / 10.0
expected_ious = torch.tensor([[iou]], dtype=torch.float32)
ious = pairwise_iou_rotated(boxes1, boxes2)
self.assertTrue(torch.allclose(ious, expected_ious))

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_iou_precision_cuda(self):
boxes1 = torch.tensor([[565, 565, 10, 10, 0]], dtype=torch.float32)
boxes2 = torch.tensor([[565, 565, 10, 8.3, 0]], dtype=torch.float32)
iou = 8.3 / 10.0
expected_ious = torch.tensor([[iou]], dtype=torch.float32)
ious_cuda = pairwise_iou_rotated(boxes1.cuda(), boxes2.cuda())
self.assertTrue(torch.allclose(ious_cuda.cpu(), expected_ious))
def test_iou_precision(self):
for device in ["cpu"] + ["cuda"] if torch.cuda.is_available() else []:
boxes1 = torch.tensor([[565, 565, 10, 10.0, 0]], dtype=torch.float32, device=device)
boxes2 = torch.tensor([[565, 565, 10, 8.3, 0]], dtype=torch.float32, device=device)
iou = 8.3 / 10.0
expected_ious = torch.tensor([[iou]], dtype=torch.float32)
ious = pairwise_iou_rotated(boxes1, boxes2)
self.assertTrue(torch.allclose(ious.cpu(), expected_ious))

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_iou_too_many_boxes_cuda(self):
Expand All @@ -82,6 +74,25 @@ def test_iou_too_many_boxes_cuda(self):
ious_cuda = pairwise_iou_rotated(boxes1.cuda(), boxes2.cuda())
self.assertTupleEqual(tuple(ious_cuda.shape), (s1, s2))

def test_iou_extreme(self):
# Cause floating point issues in cuda kernels (#1266)
for device in ["cpu"] + ["cuda"] if torch.cuda.is_available() else []:
boxes1 = torch.tensor([[160.0, 153.0, 230.0, 23.0, -37.0]], device=device)
boxes2 = torch.tensor(
[
[
-1.117407639806935e17,
1.3858420478349148e18,
1000.0000610351562,
1000.0000610351562,
1612.0,
]
],
device=device,
)
ious = pairwise_iou_rotated(boxes1, boxes2)
self.assertTrue(ious.min() >= 0, ious)


class TestRotatedBoxesStructure(unittest.TestCase):
def test_clip_area_0_degree(self):
Expand Down
4 changes: 4 additions & 0 deletions tools/deploy/caffe2_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import argparse
import os
import torch

from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
Expand All @@ -20,6 +21,9 @@ def setup_cfg(args):
cfg.merge_from_file(args.config_file)
cfg.merge_from_list(args.opts)
cfg.freeze()
if cfg.MODEL.DEVICE != "cpu":
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
assert TORCH_VERSION >= (1, 5), "PyTorch>=1.5 required!"
return cfg


Expand Down

0 comments on commit 0c045f2

Please sign in to comment.