-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathextractSDAE.py
43 lines (39 loc) · 1.58 KB
/
extractSDAE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# This class mirrors the SDAE architecture, but without Dropout modules.
# It is instantiated when the trained SDAE weights are needed dropout-free,
# i.e. during DCC.
class extractSDAE(nn.Module):
    """Stacked (denoising) autoencoder without Dropout.

    Builds `len(dim)-1` encoder Linear layers (dim[i] -> dim[i+1]) and the
    mirrored decoder layers (dim[i+1] -> dim[i]). LeakyReLU (negative slope
    `slope`) is applied between layers, but not after the bottleneck output
    nor after the final reconstruction layer.
    """
    def __init__(self, dim, slope=0.0):
        """
        Args:
            dim: list of layer widths, e.g. [input_dim, h1, ..., bottleneck].
            slope: negative slope for the LeakyReLU non-linearities.
        """
        super(extractSDAE, self).__init__()
        self.in_dim = dim[0]
        self.nlayers = len(dim)-1
        self.reluslope = slope
        self.enc, self.dec = [], []
        for i in range(self.nlayers):
            self.enc.append(nn.Linear(dim[i], dim[i+1]))
            # Register each layer as an attribute so nn.Module tracks its
            # parameters (plain Python lists are not tracked).
            setattr(self, 'enc_{}'.format(i), self.enc[-1])
            self.dec.append(nn.Linear(dim[i+1], dim[i]))
            setattr(self, 'dec_{}'.format(i), self.dec[-1])
        # base[i] chains the first i encoder layers (base[0] is empty),
        # giving access to partial encoder stacks.
        self.base = []
        for i in range(self.nlayers):
            self.base.append(nn.Sequential(*self.enc[:i]))
        # initialization: small Gaussian weights, zero biases.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # FIX: use the in-place initializers; `init.normal` /
                # `init.constant` are deprecated and removed in modern PyTorch.
                init.normal_(m.weight, std=1e-2)
                # FIX: guard on `m.bias` itself — `m.bias.data` would raise
                # AttributeError when the layer was built with bias=None.
                if m.bias is not None:
                    init.constant_(m.bias, 0)
    def forward(self, x):
        """Encode then decode `x`.

        Args:
            x: input tensor; flattened to shape (-1, in_dim).
        Returns:
            (encoded, out): bottleneck encoding and reconstruction.
        """
        inp = x.view(-1, self.in_dim)
        encoded = inp
        for i, encoder in enumerate(self.enc):
            encoded = encoder(encoded)
            # No activation on the bottleneck output (last encoder layer).
            if i < self.nlayers-1:
                encoded = F.leaky_relu(encoded, negative_slope=self.reluslope)
        out = encoded
        # Decode in reverse layer order; no activation after the outermost
        # (i == 0) decoder layer, which produces the reconstruction.
        for i, decoder in reversed(list(enumerate(self.dec))):
            out = decoder(out)
            if i:
                out = F.leaky_relu(out, negative_slope=self.reluslope)
        return encoded, out