Skip to content

Commit

Permalink
Changed LegacyDDPMDiscretization for sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
timudk committed Jun 30, 2023
1 parent 613af10 commit e9869d7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
2 changes: 0 additions & 2 deletions scripts/demo/streamlit_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,10 +325,8 @@ def init_sampling(

def get_discretization(discretization, key=1):
if discretization == "LegacyDDPMDiscretization":
use_new_range = st.checkbox(f"Start from highest noise level? #{key}", False)
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
"params": {"legacy_range": not use_new_range},
}
elif discretization == "EDMDiscretization":
sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292
Expand Down
37 changes: 24 additions & 13 deletions sgm/modules/diffusionmodules/discretizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,26 @@
from ...modules.diffusionmodules.util import make_beta_schedule


def generate_roughly_equally_spaced_steps(n, m):
# 0, ..., m - 1
m -= 1

# We are getting rid of leading 0 later, so increase steps
n += 1

# Calculate the step size
step = m / (n - 1)

# Generate the list
steps_reversed = [int(m - i * step) for i in range(n)]
steps = steps_reversed[::-1]

# Get rid of leading 0
steps = steps[1:]

return np.array(steps)


class Discretization:
def __call__(self, n, do_append_zero=True, device="cuda", flip=False):
sigmas = self.get_sigmas(n, device)
Expand Down Expand Up @@ -33,7 +53,6 @@ def __init__(
linear_start=0.00085,
linear_end=0.0120,
num_timesteps=1000,
legacy_range=True,
):
self.num_timesteps = num_timesteps
betas = make_beta_schedule(
Expand All @@ -42,23 +61,15 @@ def __init__(
alphas = 1.0 - betas
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.to_torch = partial(torch.tensor, dtype=torch.float32)
self.legacy_range = legacy_range

def get_sigmas(self, n, device):
if n < self.num_timesteps:
c = self.num_timesteps // n

if self.legacy_range:
timesteps = np.asarray(list(range(0, self.num_timesteps, c)))
timesteps += 1 # Legacy LDM Hack
else:
timesteps = np.asarray(list(range(0, self.num_timesteps + 1, c)))
timesteps -= 1
timesteps = timesteps[1:]

timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
alphas_cumprod = self.alphas_cumprod[timesteps]
else:
elif n == self.num_timesteps:
alphas_cumprod = self.alphas_cumprod
else:
raise ValueError

to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
Expand Down

0 comments on commit e9869d7

Please sign in to comment.