Skip to content

Commit

Permalink
avoid using mps for complex numbers (facebookresearch#503)
Browse files Browse the repository at this point in the history
* avoid using mps for complex numbers

* fixed typo (x->z)

* fixed lint and added description in release.md
  • Loading branch information
HeLehm authored Jun 2, 2023
1 parent 65d2bee commit 5d2ccf2
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 2 deletions.
14 changes: 13 additions & 1 deletion demucs/hdemucs.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ def forward(self, mix):
length = x.shape[-1]

z = self._spec(mix)
mag = self._magnitude(z)
mag = self._magnitude(z).to(mix.device)
x = mag

B, C, Fq, T = x.shape
Expand Down Expand Up @@ -772,9 +772,21 @@ def forward(self, mix):
x = x.view(B, S, -1, Fq, T)
x = x * std[:, None] + mean[:, None]

# to cpu as mps doesnt support complex numbers
# demucs issue #435 ##432
# NOTE: in this case z already is on cpu
# TODO: remove this when mps supports complex numbers
x_is_mps = x.device.type == "mps"
if x_is_mps:
x = x.cpu()

zout = self._mask(z, x)
x = self._ispec(zout, length)

# back to mps device
if x_is_mps:
x = x.to('mps')

if self.hybrid:
xt = xt.view(B, S, -1, length)
xt = xt * stdt[:, None] + meant[:, None]
Expand Down
14 changes: 13 additions & 1 deletion demucs/htdemucs.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def forward(self, mix):
length_pre_pad = mix.shape[-1]
mix = F.pad(mix, (0, training_length - length_pre_pad))
z = self._spec(mix)
mag = self._magnitude(z)
mag = self._magnitude(z).to(mix.device)
x = mag

B, C, Fq, T = x.shape
Expand Down Expand Up @@ -625,6 +625,14 @@ def forward(self, mix):
x = x.view(B, S, -1, Fq, T)
x = x * std[:, None] + mean[:, None]

# to cpu as mps doesnt support complex numbers
# demucs issue #435 ##432
# NOTE: in this case z already is on cpu
# TODO: remove this when mps supports complex numbers
x_is_mps = x.device.type == "mps"
if x_is_mps:
x = x.cpu()

zout = self._mask(z, x)
if self.use_train_segment:
if self.training:
Expand All @@ -634,6 +642,10 @@ def forward(self, mix):
else:
x = self._ispec(zout, length)

# back to mps device
if x_is_mps:
x = x.to("mps")

if self.use_train_segment:
if self.training:
xt = xt.view(B, S, -1, length)
Expand Down
6 changes: 6 additions & 0 deletions demucs/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
def spectro(x, n_fft=512, hop_length=None, pad=0):
*other, length = x.shape
x = x.reshape(-1, length)
is_mps = x.device.type == 'mps'
if is_mps:
x = x.cpu()
z = th.stft(x,
n_fft * (1 + pad),
hop_length or n_fft // 4,
Expand All @@ -29,6 +32,9 @@ def ispectro(z, hop_length=None, length=None, pad=0):
n_fft = 2 * freqs - 2
z = z.view(-1, freqs, frames)
win_length = n_fft // (1 + pad)
is_mps = z.device.type == 'mps'
if is_mps:
z = z.cpu()
x = th.istft(z,
n_fft,
hop_length,
Expand Down
2 changes: 2 additions & 0 deletions docs/release.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ Made diffq an optional dependency, with an error message if not installed.

Added output format flac (Free Lossless Audio Codec)

Will use CPU for complex numbers, when using MPS device (all other computations are performed by mps).

Optimize codes to save memory

## V4.0.0, 7th of December 2022
Expand Down

0 comments on commit 5d2ccf2

Please sign in to comment.