Always use CPU RAM for the whole WAV track and most temporary results
If CUDA is available, we use it only for the computationally intensive
operations, like applying the model. Small temporary results are still
kept in CUDA memory, to avoid unneeded transfers to the CPU.

This approach lets us process arbitrarily long tracks while requiring
GPU VRAM only for the model itself. The MDX model takes about 1.3 GB
of GPU VRAM.

Processing a one-hour audio track this way takes about 6% more time
than doing it purely on the GPU, and it requires an additional 25 GB
of CPU RAM.

Without these code changes, processing a one-hour track is not
possible even on a GPU with 16 GB of VRAM.
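
In outline, the pattern is: allocate the full-length buffers in CPU RAM and let only one segment at a time visit the GPU for the expensive forward pass. Below is a minimal hedged sketch of that idea, not the demucs code itself (names like process_long_track are illustrative, and the model is assumed to map a (1, channels, n) batch to an equally long output):

import torch as th

def process_long_track(model, wav, device, segment):
    # The full-length output stays in CPU RAM; GPU VRAM only ever holds
    # the model plus a single segment and its result.
    channels, length = wav.shape
    out = th.zeros(channels, length, device="cpu")
    for offset in range(0, length, segment):
        chunk = wav[:, offset:offset + segment].to(device)   # small upload
        with th.no_grad():
            processed = model(chunk[None])[0]                # heavy compute on GPU
        out[:, offset:offset + processed.shape[-1]] = processed.cpu()  # download
    return out

# Back-of-envelope for the figures above (assuming float32 at 44.1 kHz):
# one hour of stereo audio is 3600 * 44100 * 2 * 4 bytes, about 1.2 GB per
# copy, and a 4-source output is about 4x that, so a handful of working
# copies and temporaries can plausibly add up to the quoted 25 GB of CPU RAM.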
famzah committed Dec 6, 2021
1 parent c03490f commit c5f28b4
Showing 2 changed files with 11 additions and 13 deletions.
demucs/apply.py (9 additions, 10 deletions)
@@ -115,7 +115,7 @@ def tensor_chunk(tensor_or_chunk):
         return TensorChunk(tensor_or_chunk)


-def apply_model(model, mix, shifts=1, split=True,
+def apply_model(model, mix, device, shifts=1, split=True,
                 overlap=0.25, transition_power=1., progress=False):
     """
     Apply model to a given mixture.
@@ -138,7 +138,7 @@ def apply_model(model, mix, shifts=1, split=True,
         totals = [0] * len(model.sources)
         for sub_model, weight in zip(model.models, model.weights):
             out = apply_model(
-                sub_model, mix,
+                sub_model, mix, device,
                 shifts=shifts, split=split, overlap=overlap,
                 transition_power=transition_power, progress=progress)
             for k in range(out.shape[0]):
@@ -150,11 +150,10 @@ def apply_model(model, mix, shifts=1, split=True,
         return estimates

     assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
-    device = mix.device
     batch, channels, length = mix.shape
     if split:
-        out = th.zeros(batch, len(model.sources), channels, length, device=device)
-        sum_weight = th.zeros(length, device=device)
+        out = th.zeros(batch, len(model.sources), channels, length, device="cpu")
+        sum_weight = th.zeros(length, device="cpu")
         segment = int(model.samplerate * model.segment)
         stride = int((1 - overlap) * segment)
         offsets = range(0, length, stride)
@@ -172,10 +171,10 @@ def apply_model(model, mix, shifts=1, split=True,
         weight = (weight / weight.max())**transition_power
         for offset in offsets:
             chunk = TensorChunk(mix, offset, segment)
-            chunk_out = apply_model(model, chunk, shifts=shifts, split=False)
+            chunk_out = apply_model(model, chunk, device, shifts=shifts, split=False)
             chunk_length = chunk_out.shape[-1]
-            out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
-            sum_weight[offset:offset + segment] += weight[:chunk_length]
+            out[..., offset:offset + segment] += (weight[:chunk_length] * chunk_out).cpu()
+            sum_weight[offset:offset + segment] += weight[:chunk_length].cpu()
             offset += segment
         assert sum_weight.min() > 0
         out /= sum_weight
@@ -188,7 +187,7 @@ def apply_model(model, mix, shifts=1, split=True,
         for _ in range(shifts):
             offset = random.randint(0, max_shift)
             shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
-            shifted_out = apply_model(model, shifted, shifts=0, split=False)
+            shifted_out = apply_model(model, shifted, device, shifts=0, split=False)
             out += shifted_out[..., max_shift - offset:]
         out /= shifts
         return out
@@ -198,7 +197,7 @@ def apply_model(model, mix, shifts=1, split=True,
         else:
             valid_length = length
         mix = tensor_chunk(mix)
-        padded_mix = mix.padded(valid_length)
+        padded_mix = mix.padded(valid_length).to(device)
         with th.no_grad():
             out = model(padded_mix)
         return center_trim(out, length)
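
The split branch above blends overlapping segments with a triangular weight raised to transition_power and then normalizes by the accumulated weight; with this commit both accumulators live on the CPU. A hedged sketch of that overlap-add logic in isolation (toy 1-D signals; chunks is assumed to hold one model output per offset, covering the whole signal):

import torch as th

def overlap_add(chunks, length, segment, stride, transition_power=1.0):
    # Triangular weight, maximal in the middle of the segment; larger
    # transition_power values give sharper transitions between segments.
    weight = th.cat([th.arange(1, segment // 2 + 1),
                     th.arange(segment - segment // 2, 0, -1)]).float()
    weight = (weight / weight.max()) ** transition_power
    out = th.zeros(length)         # accumulators stay in CPU RAM
    sum_weight = th.zeros(length)
    for chunk, offset in zip(chunks, range(0, length, stride)):
        n = chunk.shape[-1]
        out[offset:offset + n] += weight[:n] * chunk.cpu()   # weighted contribution
        sum_weight[offset:offset + n] += weight[:n]
    return out / sum_weight        # normalize wherever segments overlap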
demucs/separate.py (2 additions, 3 deletions)
@@ -112,18 +112,17 @@ def main():
                   file=sys.stderr)
             continue
         print(f"Separating track {track}")
-        wav = load_track(track, args.device, model.audio_channels, model.samplerate)
+        wav = load_track(track, "cpu", model.audio_channels, model.samplerate)

         ref = wav.mean(0)
         wav = (wav - ref.mean()) / ref.std()
-        sources = apply_model(model, wav[None], shifts=args.shifts, split=args.split,
+        sources = apply_model(model, wav[None], args.device, shifts=args.shifts, split=args.split,
                               overlap=args.overlap, progress=True)[0]
         sources = sources * ref.std() + ref.mean()

         track_folder = out / track.name.rsplit(".", 1)[0]
         track_folder.mkdir(exist_ok=True)
         for source, name in zip(sources, model.sources):
-            source = source.cpu()
             stem = str(track_folder / name)
             if args.mp3:
                 stem += ".mp3"