forked from abarankab/DDPM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathscript_utils.py
122 lines (100 loc) · 2.89 KB
/
script_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import torchvision
import torch.nn.functional as F
from .unet import UNet
from .diffusion import (
GaussianDiffusion,
generate_linear_schedule,
generate_cosine_schedule,
)
def cycle(dl):
    """Yield the items of ``dl`` over and over, restarting it when exhausted.

    Adapted from
    https://github.com/lucidrains/denoising-diffusion-pytorch/

    Args:
        dl: Any iterable. It is re-iterated with a fresh ``for`` loop every
            time the previous pass finishes.

    Yields:
        The items of ``dl``, pass after pass.

    The generator finishes (instead of spinning forever without yielding)
    as soon as a complete pass over ``dl`` produces no items, which happens
    for an empty iterable or for a one-shot iterator on its second pass.
    """
    while True:
        produced_any = False
        for data in dl:
            produced_any = True
            yield data
        if not produced_any:
            # Nothing came out of this pass, so no later pass can be relied
            # on either; stop rather than busy-loop and hang the caller.
            return
class RescaleChannels:
    """Callable transform that maps ``sample`` to ``2 * sample - 1``.

    Defined at module level rather than inside ``get_transform`` so that
    instances can be pickled; a class local to a function cannot be, which
    breaks any consumer that has to serialize the transform (for example
    data-loading worker processes started with the ``spawn`` method).
    """

    def __call__(self, sample):
        return 2 * sample - 1


def get_transform():
    """Return the preprocessing pipeline: ``ToTensor`` then ``RescaleChannels``.

    Returns:
        A ``torchvision.transforms.Compose`` that applies
        ``torchvision.transforms.ToTensor()`` followed by the affine map
        ``x -> 2 * x - 1``.

    NOTE(review): per the torchvision docs ``ToTensor`` scales uint8 image
    data to [0, 1], in which case the output lies in [-1, 1] -- confirm for
    the input types actually used.
    """
    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        RescaleChannels(),
    ])
def str2bool(v):
    """Interpret a command-line token as a boolean.

    Adapted from
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse

    Args:
        v: Either a ``bool`` (returned unchanged) or a value with a
            ``.lower()`` method, normally the raw string from argparse.

    Returns:
        ``True`` for "yes", "true", "t", "y", "1" and ``False`` for
        "no", "false", "f", "n", "0" (compared case-insensitively).

    Raises:
        argparse.ArgumentTypeError: If the token is none of the above.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ("yes", "true", "t", "y", "1"):
        return True
    if token in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected")
def _sequence_arg_type(default):
    """Build an argparse ``type`` callable for a tuple/list default."""
    container = type(default)
    if not default:
        # No element to infer a type from; keep the items as strings.
        element_type = str
    elif isinstance(default[0], bool):
        # bool("false") is True, so booleans need the dedicated parser.
        element_type = str2bool
    else:
        element_type = type(default[0])

    def comma_separated(text):
        items = (item.strip() for item in text.split(","))
        return container(element_type(item) for item in items if item)

    return comma_separated


def add_dict_to_argparser(parser, default_dict):
    """Register one ``--<key>`` option on ``parser`` per entry of a dict.

    Adapted from
    https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py

    The converter for each option is derived from its default value:

    * ``None``  -> ``str``
    * ``bool``  -> ``str2bool`` (plain ``bool("false")`` would be ``True``)
    * ``tuple`` / ``list`` -> a comma-separated value such as ``1,2,2,2``,
      converted element-wise with the type of the default's first element
      and returned as the same container type as the default
    * anything else -> ``type(default)``

    Sequence defaults used to be registered with ``type=tuple`` (or
    ``list``), which splits the raw command-line string into single
    characters (``"1222"`` -> ``('1', '2', '2', '2')``) instead of producing
    numbers.

    Args:
        parser: The ``argparse.ArgumentParser`` (or argument group) to extend.
        default_dict: Mapping from option name (without the leading ``--``)
            to its default value.
    """
    for name, default in default_dict.items():
        if default is None:
            arg_type = str
        elif isinstance(default, bool):
            arg_type = str2bool
        elif isinstance(default, (tuple, list)):
            arg_type = _sequence_arg_type(default)
        else:
            arg_type = type(default)
        parser.add_argument(f"--{name}", default=default, type=arg_type)
def diffusion_defaults():
    """Return a fresh dict of default model/diffusion hyperparameters.

    A new dict is built on every call, so callers may mutate the result.
    The keys match the attribute names read by ``get_diffusion_from_args``;
    ``num_res_blocks`` is listed here but is not read by that function.
    """
    return {
        "num_timesteps": 1000,
        "schedule": "linear",
        "loss_type": "l2",
        "use_labels": False,
        "base_channels": 128,
        "channel_mults": (1, 2, 2, 2),
        "num_res_blocks": 2,
        "time_emb_dim": 128 * 4,
        "norm": "gn",
        "dropout": 0.1,
        "activation": "silu",
        "attention_resolutions": (1,),
        "ema_decay": 0.9999,
        "ema_update_rate": 1,
    }
def get_diffusion_from_args(args):
    """Build a ``GaussianDiffusion`` around a new ``UNet`` from parsed args.

    Args:
        args: Namespace-like object. Attributes read: ``base_channels``,
            ``channel_mults``, ``time_emb_dim``, ``norm``, ``dropout``,
            ``activation`` (one of "relu", "mish", "silu"),
            ``attention_resolutions``, ``use_labels``, ``schedule``,
            ``num_timesteps``, ``ema_decay``, ``ema_update_rate`` and
            ``loss_type``. ``schedule_low`` and ``schedule_high`` are read
            only when ``schedule`` is not "cosine"; they are not part of
            ``diffusion_defaults`` and must be supplied by the caller.

    Returns:
        The constructed ``GaussianDiffusion`` instance.

    Raises:
        KeyError: If ``args.activation`` is not a known activation name.

    NOTE(review): ``num_res_blocks`` exists in ``diffusion_defaults`` but is
    not forwarded to ``UNet`` here -- confirm whether that is intended.
    NOTE(review): any ``schedule`` value other than "cosine" selects the
    linear schedule, including misspelled names.
    """
    activation_table = {
        "relu": F.relu,
        "mish": F.mish,
        "silu": F.silu,
    }

    unet_kwargs = dict(
        img_channels=3,
        base_channels=args.base_channels,
        channel_mults=args.channel_mults,
        time_emb_dim=args.time_emb_dim,
        norm=args.norm,
        dropout=args.dropout,
        activation=activation_table[args.activation],
        attention_resolutions=args.attention_resolutions,
        # Class conditioning is enabled only when labels are requested.
        num_classes=10 if args.use_labels else None,
        initial_pad=0,
    )
    network = UNet(**unet_kwargs)

    if args.schedule != "cosine":
        # The endpoints are rescaled by 1000 / num_timesteps, so they are
        # used exactly as given when num_timesteps == 1000.
        beta_start = args.schedule_low * 1000 / args.num_timesteps
        beta_end = args.schedule_high * 1000 / args.num_timesteps
        betas = generate_linear_schedule(
            args.num_timesteps,
            beta_start,
            beta_end,
        )
    else:
        betas = generate_cosine_schedule(args.num_timesteps)

    # Positional arguments (32, 32), 3, 10 are presumably image size, image
    # channels and number of classes -- confirm against GaussianDiffusion.
    return GaussianDiffusion(
        network, (32, 32), 3, 10,
        betas,
        ema_decay=args.ema_decay,
        ema_update_rate=args.ema_update_rate,
        ema_start=2000,
        loss_type=args.loss_type,
    )