enable moving traced model between devices
Summary:
Pull Request resolved: facebookresearch#4132

X-link: fairinternal/detectron2#568
X-link: facebookresearch/d2go#203

For full discussion: https://fb.workplace.com/groups/1405155842844877/posts/5744470455580039

Tracing `.to(device)` causes problems when moving the traced torchscript to another device (e.g. from cpu to gpu, or even from `cuda:0` to `cuda:1`). The reason is that `device` is not a `torch.Tensor`, so the tracer simply hard-codes its value during tracing. The solution is to script the casting operation. Here's a code snippet illustrating this:

```
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

# define MyModel similar to GeneralizedRCNN, which casts the input to the model's device
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # cast the input to the same device as this model; this makes it possible to
        # take a cpu tensor as input when the model is on GPU
        x = x.to(self.conv1.weight.device)
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

# export the model by tracing
model = MyModel()
x = torch.zeros([1, 3, 32, 32])
ts = torch.jit.trace(model, x)
print(ts.graph)

# =====================================================
graph(%self.1 : __torch__.MyModel,
      %x : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu)):
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_0.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %conv1 : __torch__.torch.nn.modules.conv.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %14 : int = prim::Constant[value=6]() # <ipython-input-2-5abde0efc36f>:11:0
  %15 : int = prim::Constant[value=0]() # <ipython-input-2-5abde0efc36f>:11:0
  %16 : Device = prim::Constant[value="cpu"]() # <ipython-input-2-5abde0efc36f>:11:0
  %17 : NoneType = prim::Constant()
  %18 : bool = prim::Constant[value=0]() # <ipython-input-2-5abde0efc36f>:11:0
  %19 : bool = prim::Constant[value=0]() # <ipython-input-2-5abde0efc36f>:11:0
  %20 : NoneType = prim::Constant()
  %input.1 : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu) = aten::to(%x, %14, %15, %16, %17, %18, %19, %20) # <ipython-input-2-5abde0efc36f>:11:0
  %72 : Tensor = prim::CallMethod[name="forward"](%conv1, %input.1)
  %input.5 : Float(1, 20, 28, 28, strides=[15680, 784, 28, 1], requires_grad=1, device=cpu) = aten::relu(%72) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
  %73 : Tensor = prim::CallMethod[name="forward"](%conv2, %input.5)
  %61 : Float(1, 20, 24, 24, strides=[11520, 576, 24, 1], requires_grad=1, device=cpu) = aten::relu(%73) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
  return (%61)
# =====================================================

# PyTorch cuda works
model = copy.deepcopy(model)
model.to("cuda")
y = model(x)

# torchscript cpu works
y = ts(x)

# torchscript cuda doesn't work
ts = ts.to("cuda")
y = ts(x)

# =====================================================
RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-2aece3ad6c9a> in <module>
      7 # torchscript cuda doesn't work
      8 ts = ts.to("cuda")
----> 9 y = ts(x)
/mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

RuntimeError: The following operation failed in the TorchScript interpreter.
# =====================================================

# One solution is scripting the casting instead of tracing it; the following code
# demonstrates how to do it. We need to use mixed scripting/tracing.
@torch.jit.script_if_tracing
def cast_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    return src.to(dst.device)

class MyModel2(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        # cast the input to the same device as this model; this makes it possible to
        # take a cpu tensor as input when the model is on GPU
        x = cast_device_like(x, self.conv1.weight)
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

# export the model by tracing
model = MyModel2()
x = torch.zeros([1, 3, 32, 32])
ts = torch.jit.trace(model, x)
print(ts.graph)

# =====================================================
graph(%self.1 : __torch__.MyModel2,
      %x : Float(1, 3, 32, 32, strides=[3072, 1024, 32, 1], requires_grad=0, device=cpu)):
  %conv2 : __torch__.torch.nn.modules.conv.___torch_mangle_5.Conv2d = prim::GetAttr[name="conv2"](%self.1)
  %conv1 : __torch__.torch.nn.modules.conv.___torch_mangle_4.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %conv1.1 : __torch__.torch.nn.modules.conv.___torch_mangle_4.Conv2d = prim::GetAttr[name="conv1"](%self.1)
  %weight.5 : Tensor = prim::GetAttr[name="weight"](%conv1.1)
  %14 : Function = prim::Constant[name="cast_device_like"]()
  %input.1 : Tensor = prim::CallFunction(%14, %x, %weight.5)
  %68 : Tensor = prim::CallMethod[name="forward"](%conv1, %input.1)
  %input.5 : Float(1, 20, 28, 28, strides=[15680, 784, 28, 1], requires_grad=1, device=cpu) = aten::relu(%68) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
  %69 : Tensor = prim::CallMethod[name="forward"](%conv2, %input.5)
  %55 : Float(1, 20, 24, 24, strides=[11520, 576, 24, 1], requires_grad=1, device=cpu) = aten::relu(%69) # /mnt/xarfuse/uid-20293/a90d1698-seed-nspid4026533681_cgpid21128615-ns-4026533618/torch/nn/functional.py:1406:0
  return (%55)
# =====================================================

# PyTorch cuda works
model = copy.deepcopy(model)
model.to("cuda")
y = model(x)

# torchscript cpu works
y = ts(x)

# Note that now torchscript cuda works
ts = ts.to("cuda")
y = ts(x)
print(y.device)

# =====================================================
cuda:0
# =====================================================
```

For D2 (facebookresearch@11528ce), this diff creates a `move_tensor_device_same_as_another(A, B)` function to replace `A.to(B.device)`, and updates `rcnn.py` and all of its utils accordingly.

For D2Go, since the exported model becomes device-agnostic, we can remove the "_gpu" suffix from the predictor type.

Update (April 11): Added a test covering tracing on one device and moving the traced model to another device for inference.
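For reference, here is a minimal sketch of how a `move_tensor_device_same_as_another`-style helper composes with GeneralizedRCNN-style preprocessing. It is hypothetical: it assumes the helper follows the same `@torch.jit.script_if_tracing` pattern as `cast_device_like` above, and `ToyRCNN` with its `pixel_mean` buffer is an illustrative stand-in, not the actual `rcnn.py` code.

```
import torch
import torch.nn as nn

# Sketch only: assumes the diff's helper mirrors `cast_device_like` above.
@torch.jit.script_if_tracing
def move_tensor_device_same_as_another(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    # Scripted, so `dst.device` is read at run time rather than being
    # baked into the trace as a constant.
    return src.to(dst.device)


class ToyRCNN(nn.Module):
    """Hypothetical stand-in for GeneralizedRCNN's preprocessing."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(3, 8, 3)
        # Buffers follow the module across `.to(device)` calls.
        self.register_buffer("pixel_mean", torch.full((3, 1, 1), 0.5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Device-agnostic input cast: valid wherever the traced model ends up.
        x = move_tensor_device_same_as_another(x, self.pixel_mean)
        return self.backbone(x - self.pixel_mean)


# Trace once, then the TorchScript module can be moved freely:
#   ts = torch.jit.trace(ToyRCNN(), torch.zeros(1, 3, 32, 32))
#   ts = ts.to("cuda")  # still accepts CPU inputs
```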
In the added test, when a GPU is available, tracing happens on `cuda:0` and inference runs on `cpu` and `cuda:0` (and on `cuda:N-1` if available).

Summary of the device-related patterns:

- The usage of `.to(dtype=another_dtype)` doesn't affect the device.
- Explicit device casting like `.to(device)` can generally be replaced by `move_device_like`.
- For creating variables directly on a device (e.g. `torch.zeros`, `torch.arange`), we can replace them with a ScriptModule to avoid first creating on CPU and then moving to the new device (see the sketch below).
- Creating things on the tracing device and then moving them to a new device is dangerous, because the tracing device (e.g. `cuda:0`) might not be available (e.g. when running on a CPU-only machine).
- It's hard to write `image_list.py` in this pattern because sizes behave differently during tracing (int vs. scalar tensor), so in this diff we still create on CPU first and then move to the target device.

Reviewed By: tglik

Differential Revision: D35367772

fbshipit-source-id: 02d07e3d96da85f4cfbeb996e3c14c2a6f619beb
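As referenced in the patterns list above, here is a minimal, hypothetical sketch of the "create directly on device" pattern via a scripted function (`make_offsets` is illustrative, not a function from this diff):

```
import torch
import torch.nn as nn

# Hypothetical sketch of the "create directly on device" pattern.
@torch.jit.script_if_tracing
def make_offsets(x: torch.Tensor) -> torch.Tensor:
    # Inside a scripted function, both the size and the device are resolved
    # at run time, so the tensor is created directly on `x`'s device instead
    # of being created on CPU (or on a hard-coded tracing device) and moved.
    return torch.arange(x.shape[-1], device=x.device, dtype=x.dtype)


class AddOffsets(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + make_offsets(x)


# Traced once on any device, the module stays device-agnostic:
#   ts = torch.jit.trace(AddOffsets(), torch.zeros(2, 4))
#   y = ts.to("cuda")(torch.zeros(2, 4, device="cuda"))
```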