Skip to content

Commit

Permalink
Minor fixes (facebookresearch#489)
Browse files Browse the repository at this point in the history
* Allow manually run workflow

* Add HTDemucs to AnyModel

* Set minimum Python version to 3.8

* Add type hints for save_audio

* Update segment machenism

* Update help message that htdemucs is the default

* Add segment test

* Fix linter
  • Loading branch information
CarlGao4 authored May 23, 2023
1 parent e25cfeb commit 51b6545
Show file tree
Hide file tree
Showing 13 changed files with 60 additions and 41 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/linter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:

jobs:
build:
runs-on: ubuntu-latest
if: ${{ github.repository == 'facebookresearch/demucs' || github.event_name == 'workflow_dispatch' }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: 3.8

- uses: actions/cache@v2
with:
path: env
key: env-${{ hashFiles('**/requirements.txt') }}
key: env-${{ hashFiles('**/requirements.txt', '.github/workflows/*') }}

- name: Install dependencies
run: |
Expand Down
6 changes: 4 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:

jobs:
build:
runs-on: ubuntu-latest
if: ${{ github.repository == 'facebookresearch/demucs' || github.event_name == 'workflow_dispatch' }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: 3.8

- uses: actions/cache@v2
with:
path: env
key: env-${{ hashFiles('**/requirements.txt') }}
key: env-${{ hashFiles('**/requirements.txt', '.github/workflows/*') }}

- name: Install dependencies
run: |
Expand Down
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ test_eval:
python3 -m demucs -n demucs_unittest --two-stems=vocals test.mp3
python3 -m demucs -n demucs_unittest --mp3 test.mp3
python3 -m demucs -n demucs_unittest --int24 --clip-mode clamp test.mp3
python3 -m demucs -n demucs_unittest --segment 8 test.mp3

tests/musdb:
test -e tests || mkdir tests
Expand Down
34 changes: 17 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,28 +79,28 @@ of the naturalness and absence of artifacts given by human listeners (5 = no art
is a rating from 1 to 5 with 5 being zero contamination by other sources. We refer the reader to our [paper][hybrid_paper],
for more details.

| Model | Domain | Extra data? | Overall SDR | MOS Quality | MOS Contamination |
|------------------------------|-------------|-------------|-------------|-------------|-------------------|
| [Wave-U-Net][waveunet] | waveform | no | 3.2 | - | - |
| [Open-Unmix][openunmix] | spectrogram | no | 5.3 | - | - |
| [D3Net][d3net] | spectrogram | no | 6.0 | - | - |
| [Conv-Tasnet][demucs_v2] | waveform | no | 5.7 | - | |
| [Demucs (v2)][demucs_v2] | waveform | no | 6.3 | 2.37 | 2.36 |
| [ResUNetDecouple+][decouple] | spectrogram | no | 6.7 | - | - |
| [KUIELAB-MDX-Net][kuielab] | hybrid | no | 7.5 | **2.86** | 2.55 |
| [Band-Spit RNN][bandsplit] | spectrogram | no | **8.2** | - | - |
| **Hybrid Demucs (v3)** | hybrid | no | 7.7 | **2.83** | **3.04** |
| [MMDenseLSTM][mmdenselstm] | spectrogram | 804 songs | 6.0 | - | - |
| [D3Net][d3net] | spectrogram | 1.5k songs | 6.7 | - | - |
| [Spleeter][spleeter] | spectrogram | 25k songs | 5.9 | - | - |
| [Band-Spit RNN][bandsplit] | spectrogram | 1.7k (mixes only) | **9.0** | - | - |
| **HT Demucs f.t. (v4)** | hybrid | 800 songs | **9.0** | - | - |
| Model | Domain | Extra data? | Overall SDR | MOS Quality | MOS Contamination |
|------------------------------|-------------|-------------------|-------------|-------------|-------------------|
| [Wave-U-Net][waveunet] | waveform | no | 3.2 | - | - |
| [Open-Unmix][openunmix] | spectrogram | no | 5.3 | - | - |
| [D3Net][d3net] | spectrogram | no | 6.0 | - | - |
| [Conv-Tasnet][demucs_v2] | waveform | no | 5.7 | - | |
| [Demucs (v2)][demucs_v2] | waveform | no | 6.3 | 2.37 | 2.36 |
| [ResUNetDecouple+][decouple] | spectrogram | no | 6.7 | - | - |
| [KUIELAB-MDX-Net][kuielab] | hybrid | no | 7.5 | **2.86** | 2.55 |
| [Band-Spit RNN][bandsplit] | spectrogram | no | **8.2** | - | - |
| **Hybrid Demucs (v3)** | hybrid | no | 7.7 | **2.83** | **3.04** |
| [MMDenseLSTM][mmdenselstm] | spectrogram | 804 songs | 6.0 | - | - |
| [D3Net][d3net] | spectrogram | 1.5k songs | 6.7 | - | - |
| [Spleeter][spleeter] | spectrogram | 25k songs | 5.9 | - | - |
| [Band-Spit RNN][bandsplit] | spectrogram | 1.7k (mixes only) | **9.0** | - | - |
| **HT Demucs f.t. (v4)** | hybrid | 800 songs | **9.0** | - | - |



## Requirements

You will need at least Python 3.7. See `requirements_minimal.txt` for requirements for separation only,
You will need at least Python 3.8. See `requirements_minimal.txt` for requirements for separation only,
and `environment-[cpu|cuda].yml` (or `requirements.txt`) if you want to train a new model.

### For Windows users
Expand Down
13 changes: 10 additions & 3 deletions demucs/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@

from .demucs import Demucs
from .hdemucs import HDemucs
from .htdemucs import HTDemucs
from .utils import center_trim, DummyPoolExecutor

Model = tp.Union[Demucs, HDemucs]
Model = tp.Union[Demucs, HDemucs, HTDemucs]


class BagOfModels(nn.Module):
Expand Down Expand Up @@ -122,7 +123,7 @@ def tensor_chunk(tensor_or_chunk):

def apply_model(model, mix, shifts=1, split=True,
overlap=0.25, transition_power=1., progress=False, device=None,
num_workers=0, pool=None):
num_workers=0, segment=None, pool=None):
"""
Apply model to a given mixture.
Expand Down Expand Up @@ -157,6 +158,7 @@ def apply_model(model, mix, shifts=1, split=True,
'progress': progress,
'device': device,
'pool': pool,
'segment': segment,
}
if isinstance(model, BagOfModels):
# Special treatment for bag of model.
Expand Down Expand Up @@ -201,7 +203,11 @@ def apply_model(model, mix, shifts=1, split=True,
kwargs['split'] = False
out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
sum_weight = th.zeros(length, device=mix.device)
segment = int(model.samplerate * model.segment)
if segment is None:
segment = model.segment
segment_old = model.segment
model.segment = segment
segment = int(model.samplerate * segment)
stride = int((1 - overlap) * segment)
offsets = range(0, length, stride)
scale = float(format(stride / model.samplerate, ".2f"))
Expand All @@ -227,6 +233,7 @@ def apply_model(model, mix, shifts=1, split=True,
chunk_length = chunk_out.shape[-1]
out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
sum_weight[offset:offset + segment] += weight[:chunk_length].to(mix.device)
model.segment = segment_old
assert sum_weight.min() > 0
out /= sum_weight
return out
Expand Down
12 changes: 10 additions & 2 deletions demucs/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
import torch
import torchaudio as ta
import typing as tp

from .utils import temp_filenames

Expand Down Expand Up @@ -218,6 +219,8 @@ def prevent_clip(wav, mode='rescale'):
"""
different strategies for avoiding raw clipping.
"""
if mode is None or mode == 'none':
return wav
assert wav.dtype.is_floating_point, "too late for clipping"
if mode == 'rescale':
wav = wav / max(1.01 * wav.abs().max(), 1)
Expand All @@ -230,8 +233,13 @@ def prevent_clip(wav, mode='rescale'):
return wav


def save_audio(wav, path, samplerate, bitrate=320, clip='rescale',
bits_per_sample=16, as_float=False):
def save_audio(wav: torch.Tensor,
path: tp.Union[str, Path],
samplerate: int,
bitrate: int = 320,
clip: tp.Literal["rescale", "clamp", "tanh", "none"] = 'rescale',
bits_per_sample: tp.Literal[16, 24, 32] = 16,
as_float: bool = False):
"""Save audio file, automatically preventing clipping if necessary
based on the given `clip` strategy. If the path ends in `.mp3`, this
will save as mp3 with the given `bitrate`.
Expand Down
2 changes: 1 addition & 1 deletion demucs/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def add_model_flags(parser):
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument("-s", "--sig", help="Locally trained XP signature.")
group.add_argument("-n", "--name", default=None,
help="Pretrained model name or signature. Default is mdx_extra_q.")
help="Pretrained model name or signature. Default is htdemucs.")
parser.add_argument("--repo", type=Path,
help="Folder containing all pre-trained models for use with -n.")

Expand Down
15 changes: 7 additions & 8 deletions demucs/separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def load_track(track, audio_channels, samplerate):
return wav


def main(opts=None):
def get_parser():
parser = argparse.ArgumentParser("demucs.separate",
description="Separate the sources for the given tracks")
parser.add_argument("tracks", nargs='+', type=Path, default=[], help='Path to tracks')
Expand Down Expand Up @@ -115,6 +115,11 @@ def main(opts=None):
help="Number of jobs. This can increase memory usage but will "
"be much faster when multiple cores are available.")

return parser


def main(opts=None):
parser = get_parser()
args = parser.parse_args(opts)

try:
Expand All @@ -131,12 +136,6 @@ def main(opts=None):
if isinstance(model, BagOfModels):
print(f"Selected model is a bag of {len(model.models)} models. "
"You will see that many progress bars per track.")
if args.segment is not None:
for sub in model.models:
sub.segment = args.segment
else:
if args.segment is not None:
model.segment = args.segment

model.cpu()
model.eval()
Expand All @@ -162,7 +161,7 @@ def main(opts=None):
wav = (wav - ref.mean()) / ref.std()
sources = apply_model(model, wav[None], device=args.device, shifts=args.shifts,
split=args.split, overlap=args.overlap, progress=True,
num_workers=args.jobs)[0]
num_workers=args.jobs, segment=args.segment)[0]
sources = sources * ref.std() + ref.mean()

