Skip to content

Commit

Permalink
Merge pull request ddPn08#87 from Isotr0py/karras
Browse files Browse the repository at this point in the history
Support karras samplers
  • Loading branch information
ddPn08 authored May 16, 2023
2 parents 0eb692e + e6b58cf commit 4ced4f1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
24 changes: 21 additions & 3 deletions lib/diffusers/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import diffusers

# TODO: add rest Auto111 samplers
# Commented schedulers are still unavailable in diffusers 0.16.1
SCHEDULERS = {
"unipc": diffusers.schedulers.UniPCMultistepScheduler,
"euler_a": diffusers.schedulers.EulerAncestralDiscreteScheduler,
Expand All @@ -9,8 +11,24 @@
"deis": diffusers.schedulers.DEISMultistepScheduler,
"dpm2": diffusers.schedulers.KDPM2DiscreteScheduler,
"dpm2-a": diffusers.schedulers.KDPM2AncestralDiscreteScheduler,
"heun": diffusers.schedulers.DPMSolverMultistepScheduler,
"dpm++": diffusers.schedulers.DPMSolverMultistepScheduler,
"dpm": diffusers.schedulers.DPMSolverMultistepScheduler,
"dpm++_2s": diffusers.schedulers.DPMSolverSinglestepScheduler,
"dpm++_2m": diffusers.schedulers.DPMSolverMultistepScheduler,
"dpm++_2m_karras": diffusers.schedulers.DPMSolverMultistepScheduler,
# "dpm++_sde": diffusers.schedulers.DPMSolverSDEScheduler,
# "dpm++_sde_karras": diffusers.schedulers.DPMSolverSDEScheduler,
"heun": diffusers.schedulers.HeunDiscreteScheduler,
"heun_karras": diffusers.schedulers.HeunDiscreteScheduler,
"lms": diffusers.schedulers.LMSDiscreteScheduler,
# "lms_karras": diffusers.schedulers.LMSDiscreteScheduler,
"pndm": diffusers.schedulers.PNDMScheduler,
}


def parser_schedulers_config(scheduler_id: str):
"""
add extra config parameter to scheduler
"""
kwargs = {}
if "karras" in scheduler_id:
kwargs["use_karras_sigmas"] = True
return kwargs
4 changes: 2 additions & 2 deletions modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch

from api.models.diffusion import ImageGenerationOptions
from lib.diffusers.scheduler import SCHEDULERS
from lib.diffusers.scheduler import SCHEDULERS, parser_schedulers_config

from . import config, utils
from .images import save_image
Expand Down Expand Up @@ -133,7 +133,7 @@ def swap_scheduler(self, scheduler_id: str):
if not self.activated:
raise RuntimeError("Model not activated")
self.pipe.scheduler = SCHEDULERS[scheduler_id].from_config(
self.pipe.scheduler.config
self.pipe.scheduler.config, **parser_schedulers_config(scheduler_id)
)

def __call__(self, opts: ImageGenerationOptions, plugin_data: Dict[str, List] = {}):
Expand Down

0 comments on commit 4ced4f1

Please sign in to comment.