forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcuda_lu_pivot_kernels.cu.cc
77 lines (63 loc) · 2.8 KB
/
cuda_lu_pivot_kernels.cu.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/* Copyright 2021 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/cuda_lu_pivot_kernels.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <iostream>
namespace jax {
namespace {
// Converts one sequence of row transpositions (a pivot array) into an explicit
// permutation.
//
// The output starts as the identity permutation of length `permutation_size`.
// Then, for each position i in order, entries i and pivots[i] are swapped.
// Pivot values are treated as 0-based indices: values outside
// [0, permutation_size) are skipped. Positions i >= permutation_size are also
// skipped, because there is no output entry i to swap.
//
//   pivots:           `pivot_size` transposition targets (read only).
//   permutation_out:  receives `permutation_size` entries.
//   pivot_size:       number of entries in `pivots`.
//   permutation_size: length of the produced permutation.
__device__ void ComputePermutation(const std::int32_t* pivots,
                                   std::int32_t* permutation_out,
                                   const std::int32_t pivot_size,
                                   const std::int32_t permutation_size) {
  for (std::int32_t i = 0; i < permutation_size; ++i) {
    permutation_out[i] = i;
  }

  // Apply the transpositions in order on the identity permutation. Only the
  // first min(pivot_size, permutation_size) can be applied: for
  // i >= permutation_size, permutation_out[i] would be written out of bounds.
  // (A ternary is used instead of std::min, which is not a __device__
  // function without relaxed-constexpr.)
  const std::int32_t num_swaps =
      pivot_size < permutation_size ? pivot_size : permutation_size;
  for (std::int32_t i = 0; i < num_swaps; ++i) {
    const std::int32_t target = pivots[i];
    if (target < 0 || target >= permutation_size) {
      continue;
    }
    const std::int32_t swap_temporary = permutation_out[i];
    permutation_out[i] = permutation_out[target];
    permutation_out[target] = swap_temporary;
  }
}
// Converts a batch of pivot arrays into permutations, one batch element per
// thread iteration.
//
// Launch layout: 1-D grid and 1-D blocks. A grid-stride loop covers all
// `batch_size` elements for any launch configuration. No shared memory is
// used.
// Data layout: element b reads `pivot_size` entries starting at
// pivots + b * pivot_size and writes `permutation_size` entries starting at
// permutation_out + b * permutation_size.
__global__ void LuPivotsToPermutationKernel(
    const std::int32_t* pivots, std::int32_t* permutation_out,
    const std::int64_t batch_size, const std::int32_t pivot_size,
    const std::int32_t permutation_size) {
  // blockIdx.x, blockDim.x and gridDim.x are 32-bit unsigned values. Widen them
  // before multiplying so the start index and stride cannot wrap around for
  // large launch configurations.
  const std::int64_t stride =
      static_cast<std::int64_t>(blockDim.x) * gridDim.x;
  for (std::int64_t idx =
           static_cast<std::int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < batch_size; idx += stride) {
    // Offsets are computed in 64 bits because idx is std::int64_t.
    ComputePermutation(pivots + idx * pivot_size,
                       permutation_out + idx * permutation_size, pivot_size,
                       permutation_size);
  }
}
} // namespace
// Host entry point: asynchronously launches LuPivotsToPermutationKernel on
// `stream`.
//
// buffers[0] is read as the int32 pivots input and buffers[1] as the int32
// permutation output. `descriptor` supplies batch_size, pivot_size and
// permutation_size. The kernel uses 128-thread blocks and at most 1024 blocks;
// its grid-stride loop covers any batch elements beyond that.
//
// The launch status is not checked here. Callers presumably query
// cudaGetLastError() after this returns (NOTE(review): confirm).
void LaunchLuPivotsToPermutationKernel(
    cudaStream_t stream, void** buffers,
    LuPivotsToPermutationDescriptor descriptor) {
  // An empty batch has nothing to compute. Launching with a 0-block grid is an
  // invalid configuration (cudaErrorInvalidConfiguration), so return early.
  if (descriptor.batch_size <= 0) {
    return;
  }

  const std::int32_t* pivots =
      reinterpret_cast<const std::int32_t*>(buffers[0]);
  std::int32_t* permutation_out = reinterpret_cast<std::int32_t*>(buffers[1]);

  constexpr int kBlockDim = 128;
  constexpr std::int64_t kMaxGridDim = 1024;
  // Ceil-divide so every batch element gets a thread, capped at kMaxGridDim.
  const std::int64_t grid_dim = std::min<std::int64_t>(
      kMaxGridDim, (descriptor.batch_size + kBlockDim - 1) / kBlockDim);

  LuPivotsToPermutationKernel<<<static_cast<unsigned int>(grid_dim), kBlockDim,
                                /*dynamic_shared_mem_bytes=*/0, stream>>>(
      pivots, permutation_out, descriptor.batch_size, descriptor.pivot_size,
      descriptor.permutation_size);
}
} // namespace jax