utils.py
import random

import numpy as np
import torch
import torch.nn.functional as F

# Map activation names to callables. Note that 'prelu' maps to the
# torch.nn.PReLU class (it must be instantiated before use), while the
# other entries are plain functions.
activation_getter = {'iden': lambda x: x,
                     'relu': F.relu,
                     'prelu': torch.nn.PReLU,
                     'leaky_relu': F.leaky_relu,
                     'tanh': torch.tanh,
                     'sigm': torch.sigmoid}


def gpu(tensor, gpu=False):
    """Move a tensor to the GPU when `gpu` is True, otherwise return it unchanged."""
    if gpu:
        return tensor.cuda()
    else:
        return tensor


def cpu(tensor):
    """Move a tensor back to the CPU if it currently lives on the GPU."""
    if tensor.is_cuda:
        return tensor.cpu()
    else:
        return tensor


def minibatch(*tensors, **kwargs):
    """Yield successive minibatches from one or more equally sized tensors."""
    batch_size = kwargs.get('batch_size', 128)

    if len(tensors) == 1:
        tensor = tensors[0]
        for i in range(0, len(tensor), batch_size):
            yield tensor[i:i + batch_size]
    else:
        for i in range(0, len(tensors[0]), batch_size):
            yield tuple(x[i:i + batch_size] for x in tensors)


def shuffle(*arrays, **kwargs):
    """Shuffle one or more arrays with a single shared permutation."""
    require_indices = kwargs.get('indices', False)

    if len(set(len(x) for x in arrays)) != 1:
        raise ValueError('All inputs to shuffle must have '
                         'the same length.')

    shuffle_indices = np.arange(len(arrays[0]))
    np.random.shuffle(shuffle_indices)

    if len(arrays) == 1:
        result = arrays[0][shuffle_indices]
    else:
        result = tuple(x[shuffle_indices] for x in arrays)

    if require_indices:
        return result, shuffle_indices
    else:
        return result


def assert_no_grad(variable):
    """Raise if a target variable requires gradients; loss criteria do not backprop into targets."""
    if variable.requires_grad:
        raise ValueError(
            "nn criterions don't compute the gradient w.r.t. targets - please "
            "mark these variables as volatile or not requiring gradients"
        )


def set_seed(seed, cuda=False):
    """Seed NumPy, Python's `random`, and PyTorch (CUDA or CPU) for reproducibility."""
    np.random.seed(seed)
    random.seed(seed)
    if cuda:
        torch.cuda.manual_seed(seed)
    else:
        torch.manual_seed(seed)


def compute_model_size(model):
    """Return the total number of parameters in a model."""
    num_params = 0
    for param in model.parameters():
        num_params += param.view(-1).size()[0]
    return num_params


def str2bool(v):
    # The original `v.lower() in ('true')` was a substring test against the
    # string 'true' (so e.g. 'rue' would pass); compare for equality instead.
    return v.lower() == 'true'
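

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): shows how the helpers
# above are typically combined. The arrays `users` and `items` are purely
# illustrative; any equally sized index arrays would work.
if __name__ == '__main__':
    set_seed(42, cuda=torch.cuda.is_available())

    users = np.arange(10)
    items = np.arange(10, 20)

    # Shuffle both arrays with the same permutation, then iterate minibatches.
    users, items = shuffle(users, items)
    for user_batch, item_batch in minibatch(users, items, batch_size=4):
        user_var = gpu(torch.from_numpy(user_batch), gpu=False)
        item_var = gpu(torch.from_numpy(item_batch), gpu=False)
        print(user_var.tolist(), item_var.tolist())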