Skip to content

Commit

Permalink
Debugged grouping.
Browse files Browse the repository at this point in the history
  • Loading branch information
Enigmatisms committed Jun 5, 2023
1 parent 40f8e83 commit e139706
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions model_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ def train(gpu, args):
actual_lr = args.lr * sample_ray_num / 512 # bigger batch -> higher lr (linearity)
ma_epoch = args.ma_epoch
ma_method = args.ma_method
group = None if not args.group else args.group

train_cnt, ep_start = None, None

rank = args.nr * args.gpus + gpu
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank, group_name = group)
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
process_group = dist.new_group(backend = 'nccl')
torch.cuda.set_device(gpu)

for folder in ("./output/", "./check_points/", "./model/"):
Expand Down Expand Up @@ -229,31 +229,31 @@ def run():
train_sampler.set_epoch(train_cnt)
if ep % ma_epoch == 0:
# double barrier to ensure synchronized sending / receiving
dist.barrier()
dist.barrier(group = process_group)
comm_timer.tic()
print(f"Using model average, method: {args.ma_method}... ", end = '')
if ma_method == 'p2p':
# This is a serialized reduce-broadcast (a central node exists)
if rank == 0:
param_recv_avg(mip_net, container, model_weights, [1, 2, 3], group = group)
param_recv_avg(mip_net, container, model_weights, [1, 2, 3], group = process_group)
# Receive from multiple nodes
param_send(mip_net, dist_ranks = [1, 2, 3], group = group)
param_send(mip_net, dist_ranks = [1, 2, 3], group = process_group)
else:
param_send(mip_net, dist_ranks = [0], group = group)
param_send(mip_net, dist_ranks = [0], group = process_group)
# Receive from only one node
param_recv(mip_net, source_rank = 0, group = group)
param_recv(mip_net, source_rank = 0, group = process_group)
elif ma_method == 'broadcast': # reduce-broadcast (one of the nodes is the bottleneck)
param_reduce(mip_net, model_weights, rank, 0, group = group)
param_broadcast(mip_net, 0, group = group)
param_reduce(mip_net, model_weights, rank, 0, group = process_group)
param_broadcast(mip_net, 0, group = process_group)
elif ma_method == 'all_reduce': # all-reduce (one-step reduce-broadcast)
for param in mip_net.parameters():
param.data *= model_weights[rank]
param_all_reduce(mip_net, group = group)
param_all_reduce(mip_net, group = process_group)
else:
# TODO: more delicate communication strategy should be implemented
# This is basically the case with correlated camera poses
pass
dist.barrier()
dist.barrier(group = process_group)
comm_timer.toc()
mean_comm_time = comm_timer.get_mean_time()
writer.add_scalar('Time/comm time', mean_comm_time, train_cnt)
Expand Down Expand Up @@ -312,8 +312,6 @@ def main():
parser.add_argument('--ma_method', choices=['p2p', 'broadcast', 'delicate', 'all_reduce'], type = str, default = 'p2p',
help='Model average strategies')

parser.add_argument('--group', default="", type=str,
help='Name of the group')
parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('-g', '--gpus', default=1, type=int,
Expand Down

0 comments on commit e139706

Please sign in to comment.