forked from d2l-ai/d2l-en
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase.py
41 lines (34 loc) · 1.03 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""The base module contains some basic functions/classes for d2l"""
import time
import mxnet as mx
from mxnet import nd
__all__ = ['try_gpu', 'try_all_gpus', 'Benchmark']
def try_gpu():
    """Return the default GPU context if it is usable, else the CPU context.

    Returns
    -------
    The context built by ``mx.gpu()`` when a small array can be created
    on it, otherwise the context built by ``mx.cpu()``.
    """
    try:
        device = mx.gpu()
        # Probe the device: creating a tiny array on it is the check, and an
        # MXNetError raised here is treated as "no usable GPU".
        nd.array([0], ctx=device)
    except mx.base.MXNetError:
        return mx.cpu()
    return device
def try_all_gpus():
    """Collect every usable GPU context, falling back to ``[mx.cpu()]``.

    Device ids 0 through 15 are probed in order by creating a tiny array on
    each one. Probing stops at the first id that raises ``MXNetError``, so
    only the leading run of working devices is returned.

    Returns
    -------
    A list of GPU contexts, or a single-element list holding the CPU
    context when no GPU could be used.
    """
    found = []
    for device_id in range(16):
        try:
            device = mx.gpu(device_id)
            # The allocation is the actual availability check.
            nd.array([0], ctx=device)
        except mx.base.MXNetError:
            # First failing id ends the scan; later ids are not tried.
            break
        found.append(device)
    return found or [mx.cpu()]
class Benchmark():
    """Context manager that prints how long its ``with`` body took to run.

    On exit it prints ``'<prefix> time: <seconds> sec'`` (or just
    ``'time: <seconds> sec'`` when no prefix was given), with the elapsed
    seconds formatted to four decimal places.

    Parameters
    ----------
    prefix : str or None
        Optional label placed in front of the timing message. A falsy value
        (``None`` or ``''``) produces no label.
    """

    def __init__(self, prefix=None):
        # Append the separating space only when there is a label to separate.
        self.prefix = prefix + ' ' if prefix else ''

    def __enter__(self):
        """Start the timer and return this instance for ``with ... as``."""
        # perf_counter is monotonic and high-resolution, so the measured
        # interval is not distorted by system clock adjustments.
        self.start = time.perf_counter()
        return self

    def __exit__(self, *args):
        """Print the elapsed time; exceptions from the body are not suppressed."""
        print('%stime: %.4f sec' % (self.prefix,
                                    time.perf_counter() - self.start))