if args.mp3:
Expand Down
4 changes: 2 additions & 2 deletions docs/linux.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Linux support for Demucs

If your distribution has at least Python 3.7, and you just wish to separate
If your distribution has at least Python 3.8, and you just wish to separate
tracks with Demucs, not train it, you can just run

```bash
Expand All @@ -11,7 +11,7 @@ python3 -m demucs -d cpu PATH_TO_AUDIO_FILE_1
demucs -d cpu PATH_TO_AUDIO_FILE_1
```

If Python is too old, or you want to be able to train, I recommend [installing Miniconda][miniconda], with Python 3.7 or more.
If Python is too old, or you want to be able to train, I recommend [installing Miniconda][miniconda], with Python 3.8 or more.

```bash
conda activate
Expand Down
2 changes: 1 addition & 1 deletion docs/windows.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

Parts of the code are untested on Windows (in particular, training a new model). If you don't have much experience with Anaconda, python or the shell, here are more detailed instructions. Note that **Demucs is not supported on 32bits systems** (as Pytorch is not available there).

- First install Anaconda with **Python 3.7** or more recent, which you can find [here][install].
- First install Anaconda with **Python 3.8** or more recent, which you can find [here][install].
- Start the [Anaconda prompt][prompt].

Then, all commands that follow must be run from this prompt.
Expand Down
2 changes: 1 addition & 1 deletion environment-cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- conda-forge

dependencies:
- python>=3.7,<3.10
- python>=3.8,<3.10
- ffmpeg>=4.2
- pytorch>=1.8.1
- torchaudio>=0.8
Expand Down
2 changes: 1 addition & 1 deletion environment-cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ channels:
- conda-forge

dependencies:
- python>=3.7,<3.10
- python>=3.8,<3.10
- ffmpeg>=4.2
- pytorch>=1.8.1
- torchaudio>=0.8
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
URL = 'https://github.com/facebookresearch/demucs'
EMAIL = '[email protected]'
AUTHOR = 'Alexandre Défossez'
REQUIRES_PYTHON = '>=3.7.0'
REQUIRES_PYTHON = '>=3.8.0'

HERE = Path(__file__).parent

Expand Down

0 comments on commit 51b6545

Please sign in to comment.