Revert "Flash Attention v2 (pytorch#105602)" (pytorch#108827)
This reverts commit add45ae.

There are some conflicts in a benchmark CSV file (pytorch#105602 (comment)), so I need to revert this manually.

The diff has been reverted internally.

Pull Request resolved: pytorch#108827
Approved by: https://github.com/kit1980
huydhn authored and pytorchmergebot committed Sep 8, 2023
1 parent 8391e3f commit 24e9bbe
Showing 89 changed files with 7,946 additions and 6,226 deletions.
8 changes: 0 additions & 8 deletions .ci/pytorch/build.sh
@@ -159,14 +159,6 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* && -z "$TORCH_CUDA_ARCH_LIST" ]]; then
   exit 1
 fi
 
-# We only build FlashAttention files for CUDA 8.0+, and they require large amounts of
-# memory to build and will OOM
-if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && [[ "$TORCH_CUDA_ARCH_LIST" == *"8.6"* || "$TORCH_CUDA_ARCH_LIST" == *"8.0"* ]]; then
-  echo "WARNING: FlashAttention files require large amounts of memory to build and will OOM"
-  echo "Setting MAX_JOBS=(nproc-2)/3 to reduce memory usage"
-  export MAX_JOBS="$(( $(nproc --ignore=2) / 3 ))"
-fi
-
 if [[ "${BUILD_ENVIRONMENT}" == *clang* ]]; then
   export CC=clang
   export CXX=clang++

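The removed block capped build parallelism so that compiling the FlashAttention translation units would not exhaust memory. As a rough illustration of the arithmetic it encoded (a sketch only, not part of any CI script; `flash_attention_max_jobs` is a hypothetical helper name, and `nproc --ignore=2` reports the CPU count minus two with a floor of one):

    import os

    def flash_attention_max_jobs(cpu_count: int) -> int:
        # Mirrors: export MAX_JOBS="$(( $(nproc --ignore=2) / 3 ))"
        return max(cpu_count - 2, 1) // 3

    # e.g. 10 parallel jobs on a 32-core builder, 4 on a 16-core builder
    print(flash_attention_max_jobs(os.cpu_count() or 1))
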
4 changes: 2 additions & 2 deletions .circleci/scripts/binary_populate_env.sh
@@ -155,8 +155,8 @@ EOL
 
 # nproc doesn't exist on darwin
 if [[ "$(uname)" != Darwin ]]; then
-  # This was lowered from 18 to 12 to avoid OOMs when compiling FlashAttentionV2
-  MEMORY_LIMIT_MAX_JOBS=12
+  # Because most Circle executors only have 20 CPUs, using more causes OOMs w/ Ninja and nvcc parallelization
+  MEMORY_LIMIT_MAX_JOBS=18
   NUM_CPUS=$(( $(nproc) - 2 ))
 
   # Defaults here for **binary** linux builds so they can be changed in one place

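The restored value raises the memory-based job cap back to 18. The line that combines the two knobs sits outside this hunk, but presumably the effective job count is the smaller of the CPU-derived value and the memory cap. A sketch under that assumption (`effective_max_jobs` is a hypothetical name, not anything in the script):

    def effective_max_jobs(num_cpus: int, memory_limit_max_jobs: int = 18) -> int:
        # NUM_CPUS=$(( $(nproc) - 2 )) leaves two CPUs free for the rest of the system;
        # the memory cap guards against OOMs from too many concurrent Ninja/nvcc jobs.
        return min(max(num_cpus - 2, 1), memory_limit_max_jobs)

    print(effective_max_jobs(20))  # 18 on a 20-CPU Circle executor
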
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -730,7 +730,7 @@ include(cmake/Dependencies.cmake)
 cmake_dependent_option(
     USE_FLASH_ATTENTION
     "Whether to build the flash_attention kernel for scaled dot product attention" ON
-    "USE_CUDA AND NOT ROCM AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
+    "USE_CUDA AND NOT ROCM AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
 
 # Flash Attention2 will error while building for sm52 while Mem Eff Attention won't
 cmake_dependent_option(

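The option above is a build-time switch: it only decides whether the flash kernel is compiled into the binary. Separately, the flash SDPA backend can be toggled at runtime, which is often the easier knob when debugging. A minimal sketch, assuming a PyTorch 2.x build where the torch.backends.cuda toggles are available:

    import torch

    # Report whether the flash-attention SDPA backend is currently allowed.
    print(torch.backends.cuda.flash_sdp_enabled())

    # Temporarily disallow flash attention so SDPA falls back to the other backends.
    with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
        print(torch.backends.cuda.flash_sdp_enabled())  # False inside the context
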
2 changes: 0 additions & 2 deletions aten/src/ATen/CMakeLists.txt
@@ -160,7 +160,6 @@ file(GLOB native_utils_cpp "native/utils/*.cpp")
 
 # flash_attention sources
 file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
-file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
 file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
 
 #Mem_eff attention sources
@@ -170,7 +169,6 @@ file(GLOB mem_eff_attention_cuda_cpp "native/transformers/cuda/mem_eff_attention
 
 if(USE_FLASH_ATTENTION)
   list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_cu})
-  list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_kernels_cu})
   list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp})
 endif()
 
4 changes: 2 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -14278,7 +14278,7 @@
   variants: function
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor ouput, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   dispatch:
     CPU: _scaled_dot_product_flash_attention_cpu
     CUDA: _scaled_dot_product_flash_attention_cuda
@@ -14304,7 +14304,7 @@
     CUDA: _scaled_dot_product_efficient_attention_backward_cuda
   tags: nondeterministic_seeded
 
-- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
   variants: function
   dispatch:
     CUDA: _flash_attention_forward

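These schemas describe the private ATen ops that back the public scaled_dot_product_attention API. For orientation, a minimal sketch of exercising the flash path from Python (illustration only; it assumes a CUDA device and fp16/bf16 inputs, which is what the flash kernel accepts):

    import torch
    import torch.nn.functional as F

    # (batch, num_heads, seq_len, head_dim) fp16 CUDA tensors
    q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16) for _ in range(3))

    # Restrict SDPA to the flash backend so a failure to dispatch surfaces as an error.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

    print(out.shape)  # torch.Size([2, 8, 128, 64])
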
@@ -679,7 +679,16 @@ inline auto sdpa_nested_preprocessing(
 
 } // namespace
 
-std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor, Tensor, Tensor>
+std::tuple<
+    Tensor,
+    Tensor,
+    Tensor,
+    Tensor,
+    int64_t,
+    int64_t,
+    Tensor,
+    Tensor,
+    Tensor>
 _scaled_dot_product_flash_attention_nestedtensor_cuda(
     const Tensor& query,
     const Tensor& key,
@@ -701,12 +710,8 @@ _scaled_dot_product_flash_attention_nestedtensor_cuda(
       max_seqlen_batch_kv,
       output_shape) = sdpa_nested_preprocessing(query, key, value);
 
-  auto
-      [attention,
-       logsumexp,
-       philox_seed,
-       philox_offset,
-       debug_attn_mask] =
+  Tensor attention, log_sumexp, debug_attn_mask, philox_seed, philox_offset;
+  std::tie(attention, log_sumexp, philox_seed, philox_offset, debug_attn_mask) =
       at::_flash_attention_forward(
           query_buffer_reshaped,
           key_buffer_reshaped,
@@ -723,7 +728,7 @@ _scaled_dot_product_flash_attention_nestedtensor_cuda(
   attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);
   return std::make_tuple(
       attention,
-      logsumexp,
+      log_sumexp,
       cumulative_sequence_length_q,
       cumulative_sequence_length_kv,
       max_seqlen_batch_q,

44 changes: 0 additions & 44 deletions aten/src/ATen/native/transformers/attention.cpp
@@ -14,7 +14,6 @@
 #include <ATen/native/transformers/sdp_utils_cpp.h>
 #include <utility>
 #include <c10/util/typeid.h>
-#include <c10/core/SymInt.h>
 #include <c10/core/SymIntArrayRef.h>
 #include <c10/util/Logging.h>
 #include <c10/util/Exception.h>
@@ -630,37 +629,6 @@ at::Tensor preprocess_mask(
 
   return attn_mask;
 }
-// FlashAttentionV2 requires that head dimension be a multiple of 8
-// This was previously done within the kernel, however
-// This causes the kernel to maybe alias query, key, value
-// So instead we pad the head_dimensions to be a multiple of 8 in the composite
-// region
-template <int alignment_size, bool slice>
-at::Tensor pad_last_dim(const at::Tensor& attn_bias) {
-  auto last_dim_size = attn_bias.sym_size(-1);
-  if (last_dim_size % alignment_size == 0) {
-    return attn_bias;
-  }
-  auto pad_count = alignment_size - (last_dim_size % alignment_size);
-  auto padded_bias = at::pad_symint(attn_bias, {c10::SymInt(0), pad_count});
-  if (slice) {
-    return padded_bias.slice_symint(-1, 0, last_dim_size);
-  }
-  return padded_bias;
-}
-
-at::Tensor post_process_flash_output(
-    at::Tensor out,
-    c10::SymInt const& og_size) {
-  if (!out.is_nested()) {
-    out = out.slice_symint(-1, 0, og_size);
-  } else {
-    TORCH_CHECK(
-        out.size(-1) == og_size,
-        "FlashAttentionV2 returned a nested tensor with an incorrect size")
-  }
-  return out;
-}
 
 } // namespace
 
@@ -711,18 +679,6 @@ Tensor scaled_dot_product_attention(
   c10::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype());
   switch (backend) {
     case sdp::SDPBackend::flash_attention: {
-      if(query_.device().type() == DeviceType::CUDA){
-        c10::SymInt og_size = query_.sym_size(-1);
-        Tensor query_padded = pad_last_dim<8, false>(query_);
-        Tensor key_padded = pad_last_dim<8, false>(key);
-        Tensor value_padded = pad_last_dim<8, false>(value);
-        // We need to calculate the scale based off the OG head dim size
-        auto og_scale = sdp::calculate_scale(query_, scale);
-        auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
-            query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
-        return post_process_flash_output(std::get<0>(out_lse_softmax), og_size);
-      }
-      // For the CPU case we do not need to pad the last dim
       auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
           query_, key, value, dropout_p, is_causal, false /*return_debug_mask*/, scale);
       return std::get<0>(out_lse_softmax);

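For readers skimming the revert, the removed composite-level logic did three things: pad the head dimension of query/key/value up to a multiple of 8 (a FlashAttention v2 requirement), compute the softmax scale from the original head size, and slice the padding back off the returned output. A Python sketch of the same idea (illustration only, not the ATen implementation; `flash_style_sdpa` is a made-up name, and the `scale=` argument stands in for the removed calculate_scale helper, assuming a PyTorch version that accepts it):

    import math
    import torch
    import torch.nn.functional as F

    def pad_last_dim(t: torch.Tensor, alignment: int = 8) -> torch.Tensor:
        # Pad the head dimension up to the next multiple of `alignment` with zeros.
        rem = t.size(-1) % alignment
        return t if rem == 0 else F.pad(t, (0, alignment - rem))

    def flash_style_sdpa(q, k, v):
        og_size = q.size(-1)
        scale = 1.0 / math.sqrt(og_size)  # scale based on the *original* head dim
        q8, k8, v8 = (pad_last_dim(x) for x in (q, k, v))
        out = F.scaled_dot_product_attention(q8, k8, v8, scale=scale)
        return out[..., :og_size]  # drop the padded columns again

    q = k = v = torch.randn(2, 4, 16, 60)  # head dim 60 is padded to 64
    print(flash_style_sdpa(q, k, v).shape)  # torch.Size([2, 4, 16, 60])

Zero-padding keeps the result exact: the padded entries of q and k contribute nothing to the attention scores, and the padded columns of v only produce zero output columns, which the final slice removes.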