Skip to content

Commit

Permalink
release fix for number of stems with mel band roformer, thanks to @cr…
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 14, 2023
1 parent ec16878 commit edb9de2
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord

- <a href="https://github.com/ZFTurbo">Roman</a> for successfully training the model and open sourcing his training code and weights at <a href="https://github.com/ZFTurbo/Music-Source-Separation-Training">this repository</a>!

- <a href="https://github.com/crlandsc">Christopher</a> for fixing an issue with multiple stems in MelBandRoFormer

## Install

```bash
Expand Down
14 changes: 7 additions & 7 deletions bs_roformer/mel_band_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,11 +400,10 @@ def forward(
if raw_audio.ndim == 2:
raw_audio = rearrange(raw_audio, 'b t -> b 1 t')

original_length = (
raw_audio.shape[-1] if self.match_input_audio_length else original_length
)
batch, channels, raw_audio_length = raw_audio.shape

istft_length = raw_audio_length if self.match_input_audio_length else None

batch, channels, *_ = raw_audio.shape
assert (not self.stereo and channels == 1) or (self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'

# to stft
Expand Down Expand Up @@ -466,9 +465,10 @@ def forward(

# need to average the estimated mask for the overlapped frequencies

scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b = batch, n=num_stems, t = stft_repr.shape[-1])
scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b = batch, n = num_stems, t = stft_repr.shape[-1])

masks_summed = torch.zeros_like(stft_repr.expand(-1, num_stems, -1, -1)).scatter_add_(2, scatter_indices, masks)
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n = num_stems)
masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)

denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r = channels)

Expand All @@ -482,7 +482,7 @@ def forward(

stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s = self.audio_channels)

recon_audio = torch.istft(stft_repr, **self.stft_kwargs, return_complex = False, length=original_length)
recon_audio = torch.istft(stft_repr, **self.stft_kwargs, return_complex = False, length = istft_length)

recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s = self.audio_channels, n = num_stems)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'BS-RoFormer',
packages = find_packages(exclude=[]),
version = '0.3.4',
version = '0.3.5',
license='MIT',
description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation',
author = 'Phil Wang',
Expand Down

0 comments on commit edb9de2

Please sign in to comment.