-
Notifications
You must be signed in to change notification settings - Fork 147
/
few_shot_classifier.py
158 lines (141 loc) · 6.17 KB
/
few_shot_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from abc import abstractmethod
from typing import Optional
import torch
from torch import Tensor, nn
from easyfsl.methods.utils import compute_prototypes
class FewShotClassifier(nn.Module):
    """
    Abstract class providing methods usable by all few-shot classification algorithms.

    Subclasses are expected to implement forward() and is_transductive(). The helpers
    defined here cover feature extraction (with optional centering and normalization),
    storage of the support set and its prototypes, and distance-based logits.
    """

    def __init__(
        self,
        backbone: Optional[nn.Module] = None,
        use_softmax: bool = False,
        feature_centering: Optional[Tensor] = None,
        feature_normalization: Optional[float] = None,
    ):
        """
        Initialize the Few-Shot Classifier
        Args:
            backbone: the feature extractor used by the method. Must output a tensor of the
                appropriate shape (depending on the method).
                If None is passed, the backbone will be initialized as nn.Identity().
            use_softmax: whether to return predictions as soft probabilities
            feature_centering: a features vector on which to center all computed features.
                If None is passed, no centering is performed.
            feature_normalization: a value by which to normalize all computed features after
                centering. It is used as the p argument in torch.nn.functional.normalize().
                If None is passed, no normalization is performed.
        """
        super().__init__()

        self.backbone = backbone if backbone is not None else nn.Identity()
        self.use_softmax = use_softmax

        # Placeholders (empty tensors) until a support set is processed.
        self.prototypes = torch.tensor(())
        self.support_features = torch.tensor(())
        self.support_labels = torch.tensor(())

        # A scalar zero makes the centering subtraction a no-op when no centering is requested.
        self.feature_centering = (
            feature_centering if feature_centering is not None else torch.tensor(0)
        )
        self.feature_normalization = feature_normalization

    @abstractmethod
    def forward(
        self,
        query_images: Tensor,
    ) -> Tensor:
        """
        Predict classification labels.
        Args:
            query_images: images of the query set of shape (n_query, **image_shape)
        Returns:
            a prediction of classification scores for query images of shape (n_query, n_classes)
        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            "All few-shot algorithms must implement a forward method."
        )

    def process_support_set(
        self,
        support_images: Tensor,
        support_labels: Tensor,
    ) -> None:
        """
        Harness information from the support set, so that query labels can later be predicted
        using a forward call.
        The default behaviour shared by most few-shot classifiers is to compute prototypes
        and store the support set.
        Args:
            support_images: images of the support set of shape (n_support, **image_shape)
            support_labels: labels of support set images of shape (n_support, )
        """
        self.compute_prototypes_and_store_support_set(support_images, support_labels)

    @staticmethod
    def is_transductive() -> bool:
        """
        Tell whether the method is transductive.
        Returns:
            a boolean, to be provided by each subclass.
        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            "All few-shot algorithms must implement a is_transductive method."
        )

    def compute_features(self, images: Tensor) -> Tensor:
        """
        Compute features from images and perform centering and normalization.
        Args:
            images: images of shape (n_images, **image_shape)
        Returns:
            features of shape (n_images, feature_dimension)
        """
        original_features = self.backbone(images)

        # feature_centering is a plain attribute (not a registered buffer), so it does not
        # follow the module when it is moved with .to() / .cuda(). Align its device with
        # the features here to avoid a device-mismatch error on the subtraction.
        centering = self.feature_centering
        if isinstance(centering, Tensor):
            centering = centering.to(original_features.device)

        centered_features = original_features - centering
        if self.feature_normalization is None:
            return centered_features
        return nn.functional.normalize(
            centered_features, p=self.feature_normalization, dim=1
        )

    def softmax_if_specified(self, output: Tensor, temperature: float = 1.0) -> Tensor:
        """
        If the option is chosen when the classifier is initialized, we perform a softmax on the
        output in order to return soft probabilities.
        Args:
            output: output of the forward method of shape (n_query, n_classes)
            temperature: scaling factor of the softmax. Note that the output is MULTIPLIED
                (not divided) by this value before the softmax is applied.
        Returns:
            output as it was, or output as soft probabilities, of shape (n_query, n_classes)
        """
        if not self.use_softmax:
            return output
        return (temperature * output).softmax(-1)

    def l2_distance_to_prototypes(self, samples: Tensor) -> Tensor:
        """
        Compute prediction logits from their euclidean distance to support set prototypes.
        The logits are the negated distances, so that a closer prototype gets a higher score.
        Args:
            samples: features of the items to classify of shape (n_samples, feature_dimension)
        Returns:
            prediction logits of shape (n_samples, n_classes)
        """
        return -torch.cdist(samples, self.prototypes)

    def cosine_distance_to_prototypes(self, samples: Tensor) -> Tensor:
        """
        Compute prediction logits from their cosine distance to support set prototypes.
        The returned values are the dot products of the normalized samples and normalized
        prototypes, i.e. cosine similarities: a higher score means a closer prototype.
        Args:
            samples: features of the items to classify of shape (n_samples, feature_dimension)
        Returns:
            prediction logits of shape (n_samples, n_classes)
        """
        normalized_samples = nn.functional.normalize(samples, dim=1)
        normalized_prototypes = nn.functional.normalize(self.prototypes, dim=1)
        return normalized_samples @ normalized_prototypes.T

    def compute_prototypes_and_store_support_set(
        self,
        support_images: Tensor,
        support_labels: Tensor,
    ) -> None:
        """
        Extract support features, compute prototypes, and store support labels, features,
        and prototypes.
        Args:
            support_images: images of the support set of shape (n_support, **image_shape)
            support_labels: labels of support set images of shape (n_support, )
        Raises:
            ValueError: if the computed support features are not a 2-dim tensor.
        """
        support_features = self.compute_features(support_images)
        # Validate before touching any attribute, so that a failure does not leave the
        # classifier with labels/features that do not match its prototypes.
        self._raise_error_if_features_are_multi_dimensional(support_features)

        self.support_labels = support_labels
        self.support_features = support_features
        self.prototypes = compute_prototypes(support_features, support_labels)

    @staticmethod
    def _raise_error_if_features_are_multi_dimensional(features: Tensor) -> None:
        """
        Check that features form a batch of 1-dim feature vectors.
        Args:
            features: tensor of features, expected to have exactly 2 dimensions
        Raises:
            ValueError: if features does not have exactly 2 dimensions.
        """
        if len(features.shape) != 2:
            raise ValueError(
                "Illegal backbone or feature shape. "
                "Expected output for an image is a 1-dim tensor."
            )