-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain_ssr.py
75 lines (59 loc) · 2.68 KB
/
train_ssr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import argparse
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchlight.utils import instantiate
from torchlight.nn.utils import get_learning_rate
from hsir.data.ssr.dataset import TrainDataset, ValidDataset
from hsir.trainer import Trainer
def train_cfg():
    """Parse command-line options for SSR training.

    Returns:
        argparse.Namespace: parsed config. `gpu_ids` is converted from a
        comma-separated string to a list of ints, and `name` falls back to
        `arch` when not given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--arch', '-a', required=True)
    parser.add_argument('--name', '-n', type=str, default=None,
                        help='name of the experiment, if not specified, arch will be used.')
    parser.add_argument('--lr', type=float, default=4e-4)
    parser.add_argument('--bs', type=int, default=20)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--schedule', type=str, default='hsir.schedule.denoise_default')
    parser.add_argument('--resume', '-r', action='store_true')
    parser.add_argument('--resume-path', '-rp', type=str, default=None)
    parser.add_argument('--data-root', type=str, default='data/rgb2hsi')
    parser.add_argument('--data-size', type=int, default=None)
    parser.add_argument('--save-root', type=str, default='checkpoints/ssr')
    parser.add_argument('--gpu-ids', type=str, default='0', help='gpu ids')
    cfg = parser.parse_args()
    # '0,1' -> [0, 1]; `gpu_id` (not `id`) avoids shadowing the builtin
    cfg.gpu_ids = [int(gpu_id) for gpu_id in cfg.gpu_ids.split(',')]
    cfg.name = cfg.arch if cfg.name is None else cfg.name
    return cfg
def main():
    """Train an SSR model: build the network from the CLI config, run the
    epoch loop, validate after each epoch, and checkpoint latest/best/periodic
    weights under `save_root/name`."""
    cfg = train_cfg()
    net = instantiate(cfg.arch)
    trainer = Trainer(
        net,
        lr=cfg.lr,
        save_dir=os.path.join(cfg.save_root, cfg.name),
        gpu_ids=cfg.gpu_ids,
    )
    trainer.logger.print(cfg)
    if cfg.resume:
        trainer.load(cfg.resume_path)

    dataset = TrainDataset(cfg.data_root, size=cfg.data_size, stride=64)
    train_loader = DataLoader(dataset, batch_size=cfg.bs, shuffle=True,
                              num_workers=8, pin_memory=True)
    val_loader = DataLoader(ValidDataset(cfg.data_root), batch_size=1)

    # Main training loop (was a stray no-op string statement; removed a stale
    # commented-out CosineAnnealingLR line that referenced the nonexistent
    # cfg.max_epochs)
    epoch_per_save = 10  # save a numbered checkpoint every 10 epochs
    best_psnr = 0
    while trainer.epoch < cfg.epochs:
        trainer.logger.print('Epoch [{}] Use lr={}'.format(
            trainer.epoch, get_learning_rate(trainer.optimizer)))
        trainer.train(train_loader)
        # Always keep the most recent weights
        trainer.save_checkpoint('model_latest.pth')
        # NOTE(review): 'NITRE' looks like a misspelling of the NTIRE
        # challenge — confirm against Trainer.validate before changing.
        metrics = trainer.validate(val_loader, 'NITRE')
        if metrics['psnr'] > best_psnr:
            best_psnr = metrics['psnr']
            trainer.save_checkpoint('model_best.pth')
        if trainer.epoch % epoch_per_save == 0:
            trainer.save_checkpoint()
# Script entry point: only run training when executed directly, not on import.
if __name__ == '__main__':
    main()