import torch
from torch import Tensor, nn

from easyfsl.methods import FewShotClassifier


class TIM(FewShotClassifier):
"""
Malik Boudiaf, Ziko Imtiaz Masud, Jérôme Rony, José Dolz, Pablo Piantanida, Ismail Ben Ayed.
"Transductive Information Maximization For Few-Shot Learning" (NeurIPS 2020)
https://arxiv.org/abs/2008.11297
Fine-tune prototypes based on
1) classification error on support images
2) mutual information between query features and their label predictions
Classify w.r.t. to euclidean distance to updated prototypes.
As is, it is incompatible with episodic training because we freeze the backbone to perform
fine-tuning.
TIM is a transductive method.
"""
    def __init__(
        self,
        *args,
        fine_tuning_steps: int = 100,
        fine_tuning_lr: float = 1e-3,
        cross_entropy_weight: float = 1.0,
        marginal_entropy_weight: float = 1.0,
        conditional_entropy_weight: float = 0.1,
        **kwargs,
    ):
        """
        Args:
            fine_tuning_steps: number of fine-tuning steps
            fine_tuning_lr: learning rate for fine-tuning
            cross_entropy_weight: weight given to the cross-entropy term of the loss
            marginal_entropy_weight: weight given to the marginal entropy term of the loss
            conditional_entropy_weight: weight given to the conditional entropy term of the loss
        """
        super().__init__(*args, **kwargs)

        # Since we fine-tune the prototypes we need to make them leaf variables,
        # i.e. we need to freeze the backbone.
        self.backbone.requires_grad_(False)

        self.fine_tuning_steps = fine_tuning_steps
        self.fine_tuning_lr = fine_tuning_lr
        self.cross_entropy_weight = cross_entropy_weight
        self.marginal_entropy_weight = marginal_entropy_weight
        self.conditional_entropy_weight = conditional_entropy_weight

    def process_support_set(
        self,
        support_images: torch.Tensor,
        support_labels: torch.Tensor,
    ):
        """
        Overrides process_support_set of FewShotClassifier.

        Args:
            support_images: images of the support set
            support_labels: labels of support set images
        """
        self.store_support_set_data(support_images, support_labels)

    def forward(
        self,
        query_images: Tensor,
    ) -> Tensor:
        """
        Overrides forward method of FewShotClassifier.
        Fine-tune prototypes based on support classification error and mutual information between
        query features and their label predictions.
        Then classify queries based on their euclidean distance to the prototypes.

        Args:
            query_images: images of the query set
        Returns:
            a prediction of classification scores for query images
        """
        query_features = self.backbone.forward(query_images)

        num_classes = self.support_labels.unique().size(0)
        support_labels_one_hot = nn.functional.one_hot(self.support_labels, num_classes)

        with torch.enable_grad():
            self.prototypes.requires_grad_()
            optimizer = torch.optim.Adam([self.prototypes], lr=self.fine_tuning_lr)
            for _ in range(self.fine_tuning_steps):
                support_logits = self.l2_distance_to_prototypes(self.support_features)
                query_logits = self.l2_distance_to_prototypes(query_features)

                # Cross-entropy between support predictions and support labels
                support_cross_entropy = (
                    -(support_labels_one_hot * support_logits.log_softmax(1))
                    .sum(1)
                    .mean(0)
                )

                # Conditional entropy of query predictions, averaged over query images
                query_soft_probs = query_logits.softmax(1)
                query_conditional_entropy = (
                    -(query_soft_probs * torch.log(query_soft_probs + 1e-12))
                    .sum(1)
                    .mean(0)
                )

                # Entropy of the marginal query prediction
                marginal_prediction = query_soft_probs.mean(0)
                marginal_entropy = -(
                    marginal_prediction * torch.log(marginal_prediction)
                ).sum(0)

                # Minimize support cross-entropy while maximizing the (weighted) mutual
                # information between query features and their predicted labels
                loss = self.cross_entropy_weight * support_cross_entropy - (
                    self.marginal_entropy_weight * marginal_entropy
                    - self.conditional_entropy_weight * query_conditional_entropy
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        return self.softmax_if_specified(
            self.l2_distance_to_prototypes(query_features)
        ).detach()

    @staticmethod
    def is_transductive() -> bool:
        return True
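

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it assumes a toy
# backbone that maps images to flat feature vectors, and uses random tensors
# in place of a real 5-way 5-shot episode. The shapes, the toy backbone, and
# the number of fine-tuning steps below are illustrative assumptions only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Toy feature extractor: flatten 3x28x28 images into 64-dimensional features
    toy_backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 64))

    model = TIM(toy_backbone, fine_tuning_steps=25)

    # Fake 5-way episode: 5 support images and 10 query images per class
    support_images = torch.randn(25, 3, 28, 28)
    support_labels = torch.arange(5).repeat_interleave(5)
    query_images = torch.randn(50, 3, 28, 28)

    model.process_support_set(support_images, support_labels)
    scores = model(query_images)  # shape (50, 5): one score per query image and class
    print(scores.shape)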