
Commit b3642b3

apaszke authored and soumith committed
Softmax/LogSoftMax refactor (wrapped up) (pytorch#3245)
* Unify CUDA kernels for SoftMax and LogSoftMax
* Improve SoftMax and LogSoftMax kernel performance: added a new instantiation of the spatial kernel for low inner_size and larger dim_size.
1 parent e43a63a commit b3642b3
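
For context, the unified kernels keep the standard numerically stable formulation of log-softmax. With $m = \max_j x_j$, the forward pass computes

$$\mathrm{log\_softmax}(x)_i = x_i - \Big(m + \log \sum_j e^{x_j - m}\Big)$$

and the backward pass, given the saved output $y$ and incoming gradient $g$, computes

$$\frac{\partial L}{\partial x_i} = g_i - e^{y_i} \sum_j g_j ,$$

which is exactly the element-wise step applied by the new LogSoftMaxForwardEpilogue (subtract the log-sum) and LogSoftMaxBackwardEpilogue (subtract exp(output) times the gradient sum) after the shared reductions.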

File tree: 6 files changed, +580 -625 lines


test/test_nn.py

Lines changed: 28 additions & 4 deletions
@@ -4390,8 +4390,20 @@ def mseloss_no_reduce_module_test():
     ),
     dict(
         constructor=wrap_functional(F.softmax, dim=1),
-        input_size=(2, 3, 4, 5),
-        fullname='softmax_functional',
+        input_size=(2, 128),  # trigger the last-dim algo in CUDA
+        fullname='softmax_lastdim',
+        pickle=False,
+    ),
+    dict(
+        constructor=wrap_functional(F.softmax, dim=1),
+        input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
+        fullname='softmax_spatial_special',
+        pickle=False,
+    ),
+    dict(
+        constructor=wrap_functional(F.softmax, dim=1),
+        input_size=(2, 2, 4, 4),  # regular spatial algorithm
+        fullname='softmax_spatial',
         pickle=False,
     ),
     dict(
@@ -4410,8 +4422,20 @@ def mseloss_no_reduce_module_test():
     ),
     dict(
         constructor=wrap_functional(F.log_softmax, dim=1),
-        input_size=(2, 3, 4, 5),
-        fullname='log_softmax',
+        input_size=(2, 128),  # trigger the last-dim algo in CUDA
+        fullname='log_softmax_lastdim',
+        pickle=False,
+    ),
+    dict(
+        constructor=wrap_functional(F.log_softmax, dim=1),
+        input_size=(2, 128, 2, 2),  # trigger special case of spatial CUDA algo
+        fullname='log_softmax_spatial_special',
+        pickle=False,
+    ),
+    dict(
+        constructor=wrap_functional(F.log_softmax, dim=1),
+        input_size=(2, 2, 4, 4),  # regular spatial algorithm
+        fullname='log_softmax_spatial',
         pickle=False,
     ),
     dict(
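
The comments on these new test entries name the CUDA code path each shape is meant to hit: the last-dim algorithm for 2-D inputs, the special-cased spatial kernel when inner_size is small, and the regular spatial kernel otherwise. A minimal sketch of the same check outside the test harness (assuming a CUDA build and the current F.softmax / F.log_softmax signatures; only the shapes come from the diff):

import torch
import torch.nn.functional as F

shapes = [
    (2, 128),        # last-dim algorithm
    (2, 128, 2, 2),  # spatial algorithm, special (small inner_size) case
    (2, 2, 4, 4),    # regular spatial algorithm
]

for shape in shapes:
    x = torch.randn(*shape, device='cuda')
    for fn in (F.softmax, F.log_softmax):
        # compare the CUDA result against a float64 CPU reference
        ref = fn(x.double().cpu(), dim=1)
        out = fn(x, dim=1).double().cpu()
        print(fn.__name__, shape, (out - ref).abs().max().item())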

torch/lib/THCUNN/LogSoftMax.cu

Lines changed: 14 additions & 320 deletions
@@ -1,337 +1,31 @@
 #include "THCUNN.h"
 #include "THCHalf.h"
-#include "THCTensorTypeUtils.cuh"
-#include "THCHalfAutoNumerics.cuh"
-#include "SharedMem.cuh"
 
-template <typename T, typename AccumT>
-__global__ void cunn_SpatialLogSoftMax_updateOutput_kernel(T *output, T *input, uint32_t outer_size, uint32_t dim_size, uint32_t inner_size)
-{
-  const uint32_t outer_stride = inner_size * dim_size;
-  const uint32_t dim_stride = inner_size;
-
-  for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
-    const uint32_t outer_offset = outer_index * outer_stride;
-    for (uint32_t inner_index = blockIdx.y * blockDim.x + threadIdx.x; inner_index < inner_size; inner_index += blockDim.x * gridDim.y) {
-      const uint32_t data_offset = outer_offset + inner_index;
-
-      T max_input = input[data_offset];
-      for (uint32_t d = 1; d < dim_size; d++) {
-        const T value = input[data_offset + d * dim_stride];
-        max_input = THCNumerics<T>::ge(max_input, value) ? max_input : value;
-      }
-
-      AccumT sum = 0;
-      for (uint32_t d = 0; d < dim_size; d++)
-        sum += THCNumerics<T>::exp(input[data_offset + d * dim_stride] - max_input);
-      const T logsum = max_input + ScalarConvert<AccumT, T>::to(THCNumerics<AccumT>::log(sum));
-
-      for (uint32_t d = 0; d < dim_size; d++)
-        output[data_offset + d * dim_stride] = input[data_offset + d * dim_stride] - logsum;
-    }
-  }
-}
-
-template <typename T, typename AccumT>
-__global__ void cunn_SpatialLogSoftMax_updateGradInput_kernel(T *gradInput, T *output, T *gradOutput, uint32_t outer_size, uint32_t dim_size, uint32_t inner_size)
-{
-  const uint32_t outer_stride = inner_size * dim_size;
-  const uint32_t dim_stride = inner_size;
-
-  for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) {
-    const uint32_t outer_offset = outer_index * outer_stride;
-    for (uint32_t inner_index = blockIdx.y * blockDim.x + threadIdx.x; inner_index < inner_size; inner_index += blockDim.x * gridDim.y) {
-      const uint32_t data_offset = outer_offset + inner_index;
-
-      AccumT sum = 0;
-      for (uint32_t d = 0; d < dim_size; d++) {
-        sum += gradOutput[data_offset + d * dim_stride];
-      }
-      const T real_sum = ScalarConvert<AccumT, T>::to(sum);
-
-      for (uint32_t d = 0; d < dim_size; d++) {
-        gradInput[data_offset + d * dim_stride] = gradOutput[data_offset + d * dim_stride] -
-          THCNumerics<T>::exp(output[data_offset + d * dim_stride]) * real_sum;
-      }
-    }
-  }
-}
-
-static void LogSoftMax_getSpatialGridSize(
-    uint32_t block_size, uint32_t max_active_blocks,
-    uint64_t outer_size, uint64_t dim_size, uint64_t inner_size,
-    dim3& grid, dim3& block) {
-  // First, tile as many blocks as we can over the y axis
-  uint32_t y_size = (inner_size + block_size - 1) / block_size;
-  if (y_size > max_active_blocks)
-    y_size = max_active_blocks;
-  // Fill the x axis with as many blocks as we can fit
-  uint32_t x_size = (max_active_blocks + y_size - 1) / y_size;
-  if (x_size > outer_size)
-    x_size = outer_size;
-  grid = dim3(x_size, y_size);
-  block = dim3(block_size);
-}
-
-template <typename T, typename AccumT>
-struct MaxFloat
-{
-  __device__ __forceinline__ AccumT operator()(AccumT max, T v) const
-  {
-    return fmaxType(max, v);
-  }
-};
-
-template<typename T, typename AccumT>
-struct SumFloat
-{
-  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const
-  {
-    return sum + v;
-  }
-};
+#include "SoftMaxCommon.cuh"
 
 template<typename T, typename AccumT>
-struct SumExpFloat
-{
-  __device__ __forceinline__ SumExpFloat(T v)
-    : max_k(v)
-  {}
+struct LogSoftMaxForwardEpilogue {
+  __device__ __forceinline__ LogSoftMaxForwardEpilogue(T max_input, AccumT sum)
+    : logsum(max_input + ScalarConvert<AccumT, T>::to(THCNumerics<AccumT>::log(sum))) {}
 
-  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const
-  {
-    return sum + THCNumerics<T>::exp(v - max_k);
+  __device__ __forceinline__ T operator()(T input) const {
+    return input - logsum;
   }
 
-  const T max_k;
-};
-
-template<typename AccumT>
-struct NoFinal
-{
-  __device__ __forceinline__ AccumT operator()(AccumT v) const
-  {
-    return v;
-  }
+  const T logsum;
 };
 
-template<typename AccumT>
-struct LSMFinal
-{
-  __device__ __forceinline__ LSMFinal(AccumT m)
-    : max_k(m)
-  {}
+template<typename T, typename AccumT>
+struct LogSoftMaxBackwardEpilogue {
+  __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum)
+    : sum(ScalarConvert<AccumT, T>::to(sum)) {}
 
-  __device__ __forceinline__ AccumT operator()(AccumT v) const
-  {
-    return max_k + THCNumerics<AccumT>::log(v);
+  __device__ __forceinline__ T operator()(T gradOutput, T output) const {
+    return gradOutput - THCNumerics<T>::exp(output) * sum;
   }
 
-  const AccumT max_k;
+  const T sum;
 };
 
-template <template<typename, typename> class Reduction, template<typename> class Finalize, typename AccumT>
-__device__ __forceinline__ AccumT
-blockReduce(AccumT* smem, AccumT val,
-            const Reduction<AccumT, AccumT>& r,
-            AccumT defaultVal,
-            const Finalize<AccumT>& f)
-{
-  // To avoid RaW races from chaining blockReduce calls together, we
-  // need a sync here
-  __syncthreads();
-
-  smem[threadIdx.x] = val;
-
-  __syncthreads();
-
-  AccumT warpVal = defaultVal;
-
-  // First warp will perform per-warp reductions for the remaining warps
-  if ((threadIdx.x / 32) == 0) // only threads in warp1 go into this (if)
-  {
-    int lane = threadIdx.x % 32; // from 0 to 31
-
-    // if less than 1024 threads per block, then only activate the relevant lanes
-    if (lane < blockDim.x / 32)
-    {
-#pragma unroll
-      for (int i = 0; i < 32; ++i)
-      {
-        warpVal = r(warpVal, smem[lane * 32 + i]);
-      }
-
-      smem[lane] = warpVal;
-    }
-  }
-
-  __syncthreads();
-
-  // First thread will perform a reduction of the above per-warp reductions
-  AccumT blockVal = defaultVal;
-
-  if (threadIdx.x == 0)
-  {
-    for (int i = 0; i < blockDim.x / 32; ++i)
-    {
-      blockVal = r(blockVal, smem[i]);
-    }
-
-    smem[0] = f(blockVal);
-  }
-
-  // Sync and broadcast
-  __syncthreads();
-  return smem[0];
-}
-
-template <template<typename, typename> class Reduction, typename AccumT>
-__device__ __forceinline__ AccumT
-blockReduce(AccumT* smem, AccumT val,
-            const Reduction<AccumT, AccumT>& r,
-            AccumT defaultVal)
-{
-  return blockReduce<Reduction, NoFinal, AccumT>(smem, val, r, defaultVal, NoFinal<AccumT>());
-}
-
-template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT>
-__device__ __forceinline__ AccumT
-ilpReduce(T* data,
-          int size,
-          const Reduction<T, AccumT>& r,
-          AccumT defaultVal)
-{
-  AccumT threadVal = defaultVal;
-  int offset = threadIdx.x;
-
-  int last = size % (ILP * blockDim.x);
-
-  // Body (unroll by ILP times)
-  for (; offset < size - last; offset += blockDim.x * ILP)
-  {
-    T tmp[ILP];
-
-#pragma unroll
-    for (int j = 0; j < ILP; ++j)
-    {
-      tmp[j] = data[offset + j * blockDim.x];
-    }
-
-#pragma unroll
-    for (int j = 0; j < ILP; ++j)
-    {
-      threadVal = r(threadVal, tmp[j]);
-    }
-  }
-
-  // Epilogue
-  for (; offset < size; offset += blockDim.x)
-  {
-    threadVal = r(threadVal, data[offset]);
-  }
-
-  return threadVal;
-}
-
-template <int ILP, typename T, typename AccumT>
-__global__ void
-cunn_LogSoftMax_updateOutput_kernel(T *output, T *input, int classes)
-{
-  SharedMem<AccumT> smem;
-  AccumT *buffer = smem.getPointer();
-  // forward pointers to batch[blockIdx.x]
-  // each block handles a sample in the mini-batch
-  input += blockIdx.x * classes;
-  output += blockIdx.x * classes;
-
-  // find the max of the batch
-  AccumT threadMax = ilpReduce<MaxFloat, ILP, T, AccumT>(
-      input, classes, MaxFloat<T, AccumT>(), -THCNumerics<AccumT>::max());
-  // find the max over all batches
-  AccumT max_k = blockReduce<MaxFloat, AccumT>(
-      buffer, threadMax, MaxFloat<AccumT, AccumT>(), -THCNumerics<AccumT>::max());
-  T max_k_non_accum = ScalarConvert<AccumT, T>::to(max_k);
-
-  AccumT threadExp = ilpReduce<SumExpFloat, ILP, T, AccumT>(
-      input, classes, SumExpFloat<T, AccumT>(max_k_non_accum), AccumT(0));
-  T logsum_k = ScalarConvert<AccumT, T>::to(
-      blockReduce<SumFloat, LSMFinal, AccumT>(
-          buffer, threadExp, SumFloat<AccumT, AccumT>(), AccumT(0), LSMFinal<AccumT>(max_k)));
-
-  // Output LSM (hand ILP)
-  int offset = threadIdx.x;
-
-  int last = classes % (ILP * blockDim.x);
-  for (; offset < classes - last; offset += blockDim.x * ILP)
-  {
-    T tmp[ILP];
-
-#pragma unroll
-    for (int j = 0; j < ILP; ++j) {
-      tmp[j] = input[offset + j * blockDim.x];
-    }
-
-#pragma unroll
-    for (int j = 0; j < ILP; ++j)
-    {
-      output[offset + j * blockDim.x] = tmp[j] - logsum_k;
-    }
-  }
-
-  for (; offset < classes; offset += blockDim.x)
-  {
-    output[offset] = input[offset] - logsum_k;
-  }
-}
-
-template <int ILP, typename T, typename AccumT>
-__global__ void
-cunn_LogSoftMax_updateGradInput_kernel(T *gradInput,
-                                       T *output,
-                                       T *gradOutput,
-                                       int classes)
-{
-  SharedMem<AccumT> smem;
-  AccumT *buffer = smem.getPointer();
-  gradInput += blockIdx.x * classes;
-  output += blockIdx.x * classes;
-  gradOutput += blockIdx.x * classes;
-
-  AccumT threadSum = ilpReduce<SumFloat, 4, T, AccumT>(
-      gradOutput, classes, SumFloat<T, AccumT>(), AccumT(0));
-  T sum_k = ScalarConvert<AccumT, T>::to(
-      blockReduce<SumFloat, AccumT>(
-          buffer, threadSum, SumFloat<AccumT, AccumT>(), AccumT(0)));
-
-  // Update gradInput (hand ILP)
-  int offset = threadIdx.x;
-  int last = classes % (ILP * blockDim.x);
-  for (; offset < classes - last; offset += blockDim.x * ILP)
-  {
-    T tmpGradOutput[ILP];
-    T tmpOutput[ILP];
-
-#pragma unroll
-    for (int j = 0; j < ILP; ++j)
-    {
-      tmpGradOutput[j] = gradOutput[offset + j * blockDim.x];
-      tmpOutput[j] = output[offset + j * blockDim.x];
-    }
-
-#pragma unroll
-    for (int j = 0; j < ILP; ++j)
-    {
-      gradInput[offset + j * blockDim.x] =
-        tmpGradOutput[j] - THCNumerics<T>::exp(tmpOutput[j]) * sum_k;
-    }
-  }
-
-  for (; offset < classes; offset += blockDim.x)
-  {
-    gradInput[offset] =
-      gradOutput[offset] - THCNumerics<T>::exp(output[offset]) * sum_k;
-  }
-}
-
 #include "generic/LogSoftMax.cu"
 #include "THCGenerateFloatTypes.h"
