-
Notifications
You must be signed in to change notification settings - Fork 0
/
Loader.py
64 lines (55 loc) · 2.08 KB
/
Loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
from PIL import Image
import torch
from torchvision import transforms
import random
import time
# Run on GPU when available; every tensor built below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def loader_train(batch_size, idx_helper, cls_1, cls_2):
    """Build one augmented training batch drawn from two lesion classes.

    Each sample's class is chosen uniformly at random; within a class,
    images are consumed round-robin via the per-class cursors in
    ``idx_helper`` (advanced in place and also returned).

    Args:
        batch_size: number of images in the returned batch.
        idx_helper: two-element mutable sequence of per-class file cursors,
            indexed 0 for ``cls_1`` and 1 for ``cls_2``; mutated in place.
        cls_1, cls_2: class sub-directory names under
            'Skin Cancer/Skin Cancer/'.

    Returns:
        Tuple ``(images, label, idx_helper)`` where ``images`` is a
        ``(batch_size, 3, 256, 256)`` float tensor on ``device`` and
        ``label`` is a one-hot ``(batch_size, 2)`` tensor
        (column 0 = cls_1, column 1 = cls_2).
    """
    transform_train = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomRotation(60),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.ToTensor()
    ])
    # NOTE: the original re-seeded the global RNG with int(time.time()) on
    # every call; calls within the same second therefore drew identical
    # class sequences. The default OS-entropy seeding is strictly better,
    # so the per-call seed is removed.
    images = []
    label = torch.zeros(batch_size, 2).to(device)
    d_types = [cls_1, cls_2]
    # Hoisted out of the per-sample loop: the original re-ran os.listdir
    # for every image. sorted() pins the idx -> file mapping, since
    # os.listdir returns entries in arbitrary, platform-dependent order.
    class_dirs = [os.path.join('Skin Cancer/Skin Cancer', d) for d in d_types]
    class_files = [sorted(os.listdir(p)) for p in class_dirs]
    for i in range(batch_size):
        flag = random.randint(0, 1)  # pick one of the two classes at random
        imgs = class_files[flag]
        idx = idx_helper[flag]
        idx_helper[flag] += 1  # advance this class's cursor; wraps via modulo
        img_path = os.path.join(class_dirs[flag], imgs[idx % len(imgs)])
        image = Image.open(img_path).convert('RGB')
        image = transform_train(image)
        images.append(image)
        label[i, flag] = 1
    images = torch.stack(images, dim=0).to(device)
    return images, label, idx_helper
def loader_test(batch_size, flag, cls_1, cls_2):
    """Build one (non-augmented) evaluation batch from the test directory.

    Files are taken sequentially starting at offset ``flag`` in the sorted
    test-directory listing. Labels are derived positionally: the hard-coded
    offset tables below describe where each class's images end in the
    pre-arranged test set.

    Args:
        batch_size: number of images in the returned batch.
        flag: starting offset into the test-directory file list.
        cls_1, cls_2: class keys into the ``nums``/``flags`` tables below.

    Returns:
        Tuple ``(images, label)`` where ``images`` is a
        ``(batch_size, 3, 256, 256)`` float tensor on ``device`` and
        ``label`` is a one-hot ``(batch_size, 2)`` tensor
        (column 0 = cls_1, column 1 = cls_2).
    """
    transform_test = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor()
    ])
    images = []
    label = torch.zeros(batch_size, 2).to(device)
    # nums: per-class test-image counts; flags: cumulative end offsets of
    # each class's span within the test directory listing.
    nums = {'akiec': 60, 'mel': 220, 'bkl': 220, 'nv': 1340, 'bcc': 100, 'vasc': 30, 'df': 20}
    flags = {'akiec': 60, 'mel': 280, 'bkl': 500, 'nv': 1840, 'bcc': 1940, 'vasc': 1970, 'df': 1990}
    # Hoisted out of the loop: the original re-ran os.listdir per image.
    # sorted() is required for the positional offset tables above to index
    # a stable ordering — raw os.listdir order is platform-dependent.
    path = 'Skin Cancer/Skin Cancer/test'
    imgs = sorted(os.listdir(path))
    for i in range(batch_size):
        img_path = os.path.join(path, imgs[i + flag])
        image = Image.open(img_path).convert('RGB')
        image = transform_test(image)
        images.append(image)
        # A file at global position flag+i belongs to cls_1 when it falls
        # before cls_1's end offset and before the start of cls_2's span.
        if flag + i + 1 <= flags[cls_1] and flag + i + 1 <= flags[cls_2] - nums[cls_2]:  # cls_1
            label[i, 0] = 1
        else:  # cls_2
            label[i, 1] = 1
    images = torch.stack(images, dim=0).to(device)
    return images, label