
Commit 238ffb3

add FEAT-Transductive
1 parent ca3ae17 commit 238ffb3
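
This commit adds two transductive variants of FEAT. SemiFEAT (model/models/semi_feat.py) adapts class prototypes with self-attention over the joint set of prototypes and unlabeled query embeddings before computing distances; SemiProtoFEAT (model/models/semi_protofeat.py) additionally re-estimates each prototype from soft assignments of the unlabeled instances, in the spirit of semi-supervised ProtoNet. The new classes are registered in model/trainer/helpers.py, the command-line parser in model/utils.py gains the corresponding --model_class choices (and changed defaults), and three sklearn imports are removed from model/models/base.py.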

5 files changed: +337 -7 lines changed

model/models/base.py

-3
@@ -1,9 +1,6 @@
 import torch
 import torch.nn as nn
 import numpy as np
-from sklearn.svm import LinearSVC
-from sklearn.linear_model import LogisticRegression
-from sklearn.model_selection import GridSearchCV
 
 class FewShotModel(nn.Module):
     def __init__(self, args):

model/models/semi_feat.py

+149
@@ -0,0 +1,149 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from model.models import FewShotModel


class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v):
        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature
        log_attn = F.log_softmax(attn, 2)
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        return output, attn, log_attn


class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        residual = q
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)  # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)  # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v)  # (n*b) x lv x dv

        output, attn, log_attn = self.attention(q, k, v)

        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)  # b x lq x (n*dv)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)

        return output


class SemiFEAT(FewShotModel):
    def __init__(self, args):
        super().__init__(args)
        if args.backbone_class == 'ConvNet':
            hdim = 64
        elif args.backbone_class == 'Res12':
            hdim = 640
        elif args.backbone_class == 'Res18':
            hdim = 512
        elif args.backbone_class == 'WRN':
            hdim = 640
        else:
            raise ValueError('Unknown backbone class: {}'.format(args.backbone_class))

        self.slf_attn = MultiHeadAttention(1, hdim, hdim, hdim, dropout=0.5)

    def _forward(self, instance_embs, support_idx, query_idx):
        emb_dim = instance_embs.size(-1)

        # organize support/query data
        support = instance_embs[support_idx.contiguous().view(-1)].contiguous().view(*(support_idx.shape + (-1,)))
        query = instance_embs[query_idx.contiguous().view(-1)].contiguous().view(*(query_idx.shape + (-1,)))

        # get mean of the support
        proto = support.mean(dim=1)  # Ntask x NK x d
        num_batch = proto.shape[0]
        num_proto = proto.shape[1]
        num_query = np.prod(query_idx.shape[-2:])

        # query: (num_batch, num_query, num_proto, num_emb)
        # proto: (num_batch, num_proto, num_emb)
        # transductive adaptation: prototypes attend over prototypes + unlabeled queries
        whole_set = torch.cat([proto, query.view(num_batch, -1, emb_dim)], 1)
        proto = self.slf_attn(proto, whole_set, whole_set)
        if self.args.use_euclidean:
            query = query.view(-1, emb_dim).unsqueeze(1)  # (Nbatch*Nq*Nw, 1, d)
            proto = proto.unsqueeze(1).expand(num_batch, num_query, num_proto, emb_dim).contiguous()
            proto = proto.view(num_batch * num_query, num_proto, emb_dim)  # (Nbatch x Nq, Nk, d)

            logits = - torch.sum((proto - query) ** 2, 2) / self.args.temperature
        else:
            proto = F.normalize(proto, dim=-1)  # normalize for cosine similarity
            query = query.view(num_batch, -1, emb_dim)  # (Nbatch, Nq*Nw, d)

            logits = torch.bmm(query, proto.permute([0, 2, 1])) / self.args.temperature
            logits = logits.view(-1, num_proto)

        # for regularization
        if self.training:
            aux_task = torch.cat([support.view(1, self.args.shot, self.args.way, emb_dim),
                                  query.view(1, self.args.query, self.args.way, emb_dim)], 1)  # T x (K+Kq) x N x d
            num_query = np.prod(aux_task.shape[1:3])
            aux_task = aux_task.permute([0, 2, 1, 3])
            aux_task = aux_task.contiguous().view(-1, self.args.shot + self.args.query, emb_dim)
            # apply the transformation over the augmented task
            aux_emb = self.slf_attn(aux_task, aux_task, aux_task)  # T x N x (K+Kq) x d
            # compute class means
            aux_emb = aux_emb.view(num_batch, self.args.way, self.args.shot + self.args.query, emb_dim)
            aux_center = torch.mean(aux_emb, 2)  # T x N x d

            if self.args.use_euclidean:
                aux_task = aux_task.permute([1, 0, 2]).contiguous().view(-1, emb_dim).unsqueeze(1)  # (Nbatch*Nq*Nw, 1, d)
                aux_center = aux_center.unsqueeze(1).expand(num_batch, num_query, num_proto, emb_dim).contiguous()
                aux_center = aux_center.view(num_batch * num_query, num_proto, emb_dim)  # (Nbatch x Nq, Nk, d)

                logits_reg = - torch.sum((aux_center - aux_task) ** 2, 2) / self.args.temperature2
            else:
                aux_center = F.normalize(aux_center, dim=-1)  # normalize for cosine similarity
                aux_task = aux_task.permute([1, 0, 2]).contiguous().view(num_batch, -1, emb_dim)  # (Nbatch, Nq*Nw, d)

                logits_reg = torch.bmm(aux_task, aux_center.permute([0, 2, 1])) / self.args.temperature2
                logits_reg = logits_reg.view(-1, num_proto)

            return logits, logits_reg
        else:
            return logits
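
SemiFEAT's transductive step is the torch.cat/slf_attn pair above: prototypes attend over the union of prototypes and unlabeled query embeddings before any distances are computed, so the unlabeled data reshapes the class centers. A minimal standalone sketch of just that step, reusing the MultiHeadAttention class defined above and assuming hypothetical episode sizes (1 task, 5 ways, 64-dim ConvNet features, 75 queries):

    import torch

    num_batch, num_way, hdim, n_query = 1, 5, 64, 75     # hypothetical episode sizes
    proto = torch.randn(num_batch, num_way, hdim)        # support class means
    query = torch.randn(num_batch, n_query, hdim)        # unlabeled query embeddings

    slf_attn = MultiHeadAttention(1, hdim, hdim, hdim, dropout=0.5)

    # prototypes attend over [prototypes; queries]
    whole_set = torch.cat([proto, query], 1)             # (1, 5 + 75, 64)
    adapted_proto = slf_attn(proto, whole_set, whole_set)  # (1, 5, 64)

The adapted prototypes then replace the raw support means in the Euclidean or cosine comparison.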

model/models/semi_protofeat.py

+182
@@ -0,0 +1,182 @@
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from model.models import FewShotModel
from model.utils import one_hot


class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v):
        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature
        log_attn = F.log_softmax(attn, 2)
        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        return output, attn, log_attn


class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        residual = q
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k)  # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k)  # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v)  # (n*b) x lv x dv

        output, attn, log_attn = self.attention(q, k, v)

        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1)  # b x lq x (n*dv)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)

        return output


class SemiProtoFEAT(FewShotModel):
    def __init__(self, args):
        super().__init__(args)
        if args.backbone_class == 'ConvNet':
            hdim = 64
        elif args.backbone_class == 'Res12':
            hdim = 640
        elif args.backbone_class == 'Res18':
            hdim = 512
        elif args.backbone_class == 'WRN':
            hdim = 640
        else:
            raise ValueError('Unknown backbone class: {}'.format(args.backbone_class))

        self.slf_attn = MultiHeadAttention(1, hdim, hdim, hdim, dropout=0.5)

    def get_proto(self, x_shot, x_pool):
        # compute prototypes from the labeled support set and an unlabeled pool set
        num_batch, num_shot, num_way, emb_dim = x_shot.shape
        num_pool_shot = x_pool.shape[1]
        num_pool = num_pool_shot * num_way
        label_support = torch.arange(self.args.way).repeat(self.args.shot).type(torch.LongTensor)
        label_support_onehot = one_hot(label_support, num_way)
        label_support_onehot = label_support_onehot.unsqueeze(0).repeat([num_batch, 1, 1])
        if torch.cuda.is_available():
            label_support_onehot = label_support_onehot.cuda()

        proto_shot = x_shot.mean(dim=1)
        if self.args.use_euclidean:
            proto_rep = proto_shot.unsqueeze(1).expand(num_batch, num_pool, num_way, emb_dim).contiguous()
            proto_rep = proto_rep.view(num_batch * num_pool, num_way, emb_dim)
            dis = - torch.sum((proto_rep - x_pool.view(-1, emb_dim).unsqueeze(1)) ** 2, 2) / self.args.temperature
        else:
            dis = torch.bmm(x_pool.view(num_batch, -1, emb_dim), F.normalize(proto_shot, dim=-1).permute([0, 2, 1])) / self.args.temperature

        dis = dis.view(num_batch, -1, num_way)
        z_hat = F.softmax(dis, dim=2)
        z = torch.cat([label_support_onehot, z_hat], dim=1)  # (num_batch, n_shot + n_pool, n_way)
        h = torch.cat([x_shot.view(num_batch, -1, emb_dim), x_pool.view(num_batch, -1, emb_dim)], dim=1)  # (num_batch, n_shot + n_pool, n_embedding)

        proto = torch.bmm(z.permute([0, 2, 1]), h)
        sum_z = z.sum(dim=1).view((num_batch, -1, 1))
        proto = proto / sum_z
        return proto

    def _forward(self, instance_embs, support_idx, query_idx):
        emb_dim = instance_embs.size(-1)

        # organize support/query data
        support = instance_embs[support_idx.contiguous().view(-1)].contiguous().view(*(support_idx.shape + (-1,)))
        query = instance_embs[query_idx.contiguous().view(-1)].contiguous().view(*(query_idx.shape + (-1,)))

        num_batch = support.shape[0]
        num_shot, num_way = support.shape[1], support.shape[2]
        num_query = np.prod(query_idx.shape[-2:])

        # transformation
        whole_set = torch.cat([support.view(num_batch, -1, emb_dim), query.view(num_batch, -1, emb_dim)], 1)
        support = self.slf_attn(support.view(num_batch, -1, emb_dim), whole_set, whole_set).view(num_batch, num_shot, num_way, emb_dim)

        # estimate prototypes from the adapted support set plus the unlabeled query set
        proto = self.get_proto(support, query)  # we can also use the adapted query set here to achieve better results
        # proto = support.mean(dim=1)  # Ntask x NK x d
        num_proto = proto.shape[1]

        # query: (num_batch, num_query, num_proto, num_emb)
        # proto: (num_batch, num_proto, num_emb)
        if self.args.use_euclidean:
            query = query.view(-1, emb_dim).unsqueeze(1)  # (Nbatch*Nq*Nw, 1, d)
            proto = proto.unsqueeze(1).expand(num_batch, num_query, num_proto, emb_dim).contiguous()
            proto = proto.view(num_batch * num_query, num_proto, emb_dim)  # (Nbatch x Nq, Nk, d)

            logits = - torch.sum((proto - query) ** 2, 2) / self.args.temperature
        else:
            proto = F.normalize(proto, dim=-1)  # normalize for cosine similarity
            query = query.view(num_batch, -1, emb_dim)  # (Nbatch, Nq*Nw, d)

            logits = torch.bmm(query, proto.permute([0, 2, 1])) / self.args.temperature
            logits = logits.view(-1, num_proto)

        # for regularization
        if self.training:
            aux_task = torch.cat([support.view(1, self.args.shot, self.args.way, emb_dim),
                                  query.view(1, self.args.query, self.args.way, emb_dim)], 1)  # T x (K+Kq) x N x d
            num_query = np.prod(aux_task.shape[1:3])
            aux_task = aux_task.permute([0, 2, 1, 3])
            aux_task = aux_task.contiguous().view(-1, self.args.shot + self.args.query, emb_dim)
            # apply the transformation over the augmented task
            aux_emb = self.slf_attn(aux_task, aux_task, aux_task)  # T x N x (K+Kq) x d
            # compute class means
            aux_emb = aux_emb.view(num_batch, self.args.way, self.args.shot + self.args.query, emb_dim)
            aux_center = torch.mean(aux_emb, 2)  # T x N x d

            if self.args.use_euclidean:
                aux_task = aux_task.permute([1, 0, 2]).contiguous().view(-1, emb_dim).unsqueeze(1)  # (Nbatch*Nq*Nw, 1, d)
                aux_center = aux_center.unsqueeze(1).expand(num_batch, num_query, num_proto, emb_dim).contiguous()
                aux_center = aux_center.view(num_batch * num_query, num_proto, emb_dim)  # (Nbatch x Nq, Nk, d)

                logits_reg = - torch.sum((aux_center - aux_task) ** 2, 2) / self.args.temperature2
            else:
                aux_center = F.normalize(aux_center, dim=-1)  # normalize for cosine similarity
                aux_task = aux_task.permute([1, 0, 2]).contiguous().view(num_batch, -1, emb_dim)  # (Nbatch, Nq*Nw, d)

                logits_reg = torch.bmm(aux_task, aux_center.permute([0, 2, 1])) / self.args.temperature2
                logits_reg = logits_reg.view(-1, num_proto)

            return logits, logits_reg
        else:
            return logits
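
The distinctive piece of SemiProtoFEAT is get_proto: each prototype becomes a weighted mean in which support instances carry one-hot label weights and pool instances carry their softmax posteriors, i.e. one soft k-means style refinement step. A toy sketch of that weighted mean with hypothetical shapes (torch.nn.functional.one_hot stands in here for the repository's one_hot helper):

    import torch
    import torch.nn.functional as F

    num_batch, num_way, num_shot, num_pool, emb_dim = 1, 5, 1, 75, 64
    x_shot = torch.randn(num_batch, num_shot * num_way, emb_dim)   # labeled support
    x_pool = torch.randn(num_batch, num_pool, emb_dim)             # unlabeled pool

    # support weights: one-hot labels 0..way-1, repeated shot times
    label = torch.arange(num_way).repeat(num_shot)
    z_support = F.one_hot(label, num_way).float().unsqueeze(0)     # (1, shot*way, way)

    # hypothetical soft assignments of the pool (in the model these come
    # from distances to the initial support means)
    z_pool = torch.softmax(torch.randn(num_batch, num_pool, num_way), dim=2)

    z = torch.cat([z_support, z_pool], dim=1)                      # (1, shot*way + pool, way)
    h = torch.cat([x_shot, x_pool], dim=1)                         # (1, shot*way + pool, d)

    # weighted mean per class: sum of weighted embeddings / sum of weights
    proto = torch.bmm(z.permute(0, 2, 1), h) / z.sum(dim=1).view(num_batch, -1, 1)  # (1, way, d)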

model/trainer/helpers.py

+2
@@ -11,6 +11,8 @@
 from model.models.deepset import DeepSet
 from model.models.bilstm import BILSTM
 from model.models.graphnet import GCN
+from model.models.semi_feat import SemiFEAT
+from model.models.semi_protofeat import SemiProtoFEAT
 
 class MultiGPUDataloader:
     def __init__(self, dataloader, num_device):
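
These imports make the two new classes visible to the model-construction code in this module. That construction code is not shown in the diff; a plausible sketch of the dispatch the imports enable (the prepare_model name and its body are assumptions, not part of this commit):

    import torch

    def prepare_model(args):
        # assumed dispatch: args.model_class names one of the imported
        # classes, e.g. 'SemiFEAT' or 'SemiProtoFEAT'
        model = eval(args.model_class)(args)
        if torch.cuda.is_available():
            model = model.cuda()
        return model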

model/utils.py

+4-4
@@ -142,17 +142,17 @@ def get_command_line_parser():
     parser.add_argument('--max_epoch', type=int, default=200)
     parser.add_argument('--episodes_per_epoch', type=int, default=100)
     parser.add_argument('--num_eval_episodes', type=int, default=600)
-    parser.add_argument('--model_class', type=str, default='FEAT',
-                        choices=['MatchNet', 'ProtoNet', 'BILSTM', 'DeepSet', 'GCN', 'FEAT', 'FEATSTAR'])  # None for MatchNet or ProtoNet
+    parser.add_argument('--model_class', type=str, default='SemiProtoFEAT',
+                        choices=['MatchNet', 'ProtoNet', 'BILSTM', 'DeepSet', 'GCN', 'FEAT', 'FEATSTAR', 'SemiFEAT', 'SemiProtoFEAT'])  # None for MatchNet or ProtoNet
     parser.add_argument('--use_euclidean', action='store_true', default=False)
-    parser.add_argument('--backbone_class', type=str, default='Res12',
+    parser.add_argument('--backbone_class', type=str, default='ConvNet',
                         choices=['ConvNet', 'Res12', 'Res18', 'WRN'])
     parser.add_argument('--dataset', type=str, default='MiniImageNet',
                         choices=['MiniImageNet', 'TieredImageNet', 'CUB'])
 
     parser.add_argument('--way', type=int, default=5)
     parser.add_argument('--eval_way', type=int, default=5)
-    parser.add_argument('--shot', type=int, default=1)
+    parser.add_argument('--shot', type=int, default=3)
     parser.add_argument('--eval_shot', type=int, default=1)
     parser.add_argument('--query', type=int, default=15)
     parser.add_argument('--eval_query', type=int, default=15)
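
With the extended choices list, the new variants are selectable from the command line. A hypothetical invocation matching the new defaults (the entry-point name train_fsl.py is an assumption, not part of this diff):

    python train_fsl.py --model_class SemiProtoFEAT --backbone_class ConvNet \
        --dataset MiniImageNet --way 5 --shot 3 --query 15 --use_euclidean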
