Skip to content

Commit

Permalink
Fixing additional GPU memory on device 0 due to discretization
Browse files Browse the repository at this point in the history
  • Loading branch information
timudk committed Jul 9, 2023
1 parent 061d11d commit ba3e7fe
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions sgm/modules/diffusionmodules/discretizer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import numpy as np
from functools import partial
from abc import abstractmethod

from ...util import append_zero
from ...modules.diffusionmodules.util import make_beta_schedule
Expand All @@ -13,19 +14,23 @@ def generate_roughly_equally_spaced_steps(


class Discretization:
def __call__(self, n, do_append_zero=True, device="cuda", flip=False):
sigmas = self.get_sigmas(n, device)
def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
sigmas = self.get_sigmas(n, device=device)
sigmas = append_zero(sigmas) if do_append_zero else sigmas
return sigmas if not flip else torch.flip(sigmas, (0,))

@abstractmethod
def get_sigmas(self, n, device):
raise NotImplementedError("abstract class should not be called")


class EDMDiscretization(Discretization):
def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.rho = rho

def get_sigmas(self, n, device):
def get_sigmas(self, n, device="cpu"):
ramp = torch.linspace(0, 1, n, device=device)
min_inv_rho = self.sigma_min ** (1 / self.rho)
max_inv_rho = self.sigma_max ** (1 / self.rho)
Expand All @@ -40,6 +45,7 @@ def __init__(
linear_end=0.0120,
num_timesteps=1000,
):
super().__init__()
self.num_timesteps = num_timesteps
betas = make_beta_schedule(
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
Expand All @@ -48,7 +54,7 @@ def __init__(
self.alphas_cumprod = np.cumprod(alphas, axis=0)
self.to_torch = partial(torch.tensor, dtype=torch.float32)

def get_sigmas(self, n, device):
def get_sigmas(self, n, device="cpu"):
if n < self.num_timesteps:
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
alphas_cumprod = self.alphas_cumprod[timesteps]
Expand Down

0 comments on commit ba3e7fe

Please sign in to comment.