-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
341 lines (288 loc) · 9.86 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
import math
from pathlib import Path
from typing import Optional

import jax
import numpy as np
import torch
import torchvision
from jax import numpy as jnp

from config import NoiseLevel, NOISE_LEVEL_TO_VARIANCE, NoiseType
def get_6x6_numbers_bitmaps() -> torch.Tensor:
    """
    Return the hand-drawn 6x6 glyphs for the digits 0-9 as one tensor.

    Each glyph is spelled out below as six rows of '0'/'1' characters and then
    converted to numbers. Judging by the patterns, 0-valued pixels trace the
    digit and 1-valued pixels are background.

    :return: float tensor of shape (10, 6, 6); index ``d`` holds digit ``d``.
    """
    glyphs = (
        (  # 0
            "110011",
            "101101",
            "101101",
            "101101",
            "101101",
            "110011",
        ),
        (  # 1
            "110111",
            "100111",
            "110111",
            "110111",
            "110111",
            "100011",
        ),
        (  # 2
            "100011",
            "011101",
            "111011",
            "110111",
            "101111",
            "100001",
        ),
        (  # 3
            "111111",
            "100001",
            "111101",
            "110001",
            "111101",
            "100001",
        ),
        (  # 4
            "111111",
            "101011",
            "101011",
            "100011",
            "111011",
            "111011",
        ),
        (  # 5
            "111111",
            "110000",
            "110111",
            "110000",
            "111110",
            "110000",
        ),
        (  # 6
            "111111",
            "100001",
            "101111",
            "100001",
            "101101",
            "100001",
        ),
        (  # 7
            "111111",
            "100000",
            "111110",
            "111101",
            "111011",
            "110111",
        ),
        (  # 8
            "111111",
            "110011",
            "101101",
            "110011",
            "101101",
            "110011",
        ),
        (  # 9
            "100001",
            "101101",
            "101101",
            "100001",
            "111101",
            "111101",
        ),
    )
    # One nested list (digit -> row -> pixel) converted in a single constructor call.
    return torch.Tensor(
        [[[int(pixel) for pixel in row] for row in glyph] for glyph in glyphs]
    )
def change_random_bits_in_image(
    img: torch.Tensor,
    num_bits_to_change: int,
    lower_bound: float = 0,
    upper_bound: float = 1,
    variance: float = 0.3,
) -> torch.Tensor:
    """
    Return a copy of a flattened image with ``num_bits_to_change`` distinct
    positions perturbed by Gaussian noise.

    Positions are chosen without replacement via ``np.random.choice``. Each
    chosen pixel gets ``variance * N(0, 1)`` added (so ``variance`` acts as a
    standard-deviation scale) and is then clamped to
    ``[lower_bound, upper_bound]``. The input tensor is not modified.
    """
    result = img.clone()
    # Draw positions first, then noise, so both RNG streams are consumed in a fixed order.
    picked = np.random.choice(len(result), num_bits_to_change, replace=False)
    jitter = torch.randn(num_bits_to_change).float() * variance
    result[picked] = (result[picked] + jitter).clamp(lower_bound, upper_bound)
    return result
def _is_uniquely_closest(
    bitmaps: torch.Tensor, bitmap: torch.Tensor, candidate: torch.Tensor
) -> bool:
    """Return True if ``bitmap`` is the single highest dot-product match for ``candidate`` in ``bitmaps``."""
    # One score per original: each original is flattened and dotted with the flattened candidate.
    similarities = bitmaps.reshape(
        bitmaps.shape[0], -1
    ) @ candidate.flatten().unsqueeze(1)
    # NOTE(review): this is a raw dot product, not a distance, so originals with larger pixel
    # sums tend to score higher; confirm that is the intended notion of "closest".
    if torch.sum(similarities == torch.max(similarities)) > 1:
        # A tie for the best score is ambiguous and never counts as balanced.
        return False
    return torch.equal(bitmaps[torch.argmax(similarities)], bitmap)


