Skip to content

Commit

Permalink
adding support for saving states as half precision, support for concatenating musdb and custom wav dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
adefossez committed Jul 5, 2021
1 parent 9602180 commit d4b9235
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 5 deletions.
11 changes: 11 additions & 0 deletions demucs/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import torch as th
from torch import distributed, nn
from torch.utils.data import ConcatDataset
from torch.nn.parallel.distributed import DistributedDataParallel

from .augment import FlipChannels, FlipSign, Remix, Scale, Shift
Expand Down Expand Up @@ -147,6 +148,7 @@ def main():
if args.save_model:
if args.rank == 0:
model.to("cpu")
assert saved.best_state is not None, "model needs to train for 1 epoch at least."
model.load_state_dict(saved.best_state)
save_model(model, quantizer, args, args.models / model_name)
return
Expand Down Expand Up @@ -198,10 +200,19 @@ def main():
valid_set = Rawset(args.raw / "valid", channels=args.audio_channels)
elif args.wav:
train_set, valid_set = get_wav_datasets(args, samples, model.sources)

if args.concat:
if args.is_wav:
mus_train, mus_valid = get_musdb_wav_datasets(args, samples, model.sources)
else:
mus_train, mus_valid = get_compressed_datasets(args, samples)
train_set = ConcatDataset([train_set, mus_train])
valid_set = ConcatDataset([valid_set, mus_valid])
elif args.is_wav:
train_set, valid_set = get_musdb_wav_datasets(args, samples, model.sources)
else:
train_set, valid_set = get_compressed_datasets(args, samples)
print("Train set and valid set sizes", len(train_set), len(valid_set))

if args.repitch:
train_set = RepitchedWrapper(
Expand Down
4 changes: 4 additions & 0 deletions demucs/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def get_parser():
parser.add_argument("--wav", type=Path,
help="Path to a wav dataset. This should contain a 'train' and a 'valid' "
"subfolder.")
parser.add_argument("--concat", action='store_true',
help="Concat MusDB and wav dataset when provided.")
parser.add_argument("--samplerate", type=int, default=44100)
parser.add_argument("--audio_channels", type=int, default=2)
parser.add_argument("--samples",
Expand Down Expand Up @@ -188,6 +190,8 @@ def get_parser():
help="Skip training, just save state "
"for the current checkpoint value. You should "
"provide a model name as argument.")
parser.add_argument("--half", action="store_true",
help="When saving the model, uses half precision.")

# Quantization options
parser.add_argument("--q-min-size", type=float, default=1,
Expand Down
7 changes: 4 additions & 3 deletions demucs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,10 @@ def load_model(path, strict=False):
return model


def get_state(model, quantizer):
def get_state(model, quantizer, half=False):
if quantizer is None:
state = {k: p.data.to('cpu') for k, p in model.state_dict().items()}
dtype = th.half if half else None
state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()}
else:
state = quantizer.get_quantized_state()
buf = io.BytesIO()
Expand Down Expand Up @@ -301,7 +302,7 @@ def save_model(model, quantizer, training_args, path):
args, kwargs = model._init_args_kwargs
klass = model.__class__

state = get_state(model, quantizer)
state = get_state(model, quantizer, half=training_args.half)

save_to = path
package = {
Expand Down
9 changes: 7 additions & 2 deletions demucs/wav.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import math
import json
from pathlib import Path
import os

import julius
import torch as th
Expand Down Expand Up @@ -53,8 +54,12 @@ def _track_metadata(track, sources):
# Build a metadata dict for the dataset folder `path`: each entry maps a track
# name to whatever `_track_metadata` returns for that track's folder.
# NOTE(review): this is a rendered diff hunk with +/- markers and indentation
# stripped. Given the "7 additions & 2 deletions" summary for this file, the
# two-line `path.iterdir()` loop is presumably the removed version and the
# `os.walk` loop after it is its replacement — confirm against the real file.
def _build_metadata(path, sources):
meta = {}
path = Path(path)
# Presumably the removed loop: one entry per direct child of `path`, keyed by
# the child's name.
for file in path.iterdir():
meta[file.name] = _track_metadata(file, sources)
# Presumably the added loop: recursive walk that also follows symlinked
# directories (followlinks=True).
for root, folders, files in os.walk(path, followlinks=True):
root = Path(root)
# Skip directories whose own name starts with '.', directories that contain
# subfolders, and the dataset root itself; only leaf folders are recorded.
if root.name.startswith('.') or folders or root == path:
continue
# Key the entry by the folder's path relative to the dataset root.
name = str(root.relative_to(path))
meta[name] = _track_metadata(root, sources)
return meta


Expand Down

0 comments on commit d4b9235

Please sign in to comment.