|
1 | 1 | #include "THCUNN.h"
|
2 | 2 | #include "THCHalf.h"
|
3 |
| -#include "THCTensorTypeUtils.cuh" |
4 |
| -#include "THCHalfAutoNumerics.cuh" |
5 |
| -#include "SharedMem.cuh" |
6 | 3 |
|
7 |
| -template <typename T, typename AccumT> |
8 |
| -__global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input, uint32_t outer_size, uint32_t dim_size, uint32_t inner_size) |
9 |
| -{ |
10 |
| - const uint32_t outer_stride = inner_size * dim_size; |
11 |
| - const uint32_t dim_stride = inner_size; |
12 |
| - |
13 |
| - for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) { |
14 |
| - const uint32_t outer_offset = outer_index * outer_stride; |
15 |
| - for (uint32_t inner_index = blockIdx.y * blockDim.x + threadIdx.x; inner_index < inner_size; inner_index += blockDim.x * gridDim.y) { |
16 |
| - const uint32_t data_offset = outer_offset + inner_index; |
17 |
| - |
18 |
| - T max_input = input[data_offset]; |
19 |
| - for (uint32_t d = 1; d < dim_size; d++) { |
20 |
| - const T value = input[data_offset + d * dim_stride]; |
21 |
| - max_input = THCNumerics<T>::ge(max_input, value) ? max_input : value; |
22 |
| - } |
23 |
| - |
24 |
| - AccumT sum = 0; |
25 |
| - for (uint32_t d = 0; d < dim_size; d++) |
26 |
| - sum += THCNumerics<T>::exp(input[data_offset + d * dim_stride] - max_input); |
27 |
| - const T logsum = max_input + ScalarConvert<AccumT, T>::to(THCNumerics<AccumT>::log(sum)); |
28 |
| - |
29 |
| - for (uint32_t d = 0; d < dim_size; d++) |
30 |
| - output[data_offset + d * dim_stride] = input[data_offset + d * dim_stride] - logsum; |
31 |
| - } |
32 |
| - } |
33 |
| -} |
34 |
| - |
35 |
| -template <typename T, typename AccumT> |
36 |
| -__global__ void cunn_SpatialLogSoftMax_updateGradInput_kernel(T *gradInput, T *output, T *gradOutput, uint32_t outer_size, uint32_t dim_size, uint32_t inner_size) |
37 |
| -{ |
38 |
| - const uint32_t outer_stride = inner_size * dim_size; |
39 |
| - const uint32_t dim_stride = inner_size; |
40 |
| - |
41 |
| - for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) { |
42 |
| - const uint32_t outer_offset = outer_index * outer_stride; |
43 |
| - for (uint32_t inner_index = blockIdx.y * blockDim.x + threadIdx.x; inner_index < inner_size; inner_index += blockDim.x * gridDim.y) { |
44 |
| - const uint32_t data_offset = outer_offset + inner_index; |
45 |
| - |
46 |
| - AccumT sum = 0; |
47 |
| - for (uint32_t d = 0; d < dim_size; d++) { |
48 |
| - sum += gradOutput[data_offset + d * dim_stride]; |
49 |
| - } |
50 |
| - const T real_sum = ScalarConvert<AccumT, T>::to(sum); |
51 |
| - |
52 |
| - for (uint32_t d = 0; d < dim_size; d++) { |
53 |
| - gradInput[data_offset + d * dim_stride] = gradOutput[data_offset + d * dim_stride] - |
54 |
| - THCNumerics<T>::exp(output[data_offset + d * dim_stride]) * real_sum; |
55 |
| - } |
56 |
| - } |
57 |
| - } |
58 |
| -} |
59 |
| - |
60 |
| -static void LogSoftMax_getSpatialGridSize( |
61 |
| - uint32_t block_size, uint32_t max_active_blocks, |
62 |
| - uint64_t outer_size, uint64_t dim_size, uint64_t inner_size, |
63 |
| - dim3& grid, dim3& block) { |
64 |
| - // First, tile as many blocks as we can over the y axis |
65 |
| - uint32_t y_size = (inner_size + block_size - 1) / block_size; |
66 |
| - if (y_size > max_active_blocks) |
67 |
| - y_size = max_active_blocks; |
68 |
| - // Fill the x axis with as many blocks as we can fit |
69 |
| - uint32_t x_size = (max_active_blocks + y_size - 1) / y_size; |
70 |
| - if (x_size > outer_size) |
71 |
| - x_size = outer_size; |
72 |
| - grid = dim3(x_size, y_size); |
73 |
| - block = dim3(block_size); |
74 |
| -} |
75 |
| - |
76 |
| -template <typename T, typename AccumT> |
77 |
| -struct MaxFloat |
78 |
| -{ |
79 |
| - __device__ __forceinline__ AccumT operator()(AccumT max, T v) const |
80 |
| - { |
81 |
| - return fmaxType(max, v); |
82 |
| - } |
83 |
| -}; |
84 |
| - |
85 |
| -template<typename T, typename AccumT> |
86 |
| -struct SumFloat |
87 |
| -{ |
88 |
| - __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const |
89 |
| - { |
90 |
| - return sum + v; |
91 |
| - } |
92 |
| -}; |
| 4 | +#include "SoftMaxCommon.cuh" |
93 | 5 |
|
94 | 6 | template<typename T, typename AccumT>
|
95 |
| -struct SumExpFloat |
96 |
| -{ |
97 |
| - __device__ __forceinline__ SumExpFloat(T v) |
98 |
| - : max_k(v) |
99 |
| - {} |
| 7 | +struct LogSoftMaxForwardEpilogue { |
| 8 | + __device__ __forceinline__ LogSoftMaxForwardEpilogue(T max_input, AccumT sum) |
| 9 | + : logsum(max_input + ScalarConvert<AccumT, T>::to(THCNumerics<AccumT>::log(sum))) {} |
100 | 10 |
|
101 |
| - __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const |
102 |
| - { |
103 |
| - return sum + THCNumerics<T>::exp(v - max_k); |
| 11 | + __device__ __forceinline__ T operator()(T input) const { |
| 12 | + return input - logsum; |
104 | 13 | }
|
105 | 14 |
|
106 |
| - const T max_k; |
107 |
| -}; |
108 |
| - |
109 |
| -template<typename AccumT> |
110 |
| -struct NoFinal |
111 |
| -{ |
112 |
| - __device__ __forceinline__ AccumT operator()(AccumT v) const |
113 |
| - { |
114 |
| - return v; |
115 |
| - } |
| 15 | + const T logsum; |
116 | 16 | };
|
117 | 17 |
|
118 |
| -template<typename AccumT> |
119 |
| -struct LSMFinal |
120 |
| -{ |
121 |
| - __device__ __forceinline__ LSMFinal(AccumT m) |
122 |
| - : max_k(m) |
123 |
| - {} |
| 18 | +template<typename T, typename AccumT> |
| 19 | +struct LogSoftMaxBackwardEpilogue { |
| 20 | + __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) |
| 21 | + : sum(ScalarConvert<AccumT, T>::to(sum)) {} |
124 | 22 |
|
125 |
| - __device__ __forceinline__ AccumT operator()(AccumT v) const |
126 |
| - { |
127 |
| - return max_k + THCNumerics<AccumT>::log(v); |
| 23 | + __device__ __forceinline__ T operator()(T gradOutput, T output) const { |
| 24 | + return gradOutput - THCNumerics<T>::exp(output) * sum; |
128 | 25 | }
|
129 | 26 |
|
130 |
| - const AccumT max_k; |
| 27 | + const T sum; |
131 | 28 | };
|
132 | 29 |
|
133 |
| -template <template<typename, typename> class Reduction, template<typename> class Finalize, typename AccumT> |
134 |
| -__device__ __forceinline__ AccumT |
135 |
| -blockReduce(AccumT* smem, AccumT val, |
136 |
| - const Reduction<AccumT, AccumT>& r, |
137 |
| - AccumT defaultVal, |
138 |
| - const Finalize<AccumT>& f) |
139 |
| -{ |
140 |
| - // To avoid RaW races from chaining blockReduce calls together, we |
141 |
| - // need a sync here |
142 |
| - __syncthreads(); |
143 |
| - |
144 |
| - smem[threadIdx.x] = val; |
145 |
| - |
146 |
| - __syncthreads(); |
147 |
| - |
148 |
| - AccumT warpVal = defaultVal; |
149 |
| - |
150 |
| - // First warp will perform per-warp reductions for the remaining warps |
151 |
| - if ((threadIdx.x / 32) == 0) // only threads in warp1 go into this (if) |
152 |
| - { |
153 |
| - int lane = threadIdx.x % 32; // from 0 to 31 |
154 |
| - |
155 |
| - // if less than 1024 threads per block, then only activate the relevant lanes |
156 |
| - if (lane < blockDim.x / 32) |
157 |
| - { |
158 |
| -#pragma unroll |
159 |
| - for (int i = 0; i < 32; ++i) |
160 |
| - { |
161 |
| - warpVal = r(warpVal, smem[lane * 32 + i]); |
162 |
| - } |
163 |
| - |
164 |
| - smem[lane] = warpVal; |
165 |
| - } |
166 |
| - } |
167 |
| - |
168 |
| - __syncthreads(); |
169 |
| - |
170 |
| - // First thread will perform a reduction of the above per-warp reductions |
171 |
| - AccumT blockVal = defaultVal; |
172 |
| - |
173 |
| - if (threadIdx.x == 0) |
174 |
| - { |
175 |
| - for (int i = 0; i < blockDim.x / 32; ++i) |
176 |
| - { |
177 |
| - blockVal = r(blockVal, smem[i]); |
178 |
| - } |
179 |
| - |
180 |
| - smem[0] = f(blockVal); |
181 |
| - } |
182 |
| - |
183 |
| - // Sync and broadcast |
184 |
| - __syncthreads(); |
185 |
| - return smem[0]; |
186 |
| -} |
187 |
| - |
188 |
| -template <template<typename, typename> class Reduction, typename AccumT> |
189 |
| -__device__ __forceinline__ AccumT |
190 |
| -blockReduce(AccumT* smem, AccumT val, |
191 |
| - const Reduction<AccumT, AccumT>& r, |
192 |
| - AccumT defaultVal) |
193 |
| -{ |
194 |
| - return blockReduce<Reduction, NoFinal, AccumT>(smem, val, r, defaultVal, NoFinal<AccumT>()); |
195 |
| -} |
196 |
| - |
197 |
| -template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT> |
198 |
| -__device__ __forceinline__ AccumT |
199 |
| -ilpReduce(T* data, |
200 |
| - int size, |
201 |
| - const Reduction<T, AccumT>& r, |
202 |
| - AccumT defaultVal) |
203 |
| -{ |
204 |
| - AccumT threadVal = defaultVal; |
205 |
| - int offset = threadIdx.x; |
206 |
| - |
207 |
| - int last = size % (ILP * blockDim.x); |
208 |
| - |
209 |
| - // Body (unroll by ILP times) |
210 |
| - for (; offset < size - last; offset += blockDim.x * ILP) |
211 |
| - { |
212 |
| - T tmp[ILP]; |
213 |
| - |
214 |
| -#pragma unroll |
215 |
| - for (int j = 0; j < ILP; ++j) |
216 |
| - { |
217 |
| - tmp[j] = data[offset + j * blockDim.x]; |
218 |
| - } |
219 |
| - |
220 |
| -#pragma unroll |
221 |
| - for (int j = 0; j < ILP; ++j) |
222 |
| - { |
223 |
| - threadVal = r(threadVal, tmp[j]); |
224 |
| - } |
225 |
| - } |
226 |
| - |
227 |
| - // Epilogue |
228 |
| - for (; offset < size; offset += blockDim.x) |
229 |
| - { |
230 |
| - threadVal = r(threadVal, data[offset]); |
231 |
| - } |
232 |
| - |
233 |
| - return threadVal; |
234 |
| -} |
235 |
| - |
236 |
| -template <int ILP, typename T, typename AccumT> |
237 |
| -__global__ void |
238 |
| -cunn_LogSoftMax_updateOutput_kernel(T *output, T *input, int classes) |
239 |
| -{ |
240 |
| - SharedMem<AccumT> smem; |
241 |
| - AccumT *buffer = smem.getPointer(); |
242 |
| - // forward pointers to batch[blockIdx.x] |
243 |
| - // each block handles a sample in the mini-batch |
244 |
| - input += blockIdx.x * classes; |
245 |
| - output += blockIdx.x * classes; |
246 |
| - |
247 |
| - // find the max of the batch |
248 |
| - AccumT threadMax = ilpReduce<MaxFloat, ILP, T, AccumT>( |
249 |
| - input, classes, MaxFloat<T, AccumT>(), -THCNumerics<AccumT>::max()); |
250 |
| - // find the max over all batches |
251 |
| - AccumT max_k = blockReduce<MaxFloat, AccumT>( |
252 |
| - buffer, threadMax, MaxFloat<AccumT, AccumT>(), -THCNumerics<AccumT>::max()); |
253 |
| - T max_k_non_accum = ScalarConvert<AccumT, T>::to(max_k); |
254 |
| - |
255 |
| - AccumT threadExp = ilpReduce<SumExpFloat, ILP, T, AccumT>( |
256 |
| - input, classes, SumExpFloat<T, AccumT>(max_k_non_accum), AccumT(0)); |
257 |
| - T logsum_k = ScalarConvert<AccumT, T>::to( |
258 |
| - blockReduce<SumFloat, LSMFinal, AccumT>( |
259 |
| - buffer, threadExp, SumFloat<AccumT, AccumT>(), AccumT(0), LSMFinal<AccumT>(max_k))); |
260 |
| - |
261 |
| - // Output LSM (hand ILP) |
262 |
| - int offset = threadIdx.x; |
263 |
| - |
264 |
| - int last = classes % (ILP * blockDim.x); |
265 |
| - for (; offset < classes - last; offset += blockDim.x * ILP) |
266 |
| - { |
267 |
| - T tmp[ILP]; |
268 |
| - |
269 |
| -#pragma unroll |
270 |
| - for (int j = 0; j < ILP; ++j) { |
271 |
| - tmp[j] = input[offset + j * blockDim.x]; |
272 |
| - } |
273 |
| - |
274 |
| -#pragma unroll |
275 |
| - for (int j = 0; j < ILP; ++j) |
276 |
| - { |
277 |
| - output[offset + j * blockDim.x] = tmp[j] - logsum_k; |
278 |
| - } |
279 |
| - } |
280 |
| - |
281 |
| - for (; offset < classes; offset += blockDim.x) |
282 |
| - { |
283 |
| - output[offset] = input[offset] - logsum_k; |
284 |
| - } |
285 |
| -} |
286 |
| - |
287 |
| -template <int ILP, typename T, typename AccumT> |
288 |
| -__global__ void |
289 |
| -cunn_LogSoftMax_updateGradInput_kernel(T *gradInput, |
290 |
| - T *output, |
291 |
| - T *gradOutput, |
292 |
| - int classes) |
293 |
| -{ |
294 |
| - SharedMem<AccumT> smem; |
295 |
| - AccumT *buffer = smem.getPointer(); |
296 |
| - gradInput += blockIdx.x * classes; |
297 |
| - output += blockIdx.x * classes; |
298 |
| - gradOutput += blockIdx.x * classes; |
299 |
| - |
300 |
| - AccumT threadSum = ilpReduce<SumFloat, 4, T, AccumT>( |
301 |
| - gradOutput, classes, SumFloat<T, AccumT>(), AccumT(0)); |
302 |
| - T sum_k = ScalarConvert<AccumT, T>::to( |
303 |
| - blockReduce<SumFloat, AccumT>( |
304 |
| - buffer, threadSum, SumFloat<AccumT, AccumT>(), AccumT(0))); |
305 |
| - |
306 |
| - // Update gradInput (hand ILP) |
307 |
| - int offset = threadIdx.x; |
308 |
| - int last = classes % (ILP * blockDim.x); |
309 |
| - for (; offset < classes - last; offset += blockDim.x * ILP) |
310 |
| - { |
311 |
| - T tmpGradOutput[ILP]; |
312 |
| - T tmpOutput[ILP]; |
313 |
| - |
314 |
| -#pragma unroll |
315 |
| - for (int j = 0; j < ILP; ++j) |
316 |
| - { |
317 |
| - tmpGradOutput[j] = gradOutput[offset + j * blockDim.x]; |
318 |
| - tmpOutput[j] = output[offset + j * blockDim.x]; |
319 |
| - } |
320 |
| - |
321 |
| -#pragma unroll |
322 |
| - for (int j = 0; j < ILP; ++j) |
323 |
| - { |
324 |
| - gradInput[offset + j * blockDim.x] = |
325 |
| - tmpGradOutput[j] - THCNumerics<T>::exp(tmpOutput[j]) * sum_k; |
326 |
| - } |
327 |
| - } |
328 |
| - |
329 |
| - for (; offset < classes; offset += blockDim.x) |
330 |
| - { |
331 |
| - gradInput[offset] = |
332 |
| - gradOutput[offset] - THCNumerics<T>::exp(output[offset]) * sum_k; |
333 |
| - } |
334 |
| -} |
335 |
| - |
336 | 30 | #include "generic/LogSoftMax.cu"
|
337 | 31 | #include "THCGenerateFloatTypes.h"
|
0 commit comments