-
Notifications
You must be signed in to change notification settings - Fork 151
/
prototypical_networks.py
82 lines (66 loc) · 2.62 KB
/
prototypical_networks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
See original implementation (quite far from this one)
at https://github.com/jakesnell/prototypical-networks
"""
import torch
from torch import Tensor
from easyfsl.methods import FewShotClassifier
from easyfsl.utils import compute_prototypes
class PrototypicalNetworks(FewShotClassifier):
    """
    Jake Snell, Kevin Swersky, and Richard S. Zemel.
    "Prototypical networks for few-shot learning." (2017)
    https://arxiv.org/abs/1703.05175

    Prototypical networks extract feature vectors for both support and query images. Then it
    computes the mean of support features for each class (called prototypes), and predicts
    classification scores for query images based on their euclidean distance to the prototypes.
    """

    def __init__(self, *args, **kwargs):
        """
        Build the classifier and check that its backbone is a feature extractor.

        All positional and keyword arguments are forwarded unchanged to
        FewShotClassifier.__init__.

        Raises:
            ValueError: if the backbone is not a feature extractor,
                i.e. if its output for a given image is not a 1-dim tensor.
        """
        super().__init__(*args, **kwargs)
        # torch.cdist in forward() compares flat feature vectors, so each image
        # must be embedded as a 1-dim tensor.
        if len(self.backbone_output_shape) != 1:
            raise ValueError(
                "Illegal backbone for Prototypical Networks. "
                "Expected output for an image is a 1-dim tensor."
            )

    def process_support_set(
        self,
        support_images: Tensor,
        support_labels: Tensor,
    ) -> None:
        """
        Overrides process_support_set of FewShotClassifier.
        Extract feature vectors from the support set and store class prototypes
        in self.prototypes, where forward() reads them.

        Args:
            support_images: images of the support set
            support_labels: labels of support set images
        """
        # Call the backbone module itself rather than .forward(), so that any
        # hooks registered on it are run (assuming the backbone is a
        # torch.nn.Module, as is expected for FewShotClassifier backbones).
        support_features = self.backbone(support_images)
        self.prototypes = compute_prototypes(support_features, support_labels)

    def forward(
        self,
        query_images: Tensor,
    ) -> Tensor:
        """
        Overrides forward method of FewShotClassifier.
        Predict query labels based on their distance to class prototypes in the feature space.
        Classification scores are the negative of euclidean distances.
        Uses the prototypes stored by process_support_set, which should be called first.

        Args:
            query_images: images of the query set
        Returns:
            a prediction of classification scores for query images, with one row
            per query image and one column per prototype (softmax-normalized if
            softmax_if_specified applies it)
        """
        # Extract the features of query images (support features were already
        # reduced to prototypes in process_support_set).
        z_query = self.backbone(query_images)

        # Pairwise euclidean distance from every query to every prototype
        dists = torch.cdist(z_query, self.prototypes)

        # A closer prototype must give a higher score, hence the negation
        scores = -dists

        return self.softmax_if_specified(scores)

    @staticmethod
    def is_transductive() -> bool:
        """
        Return False: prototypes are computed from the support set only, and
        query images are not used to adapt them.
        """
        return False