enable moving traced model between devices
Summary:
Pull Request resolved: facebookresearch#4132

X-link: fairinternal/detectron2#568

X-link: facebookresearch/d2go#203

For full discussion: https://fb.workplace.com/groups/1405155842844877/posts/5744470455580039

Tracing `.to(device)` causes problems when moving the traced torchscript to another device (e.g. from CPU to GPU, or even from `cuda:0` to `cuda:1`). The reason is that `device` is not a `torch.Tensor`, so the tracer just hardcodes its value during tracing. The solution is to script the casting operation.

Here's the code snippet illustrating this:
```
import copy

import torch
import torch.nn.functional as F
from torch import nn


# define MyModel similar to GeneralizedRCNN, which casts the input to the model's device
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # cast the input to the same device as this model; this makes it possible
        # to take a CPU tensor as input when the model is on GPU.
        x = x.to(self.conv1.weight.device)

        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

# export the model by tracing
model = MyModel()
x = torch.zeros([1, 3, 32, 32])
ts = torch.jit.trace(model, x)
print(ts.graph)

# =====================================================
graph(%self.1 : __torch__.MyModel,
      %x : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu)):
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %conv1 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %14 : int = prim::Constant[value=6]() # <ipython-input-2-5abde0efc36f>:11:0
  %15 : int = prim::Constant[value=0]() # <ipython-input-2-5abde0efc36f>:11:0
  %16 : Device = prim::Constant[value="cpu"]() # <ipython-input-2-5abde0efc36f>:11:0
  %17 : NoneType = prim::Constant()
  %18 : bool = prim::Constant[value=0]() # <ipython-input-2-5abde0efc36f>:11:0
  %19 : bool = prim::Constant[value=0]() # <ipython-input-2-5abde0efc36f>:11:0
  %20 : NoneType = prim::Constant()
  %input.1 : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu) = aten::to(%x, %14, %15, %16, %17, %18, %19, %20) # <ipython-input-2-5abde0efc36f>:11:0
  %72 : Tensor = prim::CallMethod[name="forward"](%conv1, %input.1)
  %input.5 : Float(1, 20, 28, 28, strides=[15680, 784, 28, 1], requires_grad=1, device=cpu) = aten::relu(%72) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
  %73 : Tensor = prim::CallMethod[name="forward"](%conv2, %input.5)
  %61 : Float(1, 20, 24, 24, strides=[11520, 576, 24, 1], requires_grad=1, device=cpu) = aten::relu(%73) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
  return (%61)
# =====================================================

# PyTorch cuda works
model = copy.deepcopy(model)
model.to("cuda")
y = model(x)
# torchscript cpu works
y = ts(x)
# torchscript cuda doesn't work
ts = ts.to("cuda")
y = ts(x)

# =====================================================
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-2aece3ad6c9a> in <module>
      7 # torchscript cuda doesn't work
      8 ts = ts.to("cuda")
----> 9 y = ts(x)
/mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []
RuntimeError: The following operation failed in the TorchScript interpreter.
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
# =====================================================

# One solution is to script the casting instead of tracing it; the following code
# demonstrates how. We need to use mixed scripting/tracing.
@torch.jit.script_if_tracing
def cast_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    return src.to(dst.device)

class MyModel2(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # cast the input to the same device as this model; this makes it possible
        # to take a CPU tensor as input when the model is on GPU.
        x = cast_device_like(x, self.conv1.weight)

        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

# export the model by tracing
model = MyModel2()
x = torch.zeros([1, 3, 32, 32])
ts = torch.jit.trace(model, x)
print(ts.graph)

# =====================================================
graph(%self.1 : __torch__.MyModel2,
      %x : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu)):
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_5.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_4.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %conv1.1 : __torch__.torch.nn.modules.conv.___torch_mangle_4.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %weight.5 : Tensor = prim::GetAttr[name="weight"](%conv1.1)
  %14 : Function = prim::Constant[name="cast_device_like"]()
  %input.1 : Tensor = prim::CallFunction(%14, %x, %weight.5)
  %68 : Tensor = prim::CallMethod[name="forward"](%conv1, %input.1)
  %input.5 : Float(1, 20, 28, 28, strides=[15680, 784, 28, 1], requires_grad=1, device=cpu) = aten::relu(%68) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
  %69 : Tensor = prim::CallMethod[name="forward"](%conv2, %input.5)
  %55 : Float(1, 20, 24, 24, strides=[11520, 576, 24, 1], requires_grad=1, device=cpu) = aten::relu(%69) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
  return (%55)
# =====================================================

# PyTorch cuda works
model = copy.deepcopy(model)
model.to("cuda")
y = model(x)
# torchscript cpu works
y = ts(x)
# Note that now torchscript cuda works
ts = ts.to("cuda")
y = ts(x)
print(y.device)

# =====================================================
cuda:0
# =====================================================
```

For D2 (facebookresearch@11528ce), this diff creates a `move_device_like(A, B)` function to replace `A.to(B.device)`, and updates `rcnn.py` and all of its utils.
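
For example, the `preprocess_image` change in `rcnn.py` follows this pattern (a simplified sketch; the full change is in the diff below):

```
# Before: the tracer bakes the model's device in as a constant
images = [x["image"].to(self.device) for x in batched_inputs]

# After: a scripted helper keeps the device dynamic;
# self._move_to_current_device(x) just calls move_device_like(x, self.pixel_mean)
images = [self._move_to_current_device(x["image"]) for x in batched_inputs]
```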

For D2Go, since the exported model becomes device-agnostic, we can remove the "_gpu" suffix from the predictor type.

Update (April 11):
Added a test that covers tracing on one device and moving the traced model to another device for inference. When a GPU is available, it traces on `cuda:0` and runs inference on `cpu` and `cuda:0` (and on `cuda:N-1` if more than one GPU is available).
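
A minimal sketch of what the test checks (illustrative only, not the actual test code):

```
# Trace on one device, then run the traced model on the others. Assumes the
# model casts its inputs to its own device (as MyModel2 above does).
def check_traced_model_across_devices(model, cpu_input):
    devices = ["cpu"]
    if torch.cuda.is_available():
        model = model.to("cuda:0")
        devices.append("cuda:0")
        if torch.cuda.device_count() > 1:
            devices.append("cuda:{}".format(torch.cuda.device_count() - 1))
    ts = torch.jit.trace(model, cpu_input)
    for device in devices:
        out = ts.to(device)(cpu_input)  # input stays on CPU; the model casts it
        assert out.device == torch.device(device)
```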

Summary of the device-related patterns:
- The usage of `.to(dtype=another_dtype)` doesn't affect the device.
- Explicit device casting like `.to(device)` can generally be replaced by `move_device_like`.
- For tensors created directly on a device (e.g. `torch.zeros`, `torch.arange`), we can wrap the creation in scripted functions to avoid first creating on CPU and then moving to the new device; see the sketch after this list.
    - Creating tensors on the tracing device and then moving them to a new device is dangerous, because the tracing device (e.g. `cuda:0`) might not be available at runtime (e.g. on a CPU-only machine).
    - It's hard to write `image_list.py` in this pattern because the size behaves differently during tracing (int vs. scalar tensor); in this diff we still create on CPU first and then move to the target device.
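
A minimal sketch of the scripted-creation pattern (illustrative only; `arange_like` is a made-up name, while the real helpers live in `detectron2/layers/wrappers.py` and `detectron2/modeling/poolers.py`):

```
# Creating the tensor inside a scripted function keeps its device tied to a
# runtime tensor instead of being baked in as a constant by the tracer.
@torch.jit.script_if_tracing
def arange_like(n: int, like: torch.Tensor) -> torch.Tensor:
    return torch.arange(n, device=like.device)
```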

Reviewed By: tglik

Differential Revision: D35367772

fbshipit-source-id: 02d07e3d96da85f4cfbeb996e3c14c2a6f619beb
wat3rBro authored and facebook-github-bot committed Apr 15, 2022
1 parent 2409af0 commit 2bd05b4
Showing 12 changed files with 149 additions and 44 deletions.
1 change: 1 addition & 0 deletions detectron2/layers/__init__.py
@@ -16,6 +16,7 @@
     nonzero_tuple,
     cross_entropy,
     shapes_to_tensor,
+    move_device_like,
 )
 from .blocks import CNNBlockBase, DepthwiseSeparableConv2d
 from .aspp import ASPP
9 changes: 9 additions & 0 deletions detectron2/layers/wrappers.py
@@ -130,3 +130,12 @@ def nonzero_tuple(x):
         return x.nonzero().unbind(1)
     else:
         return x.nonzero(as_tuple=True)
+
+
+@torch.jit.script_if_tracing
+def move_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
+    """
+    Tracing friendly way to cast tensor to another tensor's device. Device will be treated
+    as constant during tracing, scripting the casting process as whole can workaround this issue.
+    """
+    return src.to(dst.device)
20 changes: 12 additions & 8 deletions detectron2/modeling/anchor_generator.py
@@ -6,7 +6,7 @@
 from torch import nn
 
 from detectron2.config import configurable
-from detectron2.layers import ShapeSpec
+from detectron2.layers import ShapeSpec, move_device_like
 from detectron2.structures import Boxes, RotatedBoxes
 from detectron2.utils.registry import Registry
 
@@ -36,13 +36,17 @@ def __iter__(self):
         return iter(self._buffers.values())
 
 
-def _create_grid_offsets(size: List[int], stride: int, offset: float, device: torch.device):
+def _create_grid_offsets(
+    size: List[int], stride: int, offset: float, target_device_tensor: torch.Tensor
+):
     grid_height, grid_width = size
-    shifts_x = torch.arange(
-        offset * stride, grid_width * stride, step=stride, dtype=torch.float32, device=device
+    shifts_x = move_device_like(
+        torch.arange(offset * stride, grid_width * stride, step=stride, dtype=torch.float32),
+        target_device_tensor,
     )
-    shifts_y = torch.arange(
-        offset * stride, grid_height * stride, step=stride, dtype=torch.float32, device=device
+    shifts_y = move_device_like(
+        torch.arange(offset * stride, grid_height * stride, step=stride, dtype=torch.float32),
+        target_device_tensor,
     )
 
     shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
@@ -167,7 +171,7 @@ def _grid_anchors(self, grid_sizes: List[List[int]]):
         # buffers() not supported by torchscript. use named_buffers() instead
         buffers: List[torch.Tensor] = [x[1] for x in self.cell_anchors.named_buffers()]
         for size, stride, base_anchors in zip(grid_sizes, self.strides, buffers):
-            shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device)
+            shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors)
             shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)
 
             anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4))
@@ -314,7 +318,7 @@ def num_anchors(self):
     def _grid_anchors(self, grid_sizes):
         anchors = []
         for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors):
-            shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors.device)
+            shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors)
             zeros = torch.zeros_like(shift_x)
             shifts = torch.stack((shift_x, shift_y, zeros, zeros, zeros), dim=1)
 
6 changes: 5 additions & 1 deletion detectron2/modeling/meta_arch/dense_detector.py
@@ -4,6 +4,7 @@
 from torch import Tensor, nn
 
 from detectron2.data.detection_utils import convert_image_to_rgb
+from detectron2.layers import move_device_like
 from detectron2.modeling import Backbone
 from detectron2.structures import Boxes, ImageList, Instances
 from detectron2.utils.events import get_event_storage
@@ -69,6 +70,9 @@ def __init__(
     def device(self):
         return self.pixel_mean.device
 
+    def _move_to_current_device(self, x):
+        return move_device_like(x, self.pixel_mean)
+
     def forward(self, batched_inputs: List[Dict[str, Tensor]]):
         """
         Args:
@@ -121,7 +125,7 @@ def preprocess_image(self, batched_inputs: List[Dict[str, Tensor]]):
         """
         Normalize, pad and batch the input images.
         """
-        images = [x["image"].to(self.device) for x in batched_inputs]
+        images = [self._move_to_current_device(x["image"]) for x in batched_inputs]
         images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         images = ImageList.from_tensors(images, self.backbone.size_divisibility)
         return images
11 changes: 9 additions & 2 deletions detectron2/modeling/meta_arch/rcnn.py
@@ -7,6 +7,7 @@
 
 from detectron2.config import configurable
 from detectron2.data.detection_utils import convert_image_to_rgb
+from detectron2.layers import move_device_like
 from detectron2.structures import ImageList, Instances
 from detectron2.utils.events import get_event_storage
 from detectron2.utils.logger import log_first_n
@@ -84,6 +85,9 @@ def from_config(cls, cfg):
     def device(self):
         return self.pixel_mean.device
 
+    def _move_to_current_device(self, x):
+        return move_device_like(x, self.pixel_mean)
+
     def visualize_training(self, batched_inputs, proposals):
         """
         A function used to visualize images and proposals. It shows ground truth
@@ -221,7 +225,7 @@ def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]):
         """
         Normalize, pad and batch the input images.
         """
-        images = [x["image"].to(self.device) for x in batched_inputs]
+        images = [self._move_to_current_device(x["image"]) for x in batched_inputs]
         images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         images = ImageList.from_tensors(images, self.backbone.size_divisibility)
         return images
@@ -285,6 +289,9 @@ def from_config(cls, cfg):
     def device(self):
         return self.pixel_mean.device
 
+    def _move_to_current_device(self, x):
+        return move_device_like(x, self.pixel_mean)
+
     def forward(self, batched_inputs):
         """
         Args:
@@ -296,7 +303,7 @@ def forward(self, batched_inputs):
             The dict contains one key "proposals" whose value is a
             :class:`Instances` with keys "proposal_boxes" and "objectness_logits".
         """
-        images = [x["image"].to(self.device) for x in batched_inputs]
+        images = [self._move_to_current_device(x["image"]) for x in batched_inputs]
         images = [(x - self.pixel_mean) / self.pixel_std for x in images]
         images = ImageList.from_tensors(images, self.backbone.size_divisibility)
         features = self.backbone(images.tensor)
42 changes: 28 additions & 14 deletions detectron2/modeling/poolers.py
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 import math
-from typing import List
+from typing import List, Optional
 import torch
 from torch import nn
 from torchvision.ops import RoIPool
@@ -58,6 +58,16 @@ def assign_boxes_to_levels(
     return level_assignments.to(torch.int64) - min_level
 
 
+# script the module to avoid hardcoded device type
+@torch.jit.script_if_tracing
+def _convert_boxes_to_pooler_format(boxes: torch.Tensor, sizes: torch.Tensor) -> torch.Tensor:
+    sizes = sizes.to(device=boxes.device)
+    indices = torch.repeat_interleave(
+        torch.arange(len(sizes), dtype=boxes.dtype, device=boxes.device), sizes
+    )
+    return cat([indices[:, None], boxes], dim=1)
+
+
 def convert_boxes_to_pooler_format(box_lists: List[Boxes]):
     """
     Convert all boxes in `box_lists` to the low-level format used by ROI pooling ops
@@ -83,11 +93,21 @@ def convert_boxes_to_pooler_format(box_lists: List[Boxes]):
     """
     boxes = torch.cat([x.tensor for x in box_lists], dim=0)
     # __len__ returns Tensor in tracing.
-    sizes = shapes_to_tensor([x.__len__() for x in box_lists], device=boxes.device)
-    indices = torch.repeat_interleave(
-        torch.arange(len(box_lists), dtype=boxes.dtype, device=boxes.device), sizes
-    )
-    return cat([indices[:, None], boxes], dim=1)
+    sizes = shapes_to_tensor([x.__len__() for x in box_lists])
+    return _convert_boxes_to_pooler_format(boxes, sizes)
+
+
+@torch.jit.script_if_tracing
+def _create_zeros(
+    batch_target: Optional[torch.Tensor],
+    channels: int,
+    height: int,
+    width: int,
+    like_tensor: torch.Tensor,
+) -> torch.Tensor:
+    batches = batch_target.shape[0] if batch_target is not None else 0
+    sizes = (batches, channels, height, width)
+    return torch.zeros(sizes, dtype=like_tensor.dtype, device=like_tensor.device)
 
 
 class ROIPooler(nn.Module):
@@ -214,9 +234,7 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
             x[0].size(0), len(box_lists)
         )
         if len(box_lists) == 0:
-            return torch.zeros(
-                (0, x[0].shape[1]) + self.output_size, device=x[0].device, dtype=x[0].dtype
-            )
+            return _create_zeros(None, x[0].shape[1], *self.output_size, x[0])
 
         pooler_fmt_boxes = convert_boxes_to_pooler_format(box_lists)
 
@@ -227,14 +245,10 @@ def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]):
            box_lists, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level
         )
 
-        num_boxes = pooler_fmt_boxes.size(0)
         num_channels = x[0].shape[1]
         output_size = self.output_size[0]
 
-        dtype, device = x[0].dtype, x[0].device
-        output = torch.zeros(
-            (num_boxes, num_channels, output_size, output_size), dtype=dtype, device=device
-        )
+        output = _create_zeros(pooler_fmt_boxes, num_channels, output_size, output_size, x[0])
 
         for level, pooler in enumerate(self.level_poolers):
             inds = nonzero_tuple(level_assignments == level)[0]
17 changes: 13 additions & 4 deletions detectron2/modeling/proposal_generator/proposal_utils.py
@@ -4,7 +4,7 @@
 from typing import List, Tuple, Union
 import torch
 
-from detectron2.layers import batched_nms, cat
+from detectron2.layers import batched_nms, cat, move_device_like
 from detectron2.structures import Boxes, Instances
 
 logger = logging.getLogger(__name__)
@@ -58,13 +58,17 @@ def find_top_rpn_proposals(
         objectness score in descending order.
     """
     num_images = len(image_sizes)
-    device = proposals[0].device
+    device = (
+        proposals[0].device
+        if torch.jit.is_scripting()
+        else ("cpu" if torch.jit.is_tracing() else proposals[0].device)
+    )
 
     # 1. Select top-k anchor for every level and every image
     topk_scores = []  # #lvl Tensor, each of shape N x topk
     topk_proposals = []
     level_ids = []  # #lvl Tensor, each of shape (topk,)
-    batch_idx = torch.arange(num_images, device=device)
+    batch_idx = move_device_like(torch.arange(num_images, device=device), proposals[0])
     for level_id, (proposals_i, logits_i) in enumerate(zip(proposals, pred_objectness_logits)):
         Hi_Wi_A = logits_i.shape[1]
         if isinstance(Hi_Wi_A, torch.Tensor):  # it's a tensor in tracing
@@ -79,7 +83,12 @@ def find_top_rpn_proposals(
 
         topk_proposals.append(topk_proposals_i)
         topk_scores.append(topk_scores_i)
-        level_ids.append(torch.full((num_proposals_i,), level_id, dtype=torch.int64, device=device))
+        level_ids.append(
+            move_device_like(
+                torch.full((num_proposals_i,), level_id, dtype=torch.int64, device=device),
+                proposals[0],
+            )
+        )
 
     # 2. Concat all levels together
     topk_scores = cat(topk_scores, dim=1)
8 changes: 7 additions & 1 deletion detectron2/modeling/roi_heads/mask_head.py
@@ -7,6 +7,7 @@
 
 from detectron2.config import configurable
 from detectron2.layers import Conv2d, ConvTranspose2d, ShapeSpec, cat, get_norm
+from detectron2.layers.wrappers import move_device_like
 from detectron2.structures import Instances
 from detectron2.utils.events import get_event_storage
 from detectron2.utils.registry import Registry
@@ -141,7 +142,12 @@ def mask_rcnn_inference(pred_mask_logits: torch.Tensor, pred_instances: List[Instances]):
     # Select masks corresponding to the predicted classes
     num_masks = pred_mask_logits.shape[0]
     class_pred = cat([i.pred_classes for i in pred_instances])
-    indices = torch.arange(num_masks, device=class_pred.device)
+    device = (
+        class_pred.device
+        if torch.jit.is_scripting()
+        else ("cpu" if torch.jit.is_tracing() else class_pred.device)
+    )
+    indices = move_device_like(torch.arange(num_masks, device=device), class_pred)
     mask_probs_pred = pred_mask_logits[indices, class_pred][:, None].sigmoid()
     # mask_probs_pred.shape: (B, 1, Hmask, Wmask)
 
8 changes: 5 additions & 3 deletions detectron2/structures/boxes.py
@@ -144,12 +144,14 @@ def __init__(self, tensor: torch.Tensor):
         Args:
             tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2).
         """
-        device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
-        tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
+        if not isinstance(tensor, torch.Tensor):
+            tensor = torch.as_tensor(tensor, dtype=torch.float32, device=torch.device("cpu"))
+        else:
+            tensor = tensor.to(torch.float32)
         if tensor.numel() == 0:
             # Use reshape, so we don't end up creating a new tensor that does not depend on
             # the inputs (and consequently confuses jit)
-            tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32, device=device)
+            tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32)
         assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
 
         self.tensor = tensor
8 changes: 6 additions & 2 deletions detectron2/structures/image_list.py
@@ -5,7 +5,7 @@
 from torch import device
 from torch.nn import functional as F
 
-from detectron2.layers.wrappers import shapes_to_tensor
+from detectron2.layers.wrappers import move_device_like, shapes_to_tensor
 
 
 class ImageList(object):
@@ -103,7 +103,11 @@ def from_tensors(
         else:
             # max_size can be a tensor in tracing mode, therefore convert to list
             batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size)
-            batched_imgs = tensors[0].new_full(batch_shape, pad_value)
+            device = (
+                None if torch.jit.is_scripting() else ("cpu" if torch.jit.is_tracing() else None)
+            )
+            batched_imgs = tensors[0].new_full(batch_shape, pad_value, device=device)
+            batched_imgs = move_device_like(batched_imgs, tensors[0])
             for img, pad_img in zip(tensors, batched_imgs):
                 pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
 
6 changes: 4 additions & 2 deletions detectron2/structures/masks.py
@@ -99,8 +99,10 @@ def __init__(self, tensor: Union[torch.Tensor, np.ndarray]):
         Args:
             tensor: bool Tensor of N,H,W, representing N instances in the image.
         """
-        device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
-        tensor = torch.as_tensor(tensor, dtype=torch.bool, device=device)
+        if isinstance(tensor, torch.Tensor):
+            tensor = tensor.to(torch.bool)
+        else:
+            tensor = torch.as_tensor(tensor, dtype=torch.bool, device=torch.device("cpu"))
         assert tensor.dim() == 3, tensor.size()
         self.image_size = tensor.shape[1:]
         self.tensor = tensor