def noise_bitmaps(
    bitmaps: torch.Tensor,
    num_variations_per_bitmap: int,
    noise_level: NoiseLevel,
    is_continuous: bool,
    flatten: bool,
    max_attempts_to_balance: int = 1,
) -> torch.Tensor:
    """
    Create ``num_variations_per_bitmap`` noisy copies of every bitmap and stack them.

    For each copy, ``torch.randn`` noise scaled by ``NOISE_LEVEL_TO_VARIANCE[noise_level]``
    (so the table value acts as a standard deviation) is clamped to [0, 1] and added to the
    flattened original. When ``is_continuous`` is false the result is clamped to [0, 1] and
    rounded.

    Balancing: when ``max_attempts_to_balance > 1``, a copy is regenerated until its own source
    bitmap is the *unique* best dot-product match among ``bitmaps``. If the last attempt still
    fails (including a tie), ValueError is raised. With the default of 1 attempt, the first
    copy is always kept.

    :param bitmaps: original bitmaps, iterated along the first dimension.
    :param num_variations_per_bitmap: number of noisy copies generated per original.
    :param noise_level: key into ``NOISE_LEVEL_TO_VARIANCE``.
    :param is_continuous: keep continuous values if true, otherwise binarise by rounding.
    :param flatten: return flat copies if true, otherwise reshape each to a square.
    :param max_attempts_to_balance: attempts per copy; values > 1 enable balancing.
    :return: stacked copies, grouped by source (all copies of ``bitmaps[0]`` first, ...).
    :raises ValueError: if balancing fails, or if ``max_attempts_to_balance`` is less than 1.
    """
    noise_variance = NOISE_LEVEL_TO_VARIANCE[noise_level]
    # Similarities are only worth computing when a failed check can trigger a retry.
    enforce_balance = max_attempts_to_balance > 1
    noisy_bitmaps: list[torch.Tensor] = []
    for bitmap in bitmaps:
        for _ in range(num_variations_per_bitmap):
            noisy_bitmap = None
            for attempt in range(max_attempts_to_balance):
                noisy_bitmap = bitmap.flatten().clone()
                # NOTE(review): the noise itself (not the noisy pixel) is clamped to [0, 1], so
                # pixels can only increase: continuous values may exceed 1, and in discrete mode
                # a 1 can never flip to 0. Left unchanged to avoid altering the distribution of
                # generated data; confirm this is intended.
                noisy_bitmap += torch.clamp(
                    torch.randn(noisy_bitmap.shape) * noise_variance, 0, 1
                )
                if not is_continuous:
                    noisy_bitmap = torch.clamp(noisy_bitmap, 0, 1).round()
                noisy_bitmap = (
                    noisy_bitmap if flatten else reshape_flatten_image(noisy_bitmap)
                )
                if not enforce_balance or _is_uniquely_closest(
                    bitmaps, bitmap, noisy_bitmap
                ):
                    break
                # Ties and wrong matches are treated alike: retry, or fail on the last attempt.
                if attempt == max_attempts_to_balance - 1:
                    raise ValueError(
                        "Tried to create a balanced dataset but failed, some digits are probably too similar"
                    )
            if noisy_bitmap is None:
                # Only reachable when the attempt loop never ran.
                raise ValueError(
                    f"Failed to create a balanced dataset for bitmap {bitmap}. "
                    "Did you set max_attempts_to_balance to a value less than 1?"
                )
            noisy_bitmaps.append(noisy_bitmap)
    return torch.stack(noisy_bitmaps)
def reshape_flatten_image(img: torch.Tensor) -> torch.Tensor:
    """
    Fold a flat image back into a square 2-D tensor.

    The side length is ``int(np.sqrt(len(img)))``. ``reshape`` raises if the
    tensor's element count is not exactly that side squared.
    """
    side_length = int(np.sqrt(len(img)))
    square_shape = (side_length, side_length)
    return img.reshape(square_shape)
