Skip to content

Commit

Permalink
[update] add ddim_edit
Browse files Browse the repository at this point in the history
  • Loading branch information
CFGpp-diffusion authored Jun 12, 2024
1 parent 1fce099 commit 8e0314e
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions latent_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,59 @@ def sample(self,
img = (img / 2 + 0.5).clamp(0, 1)
return img.detach().cpu()

@register_solver("ddim_edit")
class EditWardSwapDDIM(InversionDDIM):
"""
Editing via WardSwap after inversion.
Useful for text-guided image editing.
"""
@torch.autocast(device_type='cuda', dtype=torch.float16)
def sample(self,
src_img,
cfg_guidance=7.5,
prompt=["","",""],
callback_fn=None,
**kwargs):

# Text embedding
uc, src_c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[1])
_, tgt_c = self.get_text_embed(null_prompt=prompt[0], prompt=prompt[2])

# Initialize zT
zt = self.initialize_latent(method='ddim',
src_img=src_img,
uc=uc,
c=src_c,
cfg_guidance=cfg_guidance)
# Sampling
pbar = tqdm(self.scheduler.timesteps, desc="DDIM-edit")
for step, t in enumerate(pbar):
at = self.alpha(t)
at_prev = self.alpha(t - self.skip)

with torch.no_grad():
noise_uc, noise_c = self.predict_noise(zt, t, uc, tgt_c)
noise_pred = noise_uc + cfg_guidance * (noise_c - noise_uc)

# tweedie
z0t = (zt - (1-at).sqrt() * noise_pred) / at.sqrt()

# add noise
zt = at_prev.sqrt() * z0t + (1-at_prev).sqrt() * noise_pred

if callback_fn is not None:
callback_kwargs = {'z0t': z0t.detach(),
'zt': zt.detach(),
'decode': self.decode}
callback_kwargs = callback_fn(step, t, callback_kwargs)
z0t = callback_kwargs["z0t"]
zt = callback_kwargs["zt"]

# for the last step, do not add noise
img = self.decode(z0t)
img = (img / 2 + 0.5).clamp(0, 1)
return img.detach().cpu()

###########################################
# CFG++ version
###########################################
Expand Down

0 comments on commit 8e0314e

Please sign in to comment.