[iOS][GPU] Add Metal/MPSCNN support on iOS (pytorch#46112)
Summary:
Pull Request resolved: pytorch#46112

### Summary

This PR adds support for running TorchScript models on the iOS GPU via Metal (inference only). The feature is currently in a prototype state, and API changes are expected. A tutorial and documentation will be added once it moves to beta.

allow-large-files

- User API (a fuller, self-contained sketch follows this list)

```
  // Load the serialized TorchScript model and put it in inference mode.
  auto module = torch::jit::load(model);
  module.eval();
  // Stage the input tensor on the Metal (GPU) backend.
  at::Tensor input = at::ones({1,3,224,224}, at::ScalarType::Float).metal();
  // Run the forward pass on the GPU, then copy the result back to the CPU.
  auto output = module.forward({input}).toTensor().cpu();
```
- Supported Models
    - Person Segmentation v106 (FB Internal)
    - Mobilenetv2

- Supported Operators
    - aten::conv2d
    - aten::addmm
    - aten::add.Tensor
    - aten::sub.Tensor
    - aten::mul.Tensor
    - aten::relu
    - aten::hardtanh
    - aten::hardtanh_
    - aten::sigmoid
    - aten::max_pool2d
    - aten::adaptive_avg_pool2d
    - aten::reshape
    - aten::t
    - aten::view
    - aten::log_softmax.int
    - aten::upsample_nearest2d.vec

- Supported Devices
    - Apple A9 and above
    - iOS 10.2 and above

- CMake scripts
    - `IOS_ARCH=arm64 ./scripts/build_ios.sh -DUSE_METAL=ON`
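
For illustration, the user API snippet above expands to the following self-contained program. This is a sketch rather than part of the PR: the model path is a placeholder, and it assumes the binary is linked against a libtorch built with `USE_METAL=ON` and runs on a supported iOS device.

```
#include <torch/script.h>
#include <iostream>

int main() {
  try {
    // Load the serialized TorchScript model and put it in inference mode.
    auto module = torch::jit::load("/path/to/model.pt"); // placeholder path
    module.eval();

    // Stage the input on the Metal (GPU) backend.
    at::Tensor input =
        at::ones({1, 3, 224, 224}, at::ScalarType::Float).metal();

    // Run the forward pass on the GPU, then copy the result back to the CPU.
    auto output = module.forward({input}).toTensor().cpu();
    std::cout << "output sizes: " << output.sizes() << std::endl;
  } catch (const c10::Error& e) {
    std::cerr << "inference failed: " << e.what() << std::endl;
    return 1;
  }
  return 0;
}
```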

### Test Plan

- Circle CI

ghstack-source-id: 114155638

Test Plan:
1. Sandcastle CI
2. Circle CI

Reviewed By: dreiss

Differential Revision: D23236555

fbshipit-source-id: 98ffc48b837e308bc678c37a9a5fd8ae72d11625
xta0 authored and facebook-github-bot committed Oct 13, 2020
1 parent 7f6a1b2 commit a277c09
Showing 54 changed files with 4,149 additions and 9 deletions.
6 changes: 6 additions & 0 deletions BUILD.bazel
@@ -298,6 +298,11 @@ filegroup(
    srcs = glob(["aten/src/ATen/vulkan/*.cpp"]),
)

filegroup(
    name = "aten_base_metal",
    srcs = glob(["aten/src/ATen/metal/*.cpp"]),
)

filegroup(
    name = "ATen_QUANTIZED_SRCS",
    srcs = glob(
@@ -650,6 +655,7 @@ cc_library(
        ":ATen_CORE_SRCS",
        ":ATen_QUANTIZED_SRCS",
        ":aten_base_cpp",
        ":aten_base_metal",
        ":aten_base_vulkan",
        ":aten_native_cpp",
        ":aten_native_mkl_cpp",
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -165,7 +165,7 @@ option(USE_GLOG "Use GLOG" OFF)
option(USE_LEVELDB "Use LEVELDB" OFF)
option(USE_LITE_PROTO "Use lite protobuf instead of full." OFF)
option(USE_LMDB "Use LMDB" OFF)
-option(USE_METAL "Use Metal for iOS build" ON)
+option(USE_METAL "Use Metal for iOS build" OFF)
option(USE_NATIVE_ARCH "Use -march=native" OFF)
cmake_dependent_option(
USE_NCCL "Use NCCL" ON
18 changes: 18 additions & 0 deletions aten/src/ATen/CMakeLists.txt
@@ -66,6 +66,14 @@ file(GLOB native_mkl_cpp "native/mkl/*.cpp")
file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
file(GLOB vulkan_cpp "vulkan/*.cpp")
file(GLOB native_vulkan_cpp "native/vulkan/api/*.cpp" "native/vulkan/*.cpp")

file(GLOB metal_h "metal/*.h")
file(GLOB metal_cpp "metal/*.cpp")
file(GLOB_RECURSE native_metal_h "native/metal/*.h")
file(GLOB metal_test_srcs "native/metal/mpscnn/tests/*.mm")
file(GLOB_RECURSE native_metal_srcs "native/metal/*.mm", "native/metal/*.cpp")
EXCLUDE(native_metal_srcs "${native_metal_srcs}" ${metal_test_srcs})

file(GLOB native_sparse_cpp "native/sparse/*.cpp")
file(GLOB native_quantized_cpp
"native/quantized/*.cpp"
@@ -117,6 +125,12 @@ else()
  set(all_cpu_cpp ${all_cpu_cpp} ${vulkan_cpp})
endif()

if(USE_METAL)
  set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp} ${native_metal_srcs})
else()
  set(all_cpu_cpp ${all_cpu_cpp} ${metal_cpp})
endif()

if(USE_CUDA AND USE_ROCM)
message(FATAL_ERROR "ATen doesn't not currently support simultaneously building with CUDA and ROCM")
endif()
@@ -375,6 +389,10 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake"
set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS})
if(NOT INTERN_BUILD_MOBILE)
  list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${miopen_h})
else()
  if(USE_METAL)
    list(APPEND INSTALL_HEADERS ${metal_h} ${native_metal_h})
  endif()
endif()

# https://stackoverflow.com/questions/11096471/how-can-i-install-a-hierarchy-of-files-using-cmake
31 changes: 31 additions & 0 deletions aten/src/ATen/metal/Context.cpp
@@ -0,0 +1,31 @@
#include <atomic>

#include <ATen/Tensor.h>
#include <ATen/metal/Context.h>

namespace at {
namespace metal {

std::atomic<const MetalInterface*> g_metal_impl_registry;

MetalImplRegistrar::MetalImplRegistrar(MetalInterface* impl) {
  g_metal_impl_registry.store(impl);
}

at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src) {
  auto p = at::metal::g_metal_impl_registry.load();
  if (p) {
    return p->metal_copy_(self, src);
  }
  AT_ERROR("Metal backend was not linked to the build");
}
} // namespace metal

namespace native {
bool is_metal_available() {
  auto p = at::metal::g_metal_impl_registry.load();
  return p ? p->is_metal_available() : false;
}

} // namespace native
} // namespace at
30 changes: 30 additions & 0 deletions aten/src/ATen/metal/Context.h
@@ -0,0 +1,30 @@
#ifndef MetalContext_h
#define MetalContext_h

#include <atomic>

#include <ATen/Tensor.h>

namespace at {
namespace metal {

struct MetalInterface {
  virtual ~MetalInterface() = default;
  virtual bool is_metal_available() const = 0;
  virtual at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src)
      const = 0;
};

extern std::atomic<const MetalInterface*> g_metal_impl_registry;

class MetalImplRegistrar {
 public:
  explicit MetalImplRegistrar(MetalInterface*);
};

at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src);

} // namespace metal
} // namespace at

#endif /* MetalContext_h */
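
The header above is the hook the Metal backend uses to plug into ATen: a backend implements `MetalInterface` and publishes it through a static `MetalImplRegistrar`, and `Context.cpp` reads it back out of `g_metal_impl_registry`. The sketch below is illustrative only (the PR's actual MPSCNN implementation lives in the Objective-C++ sources under `aten/src/ATen/native/metal` and is not shown in this excerpt); it shows how a hypothetical stub would wire into this registry.

```
#include <ATen/metal/Context.h>
#include <c10/util/Exception.h>

namespace {

class StubMetalImpl final : public at::metal::MetalInterface {
 public:
  bool is_metal_available() const override {
    // A real backend would check the device/OS for MPSCNN support here.
    return false;
  }
  at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src)
      const override {
    // A real backend would move data between CPU memory and Metal textures.
    TORCH_CHECK(false, "stub backend cannot copy");
    return self;
  }
};

// Constructing the registrar with static storage duration stores the
// implementation in g_metal_impl_registry at load time, so that
// at::metal::metal_copy_() and at::native::is_metal_available() can reach it.
StubMetalImpl g_stub_impl;
at::metal::MetalImplRegistrar g_registrar(&g_stub_impl);

} // namespace
```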
7 changes: 6 additions & 1 deletion aten/src/ATen/native/Copy.cpp
@@ -7,6 +7,7 @@
#include <ATen/native/quantized/Copy.h>
#include <ATen/quantized/Quantizer.h>
#include <ATen/vulkan/Context.h>
#include <ATen/metal/Context.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/NamedTensorUtils.h>
#include <torch/library.h>
@@ -79,7 +80,7 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
// (e.g. XLA) may be supported by overriding copy_ and _copy_from.
bool is_supported_device(Device device) {
  DeviceType device_type = device.type();
-  return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan;
+  return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan || device_type == kMetal;
}

} // namespace
@@ -133,6 +134,10 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
    return at::vulkan::vulkan_copy_(self, src);
  }

  if (self.device().type() == at::kMetal || src.device().type() == at::kMetal) {
    return at::metal::metal_copy_(self, src);
  }

  auto iter = TensorIteratorConfig()
    .add_output(self)
    .add_input(src)
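
The `copy_impl` hook above is what backs the `.metal()`/`.cpu()` conversions in the user API: whenever either side of a copy lives on the `kMetal` device, the copy is forwarded to the registered backend via `at::metal::metal_copy_`. Below is a hedged sketch of that round trip at the ATen level; it assumes a Metal-capable device and a libtorch built with `USE_METAL=ON`, and the `allclose` check is only an illustration.

```
#include <ATen/ATen.h>
#include <iostream>

void metal_round_trip() {
  at::Tensor cpu_in = at::rand({1, 3, 8, 8});
  // CPU -> Metal: copy_impl sees a kMetal destination and calls metal_copy_.
  at::Tensor gpu = cpu_in.metal();
  // Metal -> CPU: the source is kMetal, so the same dispatch applies.
  at::Tensor cpu_out = gpu.cpu();
  std::cout << std::boolalpha
            << "round trip preserves values: " << cpu_in.allclose(cpu_out)
            << std::endl;
}
```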