forked from NVlabs/I2SB
-
Notifications
You must be signed in to change notification settings - Fork 0
/
compute_metrices.py
205 lines (168 loc) · 7 KB
/
compute_metrices.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# ---------------------------------------------------------------
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# for I2SB. To view a copy of this license, see the LICENSE file.
# ---------------------------------------------------------------
import os
import argparse
import random
from pathlib import Path
from easydict import EasyDict as edict
import pickle
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from logger import Logger
from evaluation.resnet import build_resnet50
from evaluation import fid_util
from i2sb import download
import colored_traceback.always
from ipdb import set_trace as debug
# Root directory holding one sub-directory per checkpoint name (opt.ckpt);
# options.pkl and the sample directories are looked up underneath it.
RESULT_DIR = Path("results")
# URL of the FID reference statistics used for the train split in get_ref_fid
# (downloaded to data/VIRTUAL_imagenet256_labeled.npz when missing).
ADM_IMG256_FID_TRAIN_REF_CKPT = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz"
def set_seed(seed):
    """Seed every RNG this script relies on with the same ``seed``.

    Covers Python's ``random``, NumPy's global generator, PyTorch's CPU
    generator and the CUDA generators (current device and all devices).
    ``PYTHONHASHSEED`` is exported as well; note that Python reads it only at
    interpreter start-up, so it affects subprocesses launched afterwards.
    """
    # https://github.com/pytorch/pytorch/issues/7068
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # if you are using multi-GPU.
    )
    for seeder in seeders:
        seeder(seed)
class NumpyDataset(Dataset):
    """Expose in-memory image/label arrays as a torch ``Dataset``.

    Each item is ``(img_t, y)``: ``img_t`` is the image converted by
    ``transforms.ToTensor`` and rescaled to [-1, 1]; ``y`` is the matching
    entry of ``targets`` as a ``LongTensor`` scalar.

    Args:
        data: indexable collection of numpy images in the layout ``ToTensor``
            accepts (HWC or HW). Supported dtypes:
            - uint8, assumed to be in [0, 255];
            - float32, assumed to already be in [0, 255] (TODO confirm with
              producers of float data).
        targets: integer class labels, one per image.
    """

    def __init__(self, data, targets):
        self.data = data
        self.targets = torch.LongTensor(targets)
        self.transform = transforms.ToTensor()

    def __getitem__(self, index):
        """Return ``(img_t, y)`` for ``index``; ``img_t`` lies in [-1, 1].

        Raises:
            TypeError: if the image dtype is neither uint8 nor float32.
        """
        img_np = self.data[index]
        y = self.targets[index]
        if img_np.dtype == "uint8":
            # ToTensor scales uint8 input from [0, 255] to [0, 1]
            img_t = self.transform(img_np) * 2 - 1
        elif img_np.dtype == "float32":
            # ToTensor does NOT rescale float input, so values stay in [0, 255]
            img_t = self.transform(img_np) / 127.5 - 1
        else:
            # Previously fell through and failed with UnboundLocalError.
            raise TypeError(
                f"Unsupported image dtype {img_np.dtype}; expected uint8 or float32"
            )
        # img_t: [-1,1]
        return img_t, y

    def __len__(self):
        return len(self.data)
@torch.no_grad()
def compute_accu(opt, numpy_arr, numpy_label_arr, batch_size=256):
    """Return the top-1 accuracy of ``build_resnet50()`` on the given images.

    Args:
        opt: options object; only ``opt.device`` is read.
        numpy_arr: images accepted by ``NumpyDataset`` (uint8 or float32).
        numpy_label_arr: integer labels, one per image.
        batch_size: evaluation batch size.

    Returns:
        float: fraction of images whose arg-max prediction equals the label.

    Raises:
        ValueError: if ``numpy_arr`` is empty (accuracy would be 0/0).
    """
    # Fail fast (before loading the classifier) instead of dividing by zero.
    if len(numpy_arr) == 0:
        raise ValueError("compute_accu received an empty image array; accuracy is undefined")

    dataset = NumpyDataset(numpy_arr, numpy_label_arr)
    loader = DataLoader(
        dataset,
        batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=1, drop_last=False,
    )
    resnet = build_resnet50().to(opt.device)
    # Force inference mode (BatchNorm running stats, no dropout) regardless of
    # the mode build_resnet50 hands the model back in; eval() is idempotent.
    resnet.eval()

    correct = total = 0
    for x, y in loader:
        logits = resnet(x.to(opt.device))
        predicted = logits.argmax(dim=1).cpu()
        correct += (predicted == y).sum().item()
        total += y.size(0)
    return correct / total
