Audio to audio handles interpolation within it
Kill the separate page.

Topic: audio_to_audio_interpolation
hmartiro committed Jan 14, 2023
1 parent: 40bf61e · commit: 8b07a5a
Showing 9 changed files with 239 additions and 384 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -27,6 +27,7 @@ jobs:

- name: Install system packages
run: |
sudo apt-get update
sudo apt-get install -y ffmpeg libsndfile1
- name: Install pip packages from requirements.txt
3 changes: 3 additions & 0 deletions .gitignore
@@ -12,6 +12,9 @@ __pycache__/
# Cog
.cog/

# Random stuff I don't care about
.graveyard/

# Distribution / packaging
.Python
build/
3 changes: 3 additions & 0 deletions riffusion/datatypes.py
@@ -19,6 +19,9 @@ class PromptInput:
# Random seed for denoising
seed: int

# Negative prompt to avoid (optional)
negative_prompt: T.Optional[str] = None

# Denoising strength
denoising: float = 0.75

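For context, a hedged sketch of constructing the updated dataclass with the new field. The `prompt` and `guidance` fields are not visible in this hunk; they are assumed from the call sites in `audio_to_audio.py` below, and the prompt strings are purely illustrative.

```python
from riffusion.datatypes import PromptInput

prompt_input = PromptInput(
    prompt="acoustic folk fingerstyle guitar",  # illustrative, not from the repo
    seed=42,
    guidance=7.0,  # assumed field, seen at the call sites below
    negative_prompt="vocals, drums",  # the new optional field added here
    denoising=0.75,  # matches the dataclass default
)
```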
292 changes: 178 additions & 114 deletions riffusion/streamlit/pages/audio_to_audio.py
@@ -6,8 +6,11 @@
import streamlit as st
from PIL import Image

from riffusion.datatypes import InferenceInput, PromptInput
from riffusion.spectrogram_params import SpectrogramParams
from riffusion.streamlit import util as streamlit_util
from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation
from riffusion.util import audio_util


def render_audio_to_audio() -> None:
@@ -37,6 +40,19 @@ def render_audio_to_audio() -> None:

device = streamlit_util.select_device(st.sidebar)

num_inference_steps = T.cast(
int,
st.sidebar.number_input(
"Steps per sample", value=50, help="Number of denoising steps per model run"
),
)

guidance = st.sidebar.number_input(
"Guidance",
value=7.0,
help="How much the model listens to the text prompt",
)

audio_file = st.file_uploader(
"Upload audio",
type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
@@ -53,113 +69,58 @@ def render_audio_to_audio() -> None:
segment = streamlit_util.load_audio_file(audio_file)

# TODO(hayk): Fix
segment = segment.set_frame_rate(44100)
if segment.frame_rate != 44100:
st.warning("Audio must be 44100Hz. Converting")
segment = segment.set_frame_rate(44100)
st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")

if "counter" not in st.session_state:
st.session_state.counter = 0
clip_p = get_clip_params()
start_time_s = clip_p["start_time_s"]
clip_duration_s = clip_p["clip_duration_s"]
overlap_duration_s = clip_p["overlap_duration_s"]

def increment_counter():
st.session_state.counter += 1

cols = st.columns(4)
start_time_s = cols[0].number_input(
"Start Time [s]",
min_value=0.0,
value=0.0,
)
duration_s = cols[1].number_input(
"Duration [s]",
min_value=0.0,
value=15.0,
)
clip_duration_s = cols[2].number_input(
"Clip Duration [s]",
min_value=3.0,
max_value=10.0,
value=5.0,
)
overlap_duration_s = cols[3].number_input(
"Overlap Duration [s]",
min_value=0.0,
max_value=10.0,
value=0.2,
)

duration_s = min(duration_s, segment.duration_seconds - start_time_s)
duration_s = min(clip_p["duration_s"], segment.duration_seconds - start_time_s)
increment_s = clip_duration_s - overlap_duration_s
clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
st.write(
f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
f"with overlap {overlap_duration_s}s."

write_clip_details(
clip_start_times=clip_start_times,
clip_duration_s=clip_duration_s,
overlap_duration_s=overlap_duration_s,
)
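
To make the slicing math concrete, a quick worked example with the UI defaults (start 0.0 s, duration 15.0 s, clip 5.0 s, overlap 0.2 s):

```python
import numpy as np

start_time_s, duration_s = 0.0, 15.0
clip_duration_s, overlap_duration_s = 5.0, 0.2

# Each clip starts where the previous one ends, minus the overlap.
increment_s = clip_duration_s - overlap_duration_s  # 4.8
clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
print(clip_start_times)  # [0.  4.8 9.6] -> three 5 s clips covering 0.0-14.6 s
```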

with st.expander("Clip Times"):
st.dataframe(
{
"Start Time [s]": clip_start_times,
"End Time [s]": clip_start_times + clip_duration_s,
"Duration [s]": clip_duration_s,
}
)
interpolate = st.checkbox("Interpolate between two settings", False)

with st.form("Conversion Params"):
with st.form("audio to audio form"):
if interpolate:
left, right = st.columns(2)

prompt = st.text_input("Text Prompt")
negative_prompt = st.text_input("Negative Prompt")
with left:
st.write("##### Prompt A")
prompt_input_a = PromptInput(guidance=guidance, **get_prompt_inputs(key="a"))

cols = st.columns(4)
denoising_strength = cols[0].number_input(
"Denoising Strength",
min_value=0.0,
max_value=1.0,
value=0.45,
)
guidance_scale = cols[1].number_input(
"Guidance Scale",
min_value=0.0,
max_value=20.0,
value=7.0,
)
num_inference_steps = int(
cols[2].number_input(
"Num Inference Steps",
min_value=1,
max_value=150,
value=50,
)
)
with right:
st.write("##### Prompt B")
prompt_input_b = PromptInput(guidance=guidance, **get_prompt_inputs(key="b"))

        seed = int(
            cols[3].number_input(
                "Seed",
                min_value=0,
                value=42,
            )
        )
        else:
            prompt_input_a = PromptInput(
                guidance=guidance,
                **get_prompt_inputs(key="a", include_negative_prompt=True, cols=True),
            )

submit_button = st.form_submit_button("Convert", on_click=increment_counter)

# TODO fix
submit_button = st.form_submit_button("Riff", type="primary")

show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
show_difference = st.sidebar.checkbox("Show Difference", False)

clip_segments: T.List[pydub.AudioSegment] = []
for i, clip_start_time_s in enumerate(clip_start_times):
clip_start_time_ms = int(clip_start_time_s * 1000)
clip_duration_ms = int(clip_duration_s * 1000)
clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]

# TODO(hayk): I don't think this is working properly
if i == len(clip_start_times) - 1:
silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
if silence_ms > 0:
clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))

clip_segments.append(clip_segment)
clip_segments = slice_audio_into_clips(
segment=segment,
clip_start_times=clip_start_times,
clip_duration_s=clip_duration_s,
)

if not prompt:
if not prompt_input_a.prompt:
st.info("Enter a prompt")
return

@@ -168,10 +129,16 @@ def increment_counter():

params = SpectrogramParams()

if interpolate:
# TODO(hayk): Make not linspace
alphas = list(np.linspace(0, 1, len(clip_segments)))
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
st.write(f"**Alphas** : [{alphas_str}]")

result_images: T.List[Image.Image] = []
result_segments: T.List[pydub.AudioSegment] = []
for i, clip_segment in enumerate(clip_segments):
st.write(f"### Clip {i} at {clip_start_times[i]}s")
st.write(f"### Clip {i} at {clip_start_times[i]:.2f}s")

audio_bytes = io.BytesIO()
clip_segment.export(audio_bytes, format="wav")
@@ -183,10 +150,7 @@ def increment_counter():
)

# TODO(hayk): Roll this into spectrogram_image_from_audio?
# TODO(hayk): Scale something when computing audio
closest_width = int(np.ceil(init_image.width / 32) * 32)
closest_height = int(np.ceil(init_image.height / 32) * 32)
init_image_resized = init_image.resize((closest_width, closest_height), Image.BICUBIC)
init_image_resized = scale_image_to_32_stride(init_image)

progress_callback = None
if show_clip_details:
@@ -203,17 +167,32 @@ def increment_counter():
progress = st.progress(0.0)
progress_callback = progress.progress

image = streamlit_util.run_img2img(
prompt=prompt,
init_image=init_image_resized,
denoising_strength=denoising_strength,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
seed=seed,
progress_callback=progress_callback,
device=device,
)
if interpolate:
inputs = InferenceInput(
alpha=float(alphas[i]),
num_inference_steps=num_inference_steps,
seed_image_id="og_beat",
start=prompt_input_a,
end=prompt_input_b,
)

image, audio_bytes = run_interpolation(
inputs=inputs,
init_image=init_image_resized,
device=device,
)
else:
image = streamlit_util.run_img2img(
prompt=prompt_input_a.prompt,
init_image=init_image_resized,
denoising_strength=prompt_input_a.denoising,
num_inference_steps=num_inference_steps,
guidance_scale=guidance,
negative_prompt=prompt_input_a.negative_prompt,
seed=prompt_input_a.seed,
progress_callback=progress_callback,
device=device,
)

# Resize back to original size
image = image.resize(init_image.size, Image.BICUBIC)
@@ -253,22 +232,107 @@ def increment_counter():
st.audio(audio_bytes)

# Combine clips with a crossfade based on overlap
crossfade_ms = int(overlap_duration_s * 1000)
combined_segment = result_segments[0]
for segment in result_segments[1:]:
combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
combined_segment = audio_util.stitch_segments(result_segments, crossfade_s=overlap_duration_s)

audio_bytes = io.BytesIO()
combined_segment.export(audio_bytes, format="mp3")
st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
st.audio(audio_bytes, format="audio/mp3")
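
The new `audio_util.stitch_segments` helper lands in one of the diffs that did not load. Judging from the loop it replaces above, a minimal sketch would be:

```python
import typing as T

import pydub


def stitch_segments(
    segments: T.Sequence[pydub.AudioSegment], crossfade_s: float
) -> pydub.AudioSegment:
    """Concatenate segments end to end, crossfading between consecutive pairs."""
    crossfade_ms = int(crossfade_s * 1000)
    combined = segments[0]
    for segment in segments[1:]:
        combined = combined.append(segment, crossfade=crossfade_ms)
    return combined
```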


@st.cache
def test(segment: pydub.AudioSegment, counter: int) -> int:
st.write("#### Trimmed")
st.write(segment.duration_seconds)
return counter
def get_clip_params(advanced: bool = False) -> T.Dict[str, T.Any]:
"""
Render the parameters of slicing audio into clips.
"""
p: T.Dict[str, T.Any] = {}

cols = st.columns(4)

p["start_time_s"] = cols[0].number_input(
"Start Time [s]",
min_value=0.0,
value=0.0,
)
p["duration_s"] = cols[1].number_input(
"Duration [s]",
min_value=0.0,
value=15.0,
)

if advanced:
p["clip_duration_s"] = cols[2].number_input(
"Clip Duration [s]",
min_value=3.0,
max_value=10.0,
value=5.0,
)
else:
p["clip_duration_s"] = 5.0

if advanced:
p["overlap_duration_s"] = cols[3].number_input(
"Overlap Duration [s]",
min_value=0.0,
max_value=10.0,
value=0.2,
)
else:
p["overlap_duration_s"] = 0.2

return p
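
Note that the page calls `get_clip_params()` with the default `advanced=False`, so only the first two inputs render and clip duration and overlap fall back to 5.0 s and 0.2 s. A caller wanting all four controls would presumably do:

```python
clip_p = get_clip_params(advanced=True)  # renders all four number inputs
```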


def write_clip_details(
clip_start_times: np.ndarray, clip_duration_s: float, overlap_duration_s: float
):
"""
Write details of the clips to be sliced from an audio segment.
"""
clip_details_text = (
f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
f"with overlap {overlap_duration_s}s"
)

with st.expander(clip_details_text):
st.dataframe(
{
"Start Time [s]": clip_start_times,
"End Time [s]": clip_start_times + clip_duration_s,
"Duration [s]": clip_duration_s,
}
)


def slice_audio_into_clips(
segment: pydub.AudioSegment, clip_start_times: T.Sequence[float], clip_duration_s: float
) -> T.List[pydub.AudioSegment]:
"""
Slice an audio segment into a list of clips of a given duration at the given start times.
"""
clip_segments: T.List[pydub.AudioSegment] = []
for i, clip_start_time_s in enumerate(clip_start_times):
clip_start_time_ms = int(clip_start_time_s * 1000)
clip_duration_ms = int(clip_duration_s * 1000)
clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]

# TODO(hayk): I don't think this is working properly
if i == len(clip_start_times) - 1:
silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
if silence_ms > 0:
clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))

clip_segments.append(clip_segment)

return clip_segments
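
On the TODO above: one plausible culprit is pydub's `append`, whose `crossfade` parameter defaults to 100 ms, so appending a shorter stretch of silence can fail or eat into the clip; the float round trip through `duration_seconds` can also truncate by a millisecond. A hedged drop-in replacement for the padding block, using `len(segment)` (pydub's length in integer milliseconds):

```python
# Hypothetical tightening of the final-clip padding, not code from the commit.
silence_ms = clip_duration_ms - len(clip_segment)  # len() is duration in ms
if silence_ms > 0:
    clip_segment = clip_segment.append(
        pydub.AudioSegment.silent(duration=silence_ms),
        crossfade=0,  # the 100 ms default breaks for short padding
    )
```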


def scale_image_to_32_stride(image: Image.Image) -> Image.Image:
"""
Scale an image to a size that is a multiple of 32.
"""
closest_width = int(np.ceil(image.width / 32) * 32)
closest_height = int(np.ceil(image.height / 32) * 32)
return image.resize((closest_width, closest_height), Image.BICUBIC)
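
Stable Diffusion's VAE and UNet downsample the input repeatedly, so image dimensions generally need to be a multiple of the model's total stride; rounding up to 32 is a safe choice here. A quick usage check:

```python
from PIL import Image

image = Image.new("RGB", (500, 200))  # arbitrary example size
resized = scale_image_to_32_stride(image)
print(resized.size)  # (512, 224): each dimension rounded up to a multiple of 32
```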


if __name__ == "__main__":
(Diffs for the remaining five changed files did not load.)
