-
Notifications
You must be signed in to change notification settings - Fork 88
/
dsmil.py
72 lines (60 loc) · 2.74 KB
/
dsmil.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
class FCLayer(nn.Module):
    """Linear scoring head that also passes its input through unchanged.

    Args:
        in_size: Size of the last dimension of the incoming features.
        out_size: Number of outputs produced per feature vector.
    """

    def __init__(self, in_size, out_size=1):
        super().__init__()
        linear = nn.Linear(in_size, out_size)
        # Kept inside an nn.Sequential so the parameters remain registered
        # under ``fc.0.weight`` / ``fc.0.bias``.
        self.fc = nn.Sequential(linear)

    def forward(self, feats):
        """Return ``(feats, scores)`` where ``scores = fc(feats)``."""
        scores = self.fc(feats)
        return feats, scores
class IClassifier(nn.Module):
    """Instance-level classifier: feature extractor followed by a linear head.

    Args:
        feature_extractor: Callable applied to the input batch; its output is
            flattened to two dimensions (first dimension kept) before being
            passed to the linear head.
        feature_size: Number of features per instance after flattening.
        output_class: Number of outputs of the linear head.
    """

    def __init__(self, feature_extractor, feature_size, output_class):
        super(IClassifier, self).__init__()
        self.feature_extractor = feature_extractor
        self.fc = nn.Linear(feature_size, output_class)

    def forward(self, x):
        """Return ``(flat_feats, c)``.

        ``flat_feats`` is the extractor output flattened to ``N x K`` and
        ``c`` is the linear head applied to it (``N x C``).
        """
        feats = self.feature_extractor(x)
        # Flatten once. ``reshape`` behaves like ``view`` for contiguous
        # tensors but also accepts non-contiguous extractor outputs
        # (e.g. after a permute), where ``view`` would raise.
        flat_feats = feats.reshape(feats.shape[0], -1)  # N x K
        c = self.fc(flat_feats)  # N x C
        return flat_feats, c
class BClassifier(nn.Module):
    """Bag-level classifier that attends to instances via per-class critical instances.

    For every class, the instance with the highest score in ``c`` is taken as
    that class's critical instance. Each instance's query is compared with the
    critical instance's query to obtain attention weights over the bag, which
    are used to pool the value vectors into one bag embedding per class. A
    1D convolution spanning the whole embedding turns the stacked embeddings
    into one bag score per class.

    Args:
        input_size: Feature dimension ``K`` of each instance.
        output_class: Number of classes ``C``.
        dropout_v: Dropout probability applied before the value projection.
    """

    def __init__(self, input_size, output_class, dropout_v=0.0):  # K, L, N
        super(BClassifier, self).__init__()
        self.q = nn.Linear(input_size, 128)
        self.v = nn.Sequential(
            nn.Dropout(dropout_v),
            nn.Linear(input_size, input_size)
        )
        # 1D convolutional layer that can handle multiple classes (including binary)
        self.fcc = nn.Conv1d(output_class, output_class, kernel_size=input_size)

    def forward(self, feats, c):  # N x K, N x C
        """Compute bag scores, attention weights, and bag embeddings.

        Args:
            feats: Instance features, ``N x K``.
            c: Instance class scores, ``N x C``.

        Returns:
            Tuple ``(C, A, B)``: bag scores ``1 x C``, attention weights
            ``N x C`` (each column sums to 1 over the instances), and bag
            embeddings ``1 x C x K``.
        """
        V = self.v(feats)  # N x K
        Q = self.q(feats).view(feats.shape[0], -1)  # N x Q

        # Pick each class's top-scoring instance from the *unpermuted* feats.
        # Re-sorting feats in place once per class would apply indices that
        # refer to the original order to an already-permuted tensor and select
        # the wrong critical instance for every class after the first.
        critical_idx = torch.argmax(c, dim=0)  # C
        critical_feats = torch.index_select(feats, 0, critical_idx)  # C x K
        q_max = self.q(critical_feats)  # C x Q

        # Scaled dot-product scores, normalised over the instances of the bag.
        scale = Q.shape[1] ** 0.5
        A = F.softmax(torch.mm(Q, q_max.transpose(0, 1)) / scale, 0)  # N x C

        B = torch.mm(A.transpose(0, 1), V)  # C x K
        B = B.view(1, B.shape[0], B.shape[1])  # 1 x C x K
        C = self.fcc(B)  # 1 x C x 1
        C = C.view(1, -1)
        return C, A, B
class MILNet(nn.Module):
    """Chain an instance-level classifier into a bag-level classifier.

    Args:
        i_classifier: Callable mapping the input to ``(feats, classes)``.
        b_classifier: Callable mapping ``(feats, classes)`` to a 3-tuple
            ``(prediction_bag, A, B)``.
    """

    def __init__(self, i_classifier, b_classifier):
        super().__init__()
        self.b_classifier = b_classifier
        self.i_classifier = i_classifier

    def forward(self, x):
        """Return ``(classes, prediction_bag, A, B)`` for input ``x``."""
        instance_feats, instance_scores = self.i_classifier(x)
        bag_outputs = self.b_classifier(instance_feats, instance_scores)
        bag_scores, attention, bag_embedding = bag_outputs
        return instance_scores, bag_scores, attention, bag_embedding