-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
112 lines (98 loc) · 3.29 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os

import torch
import torchvision
from torch.utils.data import DataLoader

from dataset import CarvanaDataset
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Write ``state`` to ``filename`` using ``torch.save``.

    Args:
        state: Object to serialize; handed to ``torch.save`` unchanged.
        filename: Destination passed to ``torch.save`` as its file argument.
    """
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)
def load_checkpoint(checkpoint, model):
    """Load the weights stored in a checkpoint mapping into ``model``.

    Args:
        checkpoint: Mapping whose ``"state_dict"`` entry is passed to
            ``model.load_state_dict``.
        model: Object exposing ``load_state_dict``; updated in place.
    """
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)
def get_loaders(
    train_dir,
    train_maskdir,
    val_dir,
    val_maskdir,
    batch_size,
    train_transform,
    val_transform,
    num_workers=4,
    pin_memory=True,
    train_count=50,
    val_count=1,
):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        train_dir: Passed to ``CarvanaDataset`` as ``image_dir`` (training).
        train_maskdir: Passed to ``CarvanaDataset`` as ``mask_dir`` (training).
        val_dir: Passed to ``CarvanaDataset`` as ``image_dir`` (validation).
        val_maskdir: Passed to ``CarvanaDataset`` as ``mask_dir`` (validation).
        batch_size: Batch size used by both loaders.
        train_transform: Passed to the training dataset as ``transform``.
        val_transform: Passed to the validation dataset as ``transform``.
        num_workers: Worker process count used by both loaders.
        pin_memory: ``pin_memory`` flag used by both loaders.
        train_count: Forwarded as the training dataset's ``count`` argument.
            The default keeps the previously hard-coded value of 50.
        val_count: Forwarded as the validation dataset's ``count`` argument.
            The default keeps the previously hard-coded value of 1.

    Returns:
        A ``(train_loader, val_loader)`` tuple. The training loader shuffles;
        the validation loader does not.
    """
    # NOTE(review): the meaning of ``count`` is defined by CarvanaDataset,
    # which is not visible from this file -- confirm against the dataset class.
    train_ds = CarvanaDataset(
        image_dir=train_dir,
        mask_dir=train_maskdir,
        count=train_count,
        transform=train_transform,
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
    )

    val_ds = CarvanaDataset(
        image_dir=val_dir,
        mask_dir=val_maskdir,
        count=val_count,
        transform=val_transform,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=False,
    )

    return train_loader, val_loader
def check_accuracy(loader, model, device="cuda"):
    """Evaluate a binary segmentation model and print/return summary metrics.

    Predictions are ``sigmoid(model(x)) > 0.5``. The model is switched to
    eval mode for the duration of the evaluation and is always switched back
    to train mode afterwards, even if evaluation raises.

    Args:
        loader: Sized iterable of ``(x, y)`` batches (``len(loader)`` is used
            to average the Dice score over batches).
        model: Callable module exposing ``eval()`` and ``train()``.
        device: Device the inputs and targets are moved to.

    Returns:
        A 4-tuple ``(dice, accuracy, tp_rate, tn_rate)``:
        the Dice score averaged over batches, the percentage of elements
        where prediction equals target, true positives divided by the number
        of target elements equal to 1, and true negatives divided by the
        number of target elements equal to 0.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    num_correct = 0
    num_pixels = 0
    dice_score = 0  # summed per-batch Dice, averaged over batches at the end
    true_pos = 0  # prediction 1, target 1
    true_neg = 0  # prediction 0, target 0
    total_gt_pos = 0  # target elements equal to 1
    total_gt_neg = 0  # target elements equal to 0

    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                # NOTE(review): assumes y arrives without a channel axis and
                # holds only 0/1 values -- confirm against the dataset.
                y = y.to(device).unsqueeze(1)
                preds = (torch.sigmoid(model(x)) > 0.5).float()

                # Explicit masks instead of inspecting preds / y for 1 and
                # NaN; counts stay as tensors so no per-batch .item() sync.
                pred_pos = preds == 1
                gt_pos = y == 1
                gt_neg = y == 0

                num_correct += (preds == y).sum()
                num_pixels += torch.numel(preds)
                true_pos += (pred_pos & gt_pos).sum()
                true_neg += (~pred_pos & gt_neg).sum()
                total_gt_pos += gt_pos.sum()
                total_gt_neg += gt_neg.sum()
                # 1e-8 keeps the ratio finite when both masks are empty.
                dice_score += (2 * (preds * y).sum()) / (
                    (preds + y).sum() + 1e-8
                )
    finally:
        # Restore training mode even when a batch fails mid-evaluation.
        model.train()

    accuracy = num_correct / num_pixels * 100
    mean_dice = dice_score / len(loader)
    # These are NaN when the targets contain no 1s / no 0s respectively.
    tp_rate = true_pos / total_gt_pos
    tn_rate = true_neg / total_gt_neg

    print(f"Got {num_correct}/{num_pixels} with acc {accuracy:.2f}")
    print(f"TP: {tp_rate}, TN: {tn_rate}")
    print(f"Dice score: {mean_dice}")

    return mean_dice, accuracy, tp_rate, tn_rate
def save_predictions_as_imgs(
    loader, model, epoch, folder="saved_images/", device="cuda"
):
    """Save thresholded predictions and their targets as PNG images.

    For every batch index ``idx`` two files are written into ``folder``:
    ``ep_{epoch}-pred_{idx}.png`` with the predictions
    (``sigmoid(model(x)) > 0.5``) and ``{idx}.png`` with the targets.
    The model is put in eval mode while saving and is always returned to
    train mode afterwards.

    Args:
        loader: Iterable of ``(x, y)`` batches.
        model: Callable module exposing ``eval()`` and ``train()``.
        epoch: Value embedded in the prediction file names.
        folder: Output directory; created if it does not exist. A trailing
            slash is optional.
        device: Device the inputs are moved to before the forward pass.
    """
    if folder:
        os.makedirs(folder, exist_ok=True)

    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)
            with torch.no_grad():
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
            # os.path.join puts both files inside ``folder`` whether or not
            # the caller supplied a trailing separator.
            torchvision.utils.save_image(
                preds, os.path.join(folder, f"ep_{epoch}-pred_{idx}.png")
            )
            # The target file name has no epoch component, so it is
            # overwritten on each call.
            torchvision.utils.save_image(
                y.unsqueeze(1), os.path.join(folder, f"{idx}.png")
            )
    finally:
        # Restore training mode even if a forward pass or a write fails.
        model.train()