# ---------------------------------------------------------------
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for I2SB. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import sys
import random
import argparse
import copy
from pathlib import Path
from datetime import datetime

import numpy as np
import torch
from torch.multiprocessing import Process

from logger import Logger
from distributed_util import init_processes
from corruption import build_corruption
from dataset import imagenet
from i2sb import Runner, download_ckpt

import colored_traceback.always  # side-effect import: colorized tracebacks
from ipdb import set_trace as debug  # debugging helper

RESULT_DIR = Path("results")
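# Checkpoints are written under results/<name>/ (see opt.ckpt_path below);
# resuming with --ckpt <name> loads results/<name>/latest.pt.
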
def set_seed(seed):
    # https://github.com/pytorch/pytorch/issues/7068
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
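
# Note: seeding alone does not make cuDNN kernels bit-exact. If full
# reproducibility is needed, one can additionally set
#   torch.backends.cudnn.deterministic = True
#   torch.backends.cudnn.benchmark = False
# at some cost in training speed.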

def create_training_options():
    # --------------- basic ---------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--name", type=str, default=None, help="experiment ID")
    parser.add_argument("--ckpt", type=str, default=None, help="checkpoint name to resume from")
    parser.add_argument("--gpu", type=int, default=None, help="set only if you wish to run on a particular device")
    parser.add_argument("--n-gpu-per-node", type=int, default=1, help="number of GPUs on each node")
    parser.add_argument("--master-address", type=str, default='localhost', help="address of the master node")
    parser.add_argument("--node-rank", type=int, default=0, help="index of this node")
    parser.add_argument("--num-proc-node", type=int, default=1, help="number of nodes in a multi-node environment")
    # parser.add_argument("--amp", action="store_true")

    # --------------- SB model ---------------
    parser.add_argument("--image-size", type=int, default=256)
    parser.add_argument("--corrupt", type=str, default=None, help="restoration task")
    parser.add_argument("--t0", type=float, default=1e-4, help="sigma start time in network parametrization")
    parser.add_argument("--T", type=float, default=1., help="sigma end time in network parametrization")
    parser.add_argument("--interval", type=int, default=1000, help="number of discretization intervals")
    parser.add_argument("--beta-max", type=float, default=0.3, help="max diffusion coefficient of the diffusion model")
    # parser.add_argument("--beta-min", type=float, default=0.1)
    parser.add_argument("--ot-ode", action="store_true", help="use the OT-ODE model")
    parser.add_argument("--clip-denoise", action="store_true", help="clamp the predicted image to [-1,1] at each sampling step")
    # optional configs for the conditional network
    parser.add_argument("--cond-x1", action="store_true", help="condition the network on the degraded image x1")
    parser.add_argument("--add-x1-noise", action="store_true", help="add noise to the conditioning image x1")

    # --------------- optimizer and loss ---------------
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--microbatch", type=int, default=2, help="accumulate gradients over microbatches until the full batch size is reached")
    parser.add_argument("--num-itr", type=int, default=1000000, help="number of training iterations")
    parser.add_argument("--lr", type=float, default=5e-5, help="learning rate")
    parser.add_argument("--lr-gamma", type=float, default=0.99, help="learning rate decay ratio")
    parser.add_argument("--lr-step", type=int, default=1000, help="learning rate decay step size")
    parser.add_argument("--l2-norm", type=float, default=0.0)  # weight decay coefficient passed to the optimizer
    parser.add_argument("--ema", type=float, default=0.99)  # EMA decay rate for the network weights

    # --------------- path and logging ---------------
    parser.add_argument("--dataset-dir", type=Path, default="/dataset", help="path to LMDB dataset")
    parser.add_argument("--log-dir", type=Path, default=".log", help="path to log std outputs and writer data")
    parser.add_argument("--log-writer", type=str, default=None, help="log writer: can be tensorboard, wandb, or None")
    parser.add_argument("--wandb-api-key", type=str, default=None, help="unique API key of your W&B account; see https://wandb.ai/authorize")
    parser.add_argument("--wandb-user", type=str, default=None, help="user name of your W&B account")
    opt = parser.parse_args()

    # ========= auto setup =========
    opt.device = 'cuda' if opt.gpu is None else f'cuda:{opt.gpu}'
    if opt.name is None:
        opt.name = opt.corrupt
    opt.distributed = opt.n_gpu_per_node > 1
    opt.use_fp16 = False  # disable fp16 for training

    # log NGC metadata
    if "NGC_JOB_ID" in os.environ.keys():
        opt.ngc_job_id = os.environ["NGC_JOB_ID"]

    # ========= path handle =========
    os.makedirs(opt.log_dir, exist_ok=True)
    opt.ckpt_path = RESULT_DIR / opt.name
    os.makedirs(opt.ckpt_path, exist_ok=True)

    if opt.ckpt is not None:
        ckpt_file = RESULT_DIR / opt.ckpt / "latest.pt"
        assert ckpt_file.exists()
        opt.load = ckpt_file
    else:
        opt.load = None

    # ========= auto assert =========
    assert opt.batch_size % opt.microbatch == 0, f"{opt.batch_size=} is not divisible by {opt.microbatch}!"
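    # With the defaults (--batch-size 256, --microbatch 2), each optimizer step
    # accumulates gradients over 256 // 2 = 128 microbatch forward/backward passes.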
    return opt
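
# Example invocation (a sketch; "mixture" is the task handled explicitly in
# main() below, other task names are defined in the corruption package, and
# --name / --dataset-dir are placeholders to fill in):
#   python train.py --corrupt mixture --name my-exp \
#       --dataset-dir <path-to-imagenet-lmdb> \
#       --batch-size 256 --microbatch 2 --log-writer tensorboard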

def main(opt):
    log = Logger(opt.global_rank, opt.log_dir)
    log.info("=======================================================")
    log.info("         Image-to-Image Schrodinger Bridge")
    log.info("=======================================================")
    log.info("Command used:\n{}".format(" ".join(sys.argv)))
    log.info(f"Experiment ID: {opt.name}")

    # set seed: make sure each GPU has a different seed!
    if opt.seed is not None:
        set_seed(opt.seed + opt.global_rank)

    # build ImageNet dataset
    train_dataset = imagenet.build_lmdb_dataset(opt, log, train=True)
    val_dataset = imagenet.build_lmdb_dataset(opt, log, train=False)

    # note: images should be normalized to [-1,1] for corruption methods to work properly
    if opt.corrupt == "mixture":
        import corruption.mixture as mix
        train_dataset = mix.MixtureCorruptDatasetTrain(opt, train_dataset)
        val_dataset = mix.MixtureCorruptDatasetVal(opt, val_dataset)

    # build corruption method
    corrupt_method = build_corruption(opt, log)

    run = Runner(opt, log)
    run.train(opt, train_dataset, val_dataset, corrupt_method)
    log.info("Finish!")

if __name__ == '__main__':
    opt = create_training_options()
    assert opt.corrupt is not None

    # one-time download: ADM checkpoint
    download_ckpt("data/")

    if opt.distributed:
        # spawn one process per local GPU; each gets its own copy of opt
        size = opt.n_gpu_per_node
        processes = []
        for rank in range(size):
            opt = copy.deepcopy(opt)
            opt.local_rank = rank
            global_rank = rank + opt.node_rank * opt.n_gpu_per_node
            global_size = opt.num_proc_node * opt.n_gpu_per_node
            opt.global_rank = global_rank
            opt.global_size = global_size
            print('Node rank %d, local proc %d, global proc %d, global_size %d' % (opt.node_rank, rank, global_rank, global_size))
            p = Process(target=init_processes, args=(global_rank, global_size, main, opt))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
    else:
        # single-process run on one GPU
        torch.cuda.set_device(0)
        opt.global_rank = 0
        opt.local_rank = 0
        opt.global_size = 1
        init_processes(0, opt.n_gpu_per_node, main, opt)