Commit 4a008d2

drisspg authored and pytorchmergebot committed
REDO of dropout support for mem eff pytorch#102038 (pytorch#103704)
This is a new PR with the changes from pytorch#102038 + pytorch#103201, plus namespacing changes to fix a bug.

# Summary
This PR builds off of:
- pytorch#101847
- pytorch#100583

It specifically adds dropout support to the memory-efficient attention kernel. In the process, roughly three changes were made:
- Update SDPA dispatching to allow inputs that require grad to be sent to efficient attention.
- Update how memory-efficient attention passes the RNG state from forward to backward, in order to enable CUDA graph support.
- Fix a bug in the kernel that was causing incorrect gradients for num_keys > 64 with dropout and causal masking set. facebookresearch/xformers#755

Pull Request resolved: pytorch#103704
Approved by: https://github.com/cpuhrsch
1 parent bfa08a1 commit 4a008d2
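At a glance, the user-visible effect is that scaled dot product attention with a nonzero dropout probability and inputs that require grad can now dispatch to the memory-efficient backend instead of falling back to the math path. Below is a minimal, hedged C++ (libtorch) sketch of that call path; whether the memory-efficient kernel is actually selected still depends on hardware, build, and SDPA dispatch heuristics, and the shapes/dtypes here are only illustrative.

#include <torch/torch.h>

int main() {
  // Illustrative shapes/dtype: (batch, heads, seq_len, head_dim) in fp16 on CUDA.
  auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
  auto q = torch::randn({2, 8, 128, 64}, opts).requires_grad_(true);
  auto k = torch::randn({2, 8, 128, 64}, opts).requires_grad_(true);
  auto v = torch::randn({2, 8, 128, 64}, opts).requires_grad_(true);

  // Public entry point; SDPA dispatch picks a backend internally. Before this
  // commit, dropout_p > 0 with requires_grad inputs could not train through
  // the memory-efficient kernel.
  auto out = at::scaled_dot_product_attention(
      q, k, v, /*attn_mask=*/{}, /*dropout_p=*/0.1, /*is_causal=*/true);
  out.sum().backward();  // exercises the new dropout-aware backward
  return 0;
}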

File tree: 68 files changed (+506, -240 lines)


aten/src/ATen/cuda/detail/UnpackRaw.cuh (+1, -1)

@@ -15,7 +15,7 @@ namespace philox {
 // Easiest thing that comes to mind is, define a __device__ unpack helper here, in ATen/cuda.
 //
 // The raw definition lives in its own file so jit codegen can easily copy it.
-__device__ __forceinline__ std::tuple<uint64_t, uint64_t>
+__host__ __device__ __forceinline__ std::tuple<uint64_t, uint64_t>
 unpack(at::PhiloxCudaState arg) {
   if (arg.captured_) {
     // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
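The one-word change above matters because the Philox seed/offset now also travel through host-side tensors (see the Note [Seed and Offset] referenced later in this diff), so the same unpacking logic has to be callable from host code as well as from kernels. A rough, self-contained sketch of the pattern follows, using illustrative field names that are not the real at::PhiloxCudaState layout and assuming compilation under nvcc the way the real header is built:

#include <cstdint>
#include <tuple>

// Simplified stand-in for at::PhiloxCudaState; the real struct differs.
struct PhiloxStateSketch {
  bool captured_ = false;                 // true when captured for a CUDA graph
  uint64_t seed_ = 0;                     // used when not captured
  uint64_t offset_ = 0;
  const int64_t* seed_ptr_ = nullptr;     // device pointers, only valid when captured
  const int64_t* offset_ptr_ = nullptr;
};

// __host__ __device__ lets host-side callers reuse the non-captured branch;
// the captured branch (device pointers) must still only be taken on the GPU.
__host__ __device__ inline std::tuple<uint64_t, uint64_t>
unpack_sketch(const PhiloxStateSketch& arg) {
  if (arg.captured_) {
    return std::make_tuple(static_cast<uint64_t>(*arg.seed_ptr_),
                           static_cast<uint64_t>(*arg.offset_ptr_));
  }
  return std::make_tuple(arg.seed_, arg.offset_);
}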

aten/src/ATen/native/native_functions.yaml (+13, -4)

@@ -14176,14 +14176,17 @@
   dispatch:
     CUDA: _scaled_dot_product_flash_attention_backward_cuda

-- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp)
+- func: _scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)
   dispatch:
     CUDA: _scaled_dot_product_efficient_attention_cuda
     NestedTensorCUDA: _scaled_dot_product_efficient_attention_nestedtensor_cuda
+  tags: nondeterministic_seeded

-- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, bool is_causal=False, bool chunk_grad_outputs=False, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+- func: _scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor)
+  device_check: NoCheck
   dispatch:
     CUDA: _scaled_dot_product_efficient_attention_backward_cuda
+  tags: nondeterministic_seeded

 # THIS FUNCTION iS DEPRECATED AND SHOULD BE REMOVED
 - func: _chunk_grad_outputs_efficient_attention(Tensor query, Tensor key, Tensor value, bool is_causal=False) -> bool

@@ -14203,13 +14206,13 @@
     CUDA: _flash_attention_backward

 # Returns ouput, logsumexp if compute_logsumexp
-- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp)
+- func: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int? max_seqlen_q, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? causal_diagonal=None, Tensor? seqlen_k=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset)
   variants: function
   dispatch:
     CUDA: _efficient_attention_forward
   tags: nondeterministic_seeded

-- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor rng_seed, Tensor rng_offset, int custom_mask_type, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
+- func: _efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, int max_seqlen_k, int max_seqlen_q, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, *, float? scale=None, int? num_splits_key=None) -> (Tensor, Tensor, Tensor, Tensor)
   device_check: NoCheck
   variants: function
   dispatch:

@@ -14219,7 +14222,13 @@
   variants: function
   dispatch:
     CUDA: triton_scaled_dot_attention
+  tags: nondeterministic_seeded
   autogen: _triton_scaled_dot_attention.out
+
+- func: _fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!)
+  variants: function
+  dispatch:
+    CUDA: _fill_mem_eff_dropout_mask_
   tags: nondeterministic_seeded

 - func: _triton_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None) -> Tensor
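The schema change above is the core of the dropout plumbing: the efficient-attention forward now takes dropout_p and returns philox_seed/philox_offset tensors, which the backward consumes in place of the old integer RNG state. A hedged sketch of a direct call, assuming the generated C++ binding simply mirrors this schema (normally callers go through the public SDPA entry point rather than the underscore-prefixed op):

#include <ATen/ATen.h>
#include <tuple>

// Returns (output, log_sumexp, philox_seed, philox_offset) per the schema above.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mem_eff_forward_sketch(
    const at::Tensor& q, const at::Tensor& k, const at::Tensor& v) {
  const bool compute_log_sumexp = q.requires_grad();  // needed for backward
  return at::_scaled_dot_product_efficient_attention(
      q, k, v, compute_log_sumexp,
      /*dropout_p=*/0.1, /*is_causal=*/false, /*scale=*/c10::nullopt);
}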

aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp (+6, -5)

@@ -738,12 +738,13 @@ _scaled_dot_product_flash_attention_nestedtensor_cuda(
       debug_attn_mask);
 }

-std::tuple<Tensor, Tensor>
+std::tuple<Tensor, Tensor, Tensor, Tensor>
 _scaled_dot_product_efficient_attention_nestedtensor_cuda(
     const Tensor& query,
     const Tensor& key,
     const Tensor& value,
     bool compute_log_sumexp,
+    double dropout_p,
     bool is_causal,
     c10::optional<double> scale) {
   Tensor query_buffer_reshaped, key_buffer_reshaped, value_buffer_reshaped,

@@ -763,23 +764,23 @@ _scaled_dot_product_efficient_attention_nestedtensor_cuda(
       ? sdp::CustomMaskType::CausalFromTopLeft
       : sdp::CustomMaskType::NoCustomMask;

-  Tensor attention, log_sumexp;
-  std::tie(attention, log_sumexp) = at::_efficient_attention_forward(
+  // See Note [Seed and Offset] for description of seed and offset
+  auto [attention, log_sumexp, seed, offset] = at::_efficient_attention_forward(
       query_buffer_reshaped.unsqueeze(0),
       key_buffer_reshaped.unsqueeze(0),
       value_buffer_reshaped.unsqueeze(0),
       c10::nullopt,
       cumulative_sequence_length_q,
       cumulative_sequence_length_kv,
       max_seqlen_batch_q,
-      0.0 /*dropout_p*/,
+      dropout_p,
       static_cast<int64_t>(custom_mask_type),
       compute_log_sumexp,
       scale);

   // Reshape output to convert nnz to batch_size and seq_len
   attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);
-  return std::make_tuple(std::move(attention), std::move(log_sumexp));
+  return std::make_tuple(std::move(attention), std::move(log_sumexp), std::move(seed), std::move(offset));
 }

 } // namespace native
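The extra seed/offset outputs threaded through here are what the backward pass later consumes. A hedged sketch of that hand-off at the composite level, again assuming the C++ binding mirrors the backward schema declared above; in the real code this wiring lives in autograd rather than in hand-written calls:

#include <ATen/ATen.h>

// grad_query is returned here; grad_key and grad_value come back in the same tuple.
at::Tensor mem_eff_backward_sketch(
    const at::Tensor& grad_out,
    const at::Tensor& q, const at::Tensor& k, const at::Tensor& v,
    const at::Tensor& out, const at::Tensor& logsumexp,
    const at::Tensor& philox_seed, const at::Tensor& philox_offset,
    double dropout_p, bool is_causal) {
  // The saved philox_seed/philox_offset tensors let the backward kernel
  // regenerate exactly the dropout pattern used in forward, including when
  // the RNG state was captured inside a CUDA graph.
  auto grads = at::_scaled_dot_product_efficient_attention_backward(
      grad_out, q, k, v, out, logsumexp,
      philox_seed, philox_offset, dropout_p, is_causal, /*scale=*/c10::nullopt);
  return std::get<0>(grads);
}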

aten/src/ATen/native/transformers/attention.cpp (+6, -3)

@@ -615,7 +615,7 @@ Tensor scaled_dot_product_attention(
           (query_.requires_grad() || key.requires_grad() ||
            value.requires_grad());
       auto out_and_lse = at::_scaled_dot_product_efficient_attention(
-          query_, key, value, compute_logsumexp, is_causal, scale);
+          query_, key, value, compute_logsumexp, dropout_p, is_causal, scale);
       return std::get<0>(out_and_lse);
     }
     case sdp::SDPBackend::math:

@@ -682,9 +682,12 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
     attn = at::softmax(attn, -1);
     if (dropout_p > 0.0) {
       if (dropout_mask.has_value()) {
-        auto attn_dropout_masked = attn.masked_fill(dropout_mask->logical_not(), 0.0);
+        // In order to validate the correctness of the fused kernels, we need to
+        // use the same dropout mask in order to compare the results.
+        TORCH_WARN_ONCE("Dropout mask should only be used for testing purposes.");
+        attn = attn.masked_fill(dropout_mask->logical_not(), 0.0);
         auto dropout_scaling = 1.0 / (1 - dropout_p);
-        return std::make_tuple(at::matmul(attn_dropout_masked, value * dropout_scaling), attn);
+        return std::make_tuple(at::matmul(attn, value * dropout_scaling), attn);
       } else {
         attn = at::dropout(attn, dropout_p, true);
       }
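For intuition on the testing-only dropout_mask branch above: zeroing the dropped attention weights and rescaling by 1/(1 - p) reproduces what at::dropout would do with that exact keep-mask, which is what lets the math reference be compared directly against the fused kernels. A small illustrative sketch (the helper name is made up, not part of the patch); the patch itself folds the 1/(1 - p) factor into value before the matmul, which is algebraically equivalent:

#include <ATen/ATen.h>

// Apply a precomputed boolean keep-mask the way the math reference does:
// drop masked-out weights, then rescale so the result matches
// at::dropout(attn, dropout_p, /*train=*/true) drawn with the same mask.
at::Tensor apply_dropout_mask_sketch(
    const at::Tensor& attn, const at::Tensor& keep_mask, double dropout_p) {
  auto dropped = attn.masked_fill(keep_mask.logical_not(), 0.0);
  return dropped * (1.0 / (1.0 - dropout_p));
}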
