
Commit c4b7311

drisspg authored and pytorchmergebot committed
Meff Attn Bias (pytorch#104310)
# Summary

### Review Points
- Automatically pad tensors to create aligned masks when seqlen_kv is not a multiple of 16. This causes a memory spike of roughly 2 * attn_mask size, which could in theory be large. It appears, though, that padding + mem_eff is faster than no_pad + math, so it seems worth it.
- Use expand to view the attn_mask in 4d. This is a little different from how we enforce q, k, v to be viewed in 4d prior to calling. The (b * n_heads, seq_len_q, seq_len_kv) case is also not supported.
- Should enable pytorch#96099.

### Profiling
I ran a number of comparisons between sdpa.MATH and sdp.MemEffAttention, adding an attn_bias of shape (1, 1, seqlen_q, seqlen_k). For these experiments seqlen_q == seqlen_k. These were all run on an A100 80GB GPU.

Configs:

```
# Run a bunch of experiments
batch_sizes = [8, 16, 32]
num_heads = [16, 32]
max_seq_lens = [15, 64, 128, 512, 555, 1024]
embed_dims = [32, 64, 128]
dtypes = [torch.float16, torch.bfloat16, torch.float32]
pad_percentages = [None]
backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
run_backward = True
attn_mask = True
```

The benchmarked function calls `sdpa(input**).sum().backward()`.

I calculated the geomean speedup of the efficient attention path over the math path for all these configs: `Geomean Speedup: 1.977`

An example comparison with batch_size = 8, num_heads = 32, embed_dim = 64, and dtype = torch.float16:

![attn_mask_compare_bsz_8_num_heads_32_embed_dim_64_dtype_fp16](https://github.com/pytorch/pytorch/assets/32754868/0d75bffe-350b-43f2-a37f-514f9158dcff)

This was done using the current state of the branch, where we force alignment of the mask when the last dim is not divisible by 16, which shows up in the seq_len = 15 and 555 cases.

The full data can be found here: [attn_mask_sweep.csv](https://github.com/pytorch/pytorch/files/11962399/attn_mask_sweep.csv)

Pull Request resolved: pytorch#104310
Approved by: https://github.com/cpuhrsch
1 parent 45322fa commit c4b7311
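As a usage sketch (not part of this commit, shapes and device are illustrative assumptions): after this change an explicit `attn_mask` passed to `scaled_dot_product_attention` can be served by the memory-efficient kernel instead of always falling back to the math path. The `torch.backends.cuda.sdp_kernel` context manager is the backend-selection API available around this release.

```python
import torch
import torch.nn.functional as F

# Shapes picked to match one of the profiled configs; 555 is deliberately not a
# multiple of 16 so the new alignment-padding path is exercised internally.
B, n_heads, seq_len, head_dim = 8, 32, 555, 64
q = torch.randn(B, n_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
# Additive bias broadcastable to (B, n_heads, seq_len_q, seq_len_kv).
attn_bias = torch.randn(1, 1, seq_len, seq_len, device="cuda", dtype=torch.float16)

# Restrict dispatch to the memory-efficient backend so the new path is taken.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)

print(out.shape)  # torch.Size([8, 32, 555, 64])
```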

13 files changed: +321, -105 lines changed


aten/src/ATen/native/native_functions.yaml

+3 -3

```diff
@@ -14179,13 +14179,13 @@
   dispatch:
     CUDA: _scaled_dot_product_flash_attention_backward_cuda
 
-- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
+- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
   dispatch:
     CUDA: _scaled_dot_product_efficient_attention_cuda
     NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda
   tags: nondeterministic_seeded
 
-- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)
   device_check: NoCheck
   dispatch:
     CUDA: _scaled_dot_product_efficient_attention_backward_cuda
@@ -14210,7 +14210,7 @@
     CUDA: _efficient_attention_forward
   tags: nondeterministic_seeded
 
-- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
+- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
   device_check: NoCheck
   variants: function
   dispatch:
```
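For reference, a minimal sketch of how the updated private op could be exercised directly through `torch.ops.aten` (normally you would go through `F.scaled_dot_product_attention`; the private op's calling convention may change, and the shapes below are assumptions for illustration). Per the checks added later in this commit, the bias must already be a 4-d tensor matching `(B, n_heads, seq_len_q, seq_len_kv)`:

```python
import torch

q = torch.randn(2, 4, 32, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
# Already expanded to (B, n_heads, seq_len_q, seq_len_kv); 32 is a multiple of
# 16, so no alignment padding is needed here.
bias = torch.randn(2, 4, 32, 32, device="cuda", dtype=torch.float16)

# Returns the four outputs named in the schema above.
out, log_sumexp, philox_seed, philox_offset = (
    torch.ops.aten._scaled_dot_product_efficient_attention(
        q, k, v, bias, compute_log_sumexp=True, dropout_p=0.0, is_causal=False
    )
)
print(out.shape)  # torch.Size([2, 4, 32, 64])
```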

aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp

+1

```diff
@@ -743,6 +743,7 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda(
     const Tensor& query,
     const Tensor& key,
     const Tensor& value,
+    const c10::optional<at::Tensor>& attn_bias,
     bool compute_log_sumexp,
     double dropout_p,
     bool is_causal,
```

aten/src/ATen/native/transformers/attention.cpp

+66 -16

```diff
@@ -1,4 +1,5 @@
 #include <type_traits>
+#include <limits>
 #include <c10/core/DeviceType.h>
 #include <ATen/ATen.h>
 #include <ATen/AccumulateType.h>
@@ -10,13 +11,14 @@
 #include <ATen/cpu/vec/vec256/vec256.h>
 #include <ATen/native/transformers/attention.h>
 #include <ATen/native/transformers/sdp_utils_cpp.h>
-#include <type_traits>
 #include <utility>
+#include <c10/util/typeid.h>
 #include <c10/core/SymIntArrayRef.h>
 #include <c10/util/Logging.h>
 #include <c10/util/Exception.h>
 #include <c10/core/DispatchKey.h>
 #include <c10/core/DispatchKeySet.h>
+#include <ATen/TensorSubclassLikeUtils.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/NativeFunctions.h>
@@ -509,6 +511,7 @@ int64_t _fused_sdp_choice_meta(
   }
   return static_cast<int64_t>(sdp::SDPBackend::math);
 }
+namespace {
 
 inline void validate_sdpa_input(
     const Tensor& query_,
@@ -535,9 +538,53 @@ inline void validate_sdpa_input(
     TORCH_CHECK(mask_dtype == at::kBool || mask_dtype == query_.dtype(),
         "Expected attn_mask dtype to be bool or to match query dtype, but got attn_mask.dtype: ",
         mask_dtype, " and query.dtype: ", query_.dtype(), " instead.");
+    TORCH_CHECK(
+        !query_.is_nested() && !key.is_nested(),
+        "Scaled_dot_product_attention: Nested tensors for query / key are not supported "
+        "when an explicit attn_mask is set");
   }
   return;
 }
+// This function is used to produce an attn_mask
+// in a standard format that can be consumed by both
+// the math and memory efficient attn_mask implementation
+// Args:
+//    attn_mask: attn_mask of shape (B, L, S) or (L, S) or (B, N_heads, L, S)
+c10::optional<Tensor> convert_boolean_attn_mask(const c10::optional<Tensor>& attn_mask, caffe2::TypeMeta dtype) {
+  // Pass through
+  if(!attn_mask.has_value()){
+    return c10::nullopt;
+  }
+  // Convert boolean mask to additive mask; need to invert mask to indicate what
+  // to mask *out*.
+  if (attn_mask->dtype() == at::kBool) {
+    auto new_attn_mask = at::zeros_like(attn_mask.value(), dtype);
+    // TODO Use the max type of the input and output
+    new_attn_mask.masked_fill_(
+        attn_mask->logical_not(), -std::numeric_limits<double>::infinity());
+    return new_attn_mask;
+  }
+  // Otherwise, attn_mask represents an additive attention tensor
+  return attn_mask;
+}
+// Memory Efficient Attention requires a padded attn mask bias
+// This function pads the attn_mask bias to be a multiple of 16
+// Then slices the padded bias to the original size
+// We apply this function to the top level SDPA so that
+// if padding is done it will be tracked for backward automatically
+at::Tensor pad_bias(const at::Tensor& attn_bias) {
+  int align_to = 16;
+  auto last_dim_size = attn_bias.sym_size(-1);
+  if (last_dim_size % align_to == 0) {
+    return attn_bias;
+  }
+  auto pad_count = align_to - (last_dim_size % align_to);
+  auto padded_bias = at::pad_symint(attn_bias, {c10::SymInt(0), pad_count});
+  return padded_bias.slice_symint(-1, 0, last_dim_size);
+}
+
+} // namespace
+
 // Computes scaled dot product attention on query, key and value tensors, using
 // an optional attention mask if passed, and applying dropout if a probability
 // greater than 0.0 is specified.
@@ -581,6 +628,7 @@ Tensor scaled_dot_product_attention(
         query_, key, value, attn_mask_, dropout_p, is_causal, scale);
   }
   sdp::SDPBackend backend = static_cast<sdp::SDPBackend>(choice_int);
+  c10::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype());
   switch (backend) {
     case sdp::SDPBackend::flash_attention: {
       auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
@@ -591,16 +639,25 @@
       bool compute_logsumexp =
           (query_.requires_grad() || key.requires_grad() ||
            value.requires_grad());
+      if (attn_mask.has_value()) {
+        // Expand to 4d case
+        attn_mask = attn_mask.value().expand_symint(
+            {query_.sym_size(0),
+             query_.sym_size(1),
+             query_.sym_size(2),
+             key.sym_size(2)});
+        attn_mask = pad_bias(attn_mask.value());
+      }
       auto out_and_lse = at::_scaled_dot_product_efficient_attention(
-          query_, key, value, compute_logsumexp, dropout_p, is_causal, scale);
+          query_, key, value, attn_mask, compute_logsumexp, dropout_p, is_causal, scale);
       return std::get<0>(out_and_lse);
     }
     case sdp::SDPBackend::math:
       return std::get<0>(at::_scaled_dot_product_attention_math(
           query_,
           key,
           value,
-          attn_mask_,
+          attn_mask,
           dropout_p,
           is_causal,
           c10::nullopt, /*dropout_mask*/
@@ -639,22 +696,15 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
     // Replace attn_mask with causal mask; lower triangular elements take part in attention.
     const auto L = query.sym_size(-2), S = key.sym_size(-2);
     attn_mask = at::ones_symint({L, S}, query.options().dtype(at::kBool)).tril();
+    attn_mask = convert_boolean_attn_mask(attn_mask, query.dtype());
   }
+  auto attn = at::matmul(query, key.transpose(-2, -1) * scaling_factor);
   if (attn_mask.has_value()) {
-    TORCH_CHECK(!query.is_nested() && !key.is_nested(),
-        "_scaled_dot_product_attention: Nested tensors for query / key are not supported "
-        "when an explicit attn_mask is set");
-    // Convert boolean mask to additive mask; need to invert mask to indicate what to mask *out*.
-    if (attn_mask->dtype() == at::kBool){
-      auto new_attn_mask = at::zeros_like(*attn_mask, query.dtype());
-      new_attn_mask.masked_fill_(attn_mask->logical_not(), -std::numeric_limits<double>::infinity());
-      attn_mask = new_attn_mask;
-    }
-    // Otherwise, attn_mask represents an additive attention tensor
-  }
-  auto attn = at::matmul(query, key.transpose(-2, -1)*scaling_factor);
-  if (attn_mask.has_value()) {
+    if (at::areAnyTensorSubclassLike({attn, *attn_mask})) {
+      attn = attn.add(*attn_mask);
+    } else {
       attn.add_(*attn_mask);
+    }
   }
   attn = at::softmax(attn, -1);
   if (dropout_p > 0.0) {
```
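The two new helpers introduced above are small; a rough Python rendering of the same logic, for intuition only (names mirror the C++; this is not the code that ships):

```python
import torch

ALIGN_TO = 16

def convert_boolean_attn_mask(attn_mask, dtype):
    # Pass through when no mask is given.
    if attn_mask is None:
        return None
    # Boolean masks become additive: positions to mask *out* get -inf.
    if attn_mask.dtype == torch.bool:
        new_mask = torch.zeros_like(attn_mask, dtype=dtype)
        return new_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
    # Otherwise the mask is already an additive bias.
    return attn_mask

def pad_bias(attn_bias):
    # Pad the last dim up to a multiple of 16, then slice back to the original
    # length: the values are unchanged, but the underlying storage (and hence
    # the last-dim stride) is aligned for the memory-efficient kernel.
    last = attn_bias.size(-1)
    if last % ALIGN_TO == 0:
        return attn_bias
    pad = ALIGN_TO - (last % ALIGN_TO)
    return torch.nn.functional.pad(attn_bias, (0, pad))[..., :last]
```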

aten/src/ATen/native/transformers/cuda/attention.cu

+15 -5

```diff
@@ -524,7 +524,11 @@ std::tuple<Tensor, Tensor> native_multi_head_attention_cuda(
   // strides from packed projection for nested tensors when seq_len is 1 will be
   // and will trigger a contiguous call in the kernel, so we prevent this
   bool no_seq_len_1_nested = query.is_nested() ? check_for_seq_len_1_nested_tensor(kernel_params, false) : true;
-  if (no_seq_len_1_nested &&
+  // The API for transfomer_encoder is a mask of shape (Batch_Size, Seq_len_q)
+  // For mem-eff attention this will cause the expand call to error
+  // For now I am going to turn of that path not have to deal with all the annoying
+  // Mask type shape grossness
+  if (!mask.has_value() && no_seq_len_1_nested &&
       (backend == sdp::SDPBackend::flash_attention || backend == sdp::SDPBackend::efficient_attention)) {
     auto x = at::linear(query, qkv_weight, qkv_bias);
     auto chunks = x.chunk(3, -1);
@@ -536,7 +540,6 @@
                     .transpose(1, 2);
     chunks[2] = (chunks[2].view({x_size_0, -1, num_head, dim_per_head}))
                     .transpose(1, 2);
-
     auto y = at::scaled_dot_product_attention(
         chunks[0], chunks[1], chunks[2], mask, 0.0, false, c10::nullopt);
 
@@ -712,6 +715,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
     const Tensor& query,
     const Tensor& key,
     const Tensor& value,
+    const c10::optional<at::Tensor>& attn_bias,
     bool compute_log_sumexp,
     double dropout_p,
     bool is_causal,
@@ -733,7 +737,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
       q_t,
       k_t,
       v_t,
-      c10::nullopt,
+      attn_bias,
       c10::nullopt,
       c10::nullopt,
       c10::nullopt,
@@ -1045,8 +1049,14 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
 
     // assign strides for bias, viewed as
     // (batch_sz, n_heads, n_queries, n_keys)
-    const at::Tensor bias_4d_view =
-        get_bias_4d_view(*bias, B, num_heads, M, N);
+    // We make sure to expand prior to calling the kernel
+    const at::Tensor& bias_4d_view = *bias;
+    TORCH_CHECK(bias_4d_view.dim()==4);
+    TORCH_CHECK(bias_4d_view.size(0)==B);
+    TORCH_CHECK(bias_4d_view.size(1)==num_heads);
+    TORCH_CHECK(bias_4d_view.size(2)==M);
+    TORCH_CHECK(bias_4d_view.size(3)==N);
+
     ASSIGN_CHECK_OVERFLOW(p.bias_strideB, bias_4d_view.stride(0));
     ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias_4d_view.stride(1));
     ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias_4d_view.stride(2));
```
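Since `get_bias_4d_view` is dropped here, the kernel now assumes the bias arrives already expanded to `(B, n_heads, M, N)`, with broadcasting expressed purely through strides. A quick illustration of what the composite op does before dispatch (a sketch of the effect, not the C++ itself):

```python
import torch

B, n_heads, L, S = 8, 32, 128, 128
# A bias that broadcasts over batch and heads.
bias = torch.randn(1, 1, L, S)

# expand produces a 4-d view without copying; the broadcast dims get stride 0,
# which is why the kernel can read bias_strideB / bias_strideH straight off the
# view instead of re-deriving a 4-d shape itself.
bias_4d = bias.expand(B, n_heads, L, S)
print(bias_4d.shape)     # torch.Size([8, 32, 128, 128])
print(bias_4d.stride())  # (0, 0, 128, 1) -- no memory duplicated
```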

aten/src/ATen/native/transformers/cuda/attention_backward.cu

+31 -15

```diff
@@ -115,6 +115,7 @@ _efficient_attention_backward(
     const at::Tensor& philox_seed, // seed using for generating random numbers for dropout
     const at::Tensor& philox_offset, // offset into random number sequence
     int64_t custom_mask_type,
+    const bool bias_requires_grad,
     const c10::optional<double> scale,
     c10::optional <int64_t> num_splits_key) {
 #if defined(USE_FLASH_ATTENTION)
@@ -187,8 +188,6 @@
   int64_t K = query.size(3);
   int64_t Kv = value.size(3);
 
-  const bool bias_requires_grad = bias.has_value() && bias->requires_grad();
-
   at::Tensor grad_q, grad_k, grad_v, grad_bias;
   grad_q = at::empty(query.sizes(), query.options());
   grad_k = at::empty(key.sizes(), key.options());
@@ -344,7 +343,13 @@
 
     // assign strides for bias, viewed as:
     // (batch_sz, n_heads, n_queries, n_keys)
-    const at::Tensor bias_4d_view = get_bias_4d_view(*bias, B, nH, M, N);
+    // We make sure to expand prior to calling the kernel
+    const at::Tensor& bias_4d_view = *bias;
+    TORCH_CHECK(bias_4d_view.dim()==4);
+    TORCH_CHECK(bias_4d_view.size(0)==B);
+    TORCH_CHECK(bias_4d_view.size(1)==nH);
+    TORCH_CHECK(bias_4d_view.size(2)==M);
+    TORCH_CHECK(bias_4d_view.size(3)==N);
     ASSIGN_CHECK_OVERFLOW(p.bias_strideB, bias_4d_view.stride(0));
     ASSIGN_CHECK_OVERFLOW(p.bias_strideH, bias_4d_view.stride(1));
     ASSIGN_CHECK_OVERFLOW(p.bias_strideM, bias_4d_view.stride(2));
@@ -359,8 +364,14 @@
     // different values of Q will point to the same memory
     // locations, meaning bias.stride(1) == 0, while we'd want
     // grad_bias.stride(1) == nK
-    const at::Tensor grad_bias_4d_view =
-        get_bias_4d_view(grad_bias, B, nH, M, N);
+    // We have expanded the input prior to calling the forward kernel
+    const at::Tensor& grad_bias_4d_view = grad_bias;
+    TORCH_CHECK(grad_bias_4d_view.dim()==4);
+    TORCH_CHECK(grad_bias_4d_view.size(0)==B);
+    TORCH_CHECK(grad_bias_4d_view.size(1)==nH);
+    TORCH_CHECK(grad_bias_4d_view.size(2)==M);
+    TORCH_CHECK(grad_bias_4d_view.size(3)==N);
+
     ASSIGN_CHECK_OVERFLOW(p.gB_strideB, grad_bias_4d_view.stride(0));
     ASSIGN_CHECK_OVERFLOW(p.gB_strideH, grad_bias_4d_view.stride(1));
     ASSIGN_CHECK_OVERFLOW(p.gB_strideM, grad_bias_4d_view.stride(2));
@@ -531,20 +542,23 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attenti
 }
 
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_efficient_attention_backward_cuda(
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_efficient_attention_backward_cuda(
     const at::Tensor& grad_out_,
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value,
+    const at::Tensor& attn_bias,
     const at::Tensor& out,
     const at::Tensor& logsumexp,
     const at::Tensor& philox_seed,
    const at::Tensor& philox_offset,
     double dropout_p,
+    std::array<bool, 4> grad_input_mask,
     bool causal,
-    c10::optional<double> scale){
+    c10::optional<double> scale) {
+
   if (!grad_out_.defined()) {
-    return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
+    return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
   }
   auto grad_out = grad_out_.transpose(1, 2);
   auto out_t = out.transpose(1, 2);
@@ -554,10 +568,12 @@
 
   Tensor grad_q, grad_k, grad_v, grad_bias;
 
-  // TODO_DRISS
-  // These are place holders unitl we add support for bias
-  auto bias = c10::nullopt;
-
+  // This is needed because SaveVarible automatically converts
+  // c10::optional to undefined tensor
+  c10::optional<Tensor> kernel_bias;
+  if (attn_bias.defined()) {
+    kernel_bias = attn_bias;
+  }
   // Will add with signauter changes for dropout and bias
   // We are only handiling Dense inputs, but this should be passed
   // from forward to backward
@@ -567,14 +583,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_efficient_att
   sdp::CustomMaskType custom_mask_type = causal
       ? sdp::CustomMaskType::CausalFromTopLeft
      : sdp::CustomMaskType::NoCustomMask;
-
   std::tie(grad_q, grad_k, grad_v, grad_bias) =
       at::_efficient_attention_backward(
           grad_out,
           q_t,
          k_t,
          v_t,
-         bias,
+         kernel_bias,
          out_t,
          c10::nullopt,
          c10::nullopt,
@@ -585,10 +600,11 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_efficient_att
          philox_seed,
          philox_offset,
          static_cast<int64_t>(custom_mask_type),
+         grad_input_mask[3],
          scale,
          c10::nullopt); // num_split_keys
   return std::make_tuple(
-      grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2));
+      grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
 }
 
 } // namespace native
```
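With `grad_input_mask[3]` plumbed through to `bias_requires_grad`, the backward now returns a fourth gradient for the bias. A minimal end-to-end check (a sketch, assuming a CUDA device and that the memory-efficient backend gets selected; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

q = torch.randn(4, 8, 64, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
bias = torch.randn(4, 8, 64, 64, device="cuda", dtype=torch.float16, requires_grad=True)

with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=False, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)

out.sum().backward()
# The fused kernel now produces a gradient for the bias on this path, rather
# than relying on the math fallback.
print(bias.grad is not None)  # True
```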