def get_train_bitmap_shape(train_bitmap: torch.Tensor) -> tuple[int, int]:
    """
    Return the ``(side, side)`` shape of a square bitmap.

    The side is derived from the total element count, so a flattened bitmap and
    an already-square one give the same answer.

    :param train_bitmap: a bitmap tensor of any rank whose element count is a perfect square.
    :return: ``(side, side)``.
    :raises ValueError: if the element count is not a perfect square (previously this
        silently returned a truncated shape).
    """
    num_pixels = train_bitmap.numel()
    # Integer square root is exact; a float sqrt can misround for large counts.
    side = math.isqrt(num_pixels)
    if side * side != num_pixels:
        raise ValueError(
            f"Bitmap with {num_pixels} elements is not square; cannot infer its shape"
        )
    return side, side
def generate_random_flat_bitmap(
    width: int,
    height: int,
    is_continuous: bool,
    lower_bound: float = 0,
    upper_bound: float = 1,
    discrete_values: tuple[float, ...] = (0, 1),
) -> torch.Tensor:
    """
    Generate a random flat (1-D) bitmap with ``width * height`` pixels.

    :param width: bitmap width.
    :param height: bitmap height.
    :param is_continuous: if true, pixels are ``torch.rand`` samples scaled into
        ``[lower_bound, upper_bound)``; otherwise each pixel is drawn uniformly
        from ``discrete_values`` with ``np.random.choice``.
    :param lower_bound: lower end of the range (continuous mode only).
    :param upper_bound: upper end of the range (continuous mode only).
    :param discrete_values: candidate pixel values, of any length (discrete mode only).
    :return: float tensor of shape ``(width * height,)``.
    """
    num_pixels = width * height
    if is_continuous:
        return torch.rand(num_pixels) * (upper_bound - lower_bound) + lower_bound
    return torch.from_numpy(np.random.choice(discrete_values, size=num_pixels)).to(
        torch.float
    )
def generate_random_train_bitmap_data(
    width: int, height: int, num_images: int, is_continuous: bool
) -> torch.Tensor:
    """
    Generate ``num_images`` random flat bitmaps and stack them into one tensor.

    Each row comes from ``generate_random_flat_bitmap`` with its default bounds
    and discrete values.

    :param width: bitmap width.
    :param height: bitmap height.
    :param num_images: number of bitmaps (rows) to generate.
    :param is_continuous: forwarded to ``generate_random_flat_bitmap``.
    :return: tensor of shape ``(num_images, width * height)``; an empty
        ``(0, width * height)`` tensor when no images are requested.
    """
    bitmaps = [
        generate_random_flat_bitmap(width, height, is_continuous=bool(is_continuous))
        for _ in range(num_images)
    ]
    if not bitmaps:
        # torch.stack rejects an empty list, so build the empty result explicitly.
        return torch.empty((0, width * height))
    return torch.stack(bitmaps)
def save_bitmaps(bitmaps: torch.Tensor, path: Path) -> None:
    """
    Save each bitmap under ``path`` as ``<index>.pt`` (torch.save) and ``<index>.png``.

    ``path`` must already exist; this function does not create it.

    :param bitmaps: bitmaps to save, iterated along the first dimension.
    :param path: output directory.
    """
    for i, bitmap in enumerate(bitmaps):
        # Iterating a tensor yields views into its storage, and torch.save writes a view's whole
        # underlying storage. Cloning keeps each .pt file limited to its own bitmap instead of
        # copying the entire `bitmaps` tensor into every file.
        torch.save(bitmap.clone(), path / f"{i}.pt")
        # Save a grayscale image; the added leading dimension is presumably read as one channel.
        # NOTE(review): a flattened (1-D) bitmap becomes a 1 x N strip here; confirm callers
        # pass 2-D bitmaps.
        torchvision.utils.save_image(bitmap.unsqueeze(0), path / f"{i}.png")
def get_experiments_data_path() -> Path:
    """
    Return the ``experiments_data`` directory that sits next to this module.

    The path is built from ``__file__`` as-is (not resolved) and is not checked for existence.
    """
    module_dir = Path(__file__).parent
    return module_dir.joinpath("experiments_data")
def tensor_to_jax(train_data: torch.Tensor) -> jax.Array:
    """
    Convert a torch tensor into a float32 JAX array.

    The tensor is detached and copied to host memory first, because
    ``Tensor.numpy()`` refuses tensors that require grad or live on a non-CPU device.

    :param train_data: tensor of any shape or dtype.
    :return: JAX array with the same shape and dtype ``float32``.
    """
    host_array = train_data.detach().cpu().numpy()
    return jnp.array(host_array, dtype=jnp.float32)
