diff --git a/tests/transforms/augmentation/test_random_flip.py b/tests/transforms/augmentation/test_random_flip.py index 79bc52f9..d766b591 100644 --- a/tests/transforms/augmentation/test_random_flip.py +++ b/tests/transforms/augmentation/test_random_flip.py @@ -31,6 +31,10 @@ def test_wrong_flip_probability_type(self): with pytest.raises(ValueError): tio.RandomFlip(flip_probability='wrong') + def test_wrong_anatomical_axis(self): + with pytest.raises(ValueError): + tio.RandomFlip(axes=('g',)) + def test_anatomical_axis(self): transform = tio.RandomFlip(axes=['i'], flip_probability=1) tensor = torch.rand(1, 2, 3, 4) diff --git a/tests/transforms/augmentation/test_random_ghosting.py b/tests/transforms/augmentation/test_random_ghosting.py index 944aafe4..b969538a 100644 --- a/tests/transforms/augmentation/test_random_ghosting.py +++ b/tests/transforms/augmentation/test_random_ghosting.py @@ -31,6 +31,14 @@ def test_with_ghosting(self): transformed.t1.data, ) + def test_anatomical_axis(self): + transform = RandomGhosting(axes=['a']) + transformed = transform(self.sample_subject) + self.assert_tensor_not_equal( + self.sample_subject.t1.data, + transformed.t1.data, + ) + def test_intensity_range_with_negative_min(self): with pytest.raises(ValueError): RandomGhosting(intensity=(-0.5, 4)) @@ -74,3 +82,7 @@ def test_out_of_range_restore(self): def test_wrong_restore_type(self): with pytest.raises(TypeError): RandomGhosting(restore='wrong') + + def test_wrong_anatomical_axis(self): + with pytest.raises(ValueError): + RandomGhosting(axes=('v',)) diff --git a/tests/transforms/augmentation/test_random_motion.py b/tests/transforms/augmentation/test_random_motion.py index ec2f9676..70bf3a5b 100644 --- a/tests/transforms/augmentation/test_random_motion.py +++ b/tests/transforms/augmentation/test_random_motion.py @@ -35,6 +35,14 @@ def test_with_movement(self): transformed.t1.data, ) + def test_anatomical_axis(self): + transform = RandomMotion(axes=('a',)) + transformed = transform(self.sample_subject) + self.assert_tensor_not_equal( + self.sample_subject.t1.data, + transformed.t1.data, + ) + def test_negative_degrees(self): with pytest.raises(ValueError): RandomMotion(degrees=-10) @@ -70,3 +78,7 @@ def test_out_of_range_axis_in_tuple(self): def test_wrong_axes_type(self): with pytest.raises(ValueError): RandomMotion(axes=None) + + def test_wrong_anatomical_axis(self): + with pytest.raises(ValueError): + RandomMotion(axes=('C',))