-
Notifications
You must be signed in to change notification settings - Fork 3
/
datasets.py
79 lines (67 loc) · 2.53 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import random
import h5py
import numpy as np
from torch.utils.data import Dataset
class TrainDataset(Dataset):
    """Training dataset of randomly cropped, augmented LR/HR patch pairs.

    Samples are read from an HDF5 file containing two groups, ``'lr'`` and
    ``'hr'``, in which the sample for index ``idx`` is stored under the key
    ``str(idx)``. Each item is a random aligned crop of the pair, followed by
    a random horizontal flip, vertical flip and 90-degree rotation. Both
    patches are returned as float32 arrays transposed to channel-first order
    and divided by 255.

    Args:
        h5_file: Path to the HDF5 file. It is reopened on every access instead
            of being held open, so the dataset object carries no open handle
            (presumably so it can be shared with DataLoader worker processes).
        patch_size: Side length, in LR pixels, of the square LR crop.
        scale: Integer upscaling factor; the HR crop is
            ``patch_size * scale`` pixels on a side.
    """

    def __init__(self, h5_file, patch_size, scale):
        super().__init__()
        self.h5_file = h5_file
        self.patch_size = patch_size
        self.scale = scale

    @staticmethod
    def random_crop(lr, hr, size, scale):
        """Return a random ``size`` x ``size`` LR crop and its aligned HR crop.

        Axis 0 of each array is treated as rows and axis 1 as columns. The HR
        crop starts at the LR offset multiplied by ``scale`` and spans
        ``size * scale`` pixels in each direction.

        Raises:
            ValueError: if ``size`` exceeds the LR height or width, or if the
                HR image is too small to supply the aligned HR crop.
        """
        lr_height, lr_width = lr.shape[0], lr.shape[1]
        if size > lr_height or size > lr_width:
            raise ValueError(
                f"patch size {size} exceeds LR image size "
                f"{lr_height}x{lr_width}"
            )

        # Draw the horizontal offset first, then the vertical one. This keeps
        # the same RNG consumption order as before.
        lr_left = random.randint(0, lr_width - size)
        lr_top = random.randint(0, lr_height - size)

        hr_size = size * scale
        hr_left = lr_left * scale
        hr_top = lr_top * scale

        lr_patch = lr[lr_top:lr_top + size, lr_left:lr_left + size]
        hr_patch = hr[hr_top:hr_top + hr_size, hr_left:hr_left + hr_size]

        # NumPy slicing clamps silently at the array edge. Without this check,
        # an HR image smaller than ``scale`` times the LR image would yield a
        # short patch that only fails much later (batch collation or loss).
        if tuple(hr_patch.shape[:2]) != (hr_size, hr_size):
            raise ValueError(
                f"HR image of size {hr.shape[0]}x{hr.shape[1]} cannot supply "
                f"a {hr_size}x{hr_size} crop at offset ({hr_top}, {hr_left}); "
                f"expected HR to be at least {scale}x the LR image"
            )
        return lr_patch, hr_patch

    @staticmethod
    def random_horizontal_flip(lr, hr):
        """With probability 0.5, mirror both arrays along axis 1 (columns)."""
        if random.random() < 0.5:
            # .copy() materialises the negative-stride view as a contiguous
            # array. Torch tensor conversion rejects negative strides.
            lr = lr[:, ::-1, :].copy()
            hr = hr[:, ::-1, :].copy()
        return lr, hr

    @staticmethod
    def random_vertical_flip(lr, hr):
        """With probability 0.5, mirror both arrays along axis 0 (rows)."""
        if random.random() < 0.5:
            lr = lr[::-1, :, :].copy()
            hr = hr[::-1, :, :].copy()
        return lr, hr

    @staticmethod
    def random_rotate_90(lr, hr):
        """With probability 0.5, rotate both arrays by 90 degrees.

        ``axes=(1, 0)`` is equivalent to ``np.rot90(x, k=-1)``, i.e. a
        clockwise turn when rows are drawn top-to-bottom.
        """
        if random.random() < 0.5:
            lr = np.rot90(lr, axes=(1, 0)).copy()
            hr = np.rot90(hr, axes=(1, 0)).copy()
        return lr, hr

    def __getitem__(self, idx):
        """Load sample ``idx``, crop and augment it, and return ``(lr, hr)``.

        Both results are float32 arrays transposed with ``[2, 0, 1]``. This
        assumes each stored sample is 3-D (H, W, C). The division by 255
        presumably assumes 8-bit data; confirm against the HDF5 writer.
        """
        key = str(idx)
        with h5py.File(self.h5_file, 'r') as f:
            # ``[()]`` reads the entire HDF5 dataset into a NumPy array.
            lr = f['lr'][key][()]
            hr = f['hr'][key][()]
        lr, hr = self.random_crop(lr, hr, self.patch_size, self.scale)
        lr, hr = self.random_horizontal_flip(lr, hr)
        lr, hr = self.random_vertical_flip(lr, hr)
        lr, hr = self.random_rotate_90(lr, hr)
        lr = lr.astype(np.float32).transpose([2, 0, 1]) / 255.0
        hr = hr.astype(np.float32).transpose([2, 0, 1]) / 255.0
        return lr, hr

    def __len__(self):
        """Return the number of samples, i.e. the number of entries in ``'lr'``."""
        with h5py.File(self.h5_file, 'r') as f:
            return len(f['lr'])
class EvalDataset(Dataset):
    """Evaluation dataset yielding whole, un-augmented LR/HR image pairs.

    Samples live in an HDF5 file under the groups ``'lr'`` and ``'hr'``, keyed
    by ``str(idx)``. No cropping or augmentation is applied. Each stored array
    is cast to float32, transposed with ``[2, 0, 1]`` (assumes 3-D (H, W, C)
    storage; confirm against the writer) and divided by 255.

    Args:
        h5_file: Path to the HDF5 file, reopened on every access.
    """

    def __init__(self, h5_file):
        super().__init__()
        self.h5_file = h5_file

    def __getitem__(self, idx):
        """Return the ``(lr, hr)`` pair stored for ``idx`` as float32 CHW arrays."""
        key = str(idx)
        with h5py.File(self.h5_file, 'r') as handle:
            # Groups are read in the order 'lr' then 'hr'.
            lr, hr = (
                handle[group][key][::].astype(np.float32).transpose([2, 0, 1]) / 255.0
                for group in ('lr', 'hr')
            )
        return lr, hr

    def __len__(self):
        """Return how many samples the ``'lr'`` group holds."""
        with h5py.File(self.h5_file, 'r') as handle:
            count = len(handle['lr'])
        return count