forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[iOS][GPU] Add Metal/MPSCNN support on iOS (pytorch#46112)
Summary: Pull Request resolved: pytorch#46112 ### Summary This PR adds the support of running torchscript models on iOS GPU via Metal (Inference only). The feature is currently in prototype state, API changes are expected. The tutorial and the documents will be added once it goes to beta. allow-large-files - Users API ``` auto module = torch::jit::load(model); module.eval(); at::Tensor input = at::ones({1,3,224,224}, at::ScalarType::Float).metal(); auto output = module.forward({input}).toTensor().cpu(); ``` - Supported Models - Person Segmentation v106 (FB Internal) - Mobilenetv2 - Supported Operators - aten::conv2d - aten::addmm - aten::add.Tensor - aten::sub.Tensor - aten::mul.Tensor - aten::relu - aten::hardtanh - aten::hardtanh_ - aten::sigmoid - aten::max_pool2d - aten::adaptive_avg_pool2d - aten::reshape - aten::t - aten::view - aten::log_softmax.int - aten::upsample_nearest2d.vec - Supported Devices - Apple A9 and above - iOS 10.2 and above - CMake scripts - `IOS_ARCH=arm64 ./scripts/build_ios.sh -DUSE_METAL=ON` ### Test Plan - Circle CI ghstack-source-id: 114155638 Test Plan: 1. Sandcastle CI 2. Circle CI Reviewed By: dreiss Differential Revision: D23236555 fbshipit-source-id: 98ffc48b837e308bc678c37a9a5fd8ae72d11625
- Loading branch information
1 parent
7f6a1b2
commit a277c09
Showing
54 changed files
with
4,149 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#include <atomic> | ||
|
||
#include <ATen/Tensor.h> | ||
#include <ATen/metal/Context.h> | ||
|
||
namespace at { | ||
namespace metal { | ||
|
||
std::atomic<const MetalInterface*> g_metal_impl_registry; | ||
|
||
MetalImplRegistrar::MetalImplRegistrar(MetalInterface* impl) { | ||
g_metal_impl_registry.store(impl); | ||
} | ||
|
||
at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src) { | ||
auto p = at::metal::g_metal_impl_registry.load(); | ||
if (p) { | ||
return p->metal_copy_(self, src); | ||
} | ||
AT_ERROR("Metal backend was not linked to the build"); | ||
} | ||
} // namespace metal | ||
|
||
namespace native { | ||
bool is_metal_available() { | ||
auto p = at::metal::g_metal_impl_registry.load(); | ||
return p ? p->is_metal_available() : false; | ||
} | ||
|
||
} // namespace native | ||
} // namespace at |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#ifndef MetalContext_h | ||
#define MetalContext_h | ||
|
||
#include <atomic> | ||
|
||
#include <ATen/Tensor.h> | ||
|
||
namespace at { | ||
namespace metal { | ||
|
||
struct MetalInterface { | ||
virtual ~MetalInterface() = default; | ||
virtual bool is_metal_available() const = 0; | ||
virtual at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src) | ||
const = 0; | ||
}; | ||
|
||
extern std::atomic<const MetalInterface*> g_metal_impl_registry; | ||
|
||
class MetalImplRegistrar { | ||
public: | ||
explicit MetalImplRegistrar(MetalInterface*); | ||
}; | ||
|
||
at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src); | ||
|
||
} // namespace metal | ||
} // namespace at | ||
|
||
#endif /* MetalContext_h */ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.