You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# Summary
### Review Points
- Automatically pad tensors to create aligned masks when seqlen_kv is not multiple of 16. This will cause memory spike ~ 2 * attn_mask size which could in theory be big. At appears though that doing this + mem_eff is faster than no_pad + math. SO seems to be worth it
- Using expand to view the attn_mask in 4d. This is a little different to how we enforce q,k,v to be viewed in 4d prior to calling. Also not supprint b*n_heads, seq_lenq, seq_lenkv case.
- Should enable, pytorch#96099
### Profiling
I ran a bunch of comparisons between sdpa.MATH and sdp.MemEffAttention. I added a attn_bias of shape (1, 1, seqlen_q, seqln_k). For these experiments seqlen_q == seqlen_k. These were all ran on an a100 80gb gpu.
Configs:
```
# Run a bunch of experiments
batch_sizes = [8, 16, 32]
num_heads = [16, 32]
max_seq_lens = [15, 64, 128, 512, 555, 1024]
embed_dims = [32, 64, 128]
dtypes = [torch.float16, torch.bfloat16, torch.float32]
pad_percentages = [None]
backends = [SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
run_backward = True
attn_mask = True
```
The function calls `sdpa(input**).sum().backward()`.
I calculated the geomean speedup of the efficient attention path of the math path for all these configs:
`Geomean Speedup: 1.977`
An example comparision with batchsize = 8, num_heads = 32, embed_dim = 64, and dtype = torch.float16:

This was done using the current state of the branch where we force alignment of mask when the last dim is not divisible by 16, which shows up in seq_len = 15 and 555 case.
The full data can be found here:
[attn_mask_sweep.csv](https://github.com/pytorch/pytorch/files/11962399/attn_mask_sweep.csv)
Pull Request resolved: pytorch#104310
Approved by: https://github.com/cpuhrsch
0 commit comments