Skip to content

Commit

Permalink
Working version of sweeps kernel with manual stride management.
Browse files Browse the repository at this point in the history
  • Loading branch information
insertinterestingnamehere committed Jun 16, 2021
1 parent 4fe2215 commit 45a4e9e
Showing 1 changed file with 34 additions and 41 deletions.
75 changes: 34 additions & 41 deletions benchmarks/sweeps/sweeps_numba_sgpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
uint64_t = nb.uint64
float_t = nb.float32
float_t_npy = np.float32
float_t_nbytes = 4

# A hand-rolled version of ravel_multi_index for use inside a jitted function.
# This is special-cased for 4D indexing since that's what we need.
Expand Down Expand Up @@ -100,8 +101,8 @@ def compute_new_scattering(sigma_s, I, coefs, new_sigma):
cp.cuda.cublas.sgemmStridedBatched(cublas_handle, transa, transb, m, n, k, alpha, A, lda, strideA, B, ldb, strideB, beta, C, ldc, strideC, batchCount)
#new_sigma[:] = I[1:-1,1:-1,1:-1] @ coefs

@cuda.jit(nb.void(uint_t[:], float_t[:,:,:,:,:], float_t[:], float_t[:,:,:,:], float_t[:], float_t[:,:], float_t, uint_t[:], float_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t))
def compute_fluxes(work_items, I, I_flat, sigma, sigma_flat, directions, sigma_a_s, tgroup_id, num_dirs_inv, x_offset, y_offset, z_offset, frequency_offset, sigma_I_offset, sigma_I_x_mul, direction_correction):
@cuda.jit(nb.void(uint_t[:], uint_t, uint_t, uint_t, uint_t, uint_t, float_t[:], float_t[:], float_t[:,:], float_t, uint_t[:], float_t, uint64_t, uint64_t, uint64_t, uint64_t, uint64_t))
def compute_fluxes(work_items, nx, ny, nz, num_dirs, num_groups, I_flat, sigma_flat, directions, sigma_a_s, tgroup_id, num_dirs_inv, I_s0, I_s1, I_s2, I_s3, I_s4):
block_base_index = cuda.shared.array((1,), uint_t)
if not cuda.threadIdx.x:
block_base_index[0] = cuda.atomic.add(tgroup_id, 0, 1) * uint_t(cuda.blockDim.x)
Expand All @@ -112,13 +113,9 @@ def compute_fluxes(work_items, I, I_flat, sigma, sigma_flat, directions, sigma_a
idx = work_items[work_item_idx]
sigma_flat_idx = idx
# TODO: The uint_t specs for nx, ny, and nz here drastically change the result. WHY?
nx = uint_t(sigma.shape[0])
ny = uint_t(sigma.shape[1])
nz = uint_t(sigma.shape[2])
num_dirs = uint_t(sigma.shape[3])
num_groups = uint_t(I.shape[4])
sigma_x, sigma_y, sigma_z, dir_idx = unravel_4d_index(nx, ny, nz, directions.shape[0], idx)
I_flat_idx = sigma_flat_idx * num_groups + sigma_I_x_mul * sigma_x + z_offset * sigma_y + sigma_I_offset - direction_correction * dir_idx
I_flat_base_idx = I_s0 * (sigma_x + 1) + I_s1 * (sigma_y + 1) + I_s2 * (sigma_z + 1)
I_flat_idx = I_s0 * (sigma_x + 1) + I_s1 * (sigma_y + 1) + I_s2 * (sigma_z + 1) + I_s3 * dir_idx
# Now change abstract indices in the iteration space into indices into I.
# This is necessary since the beginning and end of each spatial axis
# are used for boundary conditions.
Expand All @@ -134,9 +131,9 @@ def compute_fluxes(work_items, I, I_flat, sigma, sigma_flat, directions, sigma_a
x_neighbor_idx = ix + uint_t(1) if x_has_sign else ix - uint_t(1)
y_neighbor_idx = iy + uint_t(1) if y_has_sign else iy - uint_t(1)
z_neighbor_idx = iz + uint_t(1) if z_has_sign else iz - uint_t(1)
x_neighbor_flat_idx = I_flat_idx + x_offset if x_has_sign else I_flat_idx - x_offset
y_neighbor_flat_idx = I_flat_idx + y_offset if y_has_sign else I_flat_idx - y_offset
z_neighbor_flat_idx = I_flat_idx + z_offset if z_has_sign else I_flat_idx - z_offset
x_neighbor_flat_idx = I_flat_idx + I_s0 if x_has_sign else I_flat_idx - I_s0
y_neighbor_flat_idx = I_flat_idx + I_s1 if y_has_sign else I_flat_idx - I_s1
z_neighbor_flat_idx = I_flat_idx + I_s2 if z_has_sign else I_flat_idx - I_s2
x_coef = -float_t(nx) if x_has_sign else float_t(nx)
y_coef = -float_t(ny) if y_has_sign else float_t(ny)
z_coef = -float_t(nz) if z_has_sign else float_t(nz)
Expand All @@ -150,7 +147,7 @@ def compute_fluxes(work_items, I, I_flat, sigma, sigma_flat, directions, sigma_a
# to store the scattering terms. This sum just runs over the directions.
incoming_scattering = float_t(0.)
sigma_dir_block_idx = sigma_flat_idx - dir_idx
for j64 in range(uint_t(sigma.shape[3])):
for j64 in range(uint_t(num_dirs)):
j = uint_t(j64)
#incoming_scattering += sigma[sigma_x, sigma_y, sigma_z, j]
incoming_scattering += sigma_flat[sigma_dir_block_idx + j]
Expand All @@ -161,37 +158,37 @@ def compute_fluxes(work_items, I, I_flat, sigma, sigma_flat, directions, sigma_a
y_factor = y_coef * diry
z_factor = z_coef * dirz
div = float_t(1.) / denominator
x_neighbor_flat_idx_last = x_neighbor_flat_idx + (num_groups - 1) * num_dirs
y_neighbor_flat_idx_last = y_neighbor_flat_idx + (num_groups - 1) * num_dirs
z_neighbor_flat_idx_last = z_neighbor_flat_idx + (num_groups - 1) * num_dirs
x_neighbor_flat_idx_last = x_neighbor_flat_idx + (num_groups - 1) * I_s4
y_neighbor_flat_idx_last = y_neighbor_flat_idx + (num_groups - 1) * I_s4
z_neighbor_flat_idx_last = z_neighbor_flat_idx + (num_groups - 1) * I_s4
while True:
cuda.threadfence()
# Stop if the upstream neighbors aren't ready.
if (math.isnan(I[x_neighbor_idx, iy, iz, dir_idx, -1]) or
math.isnan(I[ix, y_neighbor_idx, iz, dir_idx, -1]) or
math.isnan(I[ix, iy, z_neighbor_idx, dir_idx, -1])):
continue
#if (math.isnan(I_flat[x_neighbor_flat_idx_last]) or
# math.isnan(I_flat[y_neighbor_flat_idx_last]) or
# math.isnan(I_flat[z_neighbor_flat_idx_last])):
#if (math.isnan(I[x_neighbor_idx, iy, iz, dir_idx, -1]) or
# math.isnan(I[ix, y_neighbor_idx, iz, dir_idx, -1]) or
# math.isnan(I[ix, iy, z_neighbor_idx, dir_idx, -1])):
# continue
if (math.isnan(I_flat[x_neighbor_flat_idx_last]) or
math.isnan(I_flat[y_neighbor_flat_idx_last]) or
math.isnan(I_flat[z_neighbor_flat_idx_last])):
continue
# For simplicity we're assuming all frequencies scatter the same, so
# sum across frequencies now.
for k64 in range(uint_t(I.shape[4])):
for k64 in range(num_groups):
k = uint_t(k64)
numerator = (incoming_scattering -
x_factor * I[x_neighbor_idx,iy,iz,dir_idx,k] -
y_factor * I[ix,y_neighbor_idx,iz,dir_idx,k] -
z_factor * I[ix,iy,z_neighbor_idx,dir_idx,k])
#numerator = (incoming_scattering -
# x_factor * I_flat[x_neighbor_flat_idx + k * num_dirs] -
# y_factor * I_flat[y_neighbor_flat_idx + k * num_dirs] -
# z_factor * I_flat[z_neighbor_flat_idx + k * num_dirs])
#numerator = (incoming_scattering -
# x_factor * I[x_neighbor_idx,iy,iz,dir_idx,k] -
# y_factor * I[ix,y_neighbor_idx,iz,dir_idx,k] -
# z_factor * I[ix,iy,z_neighbor_idx,dir_idx,k])
numerator = (incoming_scattering -
x_factor * I_flat[x_neighbor_flat_idx + k * I_s4] -
y_factor * I_flat[y_neighbor_flat_idx + k * I_s4] -
z_factor * I_flat[z_neighbor_flat_idx + k * I_s4])
flux = numerator * div
if k == I.shape[4] - uint_t(1):
if k == num_groups - uint_t(1):
cuda.threadfence()
I[ix,iy,iz,dir_idx,k] = flux
#I_flat[I_flat_idx] = flux
#I[ix,iy,iz,dir_idx,k] = flux
I_flat[I_flat_idx + k * I_s4] = flux
break

def sweep_step(work_items, tgroup_id, I, sigma, new_sigma, coefs, directions, sigma_a, sigma_s):
Expand All @@ -203,18 +200,14 @@ def sweep_step(work_items, tgroup_id, I, sigma, new_sigma, coefs, directions, si
assert I.strides[3] == 4
I_flat = np.swapaxes(I, 3, 4).ravel()
sigma_flat = sigma.ravel()
x_offset = I.shape[1] * I.shape[2] * I.shape[3] * I.shape[4]
y_offset = I.shape[2] * I.shape[3] * I.shape[4]
z_offset = I.shape[3] * I.shape[4]
# direction_offset = 1
sigma_I_offset = (I.shape[1] * I.shape[2] + I.shape[2] + 1) * z_offset
sigma_I_x_mul = (I.shape[1] + I.shape[2] + 1) * z_offset
direction_correction = (I.shape[4] - 1)
frequency_offset = I.shape[3]
cp.cuda.get_current_stream().synchronize()
start = perf_counter()
compute_fluxes[num_blocks, chunk_size, 0, uint_t_nbytes](work_items, I, I_flat, sigma, sigma_flat, directions, sigma_a + sigma_s, tgroup_id, 1. / I.shape[1], x_offset, y_offset, z_offset, frequency_offset, sigma_I_offset, sigma_I_x_mul, direction_correction)
cuda.profile_start()
compute_fluxes[num_blocks, chunk_size, 0, uint_t_nbytes](work_items, sigma.shape[0], sigma.shape[1], sigma.shape[2], sigma.shape[3], I.shape[4], I_flat, sigma_flat, directions, sigma_a + sigma_s, tgroup_id, 1. / I.shape[1], I.strides[0] // float_t_nbytes, I.strides[1] // float_t_nbytes, I.strides[2] // float_t_nbytes, I.strides[3] // float_t_nbytes, I.strides[4] // float_t_nbytes)
cp.cuda.get_current_stream().synchronize()
cuda.profile_stop()
stop = perf_counter()
print("sweep kernel time:", stop - start)
# Compute the scattering terms in the collision operator.
Expand Down

0 comments on commit 45a4e9e

Please sign in to comment.