def get_correct_and_incorrect_num_memories(
    digits_to_test: list[int], num_examples_per_digit: int
) -> tuple[int, int]:
    """
    Return ``(len(digits_to_test), len(digits_to_test) * num_examples_per_digit)``.

    :param digits_to_test: digits in the experiment.
    :param num_examples_per_digit: examples per digit.
    :return: the digit count and the total example count.
    """
    num_digits = len(digits_to_test)
    total_examples = num_digits * num_examples_per_digit
    return num_digits, total_examples
def _noisy_data_dir(numbers_path: Path, noise_type: NoiseType) -> Path:
    """Return the directory under ``numbers_path`` that holds data for a noisy ``noise_type``."""
    subdir_by_noise_type = {
        NoiseType.DISCRETE: "discrete",
        NoiseType.BALANCED_DISCRETE: "balanced_discrete",
        NoiseType.CONTINUOUS: "continuous",
        NoiseType.BALANCED_CONTINUOUS: "balanced_continuous",
    }
    subdir = subdir_by_noise_type.get(noise_type)
    if subdir is None:
        raise ValueError(f"Invalid noise type: {noise_type}")
    return numbers_path / subdir


def _pt_file_sort_key(file: Path) -> tuple[int, int, str]:
    """Order numerically named files by value (``2.pt`` before ``10.pt``), then any others by name."""
    stem = file.stem
    if stem.isdecimal():
        return 0, int(stem), ""
    return 1, 0, file.name


def get_train_data(
    noise_type: NoiseType,
    digits_to_test: list[int],
    num_noise_variations: Optional[int] = None,
    noise_level: Optional[NoiseLevel] = None,
) -> torch.Tensor:
    """
    Load training bitmaps for the requested digits, flattening each one.

    Data is read from ``get_experiments_data_path() / "numbers"``:
    ``NoiseType.NONE`` loads ``originals/<digit>.pt``. ``DISCRETE``,
    ``BALANCED_DISCRETE``, ``CONTINUOUS`` and ``BALANCED_CONTINUOUS`` load
    ``<subdir>/<noise_level.value>/<digit>/*.pt``.

    :param noise_type: which dataset to load (see above).
    :param digits_to_test: digits to load.
    :param num_noise_variations: maximum number of files loaded per digit; required
        (and must be >= 1) for noisy data. Files are taken in numeric filename order,
        so the selection is the same on every run.
    :param noise_level: required for noisy data; its ``value`` names the sub-directory.
    :return: 2-D tensor with one flattened bitmap per row.
    :raises ValueError: on a missing or invalid argument for noisy data, or an unknown noise type.
    """
    numbers_path = get_experiments_data_path() / "numbers"
    # Original (noise-free) data is one file per digit rather than a directory of variations.
    if noise_type == NoiseType.NONE:
        originals_path = numbers_path / "originals"
        # NOTE(review): sorted() orders these paths by name, so rows come back in ascending digit
        # order (for single-digit names) regardless of the order of digits_to_test, unlike the
        # noisy branch below, which follows digits_to_test.
        sorted_files = sorted(originals_path / f"{digit}.pt" for digit in digits_to_test)
        return torch.stack([torch.load(file).flatten() for file in sorted_files])
    if num_noise_variations is None or num_noise_variations < 1:
        raise ValueError("num_noise_variations must be specified for noisy data")
    if noise_level is None:
        raise ValueError("noise_level must be specified for noisy data")
    data_path = _noisy_data_dir(numbers_path, noise_type) / noise_level.value
    data: list[torch.Tensor] = []
    for digit in digits_to_test:
        # Path.glob yields files in arbitrary filesystem order; sort so the same files are
        # selected, in the same order, on every run.
        digit_files = sorted(
            (data_path / str(digit)).glob("*.pt"), key=_pt_file_sort_key
        )
        data.extend(
            torch.load(file).flatten() for file in digit_files[:num_noise_variations]
        )
    return torch.stack(data)