diff --git a/model_average.py b/model_average.py
index c314bff..aecf1a4 100644
--- a/model_average.py
+++ b/model_average.py
@@ -65,12 +65,12 @@ def train(gpu, args):
     actual_lr = args.lr * sample_ray_num / 512      # bigger batch -> higher lr (linearity)
     ma_epoch = args.ma_epoch
     ma_method = args.ma_method
-    group = None if not args.group else args.group
     train_cnt, ep_start = None, None
     rank = args.nr * args.gpus + gpu
-    dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank, group_name = group)
+    dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
+    process_group = dist.new_group(backend = 'nccl')
     torch.cuda.set_device(gpu)
     for folder in ("./output/", "./check_points/", "./model/"):
@@ -229,31 +229,31 @@ def run():
        train_sampler.set_epoch(train_cnt)
        if ep % ma_epoch == 0:
            # double barrier to ensure synchronized sending / receiving
-           dist.barrier()
+           dist.barrier(group = process_group)
            comm_timer.tic()
            print(f"Using model average, method: {args.ma_method}... ", end = '')
            if ma_method == 'p2p':
                # This is a serialized reduce-broadcast (a central node exists)
                if rank == 0:
-                   param_recv_avg(mip_net, container, model_weights, [1, 2, 3], group = group)
+                   param_recv_avg(mip_net, container, model_weights, [1, 2, 3], group = process_group)
                    # Receive from multiple nodes
-                   param_send(mip_net, dist_ranks = [1, 2, 3], group = group)
+                   param_send(mip_net, dist_ranks = [1, 2, 3], group = process_group)
                else:
-                   param_send(mip_net, dist_ranks = [0], group = group)
+                   param_send(mip_net, dist_ranks = [0], group = process_group)
                    # Receive from only one node
-                   param_recv(mip_net, source_rank = 0, group = group)
+                   param_recv(mip_net, source_rank = 0, group = process_group)
            elif ma_method == 'broadcast':
                # reduce-broadcast (one of the nodes is the bottleneck)
-               param_reduce(mip_net, model_weights, rank, 0, group = group)
-               param_broadcast(mip_net, 0, group = group)
+               param_reduce(mip_net, model_weights, rank, 0, group = process_group)
+               param_broadcast(mip_net, 0, group = process_group)
            elif ma_method == 'all_reduce':
                # all-reduce (one-step reduce-broadcast)
                for param in mip_net.parameters():
                    param.data *= model_weights[rank]
-               param_all_reduce(mip_net, group = group)
+               param_all_reduce(mip_net, group = process_group)
            else:
                # TODO: a more delicate communication strategy should be implemented
                # This is basically the case with correlated camera poses
                pass
-           dist.barrier()
+           dist.barrier(group = process_group)
            comm_timer.toc()
            mean_comm_time = comm_timer.get_mean_time()
            writer.add_scalar('Time/comm time', mean_comm_time, train_cnt)
@@ -312,8 +312,6 @@ def main():
     parser.add_argument('--ma_method', choices=['p2p', 'broadcast', 'delicate', 'all_reduce'], type = str, default = 'p2p',
                         help='Model average strategies')
-    parser.add_argument('--group', default="", type=str,
-                        help='Name of the group')
     parser.add_argument('-n', '--nodes', default=1, type=int, metavar='N',
                         help='number of nodes (default: 1)')
     parser.add_argument('-g', '--gpus', default=1, type=int,
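Reviewer note: the sketch below is a minimal, self-contained approximation of the pattern this diff settles on for the `all_reduce` branch — initialize the default process group, derive an explicit NCCL group from it, and bracket the weighted parameter all-reduce with the double barrier. `average_model` and `setup` here are hypothetical stand-ins for the project's `param_all_reduce` helper and the setup code in `train()`, and the sketch assumes the per-rank `model_weights` sum to 1 so that an all-reduce SUM yields the weighted average. The underlying point of the change is that the `group=` parameter of `torch.distributed` collectives expects a `ProcessGroup` handle such as the one returned by `dist.new_group`, not the string name previously threaded through `init_process_group`.

```python
# Minimal sketch (not the project's actual helpers) of the pattern this diff
# adopts: explicit process group + barrier / weighted all-reduce / barrier.
import torch
import torch.distributed as dist

def setup(rank: int, world_size: int):
    # Same initialization order as the diff: global init first, then an
    # explicit NCCL group that every collective call is pinned to.
    dist.init_process_group(backend='nccl', init_method='env://',
                            world_size=world_size, rank=rank)
    return dist.new_group(backend='nccl')

def average_model(model: torch.nn.Module, weight: float, group) -> None:
    """Weighted model average, analogous to the 'all_reduce' branch above."""
    dist.barrier(group=group)          # enter the averaging phase together
    for param in model.parameters():
        # Pre-scale by this rank's weight so that the all-reduce SUM across
        # ranks produces the weighted mean (weights assumed to sum to 1).
        param.data.mul_(weight)
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM, group=group)
    dist.barrier(group=group)          # leave together before training resumes
```

The two barriers serve the same purpose as the "double barrier" comment in the diff: every rank enters and exits the communication phase together, which also keeps the `comm_timer` measurement from charging straggler wait time to individual collectives.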