Skip to content

Commit

Permalink
[cutlassF] fix race condition (fairinternal/xformers#709)
Browse files Browse the repository at this point in the history
Co-authored-by: danthe3rd <danthe3rd>

__original_commit__ = fairinternal/xformers@bd904d1
  • Loading branch information
danthe3rd authored and xFormers Bot committed Jul 6, 2023
1 parent 61f757b commit 55a4798
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -747,14 +747,6 @@ class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
arch::OpMultiplyAddComplexFastF32>::value) {
accum = plus_accum(accum, tmp_accum);
}

if (SharedMemoryClear == SharedMemoryClearOption::kZfill) {
// commit and drain all pending and predicated cp.async pnz from the GEMM
// mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}
}
};

Expand Down
12 changes: 12 additions & 0 deletions xformers/csrc/attention/cuda/fmha/gemm/mma_from_smem.h
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,9 @@ class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory<
int thread_idx,
int problem_size_0_n) {}

// No-op stub: this variant provides the same drain_cp_asyncs() entry point as
// its multistage counterpart so callers can invoke it unconditionally, but it
// intentionally does nothing here — presumably this pipelined specialization
// issues no outstanding cp.async copies that would need draining. TODO(review):
// confirm against the multistage implementation, which fences, waits, and
// barriers at this point.
CUTLASS_DEVICE
static void drain_cp_asyncs() {}

/// Perform a threadblock-scoped matrix multiply-accumulate
CUTLASS_DEVICE
void operator()(
Expand Down Expand Up @@ -921,6 +924,15 @@ class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
smem_iterator_B1);
}

// Flush the asynchronous-copy pipeline left behind by the GEMM mainloop:
// cp_async_fence() commits any still-pending cp.async instructions into a
// group, cp_async_wait<0>() blocks until every committed group has completed,
// and __syncthreads() then barriers the thread block so all threads observe
// the copied data before shared memory is reused. Calling this before the
// epilogue (or any other shared-memory reuse) prevents a race between
// in-flight cp.async writes and subsequent reads/writes of the same buffers.
CUTLASS_DEVICE
static void drain_cp_asyncs() {
// commit and drain all pending and predicated cp.async pnz from the GEMM
// mainloop
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<0>();
__syncthreads();
}

CUTLASS_DEVICE
void copy_tiles_and_advance_1(
IteratorB1& iterator_B1,
Expand Down
4 changes: 4 additions & 0 deletions xformers/csrc/attention/cuda/fmha/kernel_forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,8 @@ struct AttentionKernel {

if (kPreloadV) {
prologueV(0);
} else {
MM1::Mma::drain_cp_asyncs();
}

typename MM0::Mma::Operator::IteratorC::TensorCoord
Expand Down Expand Up @@ -997,6 +999,7 @@ struct AttentionKernel {
}

if (!kKeepOutputInRF) {
MM1::Mma::drain_cp_asyncs();
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
Expand Down Expand Up @@ -1103,6 +1106,7 @@ struct AttentionKernel {
thread_id(),
warp_id(),
lane_id());
MM1::Mma::drain_cp_asyncs();
epilogue(rescale, dest_iter, accum_o);
}

Expand Down

0 comments on commit 55a4798

Please sign in to comment.