forked from adefossez/demucs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
export.py
71 lines (57 loc) · 2.63 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Export a trained model from the full checkpoint (with optimizer etc.) to
a final checkpoint, with only the model itself. The model is always stored as
half float to gain space, and because this has zero impact on the final loss.
When DiffQ was used for training, the model will actually be quantized and bitpacked."""
from argparse import ArgumentParser
from fractions import Fraction
import logging
from pathlib import Path
import sys
import torch
from demucs import train
from demucs.states import serialize_model, save_with_checksum
logger = logging.getLogger(__name__)
def main():
    """Export release checkpoints (model weights only, half precision) for XP signatures.

    For each signature given on the command line, loads the full training
    checkpoint via ``demucs.train``, restores the best validation state,
    serializes the model (quantized/bitpacked if DiffQ was used), attaches the
    last seen valid/test metrics, and writes ``<sig>.th`` into the output
    directory (optionally prefixed with a sha256 checksum).
    """
    logging.basicConfig(level=logging.INFO, stream=sys.stderr)
    parser = ArgumentParser("tools.export", description="Export trained models from XP sigs.")
    parser.add_argument('signatures', nargs='*', help='XP signatures.')
    parser.add_argument('-o', '--out', type=Path, default=Path("release_models"),
                        help="Path where to store release models (default release_models)")
    parser.add_argument('-s', '--sign', action='store_true',
                        help='Add sha256 prefix checksum to the filename.')
    args = parser.parse_args()
    args.out.mkdir(exist_ok=True, parents=True)

    for sig in args.signatures:
        xp = train.main.get_xp_from_sig(sig)
        name = train.main.get_name(xp)
        logger.info('Handling %s/%s', sig, name)
        out_path = args.out / (sig + ".th")
        solver = train.get_solver_from_sig(sig)
        # Warn (but still export) if training did not run for the full epoch budget.
        if len(solver.history) < solver.args.epochs:
            logger.warning(
                'Model %s has less epoch than expected (%d / %d)',
                sig, len(solver.history), solver.args.epochs)
        # Export the best validation weights, not the final epoch's.
        solver.model.load_state_dict(solver.best_state)
        pkg = serialize_model(solver.model, solver.args, solver.quantizer, half=True)
        if getattr(solver.model, 'use_train_segment', False):
            # Record the actual training segment length (in seconds, as an exact
            # Fraction of samples over samplerate) so inference can match training.
            batch = solver.augment(next(iter(solver.loaders['train'])))
            pkg['kwargs']['segment'] = Fraction(batch.shape[-1], solver.model.samplerate)
            # Fix: use the module logger (stderr, like every other message here)
            # instead of a bare print to stdout.
            logger.info('Override segment to %s', pkg['kwargs']['segment'])
        # Keep the last recorded valid/test metrics alongside the weights.
        valid, test = None, None
        for m in solver.history:
            if 'valid' in m:
                valid = m['valid']
            if 'test' in m:
                test = m['test']
        pkg['metrics'] = (valid, test)
        if args.sign:
            save_with_checksum(pkg, out_path)
        else:
            torch.save(pkg, out_path)


if __name__ == '__main__':
    main()