def convert_to_numpy(t):
    """Turn a 4-D image tensor in [-1, 1] into a uint8 channels-last numpy array.

    The input axes are permuted as ``(0, 2, 3, 1)``, i.e. (batch, C, H, W) ->
    (batch, H, W, C). Values are mapped to [0, 255], clamped, and cast to uint8;
    the cast truncates toward zero rather than rounding.
    """
    pixels = ((t + 1) * 127.5).clamp(0, 255)
    channels_last = pixels.to(torch.uint8).permute(0, 2, 3, 1).contiguous()
    return channels_last.cpu().numpy()  # [0, 255]
def find_recon_imgs_pts(opt, log):
    """Locate the reconstruction ``.pt`` file(s) to evaluate.

    Looks in ``RESULT_DIR / opt.ckpt / opt.sample_dir``. A single
    ``recon.pt`` there takes precedence; otherwise every ``recon_*.pt``
    partition file in that directory is returned.

    Returns:
        list[Path]: non-empty; partition files are sorted by name so the
        loading/concatenation order is reproducible across filesystems.

    Raises:
        FileNotFoundError: if neither ``recon.pt`` nor any ``recon_*.pt`` exists.
    """
    sample_dir = RESULT_DIR / opt.ckpt / opt.sample_dir

    recon_imgs_pt = sample_dir / "recon.pt"
    if recon_imgs_pt.exists():
        log.info(f"Found recon.pt in dir={str(sample_dir)}!")
        return [recon_imgs_pt]

    log.info(f"Finding partition pt files in dir={str(sample_dir)} ...")
    # Path.glob yields entries in arbitrary (filesystem) order; sort them.
    recon_imgs_pts = sorted(sample_dir.glob("recon_*.pt"))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not recon_imgs_pts:
        raise FileNotFoundError(f"Found 0 file that matches '{str(sample_dir)}/recon_*.pt'!")
    return recon_imgs_pts
def build_numpy_data(log, recon_imgs_pts):
    """Load reconstruction partitions and merge them into numpy arrays.

    Each ``.pt`` file is loaded onto the CPU and must contain ``'arr'`` (image
    tensor, concatenated along dim 0 and passed to ``convert_to_numpy``, which
    expects values in [-1, 1]) and ``'label_arr'`` (label tensor).

    Args:
        log: logger exposing ``info``.
        recon_imgs_pts: iterable of paths to ``.pt`` files.

    Returns:
        tuple: ``(numpy_arr, numpy_label_arr)`` — uint8 channels-last images
        and the labels as a numpy array.

    Raises:
        ValueError: if no files are given, or image and label counts differ.
    """
    recon_imgs_pts = list(recon_imgs_pts)
    if not recon_imgs_pts:
        # torch.cat([]) would otherwise fail with a less helpful message.
        raise ValueError("build_numpy_data: no reconstruction .pt files given!")

    arr = []
    label_arr = []
    for pt in recon_imgs_pts:
        # NOTE(review): torch.load unpickles; only point this at trusted result files.
        out = torch.load(pt, map_location="cpu")
        arr.append(out['arr'])
        label_arr.append(out['label_arr'])
        log.info(f"pt file {str(pt.name)} contains {len(out['label_arr'])} data!")

    arr = torch.cat(arr, dim=0)
    label_arr = torch.cat(label_arr, dim=0)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(arr) != len(label_arr):
        raise ValueError(
            f"Image/label count mismatch: {len(arr)} images vs {len(label_arr)} labels!"
        )

    # convert to numpy
    numpy_arr = convert_to_numpy(arr)
    numpy_label_arr = label_arr.cpu().numpy()
    return numpy_arr, numpy_label_arr
def build_ref_opt(opt, ref_fid_fn):
    """Build the options used to generate FID reference statistics.

    The split and image size are parsed from the file stem, whose last two
    ``_``-separated fields must be ``<image_size>_<split>`` (e.g.
    ``fid_imagenet_256_val.npz`` -> image_size=256, split="val").

    Args:
        opt: options providing ``image_size``, ``mode`` and ``dataset_dir``.
        ref_fid_fn: ``Path`` of the reference ``.npz`` file.

    Returns:
        edict with ``mode``, ``split``, ``image_size`` and ``dataset_dir``.

    Raises:
        ValueError: if the size field is not an integer, or it differs from
            ``opt.image_size``.
    """
    fields = ref_fid_fn.stem.split("_")
    split = fields[-1]
    image_size = int(fields[-2])
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if opt.image_size != image_size:
        raise ValueError(
            f"opt.image_size={opt.image_size} does not match image size {image_size} "
            f"encoded in {ref_fid_fn.name}"
        )
    return edict(
        mode=opt.mode,
        split=split,
        image_size=image_size,
        dataset_dir=opt.dataset_dir,
    )
