-
Notifications
You must be signed in to change notification settings - Fork 151
/
few_shot_classifier.py
126 lines (106 loc) · 3.88 KB
/
few_shot_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
from abc import abstractmethod
import torch
from torch import nn, Tensor
from easyfsl.utils import compute_backbone_output_shape, compute_prototypes
class FewShotClassifier(nn.Module):
    """
    Abstract class providing methods usable by all few-shot classification algorithms.

    Subclasses are expected to implement process_support_set(), forward() and
    is_transductive(). The helpers defined here (distance-based logits, optional softmax,
    support set bookkeeping) are available to every subclass.
    """

    def __init__(self, backbone: nn.Module, use_softmax: bool = False):
        """
        Initialize the Few-Shot Classifier
        Args:
            backbone: the feature extractor used by the method. Must output a tensor of the
                appropriate shape (depending on the method)
            use_softmax: whether to return predictions as soft probabilities
        """
        super().__init__()

        self.backbone = backbone
        # Shape reported by the compute_backbone_output_shape helper for this backbone;
        # its first entry is kept as the feature dimension.
        self.backbone_output_shape = compute_backbone_output_shape(backbone)
        self.feature_dimension = self.backbone_output_shape[0]

        self.use_softmax = use_softmax

        # Support set state. Everything stays None until store_support_set_data() fills it.
        self.prototypes = None
        self.support_features = None
        self.support_labels = None

    @abstractmethod
    def forward(
        self,
        query_images: Tensor,
    ) -> Tensor:
        """
        Predict classification labels.
        Args:
            query_images: images of the query set
        Returns:
            a prediction of classification scores for query images
        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError(
            "All few-shot algorithms must implement a forward method."
        )

    @abstractmethod
    def process_support_set(
        self,
        support_images: Tensor,
        support_labels: Tensor,
    ):
        """
        Harness information from the support set, so that query labels can later be predicted
        using a forward call
        Args:
            support_images: images of the support set
            support_labels: labels of support set images
        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError(
            "All few-shot algorithms must implement a process_support_set method."
        )

    @staticmethod
    def is_transductive() -> bool:
        """
        Tell whether the algorithm is transductive. Must be overridden by subclasses.
        Returns:
            a boolean (in subclasses)
        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError(
            "All few-shot algorithms must implement a is_transductive method."
        )

    def softmax_if_specified(self, output: Tensor) -> Tensor:
        """
        If the option is chosen when the classifier is initialized, we perform a softmax on the
        output in order to return soft probabilities.
        Args:
            output: output of the forward method
        Returns:
            output as it was, or output as soft probabilities (softmax over the last dimension)
        """
        return output.softmax(-1) if self.use_softmax else output

    def _require_prototypes(self) -> Tensor:
        """
        Return the stored prototypes, failing with an explicit message if they are still unset.
        Returns:
            the content of self.prototypes
        Raises:
            RuntimeError: if self.prototypes is None
        """
        prototypes = self.prototypes
        if prototypes is None:
            # Without this guard the failure would surface inside torch with an opaque message.
            raise RuntimeError(
                "Prototypes are not set: call process_support_set (or "
                "store_support_set_data) before computing distances to prototypes."
            )
        return prototypes

    def l2_distance_to_prototypes(self, samples: Tensor) -> Tensor:
        """
        Compute prediction logits from their euclidean distance to support set prototypes.
        Args:
            samples: features of the items to classify
        Returns:
            prediction logits: the negated pairwise distances, so that a closer prototype
            gets a higher score
        Raises:
            RuntimeError: if the prototypes have not been computed yet
        """
        return -torch.cdist(samples, self._require_prototypes())

    def cosine_distance_to_prototypes(self, samples: Tensor) -> Tensor:
        """
        Compute prediction logits from their cosine distance to support set prototypes.
        The returned values are cosine similarities (dot products of L2-normalized vectors),
        so that a closer prototype gets a higher score.
        Args:
            samples: features of the items to classify. Normalization is done along
                dimension 1.
        Returns:
            prediction logits
        Raises:
            RuntimeError: if the prototypes have not been computed yet
        """
        prototypes = self._require_prototypes()
        return (
            nn.functional.normalize(samples, dim=1)
            @ nn.functional.normalize(prototypes, dim=1).T
        )

    def store_support_set_data(
        self,
        support_images: Tensor,
        support_labels: Tensor,
    ):
        """
        Extract support features, compute prototypes,
        and store support labels, features, and prototypes
        Args:
            support_images: images of the support set
            support_labels: labels of support set images
        """
        self.support_labels = support_labels
        self.support_features = self.backbone(support_images)
        self.prototypes = compute_prototypes(self.support_features, support_labels)