Upgrade playground app to Streamlit 1.18+
The first change was adopting the new non-experimental cache decorators,
but I then decided to refactor away from the Streamlit pages feature and
instead use my own dropdown. This gives more control and fixes a page
layout issue that popped up with this version.
hmartiro committed Mar 26, 2023
1 parent 5a989ff commit a0f12d8
Showing 13 changed files with 85 additions and 109 deletions.
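For reference, the "non-experimental cache decorators" mentioned in the message are the two caching decorators that Streamlit 1.18 promoted to stable names, so that part of the change is a mechanical rename. A minimal sketch of the mapping (the toy function and loader below are illustrative, not from this repo):

import streamlit as st

# st.cache_data replaces st.experimental_memo: results are serialized and
# each caller receives a copy, which suits data-like return values.
@st.cache_data
def sum_of_squares(n: int) -> int:
    return sum(i * i for i in range(n))

# st.cache_resource replaces st.experimental_singleton: one shared object
# per process, which suits models and other global resources.
@st.cache_resource
def load_model():
    # Stand-in for an expensive, unpickleable resource (e.g. a pipeline).
    return object()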
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
-0.3.1
+0.3.1
2 changes: 1 addition & 1 deletion requirements.txt
@@ -13,7 +13,7 @@ pysoundfile
 scipy
 soundfile
 sox
-streamlit>=1.10.0
+streamlit>=1.18.0
 torch
 torchaudio
 torchvision
60 changes: 23 additions & 37 deletions riffusion/streamlit/playground.py
@@ -1,43 +1,29 @@
 import streamlit as st
 
 
-def render_main():
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.title(":guitar: Riffusion Playground")
-
-    left, right = st.columns(2)
-
-    with left:
-        create_link(":pencil2: Text to Audio", "/text_to_audio")
-        st.write("Generate audio clips from text prompts.")
-
-        create_link(":wave: Audio to Audio", "/audio_to_audio")
-        st.write("Upload audio and modify with text prompt (interpolation supported).")
-
-        create_link(":performing_arts: Interpolation", "/interpolation")
-        st.write("Interpolate between prompts in the latent space.")
-
-        create_link(":scissors: Audio Splitter", "/split_audio")
-        st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
-
-    with right:
-        create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
-        st.write("Generate audio in batch from a JSON file of text prompts.")
-
-        create_link(":paperclip: Sample Clips", "/sample_clips")
-        st.write("Export short clips from an audio file.")
-
-        create_link(":musical_keyboard: Image to Audio", "/image_to_audio")
-        st.write("Reconstruct audio from spectrogram images.")
-
-
-def create_link(name: str, url: str) -> None:
-    st.markdown(
-        f"### <a href='{url}' target='_self' style='text-decoration: none;'>{name}</a>",
-        unsafe_allow_html=True,
-    )
+PAGES = {
+    "🎛️ Home": "tasks.home",
+    "🌊 Text to Audio": "tasks.text_to_audio",
+    "✨ Audio to Audio": "tasks.audio_to_audio",
+    "🎭 Interpolation": "tasks.interpolation",
+    "✂️ Audio Splitter": "tasks.split_audio",
+    "📜 Text to Audio Batch": "tasks.text_to_audio_batch",
+    "📎 Sample Clips": "tasks.sample_clips",
+    "⏈ Spectrogram to Audio": "tasks.image_to_audio",
+}
+
+
+def main() -> None:
+    st.set_page_config(
+        page_title="Riffusion Playground",
+        page_icon="🎸",
+        layout="wide",
+    )
+
+    page = st.sidebar.selectbox("Page", list(PAGES.keys()))
+    assert page is not None
+    module = __import__(PAGES[page], fromlist=["render"])
+    module.render()
 
 
 if __name__ == "__main__":
-    render_main()
+    main()
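A note on the dispatch in main(): __import__("tasks.home", fromlist=["render"]) returns the leaf tasks.home module rather than the top-level tasks package, which is why module.render() works. A minimal standalone sketch of the same pattern using importlib (the two-entry registry is illustrative, not the app's full PAGES dict):

import importlib

# Map display names to dotted module paths; each target module is
# expected to expose a render() function.
PAGES = {
    "Home": "tasks.home",
    "Text to Audio": "tasks.text_to_audio",
}

def dispatch(page_name: str) -> None:
    # importlib.import_module returns the leaf module directly, the
    # modern equivalent of __import__(path, fromlist=["render"]).
    module = importlib.import_module(PAGES[page_name])
    module.render()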
0 changes: 0 additions & 0 deletions riffusion/streamlit/tasks/__init__.py
Empty file.
riffusion/streamlit/{pages → tasks}/audio_to_audio.py
@@ -10,14 +10,12 @@
 from riffusion.datatypes import InferenceInput, PromptInput
 from riffusion.spectrogram_params import SpectrogramParams
 from riffusion.streamlit import util as streamlit_util
-from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation
+from riffusion.streamlit.tasks.interpolation import get_prompt_inputs, run_interpolation
 from riffusion.util import audio_util
 
 
-def render_audio_to_audio() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":wave: Audio to Audio")
+def render() -> None:
+    st.subheader("✨ Audio to Audio")
     st.write(
         """
         Modify existing audio from a text prompt or interpolate between two.
@@ -408,7 +406,3 @@ def scale_image_to_32_stride(image: Image.Image) -> Image.Image:
     closest_width = int(np.ceil(image.width / 32) * 32)
     closest_height = int(np.ceil(image.height / 32) * 32)
     return image.resize((closest_width, closest_height), Image.BICUBIC)
-
-
-if __name__ == "__main__":
-    render_audio_to_audio()
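The scale_image_to_32_stride helper in the hunk above always rounds dimensions up to the next multiple of 32, never down. A quick self-contained check of the arithmetic (values illustrative):

import numpy as np

def scale_to_32_stride(width: int, height: int) -> tuple:
    # Round each dimension up to the nearest multiple of 32, mirroring
    # scale_image_to_32_stride above (which then resizes the image).
    return int(np.ceil(width / 32) * 32), int(np.ceil(height / 32) * 32)

assert scale_to_32_stride(500, 300) == (512, 320)
assert scale_to_32_stride(512, 512) == (512, 512)  # already a multiple of 32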
32 changes: 32 additions & 0 deletions riffusion/streamlit/tasks/home.py
@@ -0,0 +1,32 @@
+import streamlit as st
+
+
+def render():
+    st.title("✨🎸 Riffusion Playground 🎸✨")
+
+    st.write("Select a task from the sidebar to get started!")
+
+    left, right = st.columns(2)
+
+    with left:
+        st.subheader("🌊 Text to Audio")
+        st.write("Generate audio clips from text prompts.")
+
+        st.subheader("✨ Audio to Audio")
+        st.write("Upload audio and modify with text prompt (interpolation supported).")
+
+        st.subheader("🎭 Interpolation")
+        st.write("Interpolate between prompts in the latent space.")
+
+        st.subheader("✂️ Audio Splitter")
+        st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
+
+    with right:
+        st.subheader("📜 Text to Audio Batch")
+        st.write("Generate audio in batch from a JSON file of text prompts.")
+
+        st.subheader("📎 Sample Clips")
+        st.write("Export short clips from an audio file.")
+
+        st.subheader("⏈ Spectrogram to Audio")
+        st.write("Reconstruct audio from spectrogram images.")
riffusion/streamlit/{pages → tasks}/image_to_audio.py
@@ -9,10 +9,8 @@
 from riffusion.util.image_util import exif_from_image
 
 
-def render_image_to_audio() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":musical_keyboard: Image to Audio")
+def render() -> None:
+    st.subheader("⏈ Image to Audio")
     st.write(
         """
         Reconstruct audio from spectrogram images.
@@ -77,7 +75,3 @@ def render_image_to_audio() -> None:
         name=Path(image_file.name).stem,
         extension=extension,
     )
-
-
-if __name__ == "__main__":
-    render_image_to_audio()
riffusion/streamlit/{pages → tasks}/interpolation.py
@@ -13,10 +13,8 @@
 from riffusion.streamlit import util as streamlit_util
 
 
-def render_interpolation() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":performing_arts: Interpolation")
+def render() -> None:
+    st.subheader("🎭 Interpolation")
     st.write(
         """
         Interpolate between prompts in the latent space.
@@ -241,7 +239,7 @@ def get_prompt_inputs(
     return p
 
 
-@st.experimental_memo
+@st.cache_data
 def run_interpolation(
     inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
 ) -> T.Tuple[Image.Image, io.BytesIO]:
@@ -275,7 +273,3 @@ def run_interpolation(
     )
 
     return image, audio_bytes
-
-
-if __name__ == "__main__":
-    render_interpolation()
riffusion/streamlit/{pages → tasks}/sample_clips.py
@@ -10,10 +10,8 @@
 from riffusion.streamlit import util as streamlit_util
 
 
-def render_sample_clips() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":paperclip: Sample Clips")
+def render() -> None:
+    st.subheader("📎 Sample Clips")
     st.write(
         """
         Export short clips from an audio file.
@@ -125,7 +123,3 @@ def render_sample_clips() -> None:
 
     if save_to_disk:
         st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
-
-
-if __name__ == "__main__":
-    render_sample_clips()
riffusion/streamlit/{pages → tasks}/split_audio.py
@@ -9,10 +9,8 @@
 from riffusion.util import audio_util
 
 
-def render_split_audio() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":scissors: Audio Splitter")
+def render() -> None:
+    st.subheader("✂️ Audio Splitter")
     st.write(
         """
         Split audio into individual instrument stems.
@@ -99,7 +97,3 @@ def split_audio_cached(
     segment: pydub.AudioSegment, device: str = "cuda"
 ) -> T.Dict[str, pydub.AudioSegment]:
     return split_audio(segment, device=device)
-
-
-if __name__ == "__main__":
-    render_split_audio()
riffusion/streamlit/{pages → tasks}/text_to_audio.py
@@ -6,10 +6,8 @@
 from riffusion.streamlit import util as streamlit_util
 
 
-def render_text_to_audio() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":pencil2: Text to Audio")
+def render() -> None:
+    st.subheader("🌊 Text to Audio")
     st.write(
         """
         Generate audio from text prompts.
@@ -119,7 +117,3 @@ def render_text_to_audio() -> None:
         )
 
         seed += 1
-
-
-if __name__ == "__main__":
-    render_text_to_audio()
riffusion/streamlit/{pages → tasks}/text_to_audio_batch.py
@@ -14,7 +14,7 @@
     "seed": 42,
     "num_inference_steps": 50,
     "guidance": 7.0,
-    "width": 512,
+    "width": 512
   },
   "entries": [
     {
@@ -32,10 +32,8 @@
 """
 
 
-def render_text_to_audio_batch() -> None:
-    st.set_page_config(layout="wide", page_icon="🎸")
-
-    st.subheader(":scroll: Text to Audio Batch")
+def render() -> None:
+    st.subheader("📜 Text to Audio Batch")
     st.write(
         """
         Generate audio in batch from a JSON file of text prompts.
@@ -141,7 +139,3 @@ def render_text_to_audio_batch() -> None:
         st.info(f"Output written to {str(output_path)}")
     else:
         st.info("Enter output directory in sidebar to save to disk")
-
-
-if __name__ == "__main__":
-    render_text_to_audio_batch()
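The one-character change in the docstring's first hunk removes a trailing comma after "width": 512. That matters because the docstring is a copy-paste JSON example, and strict JSON (unlike Python literals) rejects trailing commas. A quick check:

import json

# Valid JSON parses fine.
json.loads('{"width": 512}')

# The old example had a trailing comma, which a strict parser rejects:
# json.loads('{"width": 512,}')  # raises json.decoder.JSONDecodeError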
22 changes: 11 additions & 11 deletions riffusion/streamlit/util.py
@@ -33,7 +33,7 @@
 ]
 
 
-@st.experimental_singleton
+@st.cache_resource
 def load_riffusion_checkpoint(
     checkpoint: str = DEFAULT_CHECKPOINT,
     no_traced_unet: bool = False,
@@ -49,7 +49,7 @@ def load_riffusion_checkpoint(
     )
 
 
-@st.experimental_singleton
+@st.cache_resource
 def load_stable_diffusion_pipeline(
     checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
@@ -109,15 +109,15 @@ def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
     raise ValueError(f"Unknown scheduler {scheduler}")
 
 
-@st.experimental_singleton
+@st.cache_resource
 def pipeline_lock() -> threading.Lock:
     """
     Singleton lock used to prevent concurrent access to any model pipeline.
     """
     return threading.Lock()
 
 
-@st.experimental_singleton
+@st.cache_resource
 def load_stable_diffusion_img2img_pipeline(
     checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
@@ -145,7 +145,7 @@ def load_stable_diffusion_img2img_pipeline(
     return pipeline
 
 
-@st.experimental_memo
+@st.cache_data
 def run_txt2img(
     prompt: str,
     num_inference_steps: int,
@@ -184,7 +184,7 @@ def run_txt2img(
     return output["images"][0]
 
 
-@st.experimental_singleton
+@st.cache_resource
 def spectrogram_image_converter(
     params: SpectrogramParams,
     device: str = "cuda",
@@ -202,7 +202,7 @@ def spectrogram_image_from_audio(
     return converter.spectrogram_image_from_audio(segment)
 
 
-@st.experimental_memo
+@st.cache_data
 def audio_segment_from_spectrogram_image(
     image: Image.Image,
     params: SpectrogramParams,
@@ -212,7 +212,7 @@ def audio_segment_from_spectrogram_image(
     return converter.audio_from_spectrogram_image(image)
 
 
-@st.experimental_memo
+@st.cache_data
 def audio_bytes_from_spectrogram_image(
     image: Image.Image,
     params: SpectrogramParams,
@@ -289,17 +289,17 @@ def select_checkpoint(container: T.Any = st.sidebar) -> str:
     return custom_checkpoint or DEFAULT_CHECKPOINT
 
 
-@st.experimental_memo
+@st.cache_data
 def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
     return pydub.AudioSegment.from_file(audio_file)
 
 
-@st.experimental_singleton
+@st.cache_resource
 def get_audio_splitter(device: str = "cuda"):
    return AudioSplitter(device=device)
 
 
-@st.experimental_singleton
+@st.cache_resource
 def load_magic_mix_pipeline(
     checkpoint: str = DEFAULT_CHECKPOINT,
     device: str = "cuda",
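The migration in this file follows one consistent rule: @st.experimental_singleton becomes @st.cache_resource for shared, unpickleable objects (the pipelines, the audio splitter, the lock), and @st.experimental_memo becomes @st.cache_data for value-like results. The lock is the clearest illustration of why the distinction matters; a minimal standalone sketch mirroring pipeline_lock above:

import threading

import streamlit as st

@st.cache_resource
def pipeline_lock() -> threading.Lock:
    # cache_resource hands the *same* Lock object to every session and
    # rerun. cache_data would instead try to pickle a copy per caller,
    # which fails for locks and would defeat the mutual exclusion anyway.
    return threading.Lock()

def run_exclusively(fn):
    # Serialize access to the shared GPU pipeline across sessions.
    with pipeline_lock():
        return fn()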
