Commit

hijack ancestral sampler
Isotr0py committed May 28, 2023
1 parent a506835 commit 32ca34d
Showing 2 changed files with 164 additions and 3 deletions.
32 changes: 29 additions & 3 deletions modules/diffusion/upscalers/multidiffusion.py
@@ -6,11 +6,14 @@
    AutoencoderKL,
    DDPMScheduler,
    UNet2DConditionModel,
    EulerAncestralDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
)
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from api.events.generation import UNetDenoisingEvent
from .samplers import EulerAncestralSampler, KDPM2AncestralSampler


class Multidiffusion:
@@ -25,6 +28,21 @@ def __init__(
        self.scheduler: DDPMScheduler = pipe.scheduler
        self.plugin_data = pipe.plugin_data
        self.opts = pipe.opts
        self.ancestral = False

    def hijack_ancestral_scheduler(self) -> bool:
        # Swap the scheduler for its "fake" ancestral subclass, keeping the original
        # runtime state (timesteps, sigmas, counters) by copying __dict__ across.
        if isinstance(self.scheduler, EulerAncestralDiscreteScheduler):
            config = copy.deepcopy(self.scheduler.__dict__)
            self.scheduler = EulerAncestralSampler.from_config(self.scheduler.config)
            self.scheduler.__dict__.update(config)
            return True
        elif isinstance(self.scheduler, KDPM2AncestralDiscreteScheduler):
            config = copy.deepcopy(self.scheduler.__dict__)
            self.scheduler = KDPM2AncestralSampler.from_config(self.scheduler.config)
            self.scheduler.__dict__.update(config)
            return True
        else:
            return False

    @classmethod
    def get_views(cls, panorama_height, panorama_width, window_size=64, stride=8):
@@ -57,6 +75,8 @@ def views_denoise_latent(
        callback_steps: int,
        cross_attention_kwargs: Dict[str, Any],
    ):
        # hijack ancestral schedulers so the noise can be added after the views are merged
        self.ancestral = self.hijack_ancestral_scheduler()
        # 6. Define panorama grid and initialize views for synthesis.
        views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views)
        count = torch.zeros_like(latents)
@@ -67,6 +87,7 @@
        for step, timestep in enumerate(timesteps):
            count.zero_()
            value.zero_()
            noise = torch.randn_like(latents)  # one noise draw per timestep, shared by all views
            for j, (h_start, h_end, w_start, w_end) in enumerate(views):
                # get the latents corresponding to the current view coordinates
                latents_for_view = latents[:, :, h_start:h_end, w_start:w_end]
@@ -118,12 +139,14 @@ def views_denoise_latent(
                )

                # compute the previous noisy sample x_t -> x_t-1
                latents_view_denoised = self.scheduler.step(
                scheduler_output = self.scheduler.step(
                    model_output=noise_pred,
                    timestep=timestep,
                    sample=latents_for_view,
                    **extra_step_kwargs,
                ).prev_sample
                )
                latents_view_denoised = scheduler_output.prev_sample
                sigma_up = scheduler_output.sigma_up if self.ancestral else 0

                views_scheduler_status[j] = copy.deepcopy(
                    self.scheduler.__dict__
@@ -133,7 +156,10 @@
                count[:, :, h_start:h_end, w_start:w_end] += 1

            # take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
            latents = torch.where(count > 0, value / count, value)
            # add noise for ancestral samplers; sigma_up is 0 for non-ancestral (discrete) samplers, so the term vanishes there
            latents = (
                torch.where(count > 0, value / count, value) + noise * sigma_up
            )

            # call the callback, if provided
            if step == len(timesteps) - 1 or (
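For orientation, a minimal, self-contained sketch (not part of the commit) of the aggregation pattern above: every view contributes only the deterministic part of its scheduler step into value/count, and a single shared noise * sigma_up is added once per timestep after the averaging. The fake_ancestral_step helper, the sigma_up value, and the tensor shapes are made up for illustration; in the real code the deterministic step and sigma_up come from the samplers defined in the second file.

import torch

def fake_ancestral_step(view_latents: torch.Tensor) -> tuple[torch.Tensor, float]:
    # stand-in for EulerAncestralSampler.step: returns the deterministic
    # prev_sample for this view plus the sigma_up extracted from the noise schedule
    return view_latents * 0.95, 0.8

latents = torch.randn(1, 4, 96, 96)
views = [(0, 64, 0, 64), (0, 64, 32, 96), (32, 96, 0, 64), (32, 96, 32, 96)]

noise = torch.randn_like(latents)  # one noise draw per timestep, shared by all views
count = torch.zeros_like(latents)
value = torch.zeros_like(latents)
for h_start, h_end, w_start, w_end in views:
    denoised, sigma_up = fake_ancestral_step(latents[:, :, h_start:h_end, w_start:w_end])
    value[:, :, h_start:h_end, w_start:w_end] += denoised
    count[:, :, h_start:h_end, w_start:w_end] += 1

# average the overlapping views first, then add the stochastic part once
latents = torch.where(count > 0, value / count, value) + noise * sigma_up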
135 changes: 135 additions & 0 deletions modules/diffusion/upscalers/samplers.py
@@ -0,0 +1,135 @@
# Fake ancestral samplers for MultiDiffusion
# (the random noise injection is removed and sigma_up is returned to the caller instead)
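# The ancestral update from sigma_from to sigma_to can be split into a deterministic
# step down to sigma_down plus independent noise scaled by sigma_up, where
#   sigma_up   = sqrt(sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2)
#   sigma_down = sqrt(sigma_to**2 - sigma_up**2)
# The subclasses below perform only the deterministic part and hand sigma_up back to
# the caller, which adds a single shared noise tensor per timestep after the
# MultiDiffusion views have been merged.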
from dataclasses import dataclass
from typing import *

import torch
from torch import FloatTensor
from diffusers import (
    EulerAncestralDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
)


@dataclass
class SamplerOutput:
    prev_sample: torch.FloatTensor
    sigma_up: torch.FloatTensor


class EulerAncestralSampler(EulerAncestralDiscreteScheduler):
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: float | torch.FloatTensor,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> SamplerOutput | Tuple:
        step_index = (self.timesteps == timestep).nonzero().item()
        sigma = self.sigmas[step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.config.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (
                sample / (sigma**2 + 1)
            )
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        sigma_from = self.sigmas[step_index]
        sigma_to = self.sigmas[step_index + 1]
        sigma_up = (
            sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2
        ) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        prev_sample = sample + derivative * dt

        if not return_dict:
            return (prev_sample, sigma_up)

        return SamplerOutput(prev_sample=prev_sample, sigma_up=sigma_up)


class KDPM2AncestralSampler(KDPM2AncestralDiscreteScheduler):
    def step(
        self,
        model_output: FloatTensor,
        timestep: float | FloatTensor,
        sample: FloatTensor,
        return_dict: bool = True,
    ) -> SamplerOutput:
        step_index = self.index_for_timestep(timestep)

        if self.state_in_first_order:
            sigma = self.sigmas[step_index]
            sigma_interpol = self.sigmas_interpol[step_index]
            sigma_up = self.sigmas_up[step_index]
            sigma_down = self.sigmas_down[step_index - 1]
        else:
            # 2nd order / KDPM2's method
            sigma = self.sigmas[step_index - 1]
            sigma_interpol = self.sigmas_interpol[step_index - 1]
            sigma_up = self.sigmas_up[step_index - 1]
            sigma_down = self.sigmas_down[step_index - 1]

        # currently only gamma=0 is supported. This usually works best anyways.
        # We can support gamma in the future but then need to scale the timestep before
        # passing it to the model which requires a change in API
        gamma = 0
        sigma_hat = sigma * (gamma + 1)  # Note: sigma_hat == sigma for now

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.config.prediction_type == "epsilon":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
            pred_original_sample = sample - sigma_input * model_output
        elif self.config.prediction_type == "v_prediction":
            sigma_input = sigma_hat if self.state_in_first_order else sigma_interpol
            pred_original_sample = model_output * (
                -sigma_input / (sigma_input**2 + 1) ** 0.5
            ) + (sample / (sigma_input**2 + 1))
        elif self.config.prediction_type == "sample":
            raise NotImplementedError("prediction_type not implemented yet: sample")
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        if self.state_in_first_order:
            # 2. Convert to an ODE derivative for 1st order
            derivative = (sample - pred_original_sample) / sigma_hat
            # 3. delta timestep
            dt = sigma_interpol - sigma_hat

            # store for 2nd order step
            self.sample = sample
            self.dt = dt
            prev_sample = sample + derivative * dt
        else:
            # DPM-Solver-2
            # 2. Convert to an ODE derivative for 2nd order
            derivative = (sample - pred_original_sample) / sigma_interpol
            # 3. delta timestep
            dt = sigma_down - sigma_hat

            sample = self.sample
            self.sample = None

            prev_sample = sample + derivative * dt

        if not return_dict:
            return (prev_sample, sigma_up)

        return SamplerOutput(prev_sample=prev_sample, sigma_up=sigma_up)
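A hypothetical standalone usage of EulerAncestralSampler (not from this commit), with the UNet prediction replaced by a random stand-in so the sketch is self-contained; with a real model, one iteration of this loop mirrors what views_denoise_latent does for a single view:

import torch

sampler = EulerAncestralSampler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
sampler.set_timesteps(num_inference_steps=20)
latents = torch.randn(1, 4, 64, 64) * sampler.init_noise_sigma

for t in sampler.timesteps:
    noise = torch.randn_like(latents)           # shared noise for this timestep
    model_input = sampler.scale_model_input(latents, t)
    noise_pred = torch.randn_like(model_input)  # stand-in for unet(...).sample
    out = sampler.step(model_output=noise_pred, timestep=t, sample=latents)
    # prev_sample is deterministic; adding noise * sigma_up restores the stochastic
    # part of an ordinary ancestral step
    latents = out.prev_sample + noise * out.sigma_up

Because the noise injection is deferred to the caller, views_denoise_latent can average the deterministic results of overlapping views before adding one shared noise draw per timestep, which is not possible when the noise is added inside scheduler.step().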
