bench.py · 78 lines (67 loc) · 2.23 KB · forked from facebookresearch/demucs

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Benchmarking script, useful to check for OOM and reasonable train time,
and, for the MDX competition, to estimate whether we will meet the time limit."""
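
# Assumed invocation (inferred from how sys.argv is consumed below; not documented here):
# the first argument is an experiment signature, any further arguments are passed through
# as config overrides, e.g.
#
#     python bench.py SIG dset.segment=10
#
# `SIG` and the override above are placeholders, not real values.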
from contextlib import contextmanager
from fractions import Fraction
import logging
import sys
import time

import torch

from demucs.train import get_solver, main
from demucs.apply import apply_model

logging.basicConfig(level=logging.INFO, stream=sys.stderr)


class Result:
    # Plain attribute holder; bench() fills in `mem` (MB) and `tim` (seconds).
    pass


@contextmanager
def bench():
    """Measure peak CUDA memory (in MB) and wall-clock time for the enclosed block."""
    import gc
    gc.collect()
    torch.cuda.reset_max_memory_allocated()
    torch.cuda.empty_cache()
    result = Result()
    # before = torch.cuda.memory_allocated()
    before = 0
    begin = time.time()
    try:
        yield result
    finally:
        # Wait for pending kernels so the timing covers the actual GPU work.
        torch.cuda.synchronize()
        mem = (torch.cuda.max_memory_allocated() - before) / 2 ** 20
        tim = time.time() - begin
        result.mem = mem
        result.tim = tim
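
# Illustrative standalone use of bench() (the benchmarks below follow this pattern;
# `work(x)` stands for any CUDA workload):
#
#     with bench() as res:
#         y = work(x)
#     print(res.mem, res.tim)   # peak CUDA memory in MB, elapsed wall time in seconds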


# Recover the experiment from its signature and apply extra command-line overrides.
xp = main.get_xp_from_sig(sys.argv[1])
xp = main.get_xp(xp.argv + sys.argv[2:])
with xp.enter():
    solver = get_solver(xp.cfg)
    if getattr(solver.model, 'use_train_segment', False):
        # Match the segment length the model actually sees during training.
        batch = solver.augment(next(iter(solver.loaders['train'])))
        solver.model.segment = Fraction(batch.shape[-1], solver.model.samplerate)
        train_segment = solver.model.segment
    solver.model.eval()
    model = solver.model
    model.cuda()

    # Forward + backward ("FB") on a 2-sample, 10-second batch.
    x = torch.randn(2, xp.cfg.dset.channels, int(10 * model.samplerate), device='cuda')
    with bench() as res:
        y = model(x)
        y.sum().backward()
        del y
        for p in model.parameters():
            p.grad = None
    print(f"FB: {res.mem:.1f} MB, {res.tim * 1000:.1f} ms")

    # Forward only ("FV"), no gradients, on a single segment, as during validation.
    x = torch.randn(1, xp.cfg.dset.channels, int(model.segment * model.samplerate), device='cuda')
    with bench() as res:
        with torch.no_grad():
            y = model(x)
            del y
    print(f"FV: {res.mem:.1f} MB, {res.tim * 1000:.1f} ms")

    # Single-threaded CPU inference on a 40-second random track,
    # relevant for estimating the MDX competition time limit.
    model.cpu()
    torch.set_num_threads(1)
    test = torch.randn(1, xp.cfg.dset.channels, model.samplerate * 40)
    b = time.time()
    apply_model(model, test, split=True, shifts=1)
    print("CPU 40 sec:", time.time() - b)