bump cutlass to v2.11.0 for contrib/multihead_attn (NVIDIA#1570)
Signed-off-by: Masaki Kozuki <[email protected]>
crcrpar authored Jan 26, 2023
1 parent 75f401e commit 536e549
Showing 10 changed files with 588 additions and 603 deletions.
2 changes: 1 addition & 1 deletion apex/contrib/csrc/multihead_attn/cutlass
Submodule cutlass updated 6402 files
4 changes: 2 additions & 2 deletions apex/contrib/csrc/multihead_attn/softmax.cuh
@@ -1125,7 +1125,7 @@ masked_softmax_warp_forward(input_t *dst, const output_t *src,
src + itr_idx);
apply_mask<input_t, ELEMENTS_PER_LDG_STG>(
&elements_input[i][it],
(__half)-std::numeric_limits<float>::infinity(),
__float2half(-std::numeric_limits<float>::infinity()),
curr_mask + itr_jmp);
}
}
@@ -1375,7 +1375,7 @@ __global__ void time_masked_softmax_warp_forward(
src + itr_idx);
apply_mask<input_t, ELEMENTS_PER_LDG_STG>(
&elements_input[i][it],
(__half)-std::numeric_limits<float>::infinity(),
__float2half(-std::numeric_limits<float>::infinity()),
curr_mask + itr_jmp);
}
}
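Both hunks make the same change: the half-precision negative-infinity sentinel passed to apply_mask is now produced with the __float2half intrinsic instead of a C-style (__half) cast from float. The numerical intent is unchanged: masked logits are set to -inf so they receive zero weight after softmax. A minimal PyTorch sketch of that intent (illustrative only, not the CUDA kernel itself):

import torch

# Illustration of the masking semantics in the kernel above (not the CUDA code):
# masked positions are filled with -inf so that softmax assigns them exactly
# zero probability, even in float16.
logits = torch.tensor([2.0, 1.0, 0.5, 3.0], dtype=torch.float16, device="cuda")
mask = torch.tensor([False, True, False, True], device="cuda")  # True = masked out

masked_logits = logits.masked_fill(mask, float("-inf"))
probs = torch.softmax(masked_logits, dim=-1)
print(probs)  # masked entries come out 0.0; the rest renormalize to sum to 1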
454 changes: 190 additions & 264 deletions apex/contrib/csrc/multihead_attn/strided_batched_gemm.cuh

Large diffs are not rendered by default.
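The rewritten header itself is not shown above. For context, a strided batched GEMM runs the same matrix multiply over a batch of matrices laid out at a fixed stride, which is how multihead attention computes the per-head Q·Kᵀ scores and the attention·V product. A rough PyTorch equivalent of the operation (illustrative only; the shapes are made up and this is not the CUTLASS code):

import torch

# Illustrative batched products that a strided batched GEMM implements for
# multihead attention: one independent GEMM per batch slice (here, per head).
heads, seq_q, seq_k, head_dim = 16, 80, 80, 64
q = torch.randn(heads, seq_q, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(heads, seq_k, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(heads, seq_k, head_dim, dtype=torch.float16, device="cuda")

scores = torch.bmm(q, k.transpose(1, 2)) / head_dim ** 0.5  # [heads, seq_q, seq_k]
probs = torch.softmax(scores, dim=-1)
context = torch.bmm(probs, v)                               # [heads, seq_q, head_dim]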

206 changes: 113 additions & 93 deletions apex/contrib/test/multihead_attn/test_encdec_multihead_attn.py
@@ -14,128 +14,148 @@ class EncdecMultiheadAttnTest(unittest.TestCase):
def setUp(self, seed=1234):
torch.manual_seed(seed)

self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0

self.ref_layer = EncdecMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=False,
impl='default')
self.ref_layer = EncdecMultiheadAttn(
self.hidden_dim, self.heads, dropout=self.dropout_prob, bias=False, include_norm_add=False, impl="default"
)
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.ref_inputs_q = torch.randn(
self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")
).requires_grad_(True)
self.ref_inputs_k = torch.randn(
self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")
).requires_grad_(True)

# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

self.tst_layer = EncdecMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=False,
impl='fast')
self.tst_layer = EncdecMultiheadAttn(
self.hidden_dim, self.heads, dropout=self.dropout_prob, bias=False, include_norm_add=False, impl="fast"
)
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()

self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

def test_encdec_multihead_attn(self) :
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)

tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.tst_inputs_q = torch.randn(
self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")
).requires_grad_(True)
self.tst_inputs_k = torch.randn(
self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")
).requires_grad_(True)

def test_encdec_multihead_attn(self):
ref_outputs, _ = self.ref_layer.forward(
self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True,
)

tst_outputs, _ = self.tst_layer.forward(
self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True,
)
torch.testing.assert_close(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)

with torch.no_grad():
ref_grads = torch.randn_like(ref_outputs)
ref_grads = torch.randn_like(ref_outputs)
tst_grads = ref_grads.clone()
ref_outputs.backward(ref_grads)
tst_outputs.backward(tst_grads)
self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))

def test_encdec_multihead_attn_time_mask(self) :
grads = torch.randn_like(self.tst_inputs_q)
time_mask_byte = torch.triu(torch.ones(self.tst_inputs_q.size(0), self.tst_inputs_k.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1)
torch.testing.assert_close(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3)

def test_encdec_multihead_attn_time_mask(self):
grads = torch.randn_like(self.tst_inputs_q)
time_mask_byte = torch.triu(
torch.ones(
self.tst_inputs_q.size(0), self.tst_inputs_k.size(0), device=torch.device("cuda"), dtype=torch.uint8
),
1,
)
time_mask_bool = time_mask_byte.to(torch.bool)

ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=time_mask_bool,
is_training=True)

tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=time_mask_byte,
is_training=True)
ref_outputs, _ = self.ref_layer.forward(
self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=time_mask_bool,
is_training=True,
)

tst_outputs, _ = self.tst_layer.forward(
self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=time_mask_byte,
is_training=True,
)

self.ref_inputs_q.backward(grads)
self.tst_inputs_q.backward(grads)

self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))

def test_encdec_multihead_attn_pad_mask(self) :
grads = torch.randn_like(self.tst_inputs_q)
pad_mask_byte = torch.tril(torch.ones(self.tst_inputs_k.size(1), self.tst_inputs_k.size(0), device=torch.device("cuda"), dtype=torch.uint8), 1)
torch.testing.assert_close(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3)

def test_encdec_multihead_attn_pad_mask(self):
grads = torch.randn_like(self.tst_inputs_q)
pad_mask_byte = torch.tril(
torch.ones(
self.tst_inputs_k.size(1), self.tst_inputs_k.size(0), device=torch.device("cuda"), dtype=torch.uint8
),
1,
)
pad_mask_bool = pad_mask_byte.to(torch.bool)

ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=pad_mask_bool,
need_weights=False,
attn_mask=None,
is_training=True)

tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=pad_mask_byte,
need_weights=False,
attn_mask=None,
is_training=True)
ref_outputs, _ = self.ref_layer.forward(
self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=pad_mask_bool,
need_weights=False,
attn_mask=None,
is_training=True,
)

tst_outputs, _ = self.tst_layer.forward(
self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=pad_mask_byte,
need_weights=False,
attn_mask=None,
is_training=True,
)

self.ref_inputs_q.backward(grads)
self.tst_inputs_q.backward(grads)

self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))
torch.testing.assert_close(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
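A note on the masks used in the tests above: the time (attention) mask is an upper-triangular uint8 matrix with the main diagonal excluded, so a 1/True entry marks a key position the query may not attend to; the reference implementation is given the bool version and the fast implementation the uint8 version, which encode the same pattern. A small standalone sketch of the time-mask construction, with made-up small sizes:

import torch

# Standalone sketch of the time-mask built in test_encdec_multihead_attn_time_mask:
# entries above the diagonal are 1, i.e. query position i may only attend to
# key positions j <= i.
q_len, k_len = 5, 5
time_mask_byte = torch.triu(torch.ones(q_len, k_len, dtype=torch.uint8), 1)
time_mask_bool = time_mask_byte.to(torch.bool)

print(time_mask_byte)
# tensor([[0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0]], dtype=torch.uint8)
assert torch.equal(time_mask_bool, time_mask_byte != 0)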
apex/contrib/test/multihead_attn/test_encdec_multihead_attn_norm_add.py
@@ -15,71 +15,73 @@ def setUp(self, seed=1234):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.seq_length = 80
self.sequences = 10
self.hidden_dim = 1024
self.heads = 16
self.dropout_prob = 0.0

self.ref_layer = EncdecMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=True,
impl='default')
self.ref_layer = EncdecMultiheadAttn(
self.hidden_dim, self.heads, dropout=self.dropout_prob, bias=False, include_norm_add=True, impl="default"
)
self.ref_layer.cuda().half()
self.ref_layer.reset_parameters()
self.ref_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.ref_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.ref_inputs_q = torch.randn(
self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")
).requires_grad_(True)
self.ref_inputs_k = torch.randn(
self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")
).requires_grad_(True)

# Reset seed so parameters are identical
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

self.tst_layer = EncdecMultiheadAttn(self.hidden_dim,
self.heads,
dropout=self.dropout_prob,
bias=False,
include_norm_add=True,
impl='fast')
self.tst_layer = EncdecMultiheadAttn(
self.hidden_dim, self.heads, dropout=self.dropout_prob, bias=False, include_norm_add=True, impl="fast"
)
self.tst_layer.cuda().half()
self.tst_layer.reset_parameters()

self.tst_inputs_q = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
self.tst_inputs_k = torch.randn(self.seq_length, self.sequences, self.hidden_dim,
dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)

def test_encdec_multihead_attn_norm_add(self) :
grads = torch.randn_like(self.tst_inputs_q)

for _ in range(5) :
ref_outputs,_ = self.ref_layer.forward(self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)

tst_outputs,_ = self.tst_layer.forward(self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True)
self.tst_inputs_q = torch.randn(
self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")
).requires_grad_(True)
self.tst_inputs_k = torch.randn(
self.seq_length, self.sequences, self.hidden_dim, dtype=torch.float16, device=torch.device("cuda")
).requires_grad_(True)

def test_encdec_multihead_attn_norm_add(self):
grads = torch.randn_like(self.tst_inputs_q)

for _ in range(5):
ref_outputs, _ = self.ref_layer.forward(
self.ref_inputs_q,
self.ref_inputs_k,
self.ref_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True,
)

tst_outputs, _ = self.tst_layer.forward(
self.tst_inputs_q,
self.tst_inputs_k,
self.tst_inputs_k,
key_padding_mask=None,
need_weights=False,
attn_mask=None,
is_training=True,
)

self.ref_inputs_q.backward(grads)
self.tst_inputs_q.backward(grads)

self.assertTrue(torch.allclose(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5))
self.assertTrue(torch.allclose(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3))
self.assertTrue(torch.allclose(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3))
torch.testing.assert_close(self.ref_inputs_q, self.tst_inputs_q, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(self.ref_inputs_k, self.tst_inputs_k, atol=1e-5, rtol=1e-5)
torch.testing.assert_close(ref_outputs, tst_outputs, atol=1e-3, rtol=1e-3)
torch.testing.assert_close(self.ref_inputs_q.grad, self.tst_inputs_q.grad, atol=1e-3, rtol=1e-3)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
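Across both test files the substantive change, beyond black-style reformatting, is replacing self.assertTrue(torch.allclose(...)) with torch.testing.assert_close(...). The tolerances stay the same; the benefit is that a failure now reports how many elements mismatch and the largest absolute/relative error instead of a bare "False is not true". A minimal sketch of the two styles side by side (the tensors are made up for illustration):

import unittest

import torch


class ToleranceStyleExample(unittest.TestCase):
    def test_tolerance_styles(self):
        ref = torch.tensor([1.0, 2.0, 3.0])
        tst = torch.tensor([1.0, 2.0, 3.0 + 5e-4])

        # Old style: on failure this only prints "False is not true".
        self.assertTrue(torch.allclose(ref, tst, atol=1e-3, rtol=1e-3))

        # New style: on failure assert_close raises an AssertionError with a
        # detailed mismatch report (mismatched-element count, max abs/rel diff).
        torch.testing.assert_close(ref, tst, atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
    unittest.main()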