model.py
import torch
import torch.nn as nn
from scipy.stats import bernoulli
import torch.nn.functional as F
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

# input of shape (batch_size, inp_chan, iW)
class ConvNet(nn.Module):
    def __init__(self, nummotif, motiflen, poolType, neuType, mode, dropprob, learning_rate, momentum_rate, sigmaConv,
                 sigmaNeu, beta1, beta2, beta3, reverse_complemet_mode=False):
        super(ConvNet, self).__init__()
        self.poolType = poolType
        self.neuType = neuType
        self.mode = mode
        self.reverse_complemet_mode = reverse_complemet_mode
        self.dropprob = dropprob
        self.learning_rate = learning_rate
        self.momentum_rate = momentum_rate
        self.sigmaConv = sigmaConv
        self.sigmaNeu = sigmaNeu
        self.beta1 = beta1
        self.beta2 = beta2
        self.beta3 = beta3
        # Convolution filters of shape (nummotif, 4, motiflen); 4 input channels for the one-hot nucleotide alphabet.
        self.wConv = torch.randn(nummotif, 4, motiflen).to(device)
        torch.nn.init.normal_(self.wConv, mean=0, std=self.sigmaConv)
        self.wConv.requires_grad = True
        # Per-filter rectification bias, negated so it acts as a learned threshold.
        self.wRect = torch.randn(nummotif).to(device)
        torch.nn.init.normal_(self.wRect)
        self.wRect = -self.wRect
        self.wRect.requires_grad = True
        if neuType == 'nohidden':
            # Single linear output layer directly on top of the pooled features.
            if poolType == 'maxavg':
                self.wNeu = torch.randn(2 * nummotif, 1).to(device)
            else:
                self.wNeu = torch.randn(nummotif, 1).to(device)
            self.wNeuBias = torch.randn(1).to(device)
            torch.nn.init.normal_(self.wNeu, mean=0, std=self.sigmaNeu)
            torch.nn.init.normal_(self.wNeuBias, mean=0, std=self.sigmaNeu)
        else:
            # Hidden layer with 32 units between the pooled features and the output neuron.
            if poolType == 'maxavg':
                self.wHidden = torch.randn(2 * nummotif, 32).to(device)
            else:
                self.wHidden = torch.randn(nummotif, 32).to(device)
            self.wNeu = torch.randn(32, 1).to(device)
            self.wNeuBias = torch.randn(1).to(device)
            self.wHiddenBias = torch.randn(32).to(device)
            torch.nn.init.normal_(self.wNeu, mean=0, std=self.sigmaNeu)
            torch.nn.init.normal_(self.wNeuBias, mean=0, std=self.sigmaNeu)
            torch.nn.init.normal_(self.wHidden, mean=0, std=0.3)
            torch.nn.init.normal_(self.wHiddenBias, mean=0, std=0.3)
            self.wHidden.requires_grad = True
            self.wHiddenBias.requires_grad = True
            # wHiddenBias=tf.Variable(tf.truncated_normal([32,1],mean=0,stddev=sigmaNeu)) #hidden bias for everything
        self.wNeu.requires_grad = True
        self.wNeuBias.requires_grad = True

    def divide_two_tensors(self, x):
        # Split an interleaved batch (even indices and odd indices) into two aligned sub-batches,
        # used to pair each sequence with its reverse complement.
        l = torch.unbind(x)
        list1 = [l[2 * i] for i in range(int(x.shape[0] / 2))]
        list2 = [l[2 * i + 1] for i in range(int(x.shape[0] / 2))]
        x1 = torch.stack(list1, 0)
        x2 = torch.stack(list2, 0)
        return x1, x2

    def forward_pass(self, x, mask=None, use_mask=False):
        # Convolution over the one-hot sequence, rectification, then max (and optionally average) pooling.
        conv = F.conv1d(x, self.wConv, bias=self.wRect, stride=1, padding=0)
        rect = conv.clamp(min=0)
        maxPool, _ = torch.max(rect, dim=2)
        if self.poolType == 'maxavg':
            avgPool = torch.mean(rect, dim=2)
            pool = torch.cat((maxPool, avgPool), 1)
        else:
            pool = maxPool
        if self.neuType == 'nohidden':
            if self.mode == 'training':
                if not use_mask:
                    # Sample a Bernoulli dropout mask over the pooled features (dropprob is the keep probability here).
                    mask = bernoulli.rvs(self.dropprob, size=len(pool[0]))
                    mask = torch.from_numpy(mask).float().to(device)
                pooldrop = pool * mask
                out = pooldrop @ self.wNeu
                out.add_(self.wNeuBias)
            else:
                # At test time, scale by the keep probability instead of sampling a mask.
                out = self.dropprob * (pool @ self.wNeu)
                out.add_(self.wNeuBias)
        else:
            hid = pool @ self.wHidden
            hid.add_(self.wHiddenBias)
            hid = hid.clamp(min=0)
            if self.mode == 'training':
                if not use_mask:
                    mask = bernoulli.rvs(self.dropprob, size=len(hid[0]))
                    mask = torch.from_numpy(mask).float().to(device)
                # Apply the dropout mask to the hidden activations before the output layer.
                hiddrop = hid * mask
                out = hiddrop @ self.wNeu
                out.add_(self.wNeuBias)
            else:
                out = self.dropprob * (hid @ self.wNeu)
                out.add_(self.wNeuBias)
        return out, mask

    def forward(self, x):
        if not self.reverse_complemet_mode:
            out, _ = self.forward_pass(x)
        else:
            # Score each sequence and its reverse complement, reusing the same dropout mask,
            # then take the element-wise maximum of the two scores.
            x1, x2 = self.divide_two_tensors(x)
            out1, mask = self.forward_pass(x1)
            out2, _ = self.forward_pass(x2, mask, True)
            out = torch.max(out1, out2)
        return out
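

# A minimal usage sketch, not part of the original file: the hyperparameter values below are
# assumed placeholders and the input is random noise. It only illustrates the expected input
# shape (batch_size, 4, sequence_length) and a single forward call.
if __name__ == '__main__':
    model = ConvNet(nummotif=16, motiflen=24, poolType='maxavg', neuType='nohidden',
                    mode='training', dropprob=0.75, learning_rate=0.01, momentum_rate=0.9,
                    sigmaConv=1e-6, sigmaNeu=1e-3, beta1=1e-3, beta2=1e-7, beta3=1e-3)
    x = torch.randn(8, 4, 101).to(device)  # stand-in for a batch of one-hot encoded sequences
    out = model(x)
    print(out.shape)  # torch.Size([8, 1])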