-
Notifications
You must be signed in to change notification settings - Fork 0
/
rotation_mm.py
115 lines (97 loc) · 4.54 KB
/
rotation_mm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import math
import torch
import random
import numbers
import torchvision.transforms.functional as F
class RandomMuellerRotation(object):
"""Rotate the Mueller matrix frame by angle.
Args:
degrees (sequence or float or int): Range of degrees to select from.
If degrees is a number instead of sequence like (min, max), the range of degrees
will be (-degrees, +degrees).
resample ({PIL.Image.NEAREST, PIL.Image.BILINEAR, PIL.Image.BICUBIC}, optional):
An optional resampling filter. See `filters`_ for more information.
If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST.
expand (bool, optional): Optional expansion flag.
If true, expands the output to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image.
Note that the expand flag assumes rotation around the center and no translation.
center (2-tuple, optional): Optional center of rotation.
Origin is the upper left corner.
Default is the center of the image.
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
If int, it is used for all channels respectively.
p (float): probability threshold with which image should be rotated or left untreated instead.
any (bool): If false, only angles [90, 180, 270] will be chosen. Otherwise, a random angle
within the boundaries will be generated (default).
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
"""
def __init__(self, degrees, resample=False, expand=False, center=None, fill=0, p=0.5, any=True):
if isinstance(degrees, numbers.Number):
if degrees < 0:
raise ValueError("If degrees is a single number, it must be positive.")
self.degrees = (-degrees, degrees)
else:
if len(degrees) != 2:
raise ValueError("If degrees is a sequence, it must be of len 2.")
self.degrees = degrees
self.resample = resample
self.expand = expand
self.center = center
self.fill = fill
self.p = p
self.any = any
def get_params(self, degrees):
"""Get parameters for ``rotate`` for a random rotation.
Returns:
sequence: params to be passed to ``rotate`` for random rotation.
"""
angle = random.uniform(degrees[0], degrees[1]) if self.any else random.choice([90, 180, 270])
return angle
@staticmethod
def get_rmat(deg):
"""Get rotation matrix for Mueller matrix.
Returns:
sequence: rotation matrix for Mueller matrix.
"""
theta = deg / 180 * math.pi
rmat = torch.tensor([
[1, 0, 0, 0],
[0, math.cos(2*theta), -math.sin(2*theta), 0],
[0, math.sin(2*theta), +math.cos(2*theta), 0],
[0, 0, 0, 1],
])
return rmat
def __call__(self, img, label=None, angle=None, *args, **kwargs):
"""
Args:
img (PIL Image): Image to be rotated.
Returns:
PIL Image: Rotated image.
"""
if random.random() < self.p:
# spatial transformation
angle = self.get_params(self.degrees) if angle is None else angle
fill16 = torch.eye(4).flatten().tolist()
rotated_img = F.rotate(img, angle, self.resample, self.expand, self.center, fill16)
rotated_img = rotated_img.moveaxis(0, 2)
# mueller matrix transformation
T = self.get_rmat(angle)
rotated_img = T @ rotated_img.view(*rotated_img.shape[:-1], 4, 4) @ T.transpose(-2, -1)
rotated_img = rotated_img.flatten(-2, -1).moveaxis(2, 0)
if label is not None:
rotated_label = F.rotate(label, angle, self.resample, self.expand, self.center, self.fill)
return rotated_img, rotated_label
return rotated_img
else:
if label is not None:
return img, label
return img
def __repr__(self):
format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
format_string += ', resample={0}'.format(self.resample)
format_string += ', expand={0}'.format(self.expand)
if self.center is not None:
format_string += ', center={0}'.format(self.center)
format_string += ')'
return format_string