Fix halo correction kernel
thorjohnsen committed Apr 1, 2022
1 parent 60000f7 commit 705aa35
Showing 1 changed file with 33 additions and 19 deletions.
52 changes: 33 additions & 19 deletions apex/contrib/bottleneck/bottleneck.py
@@ -268,6 +268,17 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
spatial_halo_exchanger.left_right_halo_exchange(out1[:,:,:1,:], out1[:,:,Hs-1:,:], top_out1_halo, btm_out1_halo)
if spatial_method == 1:
# overlap mid convolution with halo transfer
if spatial_group_rank > 0:
with torch.cuda.stream(stream1):
if explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
if spatial_group_rank < spatial_group_size-1:
stream2.wait_stream(stream1)
with torch.cuda.stream(stream2):
@@ -280,17 +291,6 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
btm_fat_halo[:,:,0:2,:].copy_(out1[:,:,Hs-2:,:])
btm_fat_halo[:,:,2:,:].copy_(btm_out1_halo)
btm_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, btm_fat_halo, args)
if spatial_group_rank > 0:
with torch.cuda.stream(stream1):
if explicit_nhwc:
top_fat_halo = torch.empty((N,3,W,C),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:1,:,:].copy_(top_out1_halo)
top_fat_halo[:,1:3,:,:].copy_(out1[:,:2,:,:])
else:
top_fat_halo = torch.empty((N,C,3,W),dtype=out1.dtype,device=out1.device)
top_fat_halo[:,:,:1,:].copy_(top_out1_halo)
top_fat_halo[:,:,1:3,:].copy_(out1[:,:,:2,:])
top_out2 = fast_bottleneck.forward_out2_halo(explicit_nhwc, top_fat_halo, args)
inc.add_delay(10)
elif spatial_method != 2 and spatial_method != 3:
assert(False), "spatial_method must be 1, 2 or 3"
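The two hunks above reorder the per-rank halo handling for spatial_method == 1 so the top "fat halo" block runs before the bottom one. The underlying idea: the first (and last) output row of the 3x3 mid convolution needs one row of out1 from the neighboring rank, so a three-row strip made of the received halo row plus the two adjacent local rows is convolved to recompute that border row. A minimal sketch of the fat-halo construction in plain PyTorch, assuming NCHW layout and a stride-1 3x3 mid convolution, with F.conv2d standing in for the fused fast_bottleneck.forward_out2_halo kernel (the helper name below is illustrative only):

import torch
import torch.nn.functional as F

def corrected_top_row(out1, top_out1_halo, w_mid):
    # Build the 3-row "fat halo": the row received from the rank above,
    # followed by this rank's first two rows of out1.
    fat_halo = torch.cat([top_out1_halo, out1[:, :, :2, :]], dim=2)  # (N, C, 3, W)
    # A 3x3 convolution with horizontal-only padding collapses the strip to a
    # single row: the first row of out2 recomputed with true neighbor data.
    return F.conv2d(fat_halo, w_mid, padding=(0, 1))                 # (N, K, 1, W)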
@@ -324,21 +324,35 @@ def forward(ctx, spatial_group_size, spatial_group_rank, spatial_communicator, s
torch.cuda.current_stream().wait_stream(stream2)
btm_out2_halo.copy_(btm_out2)
elif spatial_method == 3:
# Note
# out2 halo correction cannot overlap with anything since it has
# to wait for out2_mask to finish, but itself has to finish before
# the first kernel of _forward_rest can launch.
# At least we can overlap the two halo correction kernels.
if spatial_group_rank > 0:
w1by3 = args[2][:,:,2:3,:].contiguous(memory_format=torch.preserve)
top_out1_halo = top_out1_halo.contiguous(memory_format=memory_format)
top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.contiguous(memory_format=memory_format))
top_out2_halo.copy_(top_out2)
stream1.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream1):
w1by3 = args[2][:,:1,:,:].clone()
top_out1_halo = top_out1_halo.clone()
top_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, top_out1_halo, args, w1by3, top_out2_halo.clone())
top_out2_halo.copy_(top_out2)
if spatial_group_rank < spatial_group_size-1:
w1by3 = args[2][:,:,:1,:].contiguous(memory_format=torch.preserve)
btm_out1_halo = btm_out1_halo.contiguous(memory_format=memory_format)
btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.contiguous(memory_format=memory_format))
btm_out2_halo.copy_(btm_out2)
stream2.wait_stream(torch.cuda.current_stream()) # wait for *_out2_mask to finish
with torch.cuda.stream(stream2):
w1by3 = args[2][:,2:3,:,:].clone()
btm_out1_halo = btm_out1_halo.clone()
btm_out2 = fast_bottleneck.forward_out2_halo_corr(explicit_nhwc, btm_out1_halo, args, w1by3, btm_out2_halo.clone())
btm_out2_halo.copy_(btm_out2)
if spatial_group_rank > 0:
torch.cuda.current_stream().wait_stream(stream1)
if spatial_group_rank < spatial_group_size-1:
torch.cuda.current_stream().wait_stream(stream2)

fast_bottleneck.forward_rest(explicit_nhwc, stride_1x1, args, outputs)
# save halos for backward pass
if spatial_group_size > 1:
if spatial_method != 2:
# make sure copy of mid-section of out1 into out1_pad is done before exiting
torch.cuda.current_stream().wait_stream(stream3)
ctx.save_for_backward(*(args+outputs+[out1_pad,]))
else:

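For spatial_method == 3, the commit moves the two halo-correction kernels onto stream1 and stream2: each side stream first waits for the main stream (so out2_mask has finished), the top and bottom corrections then overlap each other, and the main stream waits on both before forward_rest launches. A condensed sketch of that synchronization pattern, with run_top_corr and run_bottom_corr as placeholder callables standing in for the fast_bottleneck.forward_out2_halo_corr calls (names not part of the module):

import torch

def overlap_halo_corrections(run_top_corr, run_bottom_corr,
                             has_top, has_btm, stream1, stream2):
    main = torch.cuda.current_stream()
    if has_top:
        stream1.wait_stream(main)      # out2_mask on the main stream must finish first
        with torch.cuda.stream(stream1):
            run_top_corr()             # corrects the first row of out2
    if has_btm:
        stream2.wait_stream(main)
        with torch.cuda.stream(stream2):
            run_bottom_corr()          # corrects the last row of out2
    # forward_rest must not launch until both corrections are done.
    if has_top:
        main.wait_stream(stream1)
    if has_btm:
        main.wait_stream(stream2)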