
Commit b460465

IvanKobzarev authored and facebook-github-bot committed
[Mobile GPU][Integration] Vulkan backend integration (pytorch#36491)
Summary:
This PR contains the initial version of the Vulkan (GPU) backend integration. The primary target environment is Android, but the desktop build is also supported.

## CMake
Introducing three CMake options:
- `USE_VULKAN`: the main switch; if it is OFF, the other options have no effect.
- `USE_VULKAN_WRAPPER`: ON - Vulkan is loaded at runtime as "libvulkan.so" via libdl, with every function call wrapped in vulkan_wrapper.h. OFF - link against libvulkan.so directly.
- `USE_VULKAN_SHADERC_RUNTIME`: ON - the shader compilation library is linked and shaders are compiled at runtime. OFF - shaders are precompiled and the shader compilation library is not included.

## Codegen
Shader codegen runs from cmake/VulkanCodegen.cmake, which calls `aten/src/ATen/native/vulkan/gen_glsl.py` or `aten/src/ATen/native/vulkan/gen_spv.py`:
- If `USE_VULKAN_SHADERC_RUNTIME` is OFF: shaders are precompiled and the SPIR-V bytecode is embedded in the binary as a uint32_t array in `spv.h`/`spv.cpp`.
- If `USE_VULKAN_SHADERC_RUNTIME` is ON: the GLSL shader source is embedded as `glsl.h`/`glsl.cpp`.
All codegen results land in the build directory.

## Build dependencies
cmake/Dependencies.cmake
If the target platform is Android, the Vulkan library, headers, and Vulkan wrapper are taken from the ANDROID_NDK. The desktop build requires the VULKAN_SDK environment variable, and all Vulkan dependencies are taken from it. (The desktop build was tested only on Linux.)

## PyTorch integration
Adding 'Vulkan' as a new Backend, DispatchKey, and DeviceType.
We use the Strided layout without supporting strides at the moment, but we plan to support them in the future.
Using OpaqueTensorImpl, where the OpaqueHandle is a copyable VulkanTensor; more details in the comments in `aten/src/ATen/native/vulkan/Vulkan.h`.

Main code location: `aten/src/ATen/native/vulkan`
- `aten/src/ATen/native/vulkan/VulkanAten.cpp` - the connection between ATen and the Vulkan API (Vulkan.h); converts at::Tensor to VulkanTensor.
- `aten/src/ATen/native/vulkan/Vulkan.h` - the Vulkan API that contains the VulkanTensor representation and functions to work with it. We plan to expose it so that clients can write their own Vulkan ops.
- `aten/src/ATen/native/vulkan/VulkanOps.cpp` - implementations of the Vulkan operations, using the Vulkan.h API.

## GLSL shaders
Located in `aten/src/ATen/native/vulkan/glsl` as *.glsl files. All shaders use Vulkan specialization constants for workgroup sizes, with ids 1, 2, 3.

## Supported operations
- conv2d (no groups)
- conv2d (depthwise)
- addmm
- upsample nearest 2d
- clamp
- hardtanh

## Testing
`aten/src/ATen/test/vulkan_test.cpp` - contains tests for copying from CPU to Vulkan and back, and for all supported operations.
Desktop builds are supported, so testing can be done on a desktop with a Vulkan-capable GPU or with a software implementation of Vulkan such as https://github.com/google/swiftshader.

## Vulkan execution
The initial implementation is trivial and waits for every operator's execution to finish.

Pull Request resolved: pytorch#36491

Differential Revision: D21696709

Pulled By: IvanKobzarev

fbshipit-source-id: da3e5a770b1a1995e9465d7e81963e7de56217fa
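For orientation, here is a rough end-to-end sketch of how the new backend is meant to be exercised from ATen, modeled on the testing described above (copy CPU to Vulkan, run a supported op, copy back). The exact calls are illustrative assumptions, not code taken from this commit:

```cpp
#include <ATen/ATen.h>

int main() {
  // is_vulkan_available() is registered in native_functions.yaml by this PR.
  if (!at::is_vulkan_available()) {
    return 0;  // no Vulkan driver or software implementation present
  }
  // Create a 4-D float tensor on CPU and move it to the Vulkan backend;
  // to() with a Vulkan device is routed through at::empty + copy_
  // (see the TensorConversions.cpp / Copy.cpp changes below).
  auto cpu = at::rand({1, 3, 8, 8}, at::device(at::kCPU).dtype(at::kFloat));
  auto vk = cpu.to(at::device(at::kVulkan));
  // Run one of the supported operations on the Vulkan tensor.
  auto out_vk = at::upsample_nearest2d(vk, {16, 16});
  // Copy the result back to CPU for inspection.
  auto out_cpu = out_vk.to(at::device(at::kCPU));
  return 0;
}
```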
1 parent 1fa0bb6 commit b460465


53 files changed: +4923 −10 lines

CMakeLists.txt

Lines changed: 15 additions & 0 deletions
```diff
@@ -191,6 +191,9 @@ option(USE_SNPE "Use Qualcomm's SNPE library" OFF)
 option(USE_SYSTEM_EIGEN_INSTALL
   "Use system Eigen instead of the one under third_party" OFF)
 option(USE_TENSORRT "Using Nvidia TensorRT library" OFF)
+option(USE_VULKAN "Use Vulkan GPU backend" OFF)
+option(USE_VULKAN_WRAPPER "Use Vulkan wrapper" ON)
+option(USE_VULKAN_SHADERC_RUNTIME "Use Vulkan Shader compilation runtime(Needs shaderc lib)" OFF)
 option(USE_XNNPACK "Use XNNPACK" ON)
 option(USE_ZMQ "Use ZMQ" OFF)
 option(USE_ZSTD "Use ZSTD" OFF)
@@ -475,6 +478,18 @@ if(USE_XNNPACK)
   set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_XNNPACK -DUSE_INTERNAL_THREADPOOL_IMPL")
 endif()
 
+if(USE_VULKAN)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_VULKAN")
+endif()
+
+if(USE_VULKAN_WRAPPER)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_VULKAN_WRAPPER")
+endif()
+
+if(USE_VULKAN_SHADERC_RUNTIME)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_VULKAN_SHADERC_RUNTIME")
+endif()
+
 # ---[ Whitelist file if whitelist is specified
 include(cmake/Whitelist.cmake)
 
```
aten/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
```diff
@@ -30,6 +30,7 @@ set(ATen_HIP_SRCS)
 set(ATen_HIP_SRCS_W_SORT_BY_KEY)
 set(ATen_HIP_TEST_SRCS)
 set(ATen_HIP_INCLUDE)
+set(ATen_VULKAN_TEST_SRCS)
 set(ATen_CPU_DEPENDENCY_LIBS)
 set(ATen_CUDA_DEPENDENCY_LIBS)
 set(ATen_HIP_DEPENDENCY_LIBS)
@@ -51,6 +52,9 @@ set(TH_CPU_INCLUDE
   ${CMAKE_BINARY_DIR}/aten/src)
 list(APPEND ATen_CPU_INCLUDE ${TH_CPU_INCLUDE})
 
+if(USE_VULKAN)
+  list(APPEND ATen_CPU_INCLUDE ${CMAKE_BINARY_DIR}/vulkan)
+endif()
 
 # Find the HIP package, set the HIP paths, load the HIP CMake.
 if(USE_ROCM)
@@ -113,6 +117,7 @@ set(ATen_NVRTC_STUB_SRCS ${ATen_NVRTC_STUB_SRCS} PARENT_SCOPE)
 set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE)
 set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE)
 set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE)
+set(ATen_VULKAN_TEST_SRCS ${ATen_VULKAN_TEST_SRCS} PARENT_SCOPE)
 set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE)
 set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE)
 set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE)
```

aten/src/ATen/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
```diff
@@ -63,6 +63,8 @@ file(GLOB mkldnn_cpp "mkldnn/*.cpp")
 file(GLOB native_cpp "native/*.cpp")
 file(GLOB native_mkl_cpp "native/mkl/*.cpp")
 file(GLOB native_mkldnn_cpp "native/mkldnn/*.cpp")
+file(GLOB native_vulkan_cpp "native/vulkan/*.cpp")
+file(GLOB native_vulkan_stub_cpp "native/vulkan/stub/*.cpp")
 file(GLOB native_sparse_cpp "native/sparse/*.cpp")
 file(GLOB native_quantized_cpp
   "native/quantized/*.cpp"
@@ -105,6 +107,11 @@ endif()
 if(AT_MKLDNN_ENABLED)
   set(all_cpu_cpp ${all_cpu_cpp} ${mkldnn_cpp})
 endif()
+if(USE_VULKAN)
+  set(all_cpu_cpp ${all_cpu_cpp} ${native_vulkan_cpp} ${vulkan_generated_cpp})
+else()
+  set(all_cpu_cpp ${all_cpu_cpp} ${native_vulkan_stub_cpp})
+endif()
 
 if(USE_CUDA AND USE_ROCM)
   message(FATAL_ERROR "ATen doesn't not currently support simultaneously building with CUDA and ROCM")
@@ -324,6 +331,7 @@ endif()
 # Include CPU paths for CUDA/HIP as well
 list(APPEND ATen_CUDA_INCLUDE ${ATen_CPU_INCLUDE})
 list(APPEND ATen_HIP_INCLUDE ${ATen_CPU_INCLUDE})
+list(APPEND ATen_VULKAN_INCLUDE ${ATen_CPU_INCLUDE})
 
 # We have two libraries: libATen_cpu.so and libATen_cuda.so,
 # with libATen_cuda.so depending on libATen_cpu.so. The CPU library
@@ -402,11 +410,13 @@ set(ATen_CPU_TEST_SRCS ${ATen_CPU_TEST_SRCS} PARENT_SCOPE)
 set(ATen_CUDA_TEST_SRCS ${ATen_CUDA_TEST_SRCS} PARENT_SCOPE)
 set(ATen_CORE_TEST_SRCS ${ATen_CORE_TEST_SRCS} PARENT_SCOPE)
 set(ATen_HIP_TEST_SRCS ${ATen_HIP_TEST_SRCS} PARENT_SCOPE)
+set(ATen_VULKAN_TEST_SRCS ${ATen_VULKAN_TEST_SRCS} PARENT_SCOPE)
 set(ATen_QUANTIZED_TEST_SRCS ${ATen_QUANTIZED_TEST_SRCS} PARENT_SCOPE)
 set(ATen_CPU_INCLUDE ${ATen_CPU_INCLUDE} PARENT_SCOPE)
 set(ATen_THIRD_PARTY_INCLUDE ${ATen_THIRD_PARTY_INCLUDE} PARENT_SCOPE)
 set(ATen_CUDA_INCLUDE ${ATen_CUDA_INCLUDE} PARENT_SCOPE)
 set(ATen_HIP_INCLUDE ${ATen_HIP_INCLUDE} PARENT_SCOPE)
+set(ATen_VULKAN_INCLUDE ${ATen_VULKAN_INCLUDE} PARENT_SCOPE)
 set(ATen_CPU_DEPENDENCY_LIBS ${ATen_CPU_DEPENDENCY_LIBS} PARENT_SCOPE)
 set(ATen_CUDA_DEPENDENCY_LIBS ${ATen_CUDA_DEPENDENCY_LIBS} PARENT_SCOPE)
 set(ATen_HIP_DEPENDENCY_LIBS ${ATen_HIP_DEPENDENCY_LIBS} PARENT_SCOPE)
```

aten/src/ATen/function_wrapper.py

Lines changed: 26 additions & 5 deletions
```diff
@@ -192,6 +192,12 @@ def TypedDict(name, attrs, total=True): # type: ignore
         break;
 """)
 
+IFDEF_BLOCK = CodeTemplate("""\
+#ifdef ${ifdef_guard}
+${content}
+#endif
+""")
+
 # add a native declaration for a native function
 NATIVE_DECLARATION = CodeTemplate("""\
 CAFFE2_API ${return_type} ${native_type_method_dispatch}(${formals_with_defaults});
@@ -221,7 +227,8 @@ def TypedDict(name, attrs, total=True): # type: ignore
     ('ComplexDouble', 'ComplexDouble', 'ComplexDouble', False),
 ]
 
-static_dispatch_backends = ['CPU', 'QuantizedCPU']
+static_dispatch_backends = ['CPU', 'QuantizedCPU', 'Vulkan']
+static_dispatch_backends_ifdef_guard = {'Vulkan' : 'USE_VULKAN'}
 
 
 class NYIError(Exception):
@@ -1059,11 +1066,18 @@ def swizzle_self(f): # blegh
     # calling code.
     for backend in static_dispatch_backends:
         if backend in type_method_dispatch:
-            static_dispatch_function_cases.append(STATIC_DISPATCH_FUNCTION_SWITCH_CASE.substitute(
+            static_dispatch_function_case = STATIC_DISPATCH_FUNCTION_SWITCH_CASE.substitute(
                 option,
                 backend=backend,
                 backend_function=type_method_dispatch[backend],
-                actuals=option['method_actuals']))
+                actuals=option['method_actuals'])
+            if (backend in static_dispatch_backends_ifdef_guard):
+                static_dispatch_function_cases.append(IFDEF_BLOCK.substitute(
+                    option,
+                    ifdef_guard=static_dispatch_backends_ifdef_guard[backend],
+                    content=static_dispatch_function_case))
+            else:
+                static_dispatch_function_cases.append(static_dispatch_function_case)
 
     static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute(
         option,
@@ -1094,11 +1108,18 @@ def gen_namespace_function(option, multidispatch_formals):
     static_dispatch_function_cases = []
     for backend in static_dispatch_backends:
         if backend in type_method_dispatch:
-            static_dispatch_function_cases.append(STATIC_DISPATCH_FUNCTION_SWITCH_CASE.substitute(
+            static_dispatch_function_case = STATIC_DISPATCH_FUNCTION_SWITCH_CASE.substitute(
                 option,
                 backend=backend,
                 backend_function=type_method_dispatch[backend],
-                actuals=option['actuals']))
+                actuals=option['actuals'])
+            if (backend in static_dispatch_backends_ifdef_guard):
+                static_dispatch_function_cases.append(IFDEF_BLOCK.substitute(
+                    option,
+                    ifdef_guard=static_dispatch_backends_ifdef_guard[backend],
+                    content=static_dispatch_function_case))
+            else:
+                static_dispatch_function_cases.append(static_dispatch_function_case)
     static_dispatch_function_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute(
         option,
         dispatch_key_var_name=dispatch_key_var_name,
```
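With this change, the static dispatch case generated for the Vulkan backend is wrapped in a `USE_VULKAN` guard. Roughly, the generated C++ takes the following shape; this is only a sketch of the output of the `STATIC_DISPATCH_FUNCTION_SWITCH_CASE` and `IFDEF_BLOCK` templates, and the operator and helper names shown are illustrative:

```cpp
// Illustrative only: approximate shape of the generated static dispatch for add().
switch (dispatchKeyToBackend(_dk)) {
  case Backend::CPU:
    return at::native::add(self, other, alpha);
#ifdef USE_VULKAN
  case Backend::Vulkan:
    return at::native::vulkan_add(self, other, alpha);
#endif
  default:
    AT_ERROR("add not implemented for ", at::toString(_dk));
}
```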

aten/src/ATen/gen.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -45,6 +45,10 @@
     '--rocm',
     action='store_true',
     help='reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly')
+parser.add_argument(
+    '--vulkan',
+    action='store_true',
+    help='Generate Vulkan backend functions')
 parser.add_argument(
     '--op_registration_whitelist',
     nargs='*',
@@ -67,6 +71,7 @@
     help='force it to generate schema-only registrations for all ops, including'
          'those that are not listed on --op_registration_whitelist')
 options = parser.parse_args()
+
 # NB: It is mandatory to NOT use os.path.join here, as the install directory
 # will eventually be ingested by cmake, which does not respect Windows style
 # path slashes. If you switch this to use os.path.join, you'll get an error
@@ -365,7 +370,7 @@ def generate_storage_type_and_tensor(backend, density, declarations, per_op_regi
     fm.write(env['Type'] + ".cpp", SPARSE_TYPE_DERIVED_CPP, env)
     fm.write(env['Type'] + ".h", TYPE_DERIVED_H, env)
 
-    if env['DeviceType'] == 'CPU':
+    if env['DeviceType'] == 'CPU' or env['DeviceType'] == 'Vulkan':
         top_env['cpu_type_headers'].append(
             '#include <ATen/{}.h>'.format(env['Type']))
     else:
@@ -384,6 +389,8 @@ def iterate_types():
                 yield (backend, density)
     for backend in quantized_backends:
         yield (backend, 'Dense')
+    if options.vulkan:
+        yield('Vulkan', 'Dense')
 
 
 def gen_per_op_registration_filename(opname):
```

aten/src/ATen/native/Convolution.cpp

Lines changed: 30 additions & 0 deletions
```diff
@@ -12,6 +12,9 @@
 #if AT_NNPACK_ENABLED()
 #include <nnpack.h>
 #endif
+#ifdef USE_VULKAN
+#include <ATen/native/vulkan/VulkanAten.h>
+#endif
 
 
 constexpr int MIOPEN_DIM_MAX = 5;
@@ -47,6 +50,7 @@ struct ConvParams {
   bool use_mkldnn(const at::Tensor& input) const;
   bool use_nnpack(const at::Tensor& input) const;
   bool use_xnnpack(const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias) const;
+  bool use_vulkan(const at::Tensor& input, const at::Tensor& weight) const;
   bool is_depthwise(const at::Tensor& input, const at::Tensor& weight) const;
 };
 
@@ -274,6 +278,20 @@ auto ConvParams::use_xnnpack(
   return false;
 }
 
+auto ConvParams::use_vulkan(
+    const at::Tensor &input, const at::Tensor& weight) const -> bool {
+#ifdef USE_VULKAN
+  if (!(input.is_vulkan() && input.scalar_type() == kFloat &&
+        !transposed && input.ndimension() == 4)) {
+    return false;
+  }
+  return (groups == 1) || (input.size(1) == groups && groups > 1 &&
+                           weight.size(0) % input.size(1) == 0);
+#else
+  return false;
+#endif
+}
+
 // We currently only have depthwise support for the case where groups ==
 // nInputPlane and nInputPlane == nOutputPlane (the latter due to the lack of
 // a depthwise multiplier)
@@ -669,6 +687,12 @@ at::Tensor _convolution(
       output = at::miopen_depthwise_convolution(
           input.contiguous(), weight, bias,
          padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
+#ifdef USE_VULKAN
+    } else if (params.use_vulkan(input, weight)) {
+      output = at::native::vulkan_convolution(
+          input, weight, bias,
+          params.padding, params.stride, params.dilation, params.groups);
+#endif
     } else {
       output = at::thnn_conv_depthwise2d(input.contiguous(), weight, kernel_size, bias, stride, padding, dilation);
     }
@@ -761,6 +785,12 @@ at::Tensor _convolution(
         bias,
         params.stride,
         params.padding);
+#ifdef USE_VULKAN
+  } else if (params.use_vulkan(input, weight)) {
+    output = at::native::vulkan_convolution(
+        input, weight, bias,
+        params.padding, params.stride, params.dilation, params.groups);
+#endif
   } else if (input.device().type() == c10::DeviceType::CPU || input.device().type() == c10::DeviceType::CUDA) {
     if (params.groups == 1) {
       output = at::_convolution_nogroup(
```
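Given the `use_vulkan` predicate above, a plain convolution call on a 4-D float Vulkan tensor (non-transposed, with groups == 1 or the depthwise pattern) is routed to `vulkan_convolution`. A hypothetical call site follows; the shapes are made up, and whether weight and bias also need to live on the Vulkan device is not shown by the diff, so they stay on CPU here:

```cpp
#include <ATen/ATen.h>

at::Tensor vulkan_conv_example() {
  auto opts = at::device(at::kCPU).dtype(at::kFloat);
  auto input  = at::rand({1, 4, 16, 16}, opts).to(at::device(at::kVulkan));
  auto weight = at::rand({8, 4, 3, 3}, opts);  // OC, IC, kH, kW
  auto bias   = at::rand({8}, opts);
  // A 4-D float Vulkan input with groups == 1 satisfies
  // ConvParams::use_vulkan() and takes the vulkan_convolution branch.
  return at::conv2d(input, weight, bias,
                    /*stride=*/{1, 1}, /*padding=*/{0, 0},
                    /*dilation=*/{1, 1}, /*groups=*/1);
}
```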

aten/src/ATen/native/Copy.cpp

Lines changed: 10 additions & 1 deletion
```diff
@@ -10,6 +10,9 @@
 #include <ATen/NamedTensorUtils.h>
 #include <torch/library.h>
 
+#ifdef USE_VULKAN
+#include <ATen/native/vulkan/VulkanAten.h>
+#endif
 namespace {
 
 using namespace at;
@@ -78,7 +81,7 @@ void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
 // (e.g. XLA) may be supported by overriding copy_ and _copy_from.
 bool is_supported_device(Device device) {
   DeviceType device_type = device.type();
-  return device_type == kCPU || device_type == kCUDA || device_type == kHIP;
+  return device_type == kCPU || device_type == kCUDA || device_type == kHIP || device_type == kVulkan;
 }
 
 } // namespace
@@ -126,6 +129,12 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
     TORCH_CHECK(false, "Copying from quantized Tensor to non-quantized Tensor is not allowed, please use dequantize to get a float Tensor from a quantized Tensor");
   }
 
+#ifdef USE_VULKAN
+  if (self.device().type() == at::kVulkan || src.device().type() == at::kVulkan) {
+    return vulkan_copy_(self, src);
+  }
+#endif
+
   auto iter = TensorIterator();
   iter.set_check_mem_overlap(true);
   iter.add_output(self);
```
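With this change, `copy_` reaches `vulkan_copy_` whenever either side of the copy lives on the Vulkan device (under `USE_VULKAN`). A small sketch of the two directions, with illustrative shapes:

```cpp
#include <ATen/ATen.h>

void vulkan_copy_example() {
  // CPU -> Vulkan: a destination on the Vulkan device takes the new
  // branch in copy_impl and ends up in vulkan_copy_.
  auto dst_vk = at::empty({2, 3}, at::device(at::kVulkan).dtype(at::kFloat));
  dst_vk.copy_(at::ones({2, 3}));
  // Vulkan -> CPU: a source on the Vulkan device takes the same branch.
  auto dst_cpu = at::empty({2, 3}, at::device(at::kCPU).dtype(at::kFloat));
  dst_cpu.copy_(dst_vk);
}
```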

aten/src/ATen/native/TensorConversions.cpp

Lines changed: 14 additions & 0 deletions
```diff
@@ -29,6 +29,13 @@ static inline Tensor to_impl(const Tensor& self, const TensorOptions& options, b
     return self;
   }
 
+  if (options.device().type() == DeviceType::Vulkan
+      || self.device().type() == DeviceType::Vulkan) {
+    auto r = at::empty(self.sizes(), options, c10::nullopt);
+    r.copy_(self, non_blocking);
+    return r;
+  }
+
   if (memory_format == MemoryFormat::Preserve) {
     if (self.is_non_overlapping_and_dense()) {
       // Copy all strides
@@ -62,6 +69,13 @@ Tensor to(
       "to(options) expects unset requires_grad flag, but got "
       "options.requires_grad set as ", options.requires_grad());
 
+  if (options.device().type() == DeviceType::Vulkan
+      || self.device().type() == DeviceType::Vulkan) {
+    auto r = at::empty(self.sizes(), options, c10::nullopt);
+    r.copy_(self, non_blocking);
+    return r;
+  }
+
   TORCH_CHECK(!options.has_layout() || self.layout() == options.layout(),
       "to(options) doesn't support converting to a different layout, "
       "but got self.layout being ", self.layout(),
```

aten/src/ATen/native/native_functions.yaml

Lines changed: 10 additions & 0 deletions
```diff
@@ -326,6 +326,7 @@
     SparseCPU: add_sparse
     SparseCUDA: add_sparse
     MkldnnCPU: mkldnn_add
+    Vulkan: vulkan_add
   supports_named_tensor: True
 
 - func: add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
@@ -764,6 +765,7 @@
     CPU: clamp
     CUDA: clamp
     QuantizedCPU: quantized_clamp
+    Vulkan: vulkan_clamp
 
 - func: clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)
   supports_named_tensor: True
@@ -1183,6 +1185,7 @@
     MkldnnCPU: empty_mkldnn
     SparseCPU: empty_sparse
     SparseCUDA: empty_sparse
+    Vulkan: empty_vulkan
 
 - func: new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   variants: method
@@ -1923,6 +1926,7 @@
     CPU: mean_cpu_gpu
     CUDA: mean_cpu_gpu
     QuantizedCPU: quantized_mean_cpu
+    Vulkan: mean_vulkan
 
 - func: mean.out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
   supports_named_tensor: True
@@ -2231,6 +2235,9 @@
     CPU: batch_norm_update_stats_cpu
     CUDA: batch_norm_update_stats_cuda
 
+- func: is_vulkan_available() -> bool
+  use_c10_dispatcher: full
+
 - func: _nnpack_available() -> bool
   use_c10_dispatcher: full
 
@@ -3476,6 +3483,7 @@
     CUDA: addmm_cuda
     SparseCPU: addmm_sparse_dense_cpu
     SparseCUDA: addmm_sparse_dense_cuda
+    Vulkan: vulkan_addmm
   supports_named_tensor: True
 
 - func: addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
@@ -5962,6 +5970,7 @@
     CPU: hardtanh_
     CUDA: hardtanh_
     QuantizedCPU: quantized_hardtanh_
+    Vulkan: vulkan_hardtanh_
 
 - func: hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
   python_module: nn
@@ -6705,6 +6714,7 @@
     CPU: upsample_nearest2d_cpu
     CUDA: upsample_nearest2d_cuda
    QuantizedCPU: quantized_upsample_nearest2d_cpu
+    Vulkan: upsample_nearest2d_vulkan
 
 - func: upsample_nearest2d_backward.grad_input(Tensor grad_output, int[2] output_size, int[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)
   python_module: nn
```
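Each `Vulkan:` dispatch entry above binds an operator schema to a function in the `at::native` namespace (implemented in VulkanAten.cpp). As a sketch, the declaration that the `add.Tensor` entry is expected to resolve to would look roughly like the following; the signature is inferred from the schema, not copied from the diff:

```cpp
#include <ATen/ATen.h>

namespace at {
namespace native {

// Assumed declaration matching "add.Tensor(Tensor self, Tensor other, *,
// Scalar alpha=1) -> Tensor" with dispatch "Vulkan: vulkan_add".
Tensor vulkan_add(const Tensor& self, const Tensor& other, Scalar alpha);

} // namespace native
} // namespace at
```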
