Commit

update: update inference scripts
Gal4way committed Sep 14, 2023
1 parent 85515bd commit 22df946
Showing 3 changed files with 52 additions and 14 deletions.
7 changes: 6 additions & 1 deletion .gitignore
@@ -1 +1,6 @@
/datasets
/datasets/
/checkpoints/
models/
utils/
__pycache__/
inference_log/
6 changes: 3 additions & 3 deletions commands/inference.sh
@@ -1,12 +1,12 @@
python test.py --plms --gpu_id 0 \
--ddim_steps 100 \
--outdir inference_log/VITONHD \
--outdir inference_log/VITONHD/inference_10_results_100steps_with_random_startcode \
--config configs/viton512.yaml \
--ckpt checkpoints/viton512.ckpt \
--dataroot datasets/VITONHD \
--n_samples 8 \
--n_samples 5 \
--seed 23 \
--scale 1 \
--H 512 \
--W 512 \
--unpaired
# --paired
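
For reference, the options forwarded above suggest an argument parser in test.py roughly like the sketch below. Only the flags used in this script, plus the `opt.C`, `opt.f`, and `opt.skip_save` attributes visible in the test.py diff, are known from the source; every type, default, and help string here is an assumption.

```python
import argparse

def build_parser():
    # Sketch only: flags mirror commands/inference.sh; defaults and help strings are assumptions.
    parser = argparse.ArgumentParser(description="VITON-HD inference (sketch)")
    parser.add_argument("--plms", action="store_true", help="use the PLMS sampler")
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--ddim_steps", type=int, default=100, help="number of sampling steps")
    parser.add_argument("--outdir", type=str, default="inference_log/VITONHD")
    parser.add_argument("--config", type=str, default="configs/viton512.yaml")
    parser.add_argument("--ckpt", type=str, default="checkpoints/viton512.ckpt")
    parser.add_argument("--dataroot", type=str, default="datasets/VITONHD")
    parser.add_argument("--n_samples", type=int, default=5, help="batch size per sampling call")
    parser.add_argument("--seed", type=int, default=23)
    parser.add_argument("--scale", type=float, default=1.0, help="classifier-free guidance scale")
    parser.add_argument("--H", type=int, default=512)
    parser.add_argument("--W", type=int, default=512)
    parser.add_argument("--C", type=int, default=4, help="latent channels (opt.C below)")
    parser.add_argument("--f", type=int, default=8, help="latent downsampling factor (opt.f below)")
    parser.add_argument("--skip_save", action="store_true", help="skip writing result images")
    parser.add_argument("--unpaired", action="store_true", help="pair each person with a different garment")
    return parser
```

With `--H 512 --W 512` and the assumed `--C 4 --f 8`, the `shape = [opt.C, opt.H // opt.f, opt.W // opt.f]` line in the test.py diff below would sample 4x64x64 latents.
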
53 changes: 43 additions & 10 deletions test.py
@@ -277,7 +277,7 @@ def main():
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir

result_path = os.path.join(outpath, "result")
result_path = outpath
os.makedirs(result_path, exist_ok=True)

start_code = None
@@ -322,8 +322,8 @@ def main():
warp_feat = model.encode_first_stage(feat_tensor)
warp_feat = model.get_first_stage_encoding(warp_feat).detach()

ts = torch.full((1,), 999, device=device, dtype=torch.long)
start_code = model.q_sample(warp_feat, ts)
# ts = torch.full((1,), 999, device=device, dtype=torch.long)
# start_code = model.q_sample(warp_feat, ts)

shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
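
Context for the change above: with the `q_sample` lines commented out, `start_code` keeps its earlier `start_code = None` value, and PLMS/DDIM samplers in latent-diffusion-style codebases typically draw a fresh Gaussian start code of shape `[opt.C, opt.H // opt.f, opt.W // opt.f]` in that case, which matches the `random_startcode` run name used in `commands/inference.sh`. A minimal sketch of the two initializations, assuming the usual `q_sample` forward-diffusion interface; the helper name and `use_warp_init` flag are hypothetical:

```python
import torch

def choose_start_code(model, warp_feat, device, use_warp_init=False):
    # Sketch only. use_warp_init=True mirrors the commented-out lines: push the
    # encoded warp features almost to pure noise at t = 999 using the model's
    # forward process q(x_t | x_0). use_warp_init=False mirrors the new behaviour:
    # return None so the sampler falls back to a random Gaussian start code.
    if use_warp_init:
        ts = torch.full((1,), 999, device=device, dtype=torch.long)
        return model.q_sample(warp_feat, ts)
    return None
```
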
@@ -349,18 +349,51 @@ def main():

resize = transforms.Resize((opt.H, int(opt.H / 256 * 192)))


if not opt.skip_save:
    # collect everything to be saved alongside the result in one dictionary
    tensors_to_save = {
        "mask": mask_tensor,
        "inpaint_image": inpaint_image,
        "ref": ref_tensor,
        "feat": feat_tensor,
        "source": image_tensor
    }

    def un_norm(x):
        return (x + 1.0) / 2.0

        return torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
    for i, x_sample in enumerate(x_result):
        filename = data['file_name'][i]
        # filename = data['file_name']
        save_x = resize(x_sample)
        save_x = 255. * rearrange(save_x.cpu().numpy(), 'c h w -> h w c')
        img = Image.fromarray(save_x.astype(np.uint8))
        img.save(os.path.join(result_path, filename[:-4] + ".png"))

        all_images = []
        all_images.append(resize(un_norm(image_tensor[i])).cpu())
        all_images.append(resize(un_norm(mask_tensor[i].repeat(3, 1, 1))).cpu())
        all_images.append(resize(un_norm(inpaint_image[i])).cpu())
        all_images.append(resize(un_norm(feat_tensor[i])).cpu())
        all_images.append(resize(un_norm(ref_tensor[i])).cpu())
        all_images.append(resize(x_sample).cpu())
        grid = torch.stack(all_images, 0)
        grid = make_grid(grid)
        grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
        img = Image.fromarray(grid.astype(np.uint8))
        img.save(os.path.join(result_path, filename[:-4] + "_grid.png"))


        x_sample = resize(x_sample)
        x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
        x_sample = Image.fromarray(x_sample.astype(np.uint8))
        x_sample.save(os.path.join(result_path, filename[:-4] + "_result.png"))

        # save each tensor collected in tensors_to_save as its own image
        for key, tensors in tensors_to_save.items():
            save_tensor = un_norm(tensors[i])
            save_tensor = resize(save_tensor)
            save_tensor = 255. * rearrange(save_tensor.cpu().numpy(), 'c h w -> h w c')
            save_tensor = np.squeeze(save_tensor)
            save_image = Image.fromarray(save_tensor.astype(np.uint8))
            save_image.save(os.path.join(result_path, filename[:-4] + f"_{key}.png"))

print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f" \nEnjoy.")
