-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathresnet.py
158 lines (130 loc) · 5.19 KB
/
resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from __future__ import absolute_import
from __future__ import division
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from ..utils.torchtools import weights_init_kaiming
# Public model classes exported by a wildcard import of this module.
__all__ = ['ResNet50', 'ResNet101', 'ResNet50M', 'ResNet50B']
class ResNet50(nn.Module):
    """ImageNet-pretrained ResNet-50 trunk + global average pooling + linear head.

    The last two children of torchvision's ResNet-50 are dropped. Judging by
    the 2048-d classifier input, these are presumably the avgpool and fc layers.
    The remaining convolutional trunk is followed by global average pooling,
    which yields a ``feat_dim`` (2048) vector per sample.

    Args:
        num_classes: number of logits produced by ``self.classifier``.
        loss: collection of loss names used in training. ``{'xent'}`` makes
            ``forward`` return logits only; ``{'xent', 'htri'}`` makes it
            return ``(logits, features)``. Any iterable of names is accepted,
            and a single string such as ``'xent'`` is treated as ``{'xent'}``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_classes, loss=frozenset({'xent'}), **kwargs):
        super(ResNet50, self).__init__()
        # Store a fresh per-instance set. The default is immutable, so instances
        # can never share (and accidentally mutate) a common set object.
        self.loss = {loss} if isinstance(loss, str) else set(loss)
        resnet50 = torchvision.models.resnet50(pretrained=True)
        # Keep every child except the last two (the pooling/fc head).
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        self.classifier = nn.Linear(2048, num_classes)
        self.feat_dim = 2048  # length of the pooled feature vector

    def forward(self, x):
        """Extract pooled features and, in training mode, classify them.

        Args:
            x: input batch. Assumed to be an image tensor of shape
                (batch, 3, H, W); TODO confirm against the data loader.

        Returns:
            Eval mode: features of shape (batch, feat_dim).
            Train mode: logits if ``self.loss == {'xent'}``, or
            ``(logits, features)`` if ``self.loss == {'xent', 'htri'}``.

        Raises:
            KeyError: in train mode when ``self.loss`` is not a supported set.
        """
        x = self.base(x)
        # Global average pooling: the kernel spans the whole spatial extent.
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        if not self.training:
            return f
        y = self.classifier(f)
        if self.loss == {'xent'}:
            return y
        if self.loss == {'xent', 'htri'}:
            return y, f
        raise KeyError("Unsupported loss: {}".format(self.loss))
class ResNet101(nn.Module):
    """ImageNet-pretrained ResNet-101 trunk + global average pooling + linear head.

    The last two children of torchvision's ResNet-101 are dropped. Judging by
    the 2048-d classifier input, these are presumably the avgpool and fc layers.
    The remaining convolutional trunk is followed by global average pooling,
    which yields a ``feat_dim`` (2048) vector per sample.

    Args:
        num_classes: number of logits produced by ``self.classifier``.
        loss: collection of loss names used in training. ``{'xent'}`` makes
            ``forward`` return logits only; ``{'xent', 'htri'}`` makes it
            return ``(logits, features)``. Any iterable of names is accepted,
            and a single string such as ``'xent'`` is treated as ``{'xent'}``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_classes, loss=frozenset({'xent'}), **kwargs):
        super(ResNet101, self).__init__()
        # Store a fresh per-instance set. The default is immutable, so instances
        # can never share (and accidentally mutate) a common set object.
        self.loss = {loss} if isinstance(loss, str) else set(loss)
        resnet101 = torchvision.models.resnet101(pretrained=True)
        # Keep every child except the last two (the pooling/fc head).
        self.base = nn.Sequential(*list(resnet101.children())[:-2])
        self.classifier = nn.Linear(2048, num_classes)
        self.feat_dim = 2048  # length of the pooled feature vector

    def forward(self, x):
        """Extract pooled features and, in training mode, classify them.

        Args:
            x: input batch. Assumed to be an image tensor of shape
                (batch, 3, H, W); TODO confirm against the data loader.

        Returns:
            Eval mode: features of shape (batch, feat_dim).
            Train mode: logits if ``self.loss == {'xent'}``, or
            ``(logits, features)`` if ``self.loss == {'xent', 'htri'}``.

        Raises:
            KeyError: in train mode when ``self.loss`` is not a supported set.
        """
        x = self.base(x)
        # Global average pooling: the kernel spans the whole spatial extent.
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        if not self.training:
            return f
        y = self.classifier(f)
        if self.loss == {'xent'}:
            return y
        if self.loss == {'xent', 'htri'}:
            return y, f
        raise KeyError("Unsupported loss: {}".format(self.loss))
