-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmodel.py
78 lines (53 loc) · 2.29 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import torch
class URnn(torch.nn.Module):
    """Unidirectional RNN mapping per-frame 2-channel spectra to values in [0, 1].

    Expected input shape is ``N x T x F x 2`` with ``F = frame_size // 2 + 1``.
    The trailing size-2 axis is presumably the real/imaginary pair of a
    one-sided spectrum -- NOTE(review): confirm against the caller. The output
    has shape ``N x T x F`` and is passed through a sigmoid.

    Args:
        frame_size: Frame length; determines ``F = frame_size // 2 + 1``.
        hidden_size: Number of RNN hidden units ``H``.
        num_layers: Number of stacked RNN layers.
        dropout: Dropout probability applied before the output projection and,
            when ``num_layers > 1``, between stacked RNN layers.
        type: Recurrent cell to use, ``'gru'`` or ``'lstm'``.

    Raises:
        ValueError: If ``type`` is not ``'gru'`` or ``'lstm'``.
    """

    def __init__(self, frame_size=512, hidden_size=128, num_layers=1, dropout=0.2, type='gru'):
        super(URnn, self).__init__()
        rnn_classes = {'gru': torch.nn.GRU, 'lstm': torch.nn.LSTM}
        # Fail at construction time; previously an unknown type left
        # ``self.rnn`` unset and only crashed later inside ``forward``.
        if type not in rnn_classes:
            raise ValueError(f"type must be 'gru' or 'lstm', got {type!r}")

        num_bins = int(frame_size // 2 + 1)
        # PyTorch's GRU/LSTM apply ``dropout`` only between stacked layers and
        # warn when it is non-zero with a single layer, so disable it there.
        rnn_dropout = 0.0 if num_layers == 1 else dropout

        self.bn = torch.nn.BatchNorm2d(num_features=2)
        self.rnn = rnn_classes[type](input_size=2 * num_bins,
                                     hidden_size=hidden_size,
                                     num_layers=num_layers,
                                     batch_first=True,
                                     bidirectional=False,
                                     dropout=rnn_dropout)
        self.dp = torch.nn.Dropout(p=dropout)
        # A 1x1 convolution over N x H x T x 1 acts as a per-frame linear
        # projection from H hidden units to F output bins.
        self.fc = torch.nn.Conv2d(in_channels=hidden_size,
                                  out_channels=num_bins,
                                  kernel_size=1)

    def forward(self, x):
        """Run the model on ``x`` (N x T x F x 2); return N x T x F in [0, 1]."""
        # Permute: N x T x F x 2 > N x 2 x T x F (channels-first for BatchNorm2d)
        x = x.permute(0, 3, 1, 2)
        # Batch norm over the 2 channels: N x 2 x T x F > N x 2 x T x F
        x = self.bn(x)
        # Permute back: N x 2 x T x F > N x T x F x 2
        x = x.permute(0, 2, 3, 1)
        # Flatten the last two axes: N x T x F x 2 > N x T x 2F
        x = x.flatten(2)
        # RNN (batch_first): N x T x 2F > N x T x H; hidden state is discarded
        x, _ = self.rnn(x)
        # Permute: N x T x H > N x H x T
        x = x.permute(0, 2, 1)
        # Unsqueeze: N x H x T > N x H x T x 1
        x = torch.unsqueeze(x, 3)
        # Dropout (active only in training mode)
        x = self.dp(x)
        # Projection: N x H x T x 1 > N x F x T x 1
        x = self.fc(x)
        # Permute: N x F x T x 1 > N x 1 x T x F
        x = x.permute(0, 3, 2, 1)
        # Squeeze: N x 1 x T x F > N x T x F
        x = torch.squeeze(x, dim=1)
        # Squash into [0, 1]
        return torch.sigmoid(x)