# inference.py
# Forked from ohayonguy/PMRF.
from lightning_models.mmse_rectified_flow import MMSERectifiedFlow
import torch
from torchvision.transforms.functional import to_tensor
from torchvision.utils import save_image
import os
from torch.utils.data import DataLoader
from torch_datasets.image_folder_dataset import ImageFolderDataset
from tqdm import tqdm
import argparse
# Process-wide PyTorch settings, applied once at import time.
# 'high' lets float32 matrix multiplications use faster, reduced-precision
# internal formats where the backend supports them.
torch.set_float32_matmul_precision('high')
# This script only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
def _load_model(args):
    """Build the restoration model selected by ``args`` and move it to the GPU.

    When ``args.ckpt_path_is_huggingface`` is set, ``args.ckpt_path`` is handed
    to ``MMSERectifiedFlow.from_pretrained``. Otherwise it is treated as a local
    checkpoint file: the model is restored with ``load_from_checkpoint`` and,
    if the model reports ``ema_wanted``, the EMA state stored under the
    checkpoint's ``'ema'`` key is loaded and copied into the model.

    The raw checkpoint dictionary only lives inside this function, so it is
    released as soon as the model has been built instead of staying in memory
    for the whole inference run.
    """
    if args.ckpt_path_is_huggingface:
        return MMSERectifiedFlow.from_pretrained(args.ckpt_path).cuda()

    # The checkpoint is read directly as well, because the architecture name
    # and the EMA state are taken from it below.
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    mmse_model_arch = ckpt['hyper_parameters']['mmse_model_arch']
    model = MMSERectifiedFlow.load_from_checkpoint(
        args.ckpt_path,
        # Need to provide mmse_model_arch to make sure the model initializes it.
        mmse_model_arch=mmse_model_arch,
        # Will ignore the original path of the MMSE model used for training,
        # and instead load it from the model checkpoint.
        mmse_model_ckpt_path=None,
        map_location='cpu',
    ).cuda()
    if model.ema_wanted:
        model.ema.load_state_dict(ckpt['ema'])
        model.ema.copy_to()
    return model


def _save_images(images, file_names, output_dir):
    """Save every image of a batch into ``output_dir`` under its source file's base name."""
    for image, file_name in tqdm(zip(images, file_names), total=len(file_names)):
        save_image(image, os.path.join(output_dir, os.path.basename(file_name)))


def main(args):
    """Restore every image found in ``args.lq_data_path`` and write the results to disk.

    Restored images go to ``<args.output_dir>/restored_images``. When the loaded
    model has an ``mmse_model``, that model's output for each input is also
    written to ``<args.output_dir>/restored_images_posterior_mean``. Output
    files reuse the base name of the corresponding input file.

    Args:
        args: Parsed command-line options. Reads ``lq_data_path``,
            ``batch_size``, ``output_dir``, ``ckpt_path``,
            ``ckpt_path_is_huggingface`` and ``num_flow_steps``.
    """
    # The degradation is the identity: the images on disk are already low quality.
    ds = ImageFolderDataset(args.lq_data_path, degradation=lambda x: (x, None), transform=to_tensor)
    dl = DataLoader(ds, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=10, pin_memory=True)

    output_path = os.path.join(args.output_dir, 'restored_images')
    os.makedirs(output_path, exist_ok=True)

    model = _load_model(args)

    # The second output folder is only created when there is an MMSE model to fill it.
    output_path_mmse = None
    if model.mmse_model is not None:
        output_path_mmse = os.path.join(args.output_dir, 'restored_images_posterior_mean')
        os.makedirs(output_path_mmse, exist_ok=True)

    # NOTE(review): torch.compile returns the optimized module and does not
    # modify `model` in place; the result is discarded here, so inference below
    # runs the uncompiled model and the message printed next is misleading.
    # Kept as-is to preserve current runtime behaviour - confirm whether
    # compilation is actually wanted before wiring the result in.
    torch.compile(model, mode='max-autotune')
    print("Compiled model")
    model.freeze()

    for batch in tqdm(dl):
        y = batch['y'].cuda()
        dummy_x = batch['x'].cuda()
        # Only the first element of the returned value is saved; the final
        # argument is presumably the device the results are moved to - TODO confirm.
        estimate = model.generate_reconstructions(dummy_x, y, None, args.num_flow_steps, torch.device("cpu"))[0]
        _save_images(estimate, batch['img_file_name'], output_path)
        if output_path_mmse is not None:
            mmse_estimate = model.mmse_model(y)
            _save_images(mmse_estimate, batch['img_file_name'], output_path_mmse)
if __name__ == "__main__":
    # Command-line entry point: declare the options in one table, register
    # them with argparse, then hand the parsed namespace to main().
    option_table = [
        ('--ckpt_path', dict(
            type=str,
            default='./checkpoints/blind_face_restoration_pmrf.ckpt',
            help='Path to the model checkpoint.')),
        ('--ckpt_path_is_huggingface', dict(
            action='store_true',
            help='Whether the ckpt path is a huggingface model or a path to a local file.')),
        ('--lq_data_path', dict(
            type=str,
            required=True,
            help='Path to a folder that contains low quality images.')),
        ('--output_dir', dict(
            type=str,
            required=True,
            help='Path to a folder where the reconstructed images will be saved.')),
        ('--num_flow_steps', dict(
            type=int,
            default=25,
            help='Number of flow steps to use for inference.')),
        ('--batch_size', dict(
            type=int,
            default=64,
            help='Batch size for inference.')),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    cli_args = cli.parse_args()
    main(cli_args)