-
Notifications
You must be signed in to change notification settings - Fork 151
/
prototypical_networks.py
42 lines (33 loc) · 1.42 KB
/
prototypical_networks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
"""
See the original implementation (which differs substantially from this one)
at https://github.com/jakesnell/prototypical-networks
"""
from torch import Tensor
from .few_shot_classifier import FewShotClassifier
class PrototypicalNetworks(FewShotClassifier):
    """
    Few-shot classifier following Prototypical Networks.

    Reference:
        Jake Snell, Kevin Swersky, and Richard S. Zemel.
        "Prototypical networks for few-shot learning." (2017)
        https://arxiv.org/abs/1703.05175

    Feature vectors are extracted for both support and query images. The mean of the
    support features of each class is that class's prototype, and query images are
    scored according to their euclidean distance to the prototypes.
    """

    def forward(self, query_images: Tensor) -> Tensor:
        """
        Override FewShotClassifier.forward.

        Query labels are predicted from the distance between query features and the
        class prototypes in feature space; classification scores are the negative of
        the euclidean distances.

        Args:
            query_images: images of the query set

        Returns:
            classification scores for the query images, passed through
            `softmax_if_specified` before being returned
        """
        # Embed the query images with the inherited feature extraction helper.
        embeddings = self.compute_features(query_images)

        # Validate the feature dimensionality before measuring distances.
        self._raise_error_if_features_are_multi_dimensional(embeddings)

        # Score each query by its euclidean distance to the prototypes.
        return self.softmax_if_specified(self.l2_distance_to_prototypes(embeddings))

    @staticmethod
    def is_transductive() -> bool:
        """Report that this method is not transductive (always returns False)."""
        return False