-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
115 lines (94 loc) · 3.11 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# References:
# https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/data/unaligned_dataset.py
from torch.utils.data import Dataset
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from PIL import Image
from pathlib import Path
import random
import config
class UnpairedImageDataset(Dataset):
    """Dataset yielding ``(x, y)`` image pairs from two unaligned domains.

    Domain X images are read from ``<data_dir>/<split>A/*.jpg`` and domain Y
    images from ``<data_dir>/<split>B/*.jpg``. Each item is a tuple of two
    normalized tensors produced by :meth:`transform`.

    Args:
        data_dir: Root directory containing the ``<split>A`` and ``<split>B``
            sub-directories.
        x_mean, x_std: Per-channel mean/std passed to ``T.Normalize`` for X.
        y_mean, y_std: Per-channel mean/std passed to ``T.Normalize`` for Y.
        fixed_pairs: If True, item ``idx`` always pairs
            ``x_paths[idx % x_len]`` with ``y_paths[idx % y_len]``. Otherwise
            the larger domain is indexed by ``idx`` and the partner from the
            other domain is drawn at random on every access.
        split: Sub-directory prefix (e.g. ``"train"``); random horizontal
            flips are applied only when this equals ``"train"``.

    Raises:
        ValueError: If either domain directory contains no ``*.jpg`` files.
    """

    def __init__(
        self,
        data_dir,
        x_mean,
        x_std,
        y_mean,
        y_std,
        fixed_pairs=False,
        split="train",
    ):
        super().__init__()
        self.x_mean = x_mean
        self.x_std = x_std
        self.y_mean = y_mean
        self.y_std = y_std
        self.fixed_pairs = fixed_pairs
        self.split = split

        x_dir = Path(data_dir) / f"{split}A"
        y_dir = Path(data_dir) / f"{split}B"
        # Sorted so the index -> file mapping (and therefore the pairing used
        # when ``fixed_pairs`` is True) does not depend on the filesystem's
        # directory-listing order.
        self.x_paths = sorted(x_dir.glob("*.jpg"))
        self.x_len = len(self.x_paths)
        self.y_paths = sorted(y_dir.glob("*.jpg"))
        self.y_len = len(self.y_paths)
        # Every item needs one image from each domain; fail here with a clear
        # message rather than a ZeroDivisionError / IndexError in __getitem__.
        if self.x_len == 0:
            raise ValueError(f"No *.jpg images found in {x_dir}")
        if self.y_len == 0:
            raise ValueError(f"No *.jpg images found in {y_dir}")

        # Random resized crop is not part of the original paper's recipe.
        # NOTE(review): unlike the flips in ``transform``, this crop is applied
        # for every split, including non-"train" ones — confirm intended.
        self.rand_resized_crop = T.RandomResizedCrop(
            size=config.IMG_SIZE, scale=config.SCALE, ratio=(1, 1), antialias=True,
        )

    def transform(self, x, y):
        """Crop, optionally flip, tensorize and normalize a PIL image pair.

        ``x`` and ``y`` are cropped and flipped independently of each other.
        Flips (each with probability 0.5) happen only for the ``"train"``
        split.
        """
        x = self.rand_resized_crop(x)
        y = self.rand_resized_crop(y)
        if self.split == "train":
            if random.random() > 0.5:
                x = TF.hflip(x)
            if random.random() > 0.5:
                y = TF.hflip(y)
        x = T.ToTensor()(x)
        x = T.Normalize(mean=self.x_mean, std=self.x_std)(x)
        y = T.ToTensor()(y)
        y = T.Normalize(mean=self.y_mean, std=self.y_std)(y)
        return x, y

    def __len__(self):
        """Return the size of the larger domain."""
        return max(self.x_len, self.y_len)

    def _pick_paths(self, idx):
        """Return the ``(x_path, y_path)`` pair used for dataset index ``idx``."""
        if self.fixed_pairs:
            return self.x_paths[idx % self.x_len], self.y_paths[idx % self.y_len]
        # The larger domain is walked in order (``len(self)`` equals its size),
        # while the smaller one contributes a random partner on every access.
        if self.x_len >= self.y_len:
            return self.x_paths[idx], random.choice(self.y_paths)
        return random.choice(self.x_paths), self.y_paths[idx]

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` with PIL, convert it to RGB and release the file handle."""
        with Image.open(path) as image:
            # ``convert`` returns a new, fully loaded image, so it remains
            # valid after the source file is closed.
            return image.convert("RGB")

    def __getitem__(self, idx):
        """Return the transformed ``(x, y)`` tensor pair for index ``idx``."""
        x_path, y_path = self._pick_paths(idx)
        x = self._load_rgb(x_path)
        y = self._load_rgb(y_path)
        x, y = self.transform(x=x, y=y)
        return x, y
class OneSideImageDataset(Dataset):
    """Dataset over a single domain (X or Y) of an unpaired image folder.

    Args:
        data_dir: Root directory containing ``<split>A`` and ``<split>B``.
        x_or_y: ``"x"`` reads ``<split>A/*.jpg``; ``"y"`` reads
            ``<split>B/*.jpg``.
        mean, std: Per-channel values passed to ``T.Normalize``.
        split: Sub-directory prefix (e.g. ``"train"``); random horizontal
            flips are applied only when this equals ``"train"``.

    Raises:
        ValueError: If ``x_or_y`` is neither ``"x"`` nor ``"y"``.
    """

    def __init__(
        self,
        data_dir,
        x_or_y,
        mean,
        std,
        split="train",
    ):
        super().__init__()
        self.x_or_y = x_or_y
        self.mean = mean
        self.std = std
        self.split = split
        if x_or_y == "x":
            suffix = "A"
        elif x_or_y == "y":
            suffix = "B"
        else:
            # Reject unknown values up front; otherwise ``self.paths`` would
            # never be set and the error would surface later in ``__len__``.
            raise ValueError(f"x_or_y must be 'x' or 'y', got {x_or_y!r}")
        # Sorted for a deterministic index -> file mapping across filesystems.
        self.paths = sorted((Path(data_dir) / f"{split}{suffix}").glob("*.jpg"))

    def transform(self, image):
        """Optionally flip (train split only), then tensorize and normalize."""
        if self.split == "train":
            if random.random() > 0.5:
                image = TF.hflip(image)
        image = T.ToTensor()(image)
        image = T.Normalize(mean=self.mean, std=self.std)(image)
        return image

    def __len__(self):
        """Return the number of images found for this domain."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return its transformed tensor."""
        with Image.open(self.paths[idx]) as source:
            # ``convert`` returns a new, fully loaded image, so it remains
            # valid after the source file is closed.
            image = source.convert("RGB")
        image = self.transform(image)
        return image