-
Notifications
You must be signed in to change notification settings - Fork 661
/
Copy pathutils_regularizers.py
104 lines (92 loc) · 3.34 KB
/
utils_regularizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torch.nn as nn
'''
# --------------------------------------------
# Kai Zhang (github: https://github.com/cszn)
# 03/Mar/2019
# --------------------------------------------
'''
# --------------------------------------------
# SVD Orthogonal Regularization
# --------------------------------------------
def regularizer_orth(m):
    """Soft SVD orthogonal regularization for Conv2d weights (in place).

    Flattens the conv kernel into a (f1*f2*c_in, c_out) matrix, takes its
    SVD, and nudges each singular value toward the interval [0.5, 1.5]:
    values above 1.5 are reduced by 1e-4, values below 0.5 are raised by
    1e-4. The weight is then rebuilt from the adjusted singular values and
    written back to ``m.weight.data``.

    Intended to be called through ``torch.nn.Module.apply``::

        net.apply(regularizer_orth)

    Args:
        m: a module. Only modules whose class name contains 'Conv' AND
           whose weight tensor is 4-D (Conv2d-style) are modified; every
           other module is left untouched.
    """
    classname = m.__class__.__name__
    # guard on dim() == 4 so Conv1d/Conv3d layers are skipped instead of
    # crashing on the 4-tuple unpack below
    if classname.find('Conv') != -1 and m.weight.data.dim() == 4:
        w = m.weight.data.clone()
        c_out, c_in, f1, f2 = w.size()
        # flatten so each column is one output channel's filter response
        w = w.permute(2, 3, 1, 0).contiguous().view(f1 * f2 * c_in, c_out)
        # NOTE(review): torch.svd is deprecated in recent PyTorch in favor
        # of torch.linalg.svd; kept for compatibility with older versions.
        u, s, v = torch.svd(w)
        # compute each mask once; the 'low' mask is evaluated after the
        # 'high' update to match the original statement order exactly
        high = s > 1.5
        s[high] -= 1e-4
        low = s < 0.5
        s[low] += 1e-4
        w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
        m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1)
# --------------------------------------------
# SVD Orthogonal Regularization
# --------------------------------------------
def regularizer_orth2(m):
    """Mean-relative SVD orthogonal regularization for Conv2d weights.

    Same scheme as :func:`regularizer_orth`, but the thresholds scale with
    the mean singular value: singular values above ``1.5 * mean(s)`` are
    reduced by 1e-4 and values below ``0.5 * mean(s)`` are raised by 1e-4.
    The weight is rebuilt from the adjusted spectrum and written back to
    ``m.weight.data`` in place.

    Intended to be called through ``torch.nn.Module.apply``::

        net.apply(regularizer_orth2)

    Args:
        m: a module. Only modules whose class name contains 'Conv' AND
           whose weight tensor is 4-D (Conv2d-style) are modified; every
           other module is left untouched.
    """
    classname = m.__class__.__name__
    # guard on dim() == 4 so Conv1d/Conv3d layers are skipped instead of
    # crashing on the 4-tuple unpack below
    if classname.find('Conv') != -1 and m.weight.data.dim() == 4:
        w = m.weight.data.clone()
        c_out, c_in, f1, f2 = w.size()
        # flatten so each column is one output channel's filter response
        w = w.permute(2, 3, 1, 0).contiguous().view(f1 * f2 * c_in, c_out)
        # NOTE(review): torch.svd is deprecated in recent PyTorch in favor
        # of torch.linalg.svd; kept for compatibility with older versions.
        u, s, v = torch.svd(w)
        s_mean = s.mean()
        # compute each mask once; the 'low' mask is evaluated AFTER the
        # 'high' update (not hoisted) because for near-zero spectra a value
        # pushed down by 1e-4 can fall below 0.5*s_mean — the original
        # statement order must be preserved exactly
        high = s > 1.5 * s_mean
        s[high] -= 1e-4
        low = s < 0.5 * s_mean
        s[low] += 1e-4
        w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
        m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1)
def regularizer_clip(m, eps=1e-4, c_min=-1.5, c_max=1.5):
    """Soft weight/bias clipping regularizer (in place).

    For Conv* and Linear modules, every weight (and bias, if present)
    entry above ``c_max`` is reduced by ``eps`` and every entry below
    ``c_min`` is raised by ``eps``, gently pulling parameters back toward
    the band [c_min, c_max].

    Intended to be called through ``torch.nn.Module.apply``::

        net.apply(regularizer_clip)

    Args:
        m: a module; only modules whose class name contains 'Conv' or
           'Linear' are modified.
        eps: step size of the soft clip (default 1e-4, the original
           hard-coded value).
        c_min: lower bound of the allowed band (default -1.5).
        c_max: upper bound of the allowed band (default 1.5).
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        w = m.weight.data.clone()
        # compute each mask once instead of twice per comparison
        high = w > c_max
        w[high] -= eps
        low = w < c_min
        w[low] += eps
        m.weight.data = w
        if m.bias is not None:
            b = m.bias.data.clone()
            high = b > c_max
            b[high] -= eps
            low = b < c_min
            b[low] += eps
            m.bias.data = b
    # elif classname.find('BatchNorm2d') != -1:
    #
    #     rv = m.running_var.data.clone()
    #     rm = m.running_mean.data.clone()
    #
    #     if m.affine:
    #         m.weight.data
    #         m.bias.data