-
Notifications
You must be signed in to change notification settings - Fork 1
/
densenet.py
459 lines (395 loc) · 15.9 KB
/
densenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
"""
A DenseNet-121 based binary classifier (QuackDenseNet) and a prediction writer (CensoredDataWriter) that stores metadata for samples predicted as censored.
"""
import json
import re
from datetime import datetime
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
from pytorch_lightning.utilities.distributed import rank_zero_info
from typing import Any, List, Optional, Sequence, Tuple
from pathlib import Path
import torch as pt
from torch import nn
from torchvision import models
import pytorch_lightning as pl
import torchmetrics as tm
from pytorch_lightning.callbacks import BasePredictionWriter
class CensoredDataWriter(BasePredictionWriter):
    """
    Extends pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter
    to store metadata for detected censorship.

    For every sample whose predicted probability is at least 0.5, the sample's
    metadata dictionary is written as JSON to
    ``<storage_path>/<domain>/<detection time>.json``.
    """

    def __init__(self, write_interval: str = "batch", storage_path: str = '~/data') -> None:
        """
        Constructor for CensoredDataWriter.

        Parameters
        ----------
        write_interval: str
            See parent class BasePredictionWriter
        storage_path: str
            A string file path to the directory in which output will be stored.
            A leading ``~`` is expanded to the current user's home directory.
        """
        super().__init__(write_interval)
        # pathlib never expands "~" by itself; without expanduser() the default
        # would create a literal "~" directory under the working directory.
        self.__storage_path = Path(storage_path).expanduser()
        # Running count of samples predicted as censored; logged on every hit.
        self.__found = 0.0

    def write_on_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", prediction: Any,
                           batch_indices: Optional[Sequence[int]], batch: Any, batch_idx: int,
                           dataloader_idx: int) -> None:
        """
        Logic to write the results of a single batch to files.

        Parameters
        ----------
        Parameter signature defined in the parent class. ``prediction`` is the
        ``(metadata list, probability tensor)`` pair returned by the module's
        ``predict_step``; metadata entries are paired positionally with the
        predictions and must provide the ``'domain'`` and ``'timestamp'`` keys.

        Returns
        ----------
        void

        Raises
        ------
        ValueError
            If the number of metadata entries differs from the number of predictions.
        FileExistsError
            If a JSON file for the same domain and detection time already exists.

        See Also
        ----------
        pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter.write_on_batch_end
        """
        meta: List[dict]
        processed: pt.Tensor
        meta, processed = prediction
        # Copy to the cpu and flatten to one Python float per sample; the output
        # is presumably (B, 1) shaped — the length check below guards that.
        probabilities = processed.detach().cpu().reshape(-1).tolist()
        if len(probabilities) != len(meta):
            raise ValueError(
                f"Expected one metadata entry per prediction, got {len(meta)} metadata "
                f"entries for {len(probabilities)} predictions."
            )
        # Pair positionally with zip rather than meta.pop(0): popping emptied the
        # caller's metadata list and was O(n) per sample.
        for outcome_meta, probability in zip(meta, probabilities):
            # Comparison kept in this form so NaN probabilities are skipped.
            if probability >= 0.5:
                # Predicted as censored.
                self.__found += 1.0
                # The trainer has no logger when logging is disabled.
                if trainer.logger is not None:
                    trainer.logger.log_metrics({'found': self.__found})
                # Prep storage.
                storage_path = self.__storage_path / outcome_meta['domain']
                storage_path.mkdir(parents=True, exist_ok=True)
                detection_time = datetime.fromtimestamp(outcome_meta['timestamp']).isoformat(sep='-')
                # Replace characters that are not portable in file names.
                stem = re.sub(r'[\\:]', '_', detection_time)
                data_storage = storage_path / f"{stem}.json"
                # Mode 'x' refuses to overwrite an existing file.
                with data_storage.open(mode='x') as target:
                    json.dump(outcome_meta, target)

    def write_on_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", predictions: Sequence[Any],
                           batch_indices: Optional[Sequence[Any]]) -> None:
        """
        Implementation expected by the base class. Unused in our case.
        """
        pass