def get_ref_fid(opt, log):
    """Load the FID reference statistics matching ``opt.ckpt``'s task.

    Reads ``RESULT_DIR / opt.ckpt / "options.pkl"`` to pick the split:
    - train, when the checkpoint's ``corrupt`` option contains ``'sr4x'``;
      ``data/VIRTUAL_imagenet256_labeled.npz`` is downloaded if missing;
    - val otherwise; ``data/fid_imagenet_256_val.npz`` is produced by
      ``fid_util.compute_fid_ref_stat`` if missing (presumably that call
      writes the file — verify in fid_util).

    Returns:
        tuple: ``(ref_fid_fn, ref_mu, ref_sigma)`` — the ``.npz`` path and its
        ``'mu'`` / ``'sigma'`` arrays.
    """
    # NOTE(review): pickle.load can execute arbitrary code; options.pkl must
    # come from a trusted checkpoint directory.
    with open(RESULT_DIR / opt.ckpt / "options.pkl", "rb") as f:
        ckpt_opt = pickle.load(f)

    # we use train set for super-res and val set for the rest of the tasks
    split = "train" if 'sr4x' in ckpt_opt.corrupt else "val"

    # build npz file
    if split == "train":
        ref_fid_fn = Path("data/VIRTUAL_imagenet256_labeled.npz")
        if not ref_fid_fn.exists():
            log.info(f"Downloading {ref_fid_fn=} (this can take a while ...)")
            download(ADM_IMG256_FID_TRAIN_REF_CKPT, ref_fid_fn)
    else:
        ref_fid_fn = Path("data/fid_imagenet_256_val.npz")
        if not ref_fid_fn.exists():
            log.info(f"Generating {ref_fid_fn=} (this can take a while ...)")
            ref_opt = build_ref_opt(opt, ref_fid_fn)
            fid_util.compute_fid_ref_stat(ref_opt, log)

    # np.load on an .npz returns an NpzFile that keeps the file handle open;
    # read the arrays inside `with` so the handle is released.
    with np.load(ref_fid_fn) as ref_fid:
        ref_mu, ref_sigma = ref_fid['mu'], ref_fid['sigma']
    return ref_fid_fn, ref_mu, ref_sigma
def log_metrices(opt):
    """Compute and log classifier accuracy and FID for a checkpoint's samples.

    Steps, in order:
    1. Seed the RNGs and optionally select the CUDA device.
    2. Gather the reconstruction ``.pt`` files and convert them to numpy.
    3. Log ResNet-50 accuracy.
    4. Log FID against the reference statistics for the checkpoint's task.
    """
    set_seed(opt.seed)
    if opt.gpu is not None:
        torch.cuda.set_device(opt.gpu)

    log = Logger(0, ".log")
    log.info(f"======== Compute metrices: {opt.ckpt=}, {opt.mode=} ========")

    # locate reconstruction partitions
    pt_files = find_recon_imgs_pts(opt, log)
    pt_names = [p.name for p in pt_files]
    log.info(f"Found {len(pt_files)} pt files={pt_names}")

    # merge them into numpy arrays
    numpy_arr, numpy_label_arr = build_numpy_data(log, pt_files)
    log.info(f"Collected {numpy_arr.shape=}, {numpy_label_arr.shape=}!")

    # classifier accuracy
    accuracy = compute_accu(opt, numpy_arr, numpy_label_arr)
    log.info(f"Accuracy={accuracy:.3f}!")

    # FID against the reference statistics
    ref_fid_fn, ref_mu, ref_sigma = get_ref_fid(opt, log)
    log.info(f"Loaded FID reference statistics from {ref_fid_fn}!")
    fid_score = fid_util.compute_fid_from_numpy(numpy_arr, ref_mu, ref_sigma, mode=opt.mode)
    log.info(f"FID(w.r.t. {ref_fid_fn=})={fid_score:.1f}!")
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--gpu", type=int, default=None, help="set only if you wish to run on a particular device")
    # --ckpt and --sample-dir are joined into a Path unconditionally
    # (RESULT_DIR / opt.ckpt / opt.sample_dir), so a missing value used to
    # crash later with an opaque TypeError; argparse now rejects it up front.
    parser.add_argument("--ckpt", type=str, required=True, help="the checkpoint name for which we wish to compute metrices")
    parser.add_argument("--mode", type=str, default="legacy_pytorch", help="the FID computation mode used in clean-fid")
    parser.add_argument("--dataset-dir", type=Path, default="/dataset", help="path to LMDB dataset")
    parser.add_argument("--sample-dir", type=Path, required=True, help="directory where samples are stored")
    parser.add_argument("--image-size", type=int, default=256)

    arg = parser.parse_args()

    opt = edict(
        device="cuda",
    )
    opt.update(vars(arg))

    log_metrices(opt)