Skip to content

Commit

Permalink
Use CUTLASS GEMM for NT bmm [OSS-only] (pytorch#85894)
Browse files Browse the repository at this point in the history
OSS-only copy of pytorch#85710
Pull Request resolved: pytorch#85894
Approved by: https://github.com/drisspg
  • Loading branch information
cpuhrsch authored and pytorchmergebot committed Oct 12, 2022
1 parent 73c43ce commit ef58a13
Show file tree
Hide file tree
Showing 10 changed files with 390 additions and 40 deletions.
2 changes: 2 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ cu_library(
"@cuda//:cublas",
"@cuda//:cufft",
"@cuda//:cusparse",
"@cutlass",
],
alwayslink = True,
)
Expand Down Expand Up @@ -1673,6 +1674,7 @@ cc_library(
] + if_cuda([
":torch_distributed_cuda",
"@cuda//:nvToolsExt",
"@cutlass",
]),
alwayslink = True,
)
Expand Down
6 changes: 6 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@ new_local_repository(
path = "third_party/eigen",
)

new_local_repository(
name = "cutlass",
build_file = "//third_party:cutlass.BUILD",
path = "third_party/cutlass",
)

new_local_repository(
name = "fbgemm",
build_file = "//third_party:fbgemm/BUILD.bazel",
Expand Down
4 changes: 1 addition & 3 deletions aten/src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -433,9 +433,7 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
endif()

if(USE_CUDA AND NOT USE_ROCM)
if(USE_FLASH_ATTENTION)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
endif()
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
if($ENV{ATEN_STATIC_CUDA})
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_LIBRARIES}
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1174,7 +1174,8 @@
dispatch:
SparseCPU: bmm_sparse_cpu
SparseCUDA: bmm_sparse_cuda
NestedTensorCPU, NestedTensorCUDA: bmm_nested
NestedTensorCPU: bmm_nested
NestedTensorCUDA: bmm_nested_cuda
tags: canonical

- func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
Expand Down
22 changes: 0 additions & 22 deletions aten/src/ATen/native/nested/NestedTensorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,27 +101,5 @@ std::vector<Tensor> chunk_nested_tensor(const Tensor& self, int64_t chunks, int6
return splits;
}

// Returns the per-component sizes of a nested tensor: one IntArrayRef per
// constituent tensor (ntensors == self_ptr->size(0)).
// Each returned IntArrayRef is built from raw pointers into the nested size
// matrix's storage, i.e. it is a non-owning view — valid only while the
// size matrix owned by `self_ptr` is alive.
std::vector<IntArrayRef> NestedTensor_get_sizes(
const NestedTensorImpl* self_ptr) {
int64_t ntensors = self_ptr->size(0);
std::vector<IntArrayRef> sizes(ntensors);
// No constituent tensors: return the empty vector as-is.
if (ntensors == 0) {
return sizes;
}
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
int64_t orig_dim = sizemat.size(1);
// nesting scalars has empty sizes
if (orig_dim == 0) {
return sizes;
}
const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();

// Each row of sizemat (orig_dim consecutive int64 values) holds the sizes
// of one constituent tensor; advance the pointer row by row.
for (const auto i : c10::irange(ntensors)) {
sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
sizemat_ptr += orig_dim;
}
return sizes;
}

} // namespace native
} // namespace at
24 changes: 22 additions & 2 deletions aten/src/ATen/native/nested/NestedTensorUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,28 @@ inline at::Tensor create_nested_view_tensor(
int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt);

// The sizes of the underlying tensors
std::vector<IntArrayRef> NestedTensor_get_sizes(
const NestedTensorImpl* self_ptr);
inline std::vector<IntArrayRef> NestedTensor_get_sizes(
const NestedTensorImpl* self_ptr) {
int64_t ntensors = self_ptr->size(0);
std::vector<IntArrayRef> sizes(ntensors);
if (ntensors == 0) {
return sizes;
}
const Tensor& sizemat = self_ptr->get_nested_size_tensor();
int64_t orig_dim = sizemat.size(1);
// nesting scalars has empty sizes
if (orig_dim == 0) {
return sizes;
}
const int64_t* sizemat_ptr = sizemat.data_ptr<int64_t>();

for (const auto i : c10::irange(ntensors)) {
sizes[i] = IntArrayRef(sizemat_ptr, sizemat_ptr + orig_dim);
sizemat_ptr += orig_dim;
}
return sizes;
}


TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
const NestedTensorImpl& nt);
Expand Down
Loading

0 comments on commit ef58a13

Please sign in to comment.