class QuackDenseNet(pl.LightningModule):
    """
    A modification of the Pytorch Densenet 121 pretrained model.

    The pre-trained classifier is replaced by a single-output linear layer, so the
    network emits one un-normalized score (logit) per image and is trained as a
    binary classifier with ``nn.BCEWithLogitsLoss``.

    References
    ----------
    https://pytorch.org/hub/pytorch_vision_densenet/
    """

    def __init__(self, learning_rate: float = 1e-1, learning_rate_min: float = 1e-4,
                 lr_max_epochs: int = -1, freeze: bool = True, *args: Any, **kwargs: Any) -> None:
        """
        Constructor for QuackDenseNet.

        Parameters
        ----------
        learning_rate: float
            Initial learning rate given to pt.optim.AdamW; the
            pt.optim.lr_scheduler.CosineAnnealingLR schedule starts from it.
        learning_rate_min: float
            Passed to pt.optim.lr_scheduler.CosineAnnealingLR as ``eta_min``.
        lr_max_epochs: int
            Passed to pt.optim.lr_scheduler.CosineAnnealingLR as ``T_max``.
            NOTE(review): the default of -1 is forwarded unchanged and does not
            describe a normal cosine period; confirm callers always set it.
        freeze: bool
            Should the image analyzing layers of the pre-trained Densenet be frozen?
            The replacement classifier is trainable either way.
        args: Any
            Passed to the parent constructor.
        kwargs: Any
            Passed to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        # Load the pre-trained densenet.
        pre_trained = models.densenet121(pretrained=True)
        if freeze:
            # Disable gradients for every pre-trained parameter.
            pre_trained.requires_grad_(False)
        # Replace the classifier with a single-logit head. It is created after the
        # freeze and new modules default to requires_grad=True, so it stays trainable.
        classifier_features_in = pre_trained.classifier.in_features
        pre_trained.classifier = nn.Linear(classifier_features_in, 1)
        self.__densenet = pre_trained
        self.__to_probability = nn.Sigmoid()
        self.__learning_rate_init = learning_rate
        self.__learning_rate_min = learning_rate_min
        self.__lr_max_epochs = lr_max_epochs
        # NOTE(review): with torchmetrics 0.x defaults (average='micro'),
        # F1Score(num_classes=2) averages over both classes and may coincide with
        # accuracy on 0/1 labels; confirm this is the intended F1.
        self.__train_acc = tm.Accuracy()
        self.__train_f1 = tm.F1Score(num_classes=2)
        self.__val_acc = tm.Accuracy()
        self.__val_f1 = tm.F1Score(num_classes=2)
        self.__test_acc = tm.Accuracy()
        self.__test_f1 = tm.F1Score(num_classes=2)
        # Unweighted loss until set_balanced_loss supplies a pos_weight.
        self.__loss_module = nn.BCEWithLogitsLoss()
        # For tuning (presumably read by Lightning's batch-size tuner — confirm).
        self.batch_size = 2

    def set_balanced_loss(self, balance: Optional[float]):
        """
        Re-weight positive samples in the loss using the balance factor passed.

        Parameters
        ----------
        balance: Optional[float]
            The proportion of negative samples / positive samples. Used as the
            ``pos_weight`` of nn.BCEWithLogitsLoss. When None, the current loss
            module is left unchanged.

        Returns
        -------
        void

        References
        ----------
        https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html?highlight=bceloss#bcewithlogitsloss
        """
        if balance is None:
            return
        # Build a float weight on the module's current device so the loss still
        # works when this is called after the model has been moved (e.g. to a GPU),
        # and an integer ratio does not produce an integer weight tensor.
        balance_factor = pt.tensor(balance, dtype=pt.float, device=self.device)
        self.__loss_module = nn.BCEWithLogitsLoss(pos_weight=balance_factor)

    def forward(self, x: pt.Tensor) -> pt.Tensor:
        """
        Process a batch of input through the model.

        Parameters
        ----------
        x: pt.Tensor
            The input, which should be (B, 3, H, W) shaped, where:
                B: batch size
                3: Densenet is trained on RGB = 3 channels
                H: height
                W: width

        Returns
        -------
        pt.Tensor
            The output, which should be (B, 1) sized, of un-normalized scores
            (logits); apply a sigmoid to obtain probabilities.
        """
        # Densenet outputs an un-normalized confidence score.
        return self.__densenet(x)

    def _common_step(self, x: Tuple[pt.Tensor, pt.Tensor], batch_index: int, step_id: str) -> Tuple[pt.Tensor, pt.Tensor, pt.Tensor]:
        """
        The step task in each loop type shares a common set of tasks.

        Computes the loss, logs it as ``<step_id>_loss`` and binarizes the
        predictions at probability 0.5.

        Parameters
        ----------
        x: Tuple[pt.Tensor, pt.Tensor]
            The batch: inputs and float labels shaped like the model output.
        batch_index: int
            The batch index
        step_id: str
            The step id ('train', 'val' or 'test'); used as the logging prefix.

        Returns
        -------
        Tuple[pt.Tensor, pt.Tensor, pt.Tensor]
            A tuple of batch losses, batch labels as integers, batch predictions as integers
            each in a tensor.
        """
        inputs, labels = x
        logits = self.forward(inputs)
        # BCEWithLogitsLoss applies the sigmoid itself, so it takes raw logits.
        loss = self.__loss_module(logits, labels)
        # Binarize predictions to 0 and 1.
        predicted = self.__to_probability(logits).ge(0.5)
        is_train = step_id == 'train'
        # Training keeps Lightning's default step logging; val/test log per step
        # and synchronize the value across processes.
        self.log(f"{step_id}_loss", loss, on_step=None if is_train else True, sync_dist=not is_train)
        # Return labels and predictions as integer tensors for the metrics.
        return loss, labels.to(pt.int8), predicted.to(pt.int8)

    def training_step(self, x: Tuple[pt.Tensor, pt.Tensor], batch_index: int) -> dict:
        """
        Calls _common_step for step 'train'.

        Parameters
        ----------
        x: Tuple[pt.Tensor, pt.Tensor]
            The input tensor and a label tensor
        batch_index: int
            The index of the batch. Required to match the parent signature. Unused in our model.

        Returns
        -------
        dict
            Format expected by the parent class. Has three keys:
                loss
                    The loss returned by `_common_step`.
                expected
                    The labels from the batch returned by `_common_step`.
                predicted
                    The predicted labels from the batch returned by `_common_step`.
        """
        loss, expected, predicted = self._common_step(x, batch_index, 'train')
        return {'loss': loss, 'expected': expected, 'predicted': predicted}

    def validation_step(self, x: Tuple[pt.Tensor, pt.Tensor], batch_index: int) -> dict:
        """
        Calls _common_step for step 'val'.

        Parameters
        ----------
        x: Tuple[pt.Tensor, pt.Tensor]
            The input tensor and a label tensor
        batch_index: int
            The index of the batch. Required to match the parent signature. Unused in our model.

        Returns
        -------
        dict
            Format expected by the parent class. Has three keys:
                loss
                    The loss returned by `_common_step`.
                expected
                    The labels from the batch returned by `_common_step`.
                predicted
                    The predicted labels from the batch returned by `_common_step`.
        """
        loss, expected, predicted = self._common_step(x, batch_index, 'val')
        return {'loss': loss, 'expected': expected, 'predicted': predicted}

    def test_step(self, x: Tuple[pt.Tensor, pt.Tensor], batch_index: int) -> dict:
        """
        Calls _common_step for step 'test'.

        Parameters
        ----------
        x: Tuple[pt.Tensor, pt.Tensor]
            The input tensor and a label tensor
        batch_index: int
            The index of the batch. Required to match the parent signature. Unused in our model.

        Returns
        -------
        dict
            Format expected by the parent class. Has three keys:
                loss
                    The loss returned by `_common_step`.
                expected
                    The labels from the batch returned by `_common_step`.
                predicted
                    The predicted labels from the batch returned by `_common_step`.
        """
        loss, expected, predicted = self._common_step(x, batch_index, 'test')
        return {'loss': loss, 'expected': expected, 'predicted': predicted}

    def predict_step(self, batch: Tuple[pt.Tensor, List[dict]], batch_idx: int, dataloader_idx: Optional[int] = None) -> Tuple[List[dict], pt.Tensor]:
        """
        Calls `forward` for prediction.

        Parameters
        ----------
        batch: Tuple[pt.Tensor, List[dict]]
            A tuple of the input data and the associated metadata.
        batch_idx: int
            The index of the batch. Required to match the parent signature. Unused in our model.
        dataloader_idx: int
            Index of the current dataloader. Required to match the parent signature. Unused in our model.

        Returns
        -------
        Tuple[List[dict], pt.Tensor]
            A tuple of the batch metadata (passed through unchanged) and the
            sigmoid probabilities of the model output.
        """
        inputs, meta = batch
        # Densenet outputs an un-normalized confidence score.
        # Use sigmoid to transform to a probability.
        confidence = self.forward(inputs)
        output = self.__to_probability(confidence)
        return meta, output

    def training_step_end(self, outputs: dict, *args, **kwargs):
        """
        When using distributed backends, only a portion of the batch is inside the `training_step`.
        We calculate metrics here with the entire batch.

        Parameters
        ----------
        outputs: dict
            The return values from `training_step` for each batch part.
        args: Any
            Matching to the parent constructor.
        kwargs: Any
            Matching to the parent constructor.

        Returns
        -------
        void
        """
        self.__train_acc(outputs['predicted'], outputs['expected'])
        self.log('train_acc', self.__train_acc)
        self.__train_f1(outputs['predicted'], outputs['expected'])
        self.log('train_f1', self.__train_f1)

    def validation_step_end(self, outputs: dict, *args, **kwargs):
        """
        When using distributed backends, only a portion of the batch is inside the `validation_step`.
        We accumulate metric state here with the entire batch; the metrics are
        logged in `validation_epoch_end`.

        Parameters
        ----------
        outputs: dict
            The return values from `validation_step` for each batch part.
        args: Any
            Matching to the parent constructor.
        kwargs: Any
            Matching to the parent constructor.

        Returns
        -------
        void
        """
        self.__val_acc.update(outputs['predicted'], outputs['expected'])
        self.__val_f1.update(outputs['predicted'], outputs['expected'])

    def test_step_end(self, outputs: dict, *args, **kwargs):
        """
        When using distributed backends, only a portion of the batch is inside the `test_step`.
        We accumulate metric state here with the entire batch; the metrics are
        logged in `test_epoch_end`.

        Parameters
        ----------
        outputs: dict
            The return values from `test_step` for each batch part.
        args: Any
            Matching to the parent constructor.
        kwargs: Any
            Matching to the parent constructor.

        Returns
        -------
        void
        """
        self.__test_acc.update(outputs['predicted'], outputs['expected'])
        self.__test_f1.update(outputs['predicted'], outputs['expected'])

    def test_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        """
        Called at the end of a test epoch with the output of all test steps.
        Now that all the test steps are complete, we log the metrics.

        Parameters
        ----------
        outputs: None
            No outputs are passed on from `test_step_end`.

        Returns
        -------
        void
        """
        # The metric objects are logged directly; Lightning computes their value
        # when logging, so a separate compute() whose result was discarded is not needed.
        self.log('test_acc', self.__test_acc)
        self.log('test_f1', self.__test_f1)

    def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        """
        Called at the end of a validation epoch with the output of all validation steps.
        Now that all the validation steps are complete, we log the metrics.

        Parameters
        ----------
        outputs: None
            No outputs are passed on from `validation_step_end`.

        Returns
        -------
        void
        """
        # The metric objects are logged directly; Lightning computes their value
        # when logging, so a separate compute() whose result was discarded is not needed.
        self.log('val_acc', self.__val_acc)
        self.log('val_f1', self.__val_f1)

    def configure_optimizers(self):
        """
        Configures the optimizer and learning rate scheduler objects.

        Only parameters with ``requires_grad`` set are given to the optimizer, so
        frozen Densenet layers are excluded.

        Returns
        -------
        dict
            A dictionary with keys:
                - optimizer: pt.optim.AdamW
                - lr_scheduler: pt.optim.lr_scheduler.CosineAnnealingLR

        See Also
        --------
        pytorch_lightning.core.lightning.LightningModule.configure_optimizers
        """
        parameters = list(self.parameters())
        trainable_parameters = [parameter for parameter in parameters if parameter.requires_grad]
        # The counts are of parameter tensors, not of individual scalar weights.
        rank_zero_info(
            f"The model will start training with only {len(trainable_parameters)} "
            f"trainable parameters out of {len(parameters)}."
        )
        configured_optimizer = pt.optim.AdamW(params=trainable_parameters, lr=self.__learning_rate_init)
        return {
            'optimizer': configured_optimizer,
            'lr_scheduler': pt.optim.lr_scheduler.CosineAnnealingLR(
                optimizer=configured_optimizer,
                T_max=self.__lr_max_epochs,
                eta_min=self.__learning_rate_min
            )
        }