Skip to content

Commit

Permalink
Add initial scheduler support
Browse files Browse the repository at this point in the history
Add the ability to choose diffusion schedulers, at least for a subset of the
pipeline configurations. The scheduler is selectable in the sidebar for the
text-to-audio and audio-to-audio pages. It is currently not used for
interpolation, i.e. the riffusion pipeline.

Topic: schedulers_v0
  • Loading branch information
hmartiro committed Jan 15, 2023
1 parent c771ab0 commit b459107
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 15 deletions.
32 changes: 21 additions & 11 deletions riffusion/streamlit/pages/audio_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,27 @@ def render_audio_to_audio() -> None:
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)

num_inference_steps = T.cast(
int,
st.sidebar.number_input(
"Steps per sample", value=50, help="Number of denoising steps per model run"
),
)
with st.sidebar:
num_inference_steps = T.cast(
int,
st.number_input(
"Steps per sample", value=50, help="Number of denoising steps per model run"
),
)

guidance = st.sidebar.number_input(
"Guidance",
value=7.0,
help="How much the model listens to the text prompt",
)
guidance = st.number_input(
"Guidance",
value=7.0,
help="How much the model listens to the text prompt",
)

scheduler = st.selectbox(
"Scheduler",
options=streamlit_util.SCHEDULER_OPTIONS,
index=0,
help="Which diffusion scheduler to use",
)
assert scheduler is not None

audio_file = st.file_uploader(
"Upload audio",
Expand Down Expand Up @@ -207,6 +216,7 @@ def render_audio_to_audio() -> None:
seed=prompt_input_a.seed,
progress_callback=progress_callback,
device=device,
scheduler=scheduler,
)

# Resize back to original size
Expand Down
8 changes: 8 additions & 0 deletions riffusion/streamlit/pages/text_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def render_text_to_audio() -> None:
guidance = st.number_input(
"Guidance", value=7.0, help="How much the model listens to the text prompt"
)
scheduler = st.selectbox(
"Scheduler",
options=streamlit_util.SCHEDULER_OPTIONS,
index=0,
help="Which diffusion scheduler to use",
)
assert scheduler is not None

if not prompt:
st.info("Enter a prompt")
Expand All @@ -85,6 +92,7 @@ def render_text_to_audio() -> None:
width=width,
height=512,
device=device,
scheduler=scheduler,
)
st.image(image)

Expand Down
61 changes: 57 additions & 4 deletions riffusion/streamlit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,15 @@
AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]

# Names of diffusers scheduler classes that the UI lets the user pick from.
# The first entry is used as the default. Each name must match a class
# importable from the diffusers package (see get_scheduler below).
SCHEDULER_OPTIONS = [
    "PNDMScheduler",
    "DDIMScheduler",
    "LMSDiscreteScheduler",
    "EulerDiscreteScheduler",
    "EulerAncestralDiscreteScheduler",
    "DPMSolverMultistepScheduler",
]


@st.experimental_singleton
def load_riffusion_checkpoint(
Expand All @@ -42,6 +51,7 @@ def load_stable_diffusion_pipeline(
checkpoint: str = "riffusion/riffusion-model-v1",
device: str = "cuda",
dtype: torch.dtype = torch.float16,
scheduler: str = SCHEDULER_OPTIONS[0],
) -> StableDiffusionPipeline:
"""
Load the riffusion pipeline.
Expand All @@ -52,19 +62,56 @@ def load_stable_diffusion_pipeline(
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
dtype = torch.float32

return StableDiffusionPipeline.from_pretrained(
pipeline = StableDiffusionPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=dtype,
safety_checker=lambda images, **kwargs: (images, False),
).to(device)

pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)

return pipeline


def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
    """
    Construct a denoising scheduler from a string.

    Args:
        scheduler: Name of a diffusers scheduler class, e.g. "DDIMScheduler".
        config: Scheduler configuration to initialize from, typically
            ``pipeline.scheduler.config``.

    Returns:
        A scheduler instance created via the class's ``from_config``.

    Raises:
        ValueError: If the scheduler name is not recognized.
    """
    # Recognized scheduler class names. Kept in sync with SCHEDULER_OPTIONS;
    # listed here explicitly so an arbitrary string can never be resolved as
    # an attribute of the diffusers module.
    known_schedulers = {
        "PNDMScheduler",
        "DPMSolverMultistepScheduler",
        "DDIMScheduler",
        "LMSDiscreteScheduler",
        "EulerDiscreteScheduler",
        "EulerAncestralDiscreteScheduler",
    }
    if scheduler not in known_schedulers:
        raise ValueError(f"Unknown scheduler {scheduler}")

    # Import lazily (as the original per-branch imports did) so that an
    # unknown name fails fast without paying the diffusers import cost.
    import importlib

    diffusers_module = importlib.import_module("diffusers")
    return getattr(diffusers_module, scheduler).from_config(config)


@st.experimental_singleton
def load_stable_diffusion_img2img_pipeline(
checkpoint: str = "riffusion/riffusion-model-v1",
device: str = "cuda",
dtype: torch.dtype = torch.float16,
scheduler: str = SCHEDULER_OPTIONS[0],
) -> StableDiffusionImg2ImgPipeline:
"""
Load the image to image pipeline.
Expand All @@ -75,13 +122,17 @@ def load_stable_diffusion_img2img_pipeline(
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
dtype = torch.float32

return StableDiffusionImg2ImgPipeline.from_pretrained(
pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
checkpoint,
revision="main",
torch_dtype=dtype,
safety_checker=lambda images, **kwargs: (images, False),
).to(device)

pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)

return pipeline


@st.experimental_memo
def run_txt2img(
Expand All @@ -93,11 +144,12 @@ def run_txt2img(
width: int,
height: int,
device: str = "cuda",
scheduler: str = SCHEDULER_OPTIONS[0],
) -> Image.Image:
"""
Run the text to image pipeline with caching.
"""
pipeline = load_stable_diffusion_pipeline(device=device)
pipeline = load_stable_diffusion_pipeline(device=device, scheduler=scheduler)

generator_device = "cpu" if device.lower().startswith("mps") else device
generator = torch.Generator(device=generator_device).manual_seed(seed)
Expand Down Expand Up @@ -214,9 +266,10 @@ def run_img2img(
seed: int,
negative_prompt: T.Optional[str] = None,
device: str = "cuda",
scheduler: str = SCHEDULER_OPTIONS[0],
progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
) -> Image.Image:
pipeline = load_stable_diffusion_img2img_pipeline(device=device)
pipeline = load_stable_diffusion_img2img_pipeline(device=device, scheduler=scheduler)

generator_device = "cpu" if device.lower().startswith("mps") else device
generator = torch.Generator(device=generator_device).manual_seed(seed)
Expand Down

0 comments on commit b459107

Please sign in to comment.