Skip to content

Commit

Permalink
[Playground] Add magic mix to audio2audio
Browse files Browse the repository at this point in the history
Allows for better preserving structure with audio to audio.

Topic: magic_mix
  • Loading branch information
hmartiro committed Jan 17, 2023
1 parent 2fc9b8d commit 3439aae
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 2 deletions.
39 changes: 38 additions & 1 deletion riffusion/streamlit/pages/audio_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def render_audio_to_audio() -> None:
device = streamlit_util.select_device(st.sidebar)
extension = streamlit_util.select_audio_extension(st.sidebar)

use_magic_mix = st.sidebar.checkbox("Use Magic Mix", False)

with st.sidebar:
num_inference_steps = T.cast(
int,
Expand Down Expand Up @@ -124,7 +126,27 @@ def render_audio_to_audio() -> None:
with right:
st.write("##### Prompt B")
prompt_input_b = PromptInput(guidance=guidance, **get_prompt_inputs(key="b"))

elif use_magic_mix:
prompt = st.text_input("Prompt", key="prompt_a")

row = st.columns(4)

seed = T.cast(
int,
row[0].number_input(
"Seed",
value=42,
key="seed_a",
),
)
prompt_input_a = PromptInput(
prompt=prompt,
seed=seed,
guidance=guidance,
)
magic_mix_kmin = row[1].number_input("Kmin", value=0.3)
magic_mix_kmax = row[2].number_input("Kmax", value=0.5)
magic_mix_mix_factor = row[3].number_input("Mix Factor", value=0.5)
else:
prompt_input_a = PromptInput(
guidance=guidance,
Expand Down Expand Up @@ -192,6 +214,7 @@ def render_audio_to_audio() -> None:
progress_callback = progress.progress

if interpolate:
assert use_magic_mix is False, "Cannot use magic mix and interpolate together"
inputs = InferenceInput(
alpha=float(alphas[i]),
num_inference_steps=num_inference_steps,
Expand All @@ -205,6 +228,20 @@ def render_audio_to_audio() -> None:
init_image=init_image_resized,
device=device,
)
elif use_magic_mix:
assert not prompt_input_a.negative_prompt, "No negative prompt with magic mix"
image = streamlit_util.run_img2img_magic_mix(
prompt=prompt_input_a.prompt,
init_image=init_image_resized,
num_inference_steps=num_inference_steps,
guidance_scale=guidance,
seed=prompt_input_a.seed,
kmin=magic_mix_kmin,
kmax=magic_mix_kmax,
mix_factor=magic_mix_mix_factor,
device=device,
scheduler=scheduler,
)
else:
image = streamlit_util.run_img2img(
prompt=prompt_input_a.prompt,
Expand Down
62 changes: 61 additions & 1 deletion riffusion/streamlit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pydub
import streamlit as st
import torch
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from PIL import Image

from riffusion.audio_splitter import AudioSplitter
Expand Down Expand Up @@ -256,6 +256,20 @@ def select_audio_extension(container: T.Any = st.sidebar) -> str:
return extension


def select_scheduler(container: T.Any = st.sidebar) -> str:
    """
    Dropdown to select a scheduler.

    Args:
        container: Streamlit container (or module-like object exposing
            ``selectbox``) in which the dropdown is rendered. Defaults to
            the sidebar.

    Returns:
        The selected entry from SCHEDULER_OPTIONS.
    """
    # Render into the caller-supplied container; previously the argument was
    # ignored and the widget was always placed in the sidebar.
    scheduler = container.selectbox(
        "Scheduler",
        options=SCHEDULER_OPTIONS,
        index=0,
        help="Which diffusion scheduler to use",
    )
    # Narrows the optional selectbox return value to a plain str for callers.
    assert scheduler is not None
    return scheduler


@st.experimental_memo
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
    """
    Decode an in-memory audio file into a pydub AudioSegment.
    """
    segment = pydub.AudioSegment.from_file(audio_file)
    return segment
Expand All @@ -266,6 +280,52 @@ def get_audio_splitter(device: str = "cuda"):
return AudioSplitter(device=device)


@st.experimental_singleton
def load_magic_mix_pipeline(device: str = "cuda", scheduler: str = SCHEDULER_OPTIONS[0]):
    """
    Load the riffusion model with the "magic_mix" custom pipeline.

    The pipeline is moved to the requested device and its scheduler is
    replaced by the one that get_scheduler builds for the given name.

    Args:
        device: Device passed to the pipeline's ``.to()``.
        scheduler: Scheduler name handed to get_scheduler.

    Returns:
        The configured pipeline object.
    """
    magic_mix = DiffusionPipeline.from_pretrained(
        "riffusion/riffusion-model-v1",
        custom_pipeline="magic_mix",
    )
    magic_mix = magic_mix.to(device)

    # The replacement scheduler is built from the config of the one that
    # shipped with the loaded pipeline.
    shipped_config = magic_mix.scheduler.config
    magic_mix.scheduler = get_scheduler(scheduler, shipped_config)

    return magic_mix


@st.cache
def run_img2img_magic_mix(
    prompt: str,
    init_image: Image.Image,
    num_inference_steps: int,
    guidance_scale: float,
    seed: int,
    kmin: float,
    kmax: float,
    mix_factor: float,
    device: str = "cuda",
    scheduler: str = SCHEDULER_OPTIONS[0],
):
    """
    Run the magic mix pipeline for img2img.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        init_image: Image passed as the pipeline's positional input.
        num_inference_steps: Forwarded to the pipeline as ``steps``.
        guidance_scale: Forwarded to the pipeline unchanged.
        seed: Forwarded to the pipeline unchanged.
        kmin: Forwarded to the pipeline unchanged.
        kmax: Forwarded to the pipeline unchanged.
        mix_factor: Forwarded to the pipeline unchanged.
        device: Device used when loading the pipeline.
        scheduler: Scheduler name used when loading the pipeline.

    Returns:
        Whatever the magic mix pipeline call returns.
    """
    # Keyword arguments for the pipeline call, gathered up front; note the
    # rename of num_inference_steps to the pipeline's ``steps`` argument.
    call_kwargs = dict(
        prompt=prompt,
        kmin=kmin,
        kmax=kmax,
        mix_factor=mix_factor,
        seed=seed,
        guidance_scale=guidance_scale,
        steps=num_inference_steps,
    )

    # Both loading and running the pipeline happen while holding the lock.
    with pipeline_lock():
        magic_mix = load_magic_mix_pipeline(device=device, scheduler=scheduler)
        return magic_mix(init_image, **call_kwargs)


@st.cache
def run_img2img(
prompt: str,
Expand Down

0 comments on commit 3439aae

Please sign in to comment.