-
Notifications
You must be signed in to change notification settings - Fork 1
/
calc_mean_std.py
90 lines (81 loc) · 2.44 KB
/
calc_mean_std.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader
import albumentations as a
from albumentations.pytorch.transforms import ToTensorV2
from torch.utils.data import Dataset
from skimage import io as skio
from skimage.transform import resize
from tqdm import tqdm
data_folder = Path("/data/MFWD/patches/train")
bs = 5000
class MFWDDataset(Dataset):
def __init__(self, data_folder:Path, labels, img_size:tuple, transforms):
self.data_folder = data_folder
self.img_size = img_size
self.transforms = transforms
self.img_list = sorted(list(data_folder.glob("*/*")))
self.labels = labels
return
def _load_images(self, img_path):
img = skio.imread(str(img_path))
img = resize(img, self.img_size, clip=True, preserve_range=True, anti_aliasing=True)
img = img.astype(np.uint8)
return img
def __len__(self):
return len(self.img_list)
def __getitem__(self, item):
img_path = self.img_list[item]
label = img_path.parent.stem
label_id = self.labels[f"{label}"]
img = self._load_images(img_path)
if self.transforms:
img = self.transforms(image=img)["image"]
return img.type(torch.float32), label_id
train_transform = a.Compose([
ToTensorV2(),
])
labels = {"ACHMI": 0,
"AETCY": 1,
"AGRRE": 2,
"ALOMY": 3,
"ARTVU": 4,
"CHEAL": 5,
"CIRAR": 6,
"CONAR": 7,
"ECHCG": 8,
"GALAP": 9,
"GASPA": 10,
"GERMO": 11,
"LAMAL": 12,
"MATCH": 13,
"PLAMA": 14,
"POAAN": 15,
"POLAM": 16,
"POLCO": 17,
"POROL": 18,
"PULDY": 19,
"SOLNI": 20,
"SORVU": 21,
"SSYOF": 22,
"STEME": 23,
"THLAR": 24,
"VEROF": 25,
"VIOAR": 26}
ds = MFWDDataset(data_folder=data_folder, labels=labels, img_size=(224,224), transforms=train_transform)
loader = DataLoader(ds, batch_size=bs, shuffle=False)
mean = 0.
std = 0.
nb_samples = 0.
tqdm_l = tqdm(loader, total = len(ds)//bs + 1)
for data, lbls in tqdm_l:
batch_samples = data.size(0)
data = data.view(batch_samples, data.size(1), -1)
mean += data.mean(2).sum(0)
std += data.std(2).sum(0)
nb_samples += batch_samples
mean /= nb_samples
std /= nb_samples
print(mean.numpy()/255.0)
print(std.numpy()/255.0)