forked from tqch/ddpm-torch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate.py
97 lines (83 loc) · 3.95 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
if __name__ == "__main__":
    # Draw `--total-size` samples from a trained diffusion checkpoint (ancestral
    # GaussianDiffusion sampling, or DDIM with `--use-ddim`) and write each one
    # to `<save-dir>/<dataset>/<uuid>.png`.
    import os
    import json
    import math
    import uuid
    import torch
    from tqdm import trange
    from PIL import Image
    from concurrent.futures import ThreadPoolExecutor
    from ddpm_torch import *
    from ddim import DDIM, get_selection_schedule
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--root", default="~/datasets", type=str)
    parser.add_argument("--dataset", choices=["mnist", "cifar10", "celeba"], default="cifar10")
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--total-size", default=50000, type=int)
    parser.add_argument("--config-dir", default="./configs", type=str)
    parser.add_argument("--chkpt-dir", default="./chkpts", type=str)
    parser.add_argument("--save-dir", default="./eval", type=str)
    parser.add_argument("--device", default="cuda:0", type=str)
    parser.add_argument("--use-ema", action="store_true")
    parser.add_argument("--use-ddim", action="store_true")
    parser.add_argument("--eta", default=0., type=float)
    parser.add_argument("--skip-schedule", default="linear", type=str)
    parser.add_argument("--subseq-size", default=10, type=int)
    args = parser.parse_args()

    dataset = args.dataset
    # Honour --root; it used to be silently replaced by a hard-coded "~/datasets".
    root = os.path.expanduser(args.root)
    in_channels = DATA_INFO[dataset]["channels"]
    image_res = DATA_INFO[dataset]["resolution"][0]

    # Hyper-parameters live in <config-dir>/<dataset>.json. The beta-schedule
    # keys are popped so the remaining "diffusion" entries can be forwarded
    # verbatim as keyword arguments to the diffusion constructor.
    config_dir = args.config_dir
    with open(os.path.join(config_dir, dataset + ".json")) as f:
        configs = json.load(f)
    diffusion_kwargs = configs["diffusion"]
    beta_schedule = diffusion_kwargs.pop("beta_schedule")
    beta_start = diffusion_kwargs.pop("beta_start")
    beta_end = diffusion_kwargs.pop("beta_end")
    num_diffusion_timesteps = diffusion_kwargs.pop("timesteps")
    betas = get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps)

    use_ddim = args.use_ddim
    if use_ddim:
        # The DDIM path always overrides the configured variance type.
        diffusion_kwargs["model_var_type"] = "fixed-small"
        skip_schedule = args.skip_schedule
        eta = args.eta
        subseq_size = args.subseq_size
        # Timesteps to visit, chosen by --skip-schedule out of the full schedule.
        subsequence = get_selection_schedule(skip_schedule, size=subseq_size, timesteps=num_diffusion_timesteps)
        diffusion = DDIM(betas, **diffusion_kwargs, eta=eta, subsequence=subsequence)
    else:
        diffusion = GaussianDiffusion(betas, **diffusion_kwargs)

    model = UNet(out_channels=in_channels, **configs["denoise"])
    chkpt_dir = args.chkpt_dir
    chkpt_path = os.path.join(chkpt_dir, f"{dataset}_diffusion.pt")
    # Load onto the CPU first (the model is still on the CPU at this point) so a
    # checkpoint saved on a GPU also loads on CPU-only hosts or another device.
    chkpt = torch.load(chkpt_path, map_location="cpu")
    if args.use_ema:
        model.load_state_dict(chkpt["ema"]["shadow"])
    else:
        model.load_state_dict(chkpt["model"])
    del chkpt  # load_state_dict copied the tensors; drop the checkpoint copy

    device = torch.device(args.device)
    model.to(device)
    model.eval()
    for p in model.parameters():
        if p.requires_grad:
            p.requires_grad_(False)

    save_dir = os.path.join(args.save_dir, dataset)
    os.makedirs(save_dir, exist_ok=True)

    batch_size = args.batch_size
    total_size = args.total_size
    num_eval_batches = math.ceil(total_size / batch_size)

    def save_image(arr):
        """Write one (H, W, C) uint8 image to `save_dir` under a random UUID name.

        Single-channel arrays are written as 8-bit grayscale; everything else is
        written as RGB, as before.
        """
        if arr.shape[-1] == 1:
            im = Image.fromarray(arr[..., 0])  # 2-D uint8 -> mode "L"
        else:
            im = Image.fromarray(arr, mode="RGB")
        with im:
            im.save(f"{save_dir}/{uuid.uuid4()}.png")

    with torch.inference_mode():
        with ThreadPoolExecutor(max_workers=os.cpu_count()) as pool:
            pending = []
            for i in trange(num_eval_batches):
                # Every batch is full except possibly the last, which draws only
                # what remains so exactly `total_size` images are produced.
                n = min(batch_size, total_size - i * batch_size)
                # Noise must match the model's channel count (it was hard-coded to 3).
                shape = (n, in_channels, image_res, image_res)
                x = diffusion.p_sample(model, shape=shape, device=device, noise=torch.randn(shape, device=device)).cpu()
                # Map [-1, 1] model output to [0, 255] uint8, NCHW -> NHWC for PIL.
                x = (x * 127.5 + 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1).numpy()
                # Wait for the previous batch's writes before queueing this one.
                # This re-raises any I/O error (Executor.map's unconsumed results
                # used to swallow them) and bounds how many batches sit in memory.
                for fut in pending:
                    fut.result()
                pending = [pool.submit(save_image, arr) for arr in x]
            for fut in pending:
                fut.result()