train.py
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#
import os
import torch
import sys
import json
from utils.general_utils import safe_state, init_distributed
import utils.general_utils as utils
from argparse import ArgumentParser
from arguments import (
    AuxiliaryParams,
    ModelParams,
    PipelineParams,
    OptimizationParams,
    DistributionParams,
    BenchmarkParams,
    DebugParams,
    print_all_args,
    init_args,
)
import train_internal

if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Training script parameters")
    ap = AuxiliaryParams(parser)
    lp = ModelParams(parser)
    op = OptimizationParams(parser)
    pp = PipelineParams(parser)
    dist_p = DistributionParams(parser)
    bench_p = BenchmarkParams(parser)
    debug_p = DebugParams(parser)
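    # Each *Params object above registers its own argument group on the shared
    # parser; the corresponding values are later pulled back out of the parsed
    # namespace via .extract(args) (see the train_internal.training call below).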
    args = parser.parse_args(sys.argv[1:])

    # Set up distributed training
    init_distributed(args)

    ## Prepare arguments.
    # Check arguments
    init_args(args)
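    # init_args normalizes the raw namespace and registers it globally; fetch
    # the resulting arguments back from utils so later code uses the processed
    # values rather than the raw parse result.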
    args = utils.get_args()

    # Create log folder
    if utils.GLOBAL_RANK == 0:
        os.makedirs(args.log_folder, exist_ok=True)
        os.makedirs(args.model_path, exist_ok=True)
    if utils.WORLD_SIZE > 1:
        torch.distributed.barrier(
            group=utils.DEFAULT_GROUP
        )  # make sure log_folder exists before other ranks start writing logs.
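    # Rank 0 also records the full argument set for this run.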
    if utils.GLOBAL_RANK == 0:
        with open(args.log_folder + "/args.json", "w") as f:
            json.dump(vars(args), f)
    # Initialize system state (RNG)
    safe_state(args.quiet)
    torch.autograd.set_detect_anomaly(args.detect_anomaly)
    # Initialize log file and print all args
    log_file = open(
        args.log_folder
        + "/python_ws="
        + str(utils.WORLD_SIZE)
        + "_rk="
        + str(utils.GLOBAL_RANK)
        + ".log",
        "a" if args.auto_start_checkpoint else "w",
    )
    utils.set_log_file(log_file)
    print_all_args(args, log_file)
    train_internal.training(
        lp.extract(args), op.extract(args), pp.extract(args), args, log_file
    )

    # All done
    if utils.WORLD_SIZE > 1:
        torch.distributed.barrier(group=utils.DEFAULT_GROUP)
    utils.print_rank_0("\nTraining complete.")
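
# Example launch (a sketch, not taken from this file: it assumes the script is
# driven by torchrun, since the code above calls init_distributed and uses
# torch.distributed, and that ModelParams defines the usual -s/--source_path
# and -m/--model_path flags as in the original 3DGS codebase):
#
#   torchrun --standalone --nnodes=1 --nproc_per_node=4 train.py \
#       -s <path_to_dataset> -m <output_model_path>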