Skip to content

Commit

Permalink
fix max grid limitation for shift kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
syurkevi authored and pavanky committed Aug 29, 2017
1 parent df95ec6 commit fad44c4
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/backend/cuda/kernel/shift.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ namespace cuda
const int blocksPerMatX, const int blocksPerMatY)
{
const int oz = blockIdx.x / blocksPerMatX;
const int ow = blockIdx.y / blocksPerMatY;
const int ow = (blockIdx.y + blockIdx.z * gridDim.y) / blocksPerMatY;

const int blockIdx_x = blockIdx.x - oz * blocksPerMatX;
const int blockIdx_y = blockIdx.y - ow * blocksPerMatY;
const int blockIdx_y = (blockIdx.y + blockIdx.z * gridDim.y) - ow * blocksPerMatY;

const int xx = threadIdx.x + blockIdx_x * blockDim.x;
const int yy = threadIdx.y + blockIdx_y * blockDim.y;
Expand Down Expand Up @@ -87,6 +87,10 @@ namespace cuda
blocksPerMatY * out.dims[3],
1);

const int maxBlocksY = cuda::getDeviceProp(cuda::getActiveDeviceId()).maxGridSize[1];
blocks.z = divup(blocks.y, maxBlocksY);
blocks.y = divup(blocks.y, blocks.z);

int sdims_[4];
// Need to do this because we are mapping output to input in the kernel
for(int i = 0; i < 4; i++) {
Expand Down
20 changes: 20 additions & 0 deletions test/shift.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,23 @@ TEST(Shift, CPP)
// Delete
delete[] outData;
}

// Regression test: shifting arrays whose second or fourth dimension is very
// large must still give correct results.
// NOTE(review): largeDim is presumably chosen to exceed a 65535-block grid
// limit with 32-wide blocks -- confirm against the CUDA shift launcher.
TEST(Shift, MaxDim)
{
    if (noDoubleTests<float>()) return;

    const size_t largeDim = 65535 * 32 + 1 ;
    const unsigned shift_x = 1;

    // Same check for both layouts: large extent in dim 1, then in dim 3.
    const af::dim4 shapes[2] = {
        af::dim4(2, largeDim),
        af::dim4(2, 1, 1, largeDim)
    };

    for (int s = 0; s < 2; ++s) {
        const af::array original = af::range(shapes[s]);
        const af::array shifted  = af::shift(original, shift_x);

        // The test expects every |original - shifted| element to be 1, so
        // the product over all elements must be exactly 1.
        const af::array delta = af::abs(original - shifted);
        ASSERT_EQ(1.f, af::product<float>(delta));
    }
}

0 comments on commit fad44c4

Please sign in to comment.