# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from dataclasses import dataclass

import torch

from xformers import _is_triton_available
from xformers.components.attention import Attention, AttentionConfig, register_attention

logger = logging.getLogger("xformers")

_is_blocksparse_available = _is_triton_available()

if _is_blocksparse_available:
    from triton.ops.blocksparse import matmul as blocksparse_matmul  # type: ignore
    from triton.ops.blocksparse import softmax as blocksparse_softmax  # type: ignore

    from xformers.triton.utils import gpu_capabilities_older_than_70

    # Blocksparse requires Tensor cores
    if gpu_capabilities_older_than_70():
        logger.warning(
            "Blocksparse is not available: the current GPU does not expose Tensor cores"
        )
        _is_blocksparse_available = False


if _is_blocksparse_available:

    @dataclass
    class BlockSparseAttentionConfig(AttentionConfig):
        layout: torch.Tensor  # Block-level sparsity mask, [num_heads, seq // block_size, seq // block_size]
        block_size: int
        dropout: float
        num_heads: int

    @register_attention("blocksparse", BlockSparseAttentionConfig)
    class BlockSparseAttention(Attention):
        r"""
        Thin wrap over the Triton blocksparse computations. The sparsity pattern is determined through the layout.

        .. warning: the layout is assumed to have the dimensions [heads, seq // block_size, seq // block_size],
            i.e. one block-level mask per head. If the head dimension is missing, we assume that the same
            layout is to be used across all heads.

        .. warning: for now, the sequence (context) length has to be a power of two. This constraint could
            be relaxed in the future.

        .. warning: the block size has to be picked from [16, 32, 64, 128]. Some speed is gained from bigger
            blocks. It is of course possible to reproduce coarser patterns given these primitives, as the
            user sees fit.

        .. note: see the usage sketch at the end of this file.
        """
        def __init__(
            self,
            layout: torch.Tensor,
            block_size: int = 16,
            dropout: float = 0.0,
            num_heads: int = 1,  # optional, used to adapt the layout if needed
            causal: bool = False,
            *args,
            **kwargs,
        ):
            if layout.dim() == 2:
                logger.warning(
                    "The layout passed is lacking a head dimension"
                )
                logger.warning(
                    "Now assuming that the same layout is to be used across all heads"
                )
                layout = layout.unsqueeze(0).expand(num_heads, -1, -1)
                logger.warning(f"New layout dimensions: {layout.shape}")

            assert block_size in (
                16,
                32,
                64,
                128,
            ), "Only block sizes in [16, 32, 64, 128] are supported"

            super().__init__()

            self.causal = causal
            self.attn_drop = torch.nn.Dropout(dropout, inplace=False)

            # Pure blocksparse data
            self.layout = layout
            self.block_size = block_size

            # make sure that the head dimension is not folded down with the batch
            self.requires_head_dimension = True

            # key padding mask and attention mask must be passed in separately
            self.requires_same_k_q_dimensions = True

            # The underlying triton op does not support per element attention mask
            self.supports_attention_mask = False
            self.supports_key_padding_mask = False

        def create_triton_kernels(self, device):
            # blocksparse operators
            # "sdd": sparse output from two dense operands; with trans_b=True this
            # computes the block-sparse attention matrix q @ k^T
            self.sparse_dot_sdd = blocksparse_matmul(
                self.layout,
                self.block_size,
                "sdd",
                trans_a=False,
                trans_b=True,
                device=device,
            )

            # "dsd": dense output from a sparse first operand (the attention matrix)
            # and a dense second operand (the values)
            self.sparse_dot_dsd = blocksparse_matmul(
                self.layout,
                self.block_size,
                "dsd",
                trans_a=False,
                trans_b=False,
                device=device,
            )

            # softmax applied over the non-zero blocks only
            self.sparse_softmax = blocksparse_softmax(
                self.layout,
                self.block_size,
                device=device,
            )

        def forward(
            self,
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            scale: float = 1.0,
            *args,
            **kwargs,
        ) -> torch.Tensor:
            r"""
            A thin wrap around the Triton blocksparse attention operation

            .. note: Per element attention mask is not supported, but you can specify causality
            """
            assert (
                "att_mask" not in kwargs.keys() and "att_mask" not in args
            ), "This attention does not support an attention mask, but you can specify causality."

            # Delayed triton init, to make sure that we get the right device
            # Infer device from query
            if not hasattr(self, "sparse_dot_sdd"):
                self.create_triton_kernels(q.device)

            assert (
                q.shape[-2] == k.shape[-2]
            ), "Blocksparse requires the same dimensions for K and Q for now"

            assert (
                q.shape[-2] == self.layout.shape[-2] * self.block_size
            ), "Actual sequence size and layout are inconsistent"
            assert (
                k.shape[-2] == self.layout.shape[-2] * self.block_size
            ), "Actual sequence size and layout are inconsistent"

            assert (
                q.shape[-2] % self.block_size
            ) == 0, "Sequence length {} must be a multiple of block size {}".format(
                q.shape[-2], self.block_size
            )

            # Self-attend: (B, nh, S, hs) x (B, nh, hs, S) -> (B, nh, S, S)
            # When the computations are block sparse, the matrix types change along the way:
            # - (sparse) attention matrix = (dense) Q * (dense) Kt
            q = q / math.sqrt(q.size(-1))
            sparse_att_mat = self.sparse_dot_sdd(q, k)

            # - softmax on the sparse attention matrix
            sparse_att_mat = self.sparse_softmax(
                sparse_att_mat, scale=scale, is_causal=self.causal
            )

            sparse_att_mat = self.attn_drop(sparse_att_mat)

            # - then (dense) attention is (sparse) attention matrix * (dense) value
            a = self.sparse_dot_dsd(sparse_att_mat, v)
            return a
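

# Usage sketch (not part of the upstream module): builds a simple block-level causal
# layout and runs BlockSparseAttention on random half-precision tensors. The shapes,
# block size and the "cuda" device below are illustrative assumptions; the Triton
# kernels require a CUDA GPU with Tensor cores.
if __name__ == "__main__" and _is_blocksparse_available:
    BATCH, HEADS, SEQ, HEAD_DIM, BLOCK = 2, 4, 1024, 64, 16
    blocks = SEQ // BLOCK

    # Block-level causal mask: block (i, j) is kept if j <= i, one copy per head
    layout = torch.tril(torch.ones(blocks, blocks, dtype=torch.long))
    layout = layout.unsqueeze(0).expand(HEADS, -1, -1)

    attention = BlockSparseAttention(
        layout=layout, block_size=BLOCK, dropout=0.0, num_heads=HEADS, causal=True
    )

    # q, k, v follow the (batch, heads, seq, head_dim) convention used in forward()
    q = torch.randn(BATCH, HEADS, SEQ, HEAD_DIM, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = attention(q, k, v)
    print(out.shape)  # expected: (BATCH, HEADS, SEQ, HEAD_DIM)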