class ResNet50M(nn.Module):
    """ResNet50 + mid-level features.

    Reference:
    Yu et al. The Devil is in the Middle: Exploiting Mid-level Representations for
    Cross-Domain Instance Matching. arXiv:1711.08106.

    The torchvision ResNet-50 trunk is split so that the outputs of the three
    blocks of its last stage (``base[7][0..2]``) are all available. The first
    two block outputs are globally pooled, concatenated (4096-d) and fused to
    1024-d by ``fc_fuse``. The result is concatenated with the pooled output of
    the last block (2048-d), giving a ``feat_dim`` (3072) embedding.

    Args:
        num_classes: number of logits produced by ``self.classifier``.
        loss: collection of loss names used in training. ``{'xent'}`` makes
            ``forward`` return logits only; ``{'xent', 'htri'}`` makes it
            return ``(logits, features)``. Any iterable of names is accepted,
            and a single string such as ``'xent'`` is treated as ``{'xent'}``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_classes=0, loss=frozenset({'xent'}), **kwargs):
        super(ResNet50M, self).__init__()
        # Store a fresh per-instance set. The default is immutable, so instances
        # can never share (and accidentally mutate) a common set object.
        self.loss = {loss} if isinstance(loss, str) else set(loss)
        resnet50 = torchvision.models.resnet50(pretrained=True)
        base = nn.Sequential(*list(resnet50.children())[:-2])
        self.layers1 = nn.Sequential(base[0], base[1], base[2])
        self.layers2 = nn.Sequential(base[3], base[4])
        self.layers3 = base[5]
        self.layers4 = base[6]
        # The last stage is split into its three blocks so that the first two
        # intermediate outputs can be tapped as mid-level features.
        self.layers5a = base[7][0]
        self.layers5b = base[7][1]
        self.layers5c = base[7][2]
        # Fuses the two concatenated mid-level vectors (2 x 2048) down to 1024.
        self.fc_fuse = nn.Sequential(nn.Linear(4096, 1024), nn.BatchNorm1d(1024), nn.ReLU())
        self.classifier = nn.Linear(3072, num_classes)
        self.feat_dim = 3072  # 2048 (last block) + 1024 (fused mid-level)

    @staticmethod
    def _global_avg_pool(t):
        """Average over all spatial positions and return a (batch, channels) tensor."""
        return F.avg_pool2d(t, t.size()[2:]).view(t.size(0), t.size(1))

    def forward(self, x):
        """Compute the combined final + mid-level embedding and, in training, logits.

        Args:
            x: input batch. Assumed to be an image tensor of shape
                (batch, 3, H, W); TODO confirm against the data loader.

        Returns:
            Eval mode: combined features of shape (batch, feat_dim).
            Train mode: logits if ``self.loss == {'xent'}``, or
            ``(logits, features)`` if ``self.loss == {'xent', 'htri'}``.

        Raises:
            KeyError: in train mode when ``self.loss`` is not a supported set.
        """
        x1 = self.layers1(x)
        x2 = self.layers2(x1)
        x3 = self.layers3(x2)
        x4 = self.layers4(x3)
        x5a = self.layers5a(x4)
        x5b = self.layers5b(x5a)
        x5c = self.layers5c(x5b)

        x5a_feat = self._global_avg_pool(x5a)
        x5b_feat = self._global_avg_pool(x5b)
        x5c_feat = self._global_avg_pool(x5c)

        midfeat = self.fc_fuse(torch.cat((x5a_feat, x5b_feat), dim=1))
        # Final-block features first, fused mid-level features second.
        combofeat = torch.cat((x5c_feat, midfeat), dim=1)
        if not self.training:
            return combofeat
        prelogits = self.classifier(combofeat)
        if self.loss == {'xent'}:
            return prelogits
        if self.loss == {'xent', 'htri'}:
            return prelogits, combofeat
        raise KeyError("Unsupported loss: {}".format(self.loss))
class ResNet50B(nn.Module):
    """Resnet50+bottleneck

    Reference:
    https://github.com/L1aoXingyu/reid_baseline

    Uses the torchvision ResNet-50 trunk with stride 1 forced in the first block
    of the last stage, followed by global average pooling. During training, the
    pooled 2048-d feature passes through a bottleneck (Linear -> BatchNorm1d ->
    LeakyReLU -> Dropout) to 512-d before the classifier.

    Args:
        num_classes: number of logits produced by ``self.classifier``.
        loss: collection of loss names used in training. ``{'xent'}`` makes
            ``forward`` return logits only; ``{'xent', 'htri'}`` makes it
            return ``(logits, pooled_features)``. Any iterable of names is
            accepted, and a single string such as ``'xent'`` is treated as
            ``{'xent'}``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_classes=0, loss=frozenset({'xent'}), **kwargs):
        super(ResNet50B, self).__init__()
        # Store a fresh per-instance set. The default is immutable, so instances
        # can never share (and accidentally mutate) a common set object.
        self.loss = {loss} if isinstance(loss, str) else set(loss)
        resnet50 = torchvision.models.resnet50(pretrained=True)
        # Set stride 1 on layer4's first block (its conv2 and the downsample
        # projection), presumably so the last stage does not halve the
        # spatial resolution.
        resnet50.layer4[0].conv2.stride = (1, 1)
        resnet50.layer4[0].downsample[0].stride = (1, 1)
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        self.in_planes = 2048
        # Matches the sibling models: length of the feature vector forward()
        # returns in eval mode and alongside logits for the 'htri' loss.
        self.feat_dim = self.in_planes
        self.bottleneck = nn.Sequential(
            nn.Linear(self.in_planes, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.5))
        self.bottleneck.apply(weights_init_kaiming)
        self.classifier = nn.Linear(512, num_classes)
        self.classifier.apply(weights_init_kaiming)

    def forward(self, x):
        """Return pooled trunk features (eval) or logits (+ features) in training.

        Args:
            x: input batch. Assumed to be an image tensor of shape
                (batch, 3, H, W); TODO confirm against the data loader.

        Returns:
            Eval mode: pooled features of shape (batch, feat_dim), taken
            before the bottleneck.
            Train mode: logits if ``self.loss == {'xent'}``, or
            ``(logits, pooled_features)`` if ``self.loss == {'xent', 'htri'}``.

        Raises:
            KeyError: in train mode when ``self.loss`` is not a supported set.
        """
        global_feat = self.base(x)
        # Global average pooling: the kernel spans the whole spatial extent.
        global_feat = F.avg_pool2d(global_feat, global_feat.size()[-2:])
        global_feat = global_feat.view(global_feat.size(0), -1)
        if not self.training:
            return global_feat
        feat = self.bottleneck(global_feat)
        y = self.classifier(feat)
        if self.loss == {'xent'}:
            return y
        if self.loss == {'xent', 'htri'}:
            # The metric-learning branch uses the pre-bottleneck features.
            return y, global_feat
        raise KeyError("Unsupported loss: {}".format(self.loss))