-
Notifications
You must be signed in to change notification settings - Fork 39
/
Copy pathmodule.cpp
266 lines (211 loc) · 10.6 KB
/
module.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <iostream>
#include <time.h>
#include <sys/time.h>
#include <vector>
#include <immintrin.h>
// Uncomment for ISPC
//#include "module_ispc.h"
//using namespace ispc;
// ------------------------------------ //
// WARM-UP: ACCESSING TENSORS //
// ------------------------------------ //
// Step #1: Understand Read/Write Accessors for a 2D Tensor
inline float twoDimRead(std::vector<float> &tensor, int &x, int &y, const int &sizeX) {
// Note that sizeX is the size of a Row, not the number of rows
return tensor[x * (sizeX)+ y];
}
inline void twoDimWrite(std::vector<float> &tensor, int &x, int &y, const int &sizeX, float &val) {
tensor[x * (sizeX) + y] = val;
}
// Step #2: Implement Read/Write Accessors for a 4D Tensor
inline float fourDimRead(std::vector<float> &tensor, int &x, int &y, int &z, int &b,
const int &sizeX, const int &sizeY, const int &sizeZ) {
return 0.0;
}
inline void fourDimWrite(std::vector<float> &tensor, int &x, int &y, int &z, int &b,
const int &sizeX, const int &sizeY, const int &sizeZ, float &val) {
return;
}
// DO NOT EDIT THIS FUNCTION //
std::vector<float> formatTensor(torch::Tensor tensor) {
tensor = tensor.flatten();
tensor = tensor.contiguous();
std::vector<float> vec(tensor.data_ptr<float>(), tensor.data_ptr<float>() + tensor.numel());
return vec;
}
/* Programming Your Attention Modules.
*
* You are given Q, K, and V Tensors as inputs that are formatted as vectors. We have also created O and QK^t Tensors
* that are formatted as vectors. After you have implemented your accessors in the Warm-Up you should be able to
* read/write to these tensors via the read/write functions above.
*
* You are also given 4 integers as parameters: B, H, N, d:
*
* B (Batch Size) - The number of samples for your attention layer. Think of it this way - if I asked my dnn
* a question and it output 5 different answers it had a batch size of 5. These samples are independent of each
* other and thus can be parallelized.
*
* H (Number of Heads) - Each head runs on its own set of Q, K, V matrices. This effectively allows each head
* to operate the same attention algorithm, but each with each head using different hyperparameters. These
* allow each head to have their own definition of what relevance is when looking at a token. These heads
* can operate independently of one another and thus can be parallized.
*
* N (Sequence Length) - The number of tokens. You may think of this as the number of words in a sample.
*
* d (Embedding Dimensionality) - The number of features each token encodes per attention head. Let's
* say I encoded a word using the follow (length, number of vowels, has a capital letters). The
* emvedded dimensionaliy would be 3.
* */
// ---------------------------------------------------------- //
// PART 1: NAIVE ATTENTION //
// ---------------------------------------------------------- //
torch::Tensor myNaiveAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor QK_tTensor,
int B, int H, int N, int d){
// Q, K, V are passed in with Shape: (B, H, N, d)
//QK^t Intermediate Tensor has Shape (N, N)
//Make O Tensor with Shape (B, H, N, d)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
//Format O, Q, K, and V tensors into 4D vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
//Format QK_t Tensor into a 2D vector.
std::vector<float> QK_t = formatTensor(QK_tTensor);
/* Here is an example of how to read/write 0's to Q (B, H, N, d) using the 4D accessors
//loop over Batch Size
for (int b = 0; b < B; b++) {
//loop over Heads
for (int h = 0; h < H; h++) {
//loop over Sequence Length
for (int i = 0; i < N; i++) {
//loop over Embedding Dimensionality
for (int j = 0; j < d; j++) {
float val = fourDimRead(Q, b, h, i, j, H, N, d);
val = 0.0;
fourDimWrite(Q, b, h, i, j, H, N, d, val);
}
}
}
}
*/
/* Here is an example of how to read/write 0's to QK_t (N, N) using the 2D accessors
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
float val = twoDimRead(QK_t, i, j, N);
val = 0.0;
twoDimWrite(QK_t, i, j, N, val);
}
}
*/
// -------- YOUR CODE HERE -------- //
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
// ---------------------------------------------------------- //
// PART 2: BLOCKED MATRIX MULTIPLY AND UNFUSED SOFTMAX //
// ---------------------------------------------------------- //
torch::Tensor myUnfusedAttentionBlocked(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor QK_tTensor,
int B, int H, int N, int d){
// Q, K, V are passed in with Shape: (B, H, N, d)
//QK^t Intermediate Tensor has Shape (N, N)
//Make O Tensor with Shape (B, H, N, d)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
//Format O, Q, K, and V tensors into 4D vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
//Format QK_t Tensor into a 2D vector.
std::vector<float> QK_t = formatTensor(QK_tTensor);
// -------- YOUR CODE HERE -------- //
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
// ---------------------------------------------------------- //
// PART 3: FUSED ATTENTION //
// ---------------------------------------------------------- //
torch::Tensor myFusedAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor temp,
int B, int H, int N, int d){
// Q, K, V are passed in with Shape: (B, H, N, d)
//Make O Tensor with Shape (B, H, N, d)
//and O Row Tensor with Shape (N)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
at::Tensor ORowTensor = at::zeros({N}, at::kFloat);
//Format Y, Q, K, and V tensors into 4D vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
//Format ORow Tensor into a 1D vector
// You can simply access this as ORow[i]
std::vector<float> ORow = formatTensor(ORowTensor);
// -------- YOUR CODE HERE -------- //
// We give you a template of the first three loops for your convenience
//loop over batch
for (int b = 0; b < B; b++){
//loop over heads
for (int h = 0; h < H; h++){
for (int i = 0; i < N ; i++){
// YRow is moved inside so each OpenMP thread gets a local copy.
at::Tensor ORowTensor = temp.index({torch::indexing::Slice(omp_get_thread_num(), torch::indexing::None)});
std::vector<float> ORow = formatTensor(ORowTensor);
//YOUR CODE HERE
}
}
}
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
// ---------------------------------------------------------- //
// PART 4: FLASH ATTENTION //
// ---------------------------------------------------------- //
torch::Tensor myFlashAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor,
torch::Tensor QiTensor, torch::Tensor KjTensor, torch::Tensor VjTensor,
torch::Tensor SijTensor, torch::Tensor PijTensor, torch::Tensor PVTensor,
torch::Tensor OiTensor, torch::Tensor LTensor, torch::Tensor LiTensor,
torch::Tensor LijTensor, torch::Tensor LnewTensor, int Bc, int Br,
int B, int H, int N, int d) {
// Q, K, V are passed in with Shape: (B, H, N, d)
// Sij, Pij are passed in with Shape: (Br, Bc)
// Kj, Vj are passed in with Shape: (Bc, d)
// Qi, Oi, and PV are passed in with Shape: (Br, d)
// L in passed in with Shape: (N)
// Li, Lij, and Lnew are passed in with shape (Br)
//Make O Tensor with Shape (B, H, N, d)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
//Format All Tensors into Vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
std::vector<float> Sij = formatTensor(SijTensor);
std::vector<float> Pij = formatTensor(PijTensor);
std::vector<float> Kj = formatTensor(KjTensor);
std::vector<float> Vj = formatTensor(VjTensor);
std::vector<float> Qi = formatTensor(QiTensor);
std::vector<float> Oi = formatTensor(OiTensor);
std::vector<float> l = formatTensor(LTensor);
std::vector<float> PV = formatTensor(PVTensor);
std::vector<float> li = formatTensor(LiTensor);
std::vector<float> lij = formatTensor(LijTensor);
std::vector<float> lnew = formatTensor(LnewTensor);
// -------- YOUR CODE HERE -------- //
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
/* DO NOT EDIT THESE BINDINGS */
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("myNaiveAttention", &myNaiveAttention, "Naive Attention");
m.def("myUnfusedAttentionBlocked", &myUnfusedAttentionBlocked, " Blocked Unfused Attention");
m.def("myFusedAttention", &myFusedAttention, "Fused Attention");
m.def("myFlashAttention", &myFlashAttention, "Flash Attention");
m.def("twoDimRead", &twoDimRead, "twoDimRead");
m.def("fourDimRead", &fourDimRead, "fourDimRead");
}