Skip to content

Commit

Permalink
replacing dtype torch.uint8 with torch.bool for indexing as the forme…
Browse files Browse the repository at this point in the history
…r is deprecated in pytorch v1.2.0
  • Loading branch information
garycao-cv committed Oct 7, 2019
1 parent 0ce8f6f commit d47d97d
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion demo/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def overlay_mask(self, image, predictions):
colors = self.compute_colors_for_labels(labels).tolist()

for mask, color in zip(masks, colors):
thresh = mask[0, :, :, None]
thresh = mask[0, :, :, None].astype(np.uint8)
contours, hierarchy = cv2_util.findContours(
thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ def __call__(self, matched_idxs):

# create binary mask from indices
pos_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.uint8
matched_idxs_per_image, dtype=torch.bool
)
neg_idx_per_image_mask = torch.zeros_like(
matched_idxs_per_image, dtype=torch.uint8
matched_idxs_per_image, dtype=torch.bool
)
pos_idx_per_image_mask[pos_idx_per_image] = 1
neg_idx_per_image_mask[neg_idx_per_image] = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def __call__(self, proposals, keypoint_logits):
valid.append(valid_per_image.view(-1))

keypoint_targets = cat(heatmaps, dim=0)
valid = cat(valid, dim=0).to(dtype=torch.uint8)
valid = cat(valid, dim=0).to(dtype=torch.bool)
valid = torch.nonzero(valid).squeeze(1)

# torch.mean (in binary_cross_entropy_with_logits) does'nt
Expand Down
4 changes: 2 additions & 2 deletions maskrcnn_benchmark/modeling/roi_heads/mask_head/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ def paste_mask_in_image(mask, box, im_h, im_w, thresh=0.5, padding=1):
else:
# for visualization and debugging, we also
# allow it to return an unmodified mask
mask = (mask * 255).to(torch.uint8)
mask = (mask * 255).to(torch.bool)

im_mask = torch.zeros((im_h, im_w), dtype=torch.uint8)
im_mask = torch.zeros((im_h, im_w), dtype=torch.bool)
x_0 = max(box[0], 0)
x_1 = min(box[2] + 1, im_w)
y_0 = max(box[1], 0)
Expand Down
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/modeling/rpn/anchor_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def add_visibility_to(self, boxlist):
)
else:
device = anchors.device
inds_inside = torch.ones(anchors.shape[0], dtype=torch.uint8, device=device)
inds_inside = torch.ones(anchors.shape[0], dtype=torch.bool, device=device)
boxlist.add_field("visibility", inds_inside)

def forward(self, image_list, feature_maps):
Expand Down
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/modeling/rpn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def select_over_all_levels(self, boxlists):
box_sizes = [len(boxlist) for boxlist in boxlists]
post_nms_top_n = min(self.fpn_post_nms_top_n, len(objectness))
_, inds_sorted = torch.topk(objectness, post_nms_top_n, dim=0, sorted=True)
inds_mask = torch.zeros_like(objectness, dtype=torch.uint8)
inds_mask = torch.zeros_like(objectness, dtype=torch.bool)
inds_mask[inds_sorted] = 1
inds_mask = inds_mask.split(box_sizes)
for i in range(num_images):
Expand Down
4 changes: 2 additions & 2 deletions maskrcnn_benchmark/structures/segmentation_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def convert_to_binarymask(self):
)
else:
size = self.size
masks = torch.empty([0, size[1], size[0]], dtype=torch.uint8)
masks = torch.empty([0, size[1], size[0]], dtype=torch.bool)

return BinaryMaskList(masks, size=self.size)

Expand All @@ -456,7 +456,7 @@ def __getitem__(self, item):
else:
# advanced indexing on a single dimension
selected_polygons = []
if isinstance(item, torch.Tensor) and item.dtype == torch.uint8:
if isinstance(item, torch.Tensor) and item.dtype == torch.bool:
item = item.nonzero()
item = item.squeeze(1) if item.numel() > 0 else item
item = item.tolist()
Expand Down

0 comments on commit d47d97d

Please sign in to comment.