Skip to content

Commit

Permalink
Audio splitting with demucs hybrid transformer model
Browse files Browse the repository at this point in the history
Topic: audio_splitter_transformer
  • Loading branch information
hmartiro committed Jan 7, 2023
1 parent f8595d7 commit 8e87c13
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 3 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
accelerate
argh
dacite
demucs
diffusers>=0.9.0
flask
flask_cors
Expand Down
59 changes: 59 additions & 0 deletions riffusion/audio_splitter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import shutil
import subprocess
import tempfile
import typing as T
from pathlib import Path

import numpy as np
import pydub
Expand All @@ -9,10 +13,65 @@
from riffusion.util import audio_util


def split_audio(
    segment: pydub.AudioSegment,
    model_name: str = "htdemucs_6s",
    extension: str = "wav",
    jobs: int = 4,
    device: str = "cuda",
) -> T.Dict[str, pydub.AudioSegment]:
    """
    Split audio into stems using demucs.

    The segment is exported to a temporary mp3 file, the `demucs` command line
    tool is run on it as a subprocess, and every stem file it writes under
    `<tmp>/<model_name>/audio/` with the requested extension is loaded back.
    The temporary directory is removed before returning, including when an
    error occurs.

    Args:
        segment: Audio to separate.
        model_name: Passed to demucs via `--name`; also the name of the output
            sub-directory that is scanned for stems.
        extension: File extension of the stem files to load. "mp3" adds the
            `--mp3` flag to the command; no flag is added for any other value.
            NOTE(review): presumably demucs writes wav when no flag is given,
            so values other than "wav"/"mp3" would yield an empty result --
            confirm against the demucs CLI.
        jobs: Passed to demucs via `--jobs`.
        device: Passed to demucs via `--device`. "mps" is replaced by "cpu".

    Returns:
        Mapping from stem file name (without extension) to the loaded audio,
        inserted in sorted file-name order.

    Raises:
        subprocess.CalledProcessError: If demucs exits with a non-zero status.
        FileNotFoundError: If the `demucs` executable cannot be found.
    """
    tmp_dir = Path(tempfile.mkdtemp(prefix="split_audio_"))

    try:
        # Save the audio to a temporary file
        audio_path = tmp_dir / "audio.mp3"
        segment.export(audio_path, format="mp3")

        # Assemble command
        command = [
            "demucs",
            str(audio_path),
            "--name",
            model_name,
            "--out",
            str(tmp_dir),
            "--jobs",
            str(jobs),
            "--device",
            device if device != "mps" else "cpu",
        ]
        if extension == "mp3":
            command.append("--mp3")

        # Log only once the command is complete, so the output matches what runs
        print(" ".join(command))

        # Run demucs
        subprocess.run(
            command,
            check=True,
        )

        # Load the stems; sort so the result order does not depend on the filesystem
        stems = {}
        for stem_path in sorted(tmp_dir.glob(f"{model_name}/audio/*.{extension}")):
            stems[stem_path.stem] = pydub.AudioSegment.from_file(stem_path)

        return stems
    finally:
        # Always delete the tmp dir, even if export, demucs, or loading failed.
        # Best-effort: a cleanup problem must not mask the real error or discard
        # stems that were already loaded.
        shutil.rmtree(tmp_dir, ignore_errors=True)


class AudioSplitter:
"""
Split audio into instrument stems like {drums, bass, vocals, etc.}
NOTE(hayk): This is deprecated as it has inferior performance to the newer hybrid transformer
model in the demucs repo. See the function above. Probably just delete this.
See:
https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html
"""
Expand Down
21 changes: 18 additions & 3 deletions riffusion/streamlit/pages/split_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import streamlit as st

from riffusion.audio_splitter import split_audio
from riffusion.streamlit import util as streamlit_util


Expand Down Expand Up @@ -32,11 +33,13 @@ def render_split_audio() -> None:

audio_file = st.file_uploader(
"Upload audio",
type=["mp3", "m4a", "ogg", "wav", "flac"],
type=["mp3", "m4a", "ogg", "wav", "flac", "webm"],
label_visibility="collapsed",
)

splitter = streamlit_util.get_audio_splitter(device=device)
recombine = st.sidebar.checkbox(
"Recombine", value=False, help="Show recombined audio at the end for comparison"
)

if not audio_file:
st.info("Upload audio to get started")
Expand All @@ -51,7 +54,7 @@ def render_split_audio() -> None:
segment = streamlit_util.load_audio_file(audio_file)

# Split
stems = splitter.split(segment)
stems = split_audio(segment, device=device)

# Display each
for name, stem in stems.items():
Expand All @@ -60,6 +63,18 @@ def render_split_audio() -> None:
stem.export(audio_bytes, format="mp3")
st.audio(audio_bytes)

if recombine:
stems_list = list(stems.values())
recombined = stems_list[0]
for stem in stems_list[1:]:
recombined = recombined.overlay(stem)

# Display
st.write("#### recombined")
audio_bytes = io.BytesIO()
recombined.export(audio_bytes, format="mp3")
st.audio(audio_bytes)


if __name__ == "__main__":
render_split_audio()

0 comments on commit 8e87c13

Please sign in to comment.