forked from facebookresearch/demucs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_pretrained.py
43 lines (32 loc) · 1.19 KB
/
test_pretrained.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Script to evaluate pretrained models.
from argparse import ArgumentParser
import logging
import sys
import torch
from demucs import train, pretrained, evaluate
def main(argv=None):
    """Evaluate a pre-trained model or bag of models on MusDB and print the results.

    Model-selection flags are registered by ``pretrained.add_model_flags``.
    Any remaining positional arguments are collected as ``overrides`` and
    forwarded to the experiment builder (e.g. ``test.shifts=2``).

    Args:
        argv: Optional list of command-line arguments to parse instead of the
            real command line. Defaults to ``None``, in which case argparse
            reads ``sys.argv[1:]``, exactly as before. Handy for invoking the
            evaluation programmatically.
    """
    # Pin torch to a single intra-op CPU thread.
    # NOTE(review): the rationale is not visible here. Presumably it avoids
    # CPU oversubscription or keeps runs reproducible; confirm before changing.
    torch.set_num_threads(1)
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)

    parser = ArgumentParser("tools.test_pretrained",
                            description="Evaluate pre-trained models or bags of models "
                                        "on MusDB.")
    pretrained.add_model_flags(parser)
    parser.add_argument('overrides', nargs='*',
                        help='Extra overrides, e.g. test.shifts=2.')
    # parse_args(None) falls back to sys.argv[1:], so CLI behaviour is unchanged.
    args = parser.parse_args(argv)

    # Build the experiment from the user overrides. Its config (xp.cfg) drives
    # both the solver construction and the evaluation settings below.
    xp = train.main.get_xp(args.overrides)
    with xp.enter():
        solver = train.get_solver(xp.cfg)

        # Swap in the requested pretrained model, moved to the solver's device
        # and switched to evaluation mode (disables dropout / batch-norm updates).
        model = pretrained.get_model_from_args(args)
        solver.model = model.to(solver.device)
        solver.model.eval()

        # Gradients are not needed for evaluation; skip autograd bookkeeping.
        with torch.no_grad():
            # NOTE(review): xp.cfg.test.sdr is passed straight through;
            # presumably it selects the SDR variant/computation, so check
            # evaluate.evaluate.
            results = evaluate.evaluate(solver, xp.cfg.test.sdr)
        print(results)
# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()