-
Notifications
You must be signed in to change notification settings - Fork 10
/
transform.py
82 lines (64 loc) · 2.06 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
import torch
import cv2
class PadandRandomCrop(object):
    '''
    Reflect-pad an image on its two spatial axes, then take a random crop.

    Input is a numpy array expected to have shape (H, W, C) (typically C == 3);
    the channel axis is never padded. ``border`` pixels are added on every
    side of H and W using numpy's 'reflect' mode, and a window of
    ``cropsize`` = (h, w) is cut out at a uniformly random offset.

    If the padded image is smaller than ``cropsize`` along an axis, the
    offset on that axis is 0 and the output is smaller than ``cropsize``.
    '''

    def __init__(self, border=4, cropsize=(32, 32)):
        self.border = border
        self.cropsize = cropsize

    def __call__(self, im):
        pad = self.border
        borders = [(pad, pad), (pad, pad), (0, 0)]
        canvas = np.pad(im, borders, mode='reflect')
        H, W, _ = canvas.shape
        h, w = self.cropsize
        dh, dw = max(0, H - h), max(0, W - w)
        # randint's upper bound is exclusive: +1 makes every offset in [0, dh]
        # reachable (so the crop can touch the bottom/right edge) and keeps
        # dh == 0 from raising ValueError.
        sh = np.random.randint(0, dh + 1)
        sw = np.random.randint(0, dw + 1)
        return canvas[sh:sh + h, sw:sw + w, :]
class RandomHorizontalFlip(object):
    '''
    Mirror an image along axis 1 with probability ``p``.

    For an (H, W, C) image, axis 1 is the width, so this is a left-right flip.
    The flipped result is a view of the input with a negative stride on that
    axis; no pixel data is copied. Exactly one draw from ``np.random.rand``
    is consumed per call, whether or not the flip happens.
    '''

    def __init__(self, p=0.5):
        # Probability of flipping on any given call.
        self.p = p

    def __call__(self, im):
        should_flip = np.random.rand() < self.p
        if not should_flip:
            return im
        # Reverse only the second axis; the first and last are left as-is.
        return im[:, ::-1, :]
class Resize(object):
    '''
    Resize an image to a fixed size with ``cv2.resize``.

    NOTE: ``size`` is handed to OpenCV unchanged as ``dsize``, which OpenCV
    reads as (width, height). That is the reverse of the (h, w) order that
    ``PadandRandomCrop`` uses for ``cropsize``, so non-square sizes must be
    given width-first here.
    '''

    def __init__(self, size):
        # Target size, forwarded verbatim as cv2's dsize=(width, height).
        self.size = size

    def __call__(self, im):
        return cv2.resize(im, dsize=self.size)
class Normalize(object):
    '''
    Scale pixel values to [0, 1], then standardize each channel.

    Inputs are pixel values in range of [0, 255], channel order is 'rgb'.
    ``im`` must be channels-last: a single image (H, W, C) or a batch
    (N, H, W, C). ``mean`` and ``std`` hold one value per channel and are
    applied after the division by 255. Returns a new float32 array; the
    caller's array is never modified.

    Raises:
        ValueError: if ``im`` is neither 3-D nor 4-D.
    '''

    def __init__(self, mean, std):
        # Shape (1, 1, C): broadcasts over the trailing channel axis of both
        # (H, W, C) and (N, H, W, C) inputs.
        self.mean = np.array(mean, np.float32).reshape(1, 1, -1)
        self.std = np.array(std, np.float32).reshape(1, 1, -1)

    def __call__(self, im):
        if im.ndim not in (3, 4):
            raise ValueError(
                'Normalize expects a 3-D (H, W, C) or 4-D (N, H, W, C) '
                'array, got shape {}'.format(im.shape))
        # The division yields a fresh float32 array, so the in-place ops
        # below cannot touch the caller's input.
        im = im.astype(np.float32) / 255.
        im -= self.mean
        im /= self.std
        return im
class ToTensor(object):
    '''
    Convert a channels-last numpy array into a channels-first torch tensor.

    (H, W, C) -> (C, H, W) and (N, H, W, C) -> (N, C, H, W). A 2-D (H, W)
    array is treated as a single-channel image and returned as (1, H, W).
    The dtype is left unchanged.

    Raises:
        ValueError: if ``im`` is not 2-D, 3-D or 4-D.
    '''

    def __init__(self):
        pass

    def __call__(self, im):
        if im.ndim not in (2, 3, 4):
            raise ValueError(
                'ToTensor expects a 2-D, 3-D or 4-D array, got shape '
                '{}'.format(im.shape))
        # torch.from_numpy rejects negative strides (e.g. the view made by
        # im[:, ::-1, :]); ascontiguousarray copies only when the input is
        # not already C-contiguous, so contiguous inputs still share memory.
        im = np.ascontiguousarray(im)
        if im.ndim == 4:
            return torch.from_numpy(im.transpose(0, 3, 1, 2))
        if im.ndim == 3:
            return torch.from_numpy(im.transpose(2, 0, 1))
        return torch.from_numpy(im[None, :, :])
class Compose(object):
    '''
    Chain several transforms into a single callable.

    The ops are called in order: the first receives the original input,
    each later op receives the previous op's output, and the last op's
    output is returned. With no ops, the input is returned unchanged.
    '''

    def __init__(self, ops):
        # Callables to apply, in application order.
        self.ops = ops

    def __call__(self, im):
        current = im
        for step in self.ops:
            current = step(current)
        return current