Skip to content

Commit

Permalink
Audio to audio improvements (some WIP)
Browse files (browse the repository at this point in the history)
Topic: audio_to_audio
  • Loading branch information
hmartiro committed Jan 6, 2023
1 parent 503c5e4 commit 83b2792
Showing 1 changed file with 54 additions and 21 deletions.
75 changes: 54 additions & 21 deletions riffusion/streamlit/pages/audio_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,12 @@
import numpy as np
import pydub
import streamlit as st
import torch
from PIL import Image

from riffusion.spectrogram_image_converter import SpectrogramImageConverter
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util


@st.experimental_memo
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
    """
    Decode an uploaded audio buffer into a pydub audio segment.

    The function is wrapped in Streamlit's memo decorator, so the decoded
    result is cached by Streamlit rather than recomputed on every call.

    Args:
        audio_file: In-memory file-like object holding the encoded audio.

    Returns:
        The segment produced by ``pydub.AudioSegment.from_file``.
    """
    decoded_segment = pydub.AudioSegment.from_file(audio_file)
    return decoded_segment


def render_audio_to_audio() -> None:
st.set_page_config(layout="wide", page_icon="🎸")

Expand All @@ -31,18 +24,22 @@ def render_audio_to_audio() -> None:

audio_file = st.file_uploader(
"Upload audio",
type=["mp3", "ogg", "wav", "flac"],
type=["mp3", "m4a", "ogg", "wav", "flac"],
label_visibility="collapsed",
)

if not audio_file:
st.info("Upload audio to get started")
return

st.write("#### Original Audio")
st.write("#### Original")
st.audio(audio_file)

segment = load_audio_file(audio_file)
segment = streamlit_util.load_audio_file(audio_file)

# TODO(hayk): Fix (presumably the hard-coded 44100 Hz resample below; confirm)
segment = segment.set_frame_rate(44100)
st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")

if "counter" not in st.session_state:
st.session_state.counter = 0
Expand All @@ -59,7 +56,6 @@ def increment_counter():
duration_s = cols[1].number_input(
"Duration [s]",
min_value=0.0,
max_value=segment.duration_seconds,
value=15.0,
)
clip_duration_s = cols[2].number_input(
Expand All @@ -75,12 +71,14 @@ def increment_counter():
value=0.2,
)

duration_s = min(duration_s, segment.duration_seconds - start_time_s)
increment_s = clip_duration_s - overlap_duration_s
clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
st.write(
f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
f"with overlap {overlap_duration_s}s."
)
st.write(clip_start_times)

with st.form("Conversion Params"):

Expand All @@ -92,7 +90,7 @@ def increment_counter():
"Denoising Strength",
min_value=0.0,
max_value=1.0,
value=0.65,
value=0.45,
)
guidance_scale = cols[1].number_input(
"Guidance Scale",
Expand All @@ -108,27 +106,37 @@ def increment_counter():
value=50,
)
)

seed = int(
cols[3].number_input(
"Seed",
min_value=-1,
value=-1,
min_value=0,
value=42,
)
)
# TODO: Replace a seed of -1 with a randomly chosen seed

submit_button = st.form_submit_button("Convert", on_click=increment_counter)

# TODO fix


show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
show_difference = st.sidebar.checkbox("Show Difference", False)

clip_segments: T.List[pydub.AudioSegment] = []
for clip_start_time_s in clip_start_times:
for i, clip_start_time_s in enumerate(clip_start_times):
clip_start_time_ms = int(clip_start_time_s * 1000)
clip_duration_ms = int(clip_duration_s * 1000)
clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]

if i == len(clip_start_times) - 1:
silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
st.write(f"Last clip: {clip_duration_ms=}ms")
st.write(f"Last clip: {clip_start_time_ms=}ms")
st.write(f"Last clip: {clip_segment.duration_seconds=:.2f}s")
st.write(f"Last clip: {silence_ms=}ms")
if silence_ms > 0:
clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))

clip_segments.append(clip_segment)

if not prompt:
Expand All @@ -154,6 +162,13 @@ def increment_counter():
device=device,
)

# TODO(hayk): Roll this into spectrogram_image_from_audio?
# TODO(hayk): Scale something when computing audio
closest_width = int(np.ceil(init_image.width / 32) * 32)
closest_height = int(np.ceil(init_image.height / 32) * 32)
init_image = init_image.resize((closest_width, closest_height), Image.BICUBIC)

progress_callback = None
if show_clip_details:
left, right = st.columns(2)

Expand All @@ -166,6 +181,7 @@ def increment_counter():
with empty_bin.container():
st.info("Riffing...")
progress = st.progress(0.0)
progress_callback = progress.progress

image = streamlit_util.run_img2img(
prompt=prompt,
Expand All @@ -175,10 +191,11 @@ def increment_counter():
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
seed=seed,
progress_callback=progress.progress,
progress_callback=progress_callback,
device=device,
)

st.write(init_image.size)
st.write(image.size)
result_images.append(image)

if show_clip_details:
Expand All @@ -191,13 +208,29 @@ def increment_counter():
device=device,
)
result_segments.append(riffed_segment)

st.write(clip_segment.duration_seconds)
st.write(riffed_segment.duration_seconds)
audio_bytes = io.BytesIO()
riffed_segment.export(audio_bytes, format="wav")

if show_clip_details:
right.audio(audio_bytes)

if show_clip_details and show_difference:
diff_np = np.maximum(0, np.asarray(init_image).astype(np.float32) - np.asarray(image).astype(np.float32))
st.write(diff_np.shape)
diff_image = Image.fromarray(255 - diff_np.astype(np.uint8))
st.image(diff_image)
diff_segment = streamlit_util.audio_segment_from_spectrogram_image(
image=diff_image,
params=params,
device=device,
)

audio_bytes = io.BytesIO()
diff_segment.export(audio_bytes, format="wav")
st.audio(audio_bytes)

# Combine clips with a crossfade based on overlap
crossfade_ms = int(overlap_duration_s * 1000)
combined_segment = result_segments[0]
Expand All @@ -207,7 +240,7 @@ def increment_counter():
audio_bytes = io.BytesIO()
combined_segment.export(audio_bytes, format="mp3")
st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
st.audio(audio_bytes)
st.audio(audio_bytes, format="audio/mp3")


@st.cache
Expand Down

0 comments on commit 83b2792

Please sign in to comment.