Fix race condition in bwd (overwriting sK)
tridao committed Aug 1, 2023
1 parent a4e5d1e commit 1c41d2b
Showing 3 changed files with 26 additions and 18 deletions.
8 changes: 5 additions & 3 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -1020,9 +1020,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)

// We don't need syncthreads here since we're writing to the same location as sK and sV.
// Unless Is_V_in_regs. If Is_last, there's already a __syncthreads() at the end of the loop.
if (Kernel_traits::Is_V_in_regs && !Is_last) { __syncthreads(); }
// We need syncthreads here since we're writing to the same location as sK and sV.
// Without syncthreads, some thread might modify the location of sK while another thread
// is reading it for dQ gemm, leading to a race condition.
// If Is_last, there's already a __syncthreads() at the end of the loop.
if (!Is_last) { __syncthreads(); }

copy(smem_thr_copy_dKV, taccdKrdK, taccdKsdK);
copy(smem_thr_copy_dKV, taccdVrdV, taccdVsdV);
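
For context, here is a minimal, self-contained CUDA sketch (toy code, not from this repository) of the pattern the fix addresses: when a shared-memory buffer is reused for a second purpose, a __syncthreads() is required between the last read of the old contents and the first overwrite. Otherwise one thread can clobber a location that another thread is still reading, which is the same class of race as overwriting sK while it is still being consumed for the dQ gemm.

// Toy kernel: a single shared-memory buffer is staged with input, read by
// neighboring threads, then reused as scratch for a second quantity.
// The middle __syncthreads() is the analogue of the barrier added in this
// commit; removing it lets a fast thread overwrite smem[t] while a slower
// thread is still reading it.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void reuse_smem_kernel(const float* in, float* out, int n) {
    extern __shared__ float smem[];   // one buffer, two uses
    int tid = threadIdx.x;

    // Phase 1: stage the input tile in shared memory.
    if (tid < n) { smem[tid] = in[tid]; }
    __syncthreads();                  // staged tile visible to all threads

    // Phase 2: each thread reads its neighbor's staged value.
    float neighbor = (tid < n) ? smem[(tid + 1) % n] : 0.f;

    // Without this barrier, thread t may run Phase 3 and overwrite smem[t]
    // while thread t-1 is still executing the read above: a race condition.
    __syncthreads();

    // Phase 3: reuse the same buffer for a different quantity.
    if (tid < n) { smem[tid] = 2.f * neighbor; }
    __syncthreads();

    if (tid < n) { out[tid] = smem[tid]; }
}

int main() {
    const int n = 128;
    float h_in[n], h_out[n];
    for (int i = 0; i < n; ++i) { h_in[i] = float(i); }

    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    // Dynamic shared memory: n floats for the reusable buffer.
    reuse_smem_kernel<<<1, n, n * sizeof(float)>>>(d_in, d_out, n);
    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);

    printf("out[0] = %.1f (expected %.1f)\n", h_out[0], 2.f * h_in[1]);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

Compiled with nvcc and run under compute-sanitizer --tool racecheck, the variant without the middle barrier should be flagged for a shared-memory read/write hazard, which is roughly how races like the one fixed in this commit can be confirmed.
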
1 change: 1 addition & 0 deletions setup.py
@@ -172,6 +172,7 @@ def append_nvcc_threads(nvcc_extra_args):
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v",
# "--ptxas-options=-O2",
"-lineinfo"
]
+ generator_flag
35 changes: 20 additions & 15 deletions tests/test_flash_attn.py
@@ -785,44 +785,49 @@ def test_flash_attn_varlen_output(seqlen_q, seqlen_k, d, dropout_p, causal, mha_

# @pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
@pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('causal', [False, True])
# @pytest.mark.parametrize('causal', [True])
# @pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('causal', [False])
# @pytest.mark.parametrize('d', [32, 56, 64, 80, 96, 128])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
@pytest.mark.parametrize('d', [64])
@pytest.mark.parametrize('d', [128])
# @pytest.mark.parametrize('seqlen', [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048])
@pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
# @pytest.mark.parametrize('seqlen', [193])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize('seqlen', [128])
# @pytest.mark.parametrize('dropout_p', [0.0, 0.17])
@pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
if seqlen >= 2048 and torch.cuda.get_device_properties('cuda').total_memory <= 16 * 2**30:
pytest.skip() # Reference implementation OOM
device = 'cuda'
# set seed
torch.random.manual_seed(0)
batch_size = 32
batch_size = 60 # Sometimes we need large batch size for the race conditions to trigger
nheads = 4
qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True)
qkv = torch.randn(batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype,
requires_grad=True)
out0, lse0, _ = flash_attn_qkvpacked_func(
qkv, dropout_p, return_attn_probs=True, causal=causal
)
g = torch.randn_like(out0)
dqkv0, = torch.autograd.grad(out0, qkv, g)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
dqkv0, = torch.autograd.grad(out0, qkv, g)
# Numerical error if we just do any arithmetic on dq
dq_atol = 2 * ((dqkv0[:, :, 0] + 0.3 - 0.3) - dqkv0[:, :, 0]).abs().max().item()

for _ in range(200):
for i in range(200):
torch.random.manual_seed(0)
out, lse, S_dmask = flash_attn_qkvpacked_func(
qkv, dropout_p, return_attn_probs=True, causal=causal
)
assert torch.equal(out, out0)
assert torch.equal(lse, lse0)
# sm_lse has some parts that are uninitialized from torch.empty
# assert torch.equal(sm_lse, sm_lse_0)

if not (is_sm75 and d == 128):
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
dqkv, = torch.autograd.grad(out, qkv, g)
assert torch.equal(dqkv[:, :, 0], dqkv0[:, :, 0])
dq_equal = torch.allclose(dqkv[:, :, 0], dqkv0[:, :, 0], atol=dq_atol)
if not dq_equal:
dq0 = dqkv0[:, :, 0]
dq = dqkv[:, :, 0]
print(f'Iter {i}, {dq_atol = }, dQ max diff: {(dqkv[:, :, 0] - dqkv0[:, :, 0]).abs().max().item()}')
assert dq_equal
assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])
