From bf3b8783543bdbfc31721479091e35696baadd13 Mon Sep 17 00:00:00 2001 From: LuChengTHU Date: Sun, 6 Nov 2022 18:32:02 +0800 Subject: [PATCH] add a stablizing trick for steps < 15 --- ldm/models/diffusion/dpm_solver/dpm_solver.py | 83 ++++++++++++------- ldm/models/diffusion/dpm_solver/sampler.py | 2 +- 2 files changed, 56 insertions(+), 29 deletions(-) diff --git a/ldm/models/diffusion/dpm_solver/dpm_solver.py b/ldm/models/diffusion/dpm_solver/dpm_solver.py index 7ae736c65..bdb64e0c7 100644 --- a/ldm/models/diffusion/dpm_solver/dpm_solver.py +++ b/ldm/models/diffusion/dpm_solver/dpm_solver.py @@ -394,8 +394,8 @@ def data_prediction_fn(self, x, t): if self.thresholding: p = 0.995 # A hyperparameter in the paper of "Imagen" [1]. s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) - s = expand_dims(torch.maximum(s, torch.ones_like(s).to(s.device)), dims) - x0 = torch.clamp(x0, -s, s) / (s / self.max_val) + s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims) + x0 = torch.clamp(x0, -s, s) / s return x0 def model_fn(self, x, t): @@ -436,7 +436,7 @@ def get_time_steps(self, skip_type, t_T, t_0, N, device): else: raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) - def get_orders_for_singlestep_solver(self, steps, order): + def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): """ Get the order of each step for sampling by the singlestep DPM-Solver. @@ -458,6 +458,13 @@ def get_orders_for_singlestep_solver(self, steps, order): Args: order: A `int`. The max order for the solver (2 or 3). steps: A `int`. The total number of function evaluations (NFE). + skip_type: A `str`. The type for the spacing of the time steps. We support three types: + - 'logSNR': uniform logSNR for the time steps. + - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) 
+ - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) + t_T: A `float`. The starting time of the sampling (default is T). + t_0: A `float`. The ending time of the sampling (default is epsilon). + device: A torch device. Returns: orders: A list of the solver order of each step. """ @@ -469,20 +476,26 @@ def get_orders_for_singlestep_solver(self, steps, order): orders = [3,] * (K - 1) + [1] else: orders = [3,] * (K - 1) + [2] - return orders elif order == 2: - K = steps // 2 if steps % 2 == 0: + K = steps // 2 orders = [2,] * K else: - orders = [2,] * K + [1] - return orders + K = steps // 2 + 1 + orders = [2,] * (K - 1) + [1] elif order == 1: - return [1,] * steps + K = steps + orders = [1,] * steps else: raise ValueError("'order' must be '1' or '2' or '3'.") + if skip_type == 'logSNR': + # To reproduce the results in DPM-Solver paper + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) + else: + timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), dim=0).to(device)] + return timesteps_outer, orders - def denoise_fn(self, x, s): + def denoise_to_zero_fn(self, x, s): """ Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. """ @@ -950,8 +963,8 @@ def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol return x def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time_uniform', - method='singlestep', denoise=False, solver_type='dpm_solver', atol=0.0078, - rtol=0.05, + method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver', + atol=0.0078, rtol=0.05, ): """ Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. @@ -1035,8 +1048,19 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time order: A `int`. The order of DPM-Solver. skip_type: A `str`. 
The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'. method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'. - denoise: A `bool`. Whether to denoise at the final step. Default is False. - If `denoise` is True, the total NFE is (`steps` + 1). + denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step. + Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1). + + This trick is firstly proposed by DDPM (https://arxiv.org/abs/2006.11239) and + score_sde (https://arxiv.org/abs/2011.13456). Such trick can improve the FID + for diffusion models sampling by diffusion SDEs for low-resolutional images + (such as CIFAR-10). However, we observed that such trick does not matter for + high-resolutional images. As it needs an additional NFE, we do not recommend + it for high-resolutional images. + lower_order_final: A `bool`. Whether to use lower order solvers at the final steps. + Only valid for `method=multistep` and `steps < 15`. We empirically find that + this trick is a key to stabilizing the sampling by DPM-Solver with very few steps + (especially for steps <= 10). So we recommend to set it to be `True`. solver_type: A `str`. The taylor expansion type for the solver. `dpm_solver` or `taylor`. We recommend `dpm_solver`. atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. @@ -1067,7 +1091,11 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time # Compute the remaining values by `order`-th order multistep DPM-Solver. 
for step in range(order, steps + 1): vec_t = timesteps[step].expand(x.shape[0]) - x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, order, solver_type=solver_type) + if lower_order_final and steps < 15: + step_order = min(order, steps + 1 - step) + else: + step_order = order + x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, vec_t, step_order, solver_type=solver_type) for i in range(order - 1): t_prev_list[i] = t_prev_list[i + 1] model_prev_list[i] = model_prev_list[i + 1] @@ -1077,23 +1105,22 @@ def sample(self, x, steps=20, t_start=None, t_end=None, order=3, skip_type='time model_prev_list[-1] = self.model_fn(x, vec_t) elif method in ['singlestep', 'singlestep_fixed']: if method == 'singlestep': - orders = self.get_orders_for_singlestep_solver(steps=steps, order=order) - timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) + timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) elif method == 'singlestep_fixed': K = steps // order orders = [order,] * K - timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=(K * order), device=device) - with torch.no_grad(): - i = 0 - for order in orders: - vec_s, vec_t = timesteps[i].expand(x.shape[0]), timesteps[i + order].expand(x.shape[0]) - h = self.noise_schedule.marginal_lambda(timesteps[i + order]) - self.noise_schedule.marginal_lambda(timesteps[i]) - r1 = None if order <= 1 else (self.noise_schedule.marginal_lambda(timesteps[i + 1]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h - r2 = None if order <= 2 else (self.noise_schedule.marginal_lambda(timesteps[i + 2]) - self.noise_schedule.marginal_lambda(timesteps[i])) / h - x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) - i += order - if denoise: - x = self.denoise_fn(x, 
torch.ones((x.shape[0],)).to(device) * t_0) + timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) + for i, order in enumerate(orders): + t_T_inner, t_0_inner = timesteps_outer[i], timesteps_outer[i + 1] + timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=t_T_inner.item(), t_0=t_0_inner.item(), N=order, device=device) + lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) + vec_s, vec_t = t_T_inner.tile(x.shape[0]), t_0_inner.tile(x.shape[0]) + h = lambda_inner[-1] - lambda_inner[0] + r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h + r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h + x = self.singlestep_dpm_solver_update(x, vec_s, vec_t, order, solver_type=solver_type, r1=r1, r2=r2) + if denoise_to_zero: + x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0) return x diff --git a/ldm/models/diffusion/dpm_solver/sampler.py b/ldm/models/diffusion/dpm_solver/sampler.py index c3fb7bcea..141414309 100644 --- a/ldm/models/diffusion/dpm_solver/sampler.py +++ b/ldm/models/diffusion/dpm_solver/sampler.py @@ -77,6 +77,6 @@ def sample(self, ) dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) - x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2) + x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) return x.to(device), None \ No newline at end of file