"""
General utilities
"""
from typing import List, Optional, Tuple
import numpy as np
import pandas as pd
import torch
import torchvision
from matplotlib import pyplot as plt
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from easyfsl.methods import FewShotClassifier


def plot_images(images: Tensor, title: str, images_per_row: int):
"""
Plot images in a grid.
Args:
images: 4D mini-batch Tensor of shape (B x C x H x W)
title: title of the figure to plot
images_per_row: number of images in each row of the grid
"""
plt.figure()
plt.title(title)
plt.imshow(
torchvision.utils.make_grid(images, nrow=images_per_row).permute(1, 2, 0)
)
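

# A minimal usage sketch for plot_images (a hypothetical demo, not part of the
# module's API): the random batch below stands in for real images; any float
# tensor of shape (B, C, H, W) works.
def _demo_plot_images():
    fake_batch = torch.rand(8, 3, 32, 32)  # 8 RGB images, 32x32 pixels
    plot_images(fake_batch, title="Random batch", images_per_row=4)
    plt.show()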
def sliding_average(value_list: List[float], window: int) -> float:
"""
Computes the average of the latest instances in a list
Args:
        value_list: input list of floats (can't be empty)
        window: number of instances to take into account. If window is 0 or greater than
            the length of value_list, all instances will be taken into account.
Returns:
average of the last window instances in value_list
Raises:
ValueError: if the input list is empty
"""
if len(value_list) == 0:
raise ValueError("Cannot perform sliding average on an empty list.")
return np.asarray(value_list[-window:]).mean()
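

# Sketch of tracking a smoothed metric with sliding_average; the loss values
# and window size below are arbitrary, for illustration only.
def _demo_sliding_average():
    losses = [2.3, 1.9, 1.7, 1.6, 1.4, 1.5, 1.2]
    # Average over the 3 most recent values: (1.4 + 1.5 + 1.2) / 3
    print(sliding_average(losses, window=3))
    # window=0 averages the whole list, as documented above
    print(sliding_average(losses, window=0))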
def predict_embeddings(
dataloader: DataLoader,
model: nn.Module,
device: Optional[str] = None,
) -> pd.DataFrame:
"""
Predict embeddings for a dataloader.
Args:
dataloader: dataloader to predict embeddings for. Must deliver tuples (images, class_names)
model: model to use for prediction
        device: device to cast the images to. If None, no casting is performed. Must be the same as
            the device the model is on.
Returns:
dataframe with columns embedding and class_name
"""
all_embeddings = []
all_class_names = []
with torch.no_grad():
for images, class_names in tqdm(
dataloader, unit="batch", desc="Predicting embeddings"
):
if device is not None:
images = images.to(device)
all_embeddings.append(model(images).detach().cpu())
if isinstance(class_names, torch.Tensor):
all_class_names += class_names.tolist()
else:
all_class_names += class_names
concatenated_embeddings = torch.cat(all_embeddings)
return pd.DataFrame(
{"embedding": list(concatenated_embeddings), "class_name": all_class_names}
)
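

# Sketch of predict_embeddings with a torchvision backbone. The resnet18 is a
# placeholder for whatever feature extractor you actually use; swapping its
# classification head for Identity makes the forward pass return raw features.
def _demo_predict_embeddings(dataloader: DataLoader) -> pd.DataFrame:
    backbone = torchvision.models.resnet18(weights=None)
    backbone.fc = nn.Identity()  # output 512-d features instead of class logits
    backbone.eval()  # disable batchnorm/dropout training behaviour
    return predict_embeddings(dataloader, backbone, device=None)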
def evaluate_on_one_task(
model: FewShotClassifier,
support_images: Tensor,
support_labels: Tensor,
query_images: Tensor,
query_labels: Tensor,
) -> Tuple[int, int]:
"""
Returns the number of correct predictions of query labels, and the total number of
predictions.
"""
model.process_support_set(support_images, support_labels)
    predictions = model(query_images).detach()
number_of_correct_predictions = int(
(torch.max(predictions, 1)[1] == query_labels).sum().item()
)
return number_of_correct_predictions, len(query_labels)
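

# Sketch of a single 5-way 1-shot task built from random tensors; in practice
# the tensors come from a task sampler, and the shapes here are illustrative.
def _demo_evaluate_on_one_task(model: FewShotClassifier):
    support_images = torch.rand(5, 3, 32, 32)  # one support image per class
    support_labels = torch.arange(5)           # classes 0..4
    query_images = torch.rand(10, 3, 32, 32)   # two query images per class
    query_labels = torch.arange(5).repeat_interleave(2)
    correct, total = evaluate_on_one_task(
        model, support_images, support_labels, query_images, query_labels
    )
    print(f"{correct}/{total} queries classified correctly")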
def evaluate(
model: FewShotClassifier,
data_loader: DataLoader,
device: str = "cuda",
use_tqdm: bool = True,
tqdm_prefix: Optional[str] = None,
) -> float:
"""
Evaluate the model on few-shot classification tasks
Args:
model: a few-shot classifier
        data_loader: loads data in the shape of few-shot classification tasks
device: where to cast data tensors.
Must be the same as the device hosting the model's parameters.
use_tqdm: whether to display the evaluation's progress bar
tqdm_prefix: prefix of the tqdm bar
Returns:
average classification accuracy
"""
# We'll count everything and compute the ratio at the end
total_predictions = 0
correct_predictions = 0
# eval mode affects the behaviour of some layers (such as batch normalization or dropout)
    # no_grad() tells torch not to keep the whole computational graph in memory
model.eval()
with torch.no_grad():
# We use a tqdm context to show a progress bar in the logs
with tqdm(
enumerate(data_loader),
total=len(data_loader),
disable=not use_tqdm,
desc=tqdm_prefix,
) as tqdm_eval:
for _, (
support_images,
support_labels,
query_images,
query_labels,
_,
) in tqdm_eval:
correct, total = evaluate_on_one_task(
model,
support_images.to(device),
support_labels.to(device),
query_images.to(device),
query_labels.to(device),
)
total_predictions += total
correct_predictions += correct
# Log accuracy in real time
tqdm_eval.set_postfix(accuracy=correct_predictions / total_predictions)
return correct_predictions / total_predictions
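

# Sketch of a full evaluation run. PrototypicalNetworks ships with easyfsl;
# the backbone and the episodic test loader are placeholders you would build
# yourself (e.g. with easyfsl's TaskSampler).
def _demo_evaluate(backbone: nn.Module, test_loader: DataLoader) -> float:
    from easyfsl.methods import PrototypicalNetworks

    few_shot_model = PrototypicalNetworks(backbone).to("cuda")
    return evaluate(few_shot_model, test_loader, device="cuda", tqdm_prefix="Test")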
def compute_average_features_from_images(
dataloader: DataLoader,
model: nn.Module,
device: Optional[str] = None,
) -> Tensor:
"""
Compute the average features vector from all images in a DataLoader.
    Assumes the images are always the first element of the batch.
Returns:
        Tensor: shape (feature_dimension,), the mean of all feature vectors
"""
all_embeddings = torch.stack(
predict_embeddings(dataloader, model, device)["embedding"].to_list()
)
average_features = all_embeddings.mean(dim=0)
if device is not None:
average_features = average_features.to(device)
return average_features
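

# Sketch: the average feature vector can serve as a rough dataset-level
# prototype. The backbone and dataloader here are placeholders.
def _demo_average_features(dataloader: DataLoader, backbone: nn.Module):
    average_features = compute_average_features_from_images(dataloader, backbone)
    print(average_features.shape)  # torch.Size([feature_dimension])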