Commit 3bc977a

update the default naming rules
tqch committed Sep 18, 2022
1 parent c4fb319 commit 3bc977a
Showing 3 changed files with 40 additions and 30 deletions.

eval.py: 6 additions & 5 deletions

@@ -25,21 +25,22 @@
 parser.add_argument("--row-batch-size", default=10000, type=int)
 parser.add_argument("--col-batch-size", default=10000, type=int)
 parser.add_argument("--device", default="cuda:0", type=str)
-parser.add_argument("--eval-dir", default="./eval")
+parser.add_argument("--eval-dir", default="./images/eval")
 parser.add_argument("--precomputed-dir", default="./precomputed", type=str)
 parser.add_argument("--metrics", nargs="+", default=["fid", "pr"], type=str)
 parser.add_argument("--seed", default=1234, type=int)
-parser.add_argument("--affix", default="", type=str)
+parser.add_argument("--folder-name", default="", type=str)

 args = parser.parse_args()

 root = os.path.expanduser(args.root)
 dataset = args.dataset
-affix = args.affix
 print(f"Dataset: {dataset}")

-eval_dir = args.eval_dir
-img_dir = os.path.join(eval_dir, dataset + affix)
+img_dir = eval_dir = args.eval_dir
+folder_name = args.folder_name
+if folder_name:
+    img_dir = os.path.join(img_dir, folder_name)
 device = torch.device(args.device)

 args = parser.parse_args()
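
eval.py now expects the image folder to be named explicitly via --folder-name instead of deriving it from the dataset name plus an affix. A minimal sketch of the new path resolution, using the hypothetical folder name ddpm_cifar10 (the naming produced by generate.py below):

import os

eval_dir = "./images/eval"    # new default of --eval-dir
folder_name = "ddpm_cifar10"  # hypothetical value of --folder-name; empty by default

img_dir = eval_dir
if folder_name:
    img_dir = os.path.join(img_dir, folder_name)
# img_dir == "./images/eval/ddpm_cifar10"; without --folder-name, eval_dir itself is used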

generate.py: 16 additions & 8 deletions

@@ -18,13 +18,15 @@
 parser.add_argument("--total-size", default=50000, type=int)
 parser.add_argument("--config-dir", default="./configs", type=str)
 parser.add_argument("--chkpt-dir", default="./chkpts", type=str)
-parser.add_argument("--save-dir", default="./eval", type=str)
+parser.add_argument("--chkpt-path", default="", type=str)
+parser.add_argument("--save-dir", default="./images/eval", type=str)
 parser.add_argument("--device", default="cuda:0", type=str)
 parser.add_argument("--use-ema", action="store_true")
 parser.add_argument("--use-ddim", action="store_true")
 parser.add_argument("--eta", default=0., type=float)
 parser.add_argument("--skip-schedule", default="linear", type=str)
 parser.add_argument("--subseq-size", default=10, type=int)
+parser.add_argument("--affix", default="", type=str)

 args = parser.parse_args()

@@ -38,7 +40,6 @@
 with open(os.path.join(config_dir, dataset + ".json")) as f:
     configs = json.load(f)

-
 diffusion_kwargs = configs["diffusion"]
 beta_schedule = diffusion_kwargs.pop("beta_schedule")
 beta_start = diffusion_kwargs.pop("beta_start")

@@ -57,22 +58,27 @@
 else:
     diffusion = GaussianDiffusion(betas, **diffusion_kwargs)

+device = torch.device(args.device)
 model = UNet(out_channels=in_channels, **configs["denoise"])
+model.to(device)
 chkpt_dir = args.chkpt_dir
-chkpt_path = os.path.join(chkpt_dir, f"{dataset}_diffusion.pt")
+chkpt_path = args.chkpt_path or os.path.join(chkpt_dir, f"ddpm_{dataset}.pt")
+folder_name = os.path.basename(chkpt_path)[:-3]  # truncated at file extension
 use_ema = args.use_ema
 if use_ema:
-    model.load_state_dict(torch.load(chkpt_path)["ema"]["shadow"])
+    state_dict = torch.load(chkpt_path, map_location=device)["ema"]["shadow"]
 else:
-    model.load_state_dict(torch.load(chkpt_path)["model"])
-device = torch.device(args.device)
-model.to(device)
+    state_dict = torch.load(chkpt_path, map_location=device)["model"]
+for k in list(state_dict.keys()):
+    if k.split(".")[0] == "module":  # state_dict of DDP
+        state_dict[".".join(k.split(".")[1:])] = state_dict.pop(k)
+model.load_state_dict(state_dict)
 model.eval()
 for p in model.parameters():
     if p.requires_grad:
         p.requires_grad_(False)

-folder_name = dataset + ("_ema" if use_ema else "") + ("_ddim" if use_ddim else "")
+folder_name = folder_name + args.affix
 save_dir = os.path.join(args.save_dir, folder_name)
 if not os.path.exists(save_dir):
     os.makedirs(save_dir)
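
Two effects of this hunk are easy to miss. The output folder is now named after the checkpoint file plus the optional --affix, rather than after the dataset and the _ema/_ddim flags, and checkpoints written by a DistributedDataParallel run (whose keys all carry a "module." prefix) can now be loaded into the plain UNet. A minimal sketch with hypothetical inputs:

import os
from collections import OrderedDict

# folder naming: checkpoint basename, minus ".pt", plus --affix
chkpt_path = "./chkpts/ddpm_cifar10.pt"  # hypothetical --chkpt-path
folder_name = os.path.basename(chkpt_path)[:-3] + "_ema"  # "ddpm_cifar10_ema"
save_dir = os.path.join("./images/eval", folder_name)

# prefix stripping, shown on a toy DDP-style state dict
state_dict = OrderedDict([("module.conv.weight", 0), ("module.conv.bias", 1)])
for k in list(state_dict.keys()):
    if k.split(".")[0] == "module":
        state_dict[".".join(k.split(".")[1:])] = state_dict.pop(k)
print(list(state_dict.keys()))  # ['conv.weight', 'conv.bias']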

@@ -86,6 +92,8 @@ def save_image(arr):
     with Image.fromarray(arr, mode="RGB") as im:
         im.save(f"{save_dir}/{uuid.uuid4()}.png")

+if torch.backends.cudnn.is_available():
+    torch.backends.cudnn.benchmark = True

 with torch.inference_mode():
     with ThreadPoolExecutor(max_workers=os.cpu_count()) as pool:
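
A brief note on the new cuDNN toggle: setting torch.backends.cudnn.benchmark = True lets cuDNN time candidate convolution algorithms on the first batches and cache the fastest choice per input shape; since generation feeds fixed-size batches through a fixed model, the one-off autotuning cost is quickly amortized.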

train.py: 18 additions & 17 deletions

@@ -10,13 +10,15 @@
 from functools import partial


+def logger(msg, **kwargs):
+    if dist.is_initialized() and dist.get_rank() == 0:
+        print(msg, **kwargs)
+
+
 @errors.record
 def main(args):

-    def logger(msg, **kwargs):
-        if not distributed or dist.get_rank() == 0:
-            print(msg, **kwargs)
-
+    distributed = args.distributed
     root = os.path.expanduser(args.root)
     dataset = args.dataset

@@ -61,18 +63,18 @@ def main(args):
     out_channels = 2 * in_channels if model_var_type == "learned" else in_channels
     _model = UNet(out_channels=out_channels, **configs["denoise"])

-    distributed = args.distributed
     if distributed:
         # check whether torch.distributed is available
         # CUDA devices are required to run with NCCL backend
         assert dist.is_available() and torch.cuda.is_available()
         dist.init_process_group("nccl")
-        rank = dist.get_rank()
+        rank = dist.get_rank()  # global process id across all node(s)
+        local_rank = int(os.environ["LOCAL_RANK"])  # local device id on a single node
         _model = _model.to(rank)
-        model = DDP(_model, device_ids=[rank, ])
-        train_device = torch.device(f"cuda:{rank}")
+        model = DDP(_model, device_ids=[local_rank, ])
+        train_device = torch.device(f"cuda:{local_rank}")
     else:
-        rank = 0
+        rank = local_rank = 0  # main process by default
         model = _model.to(train_device)

     optimizer = Adam(model.parameters(), lr=lr, betas=(beta1, beta2))
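
The comments above draw the distinction that matters on multi-node runs: dist.get_rank() is a global process id, while LOCAL_RANK, set per worker by the launcher, indexes GPUs within a single node, which is why DDP's device_ids and the CUDA device string now use it. A sketch, assuming a hypothetical torchrun launch of 2 nodes with 4 GPUs each:

import os
import torch.distributed as dist

# hypothetical launch: torchrun --nnodes=2 --nproc_per_node=4 train.py --distributed
dist.init_process_group("nccl")
rank = dist.get_rank()                      # 0..7 across the two nodes
local_rank = int(os.environ["LOCAL_RANK"])  # 0..3 within each node
# the first worker on the second node sees rank == 4 but local_rank == 0,
# so cuda:{local_rank} names a device that actually exists on that node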

@@ -103,10 +105,7 @@ def main(args):
     with open(os.path.join(chkpt_dir, f"exp_{timestamp}.info"), "w") as f:
         f.write(hps_info)

-    chkpt_path = os.path.join(
-        chkpt_dir,
-        f"{dataset}_diffusion.pt"
-    )
+    chkpt_path = os.path.join(chkpt_dir, f"ddpm_{dataset}.pt")
     chkpt_intv = args.chkpt_intv
     logger(f"Checkpoint will be saved to {os.path.abspath(chkpt_path)}", end=" ")
     logger(f"every {chkpt_intv} epochs")

@@ -136,9 +135,11 @@ def main(args):
         distributed=distributed
     )
     evaluator = Evaluator(dataset=dataset, device=eval_device) if args.eval else None
-    if args.resume:
+    # in case of elastic launch, resume should always be turned on
+    resume = args.resume or distributed
+    if resume:
         try:
-            map_location = {"cuda:0": f"cuda:{rank}"} if distributed else None
+            map_location = {"cuda:0": f"cuda:{local_rank}"} if distributed else train_device
             trainer.resume_from_chkpt(chkpt_path, map_location=map_location)
         except FileNotFoundError:
             logger("Checkpoint file does not exist!")

@@ -174,7 +175,7 @@
     parser.add_argument("--num-workers", default=4, type=int, help="number of workers for data loading")
     parser.add_argument("--train-device", default="cuda:0", type=str)
     parser.add_argument("--eval-device", default="cuda:0", type=str)
-    parser.add_argument("--image-dir", default="./images", type=str)
+    parser.add_argument("--image-dir", default="./images/train", type=str)
     parser.add_argument("--num-save-images", default=64, type=int, help="number of images to generate & save")
     parser.add_argument("--config-dir", default="./configs", type=str)
     parser.add_argument("--chkpt-dir", default="./chkpts", type=str)