Skip to content

Commit

Permalink
add anatomical axis tests to RandomFlip, RandomGhosting, and RandomMo…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
blakedewey committed Jul 6, 2024
1 parent 1d8761c commit a7eb171
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 0 deletions.
4 changes: 4 additions & 0 deletions tests/transforms/augmentation/test_random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def test_wrong_flip_probability_type(self):
with pytest.raises(ValueError):
tio.RandomFlip(flip_probability='wrong')

def test_wrong_anatomical_axis(self):
with pytest.raises(ValueError):
tio.RandomFlip(axes=('g',))

def test_anatomical_axis(self):
transform = tio.RandomFlip(axes=['i'], flip_probability=1)
tensor = torch.rand(1, 2, 3, 4)
Expand Down
12 changes: 12 additions & 0 deletions tests/transforms/augmentation/test_random_ghosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ def test_with_ghosting(self):
transformed.t1.data,
)

def test_anatomical_axis(self):
transform = RandomGhosting(axes=['a'])
transformed = transform(self.sample_subject)
self.assert_tensor_not_equal(
self.sample_subject.t1.data,
transformed.t1.data,
)

def test_intensity_range_with_negative_min(self):
with pytest.raises(ValueError):
RandomGhosting(intensity=(-0.5, 4))
Expand Down Expand Up @@ -74,3 +82,7 @@ def test_out_of_range_restore(self):
def test_wrong_restore_type(self):
with pytest.raises(TypeError):
RandomGhosting(restore='wrong')

def test_wrong_anatomical_axis(self):
with pytest.raises(ValueError):
RandomGhosting(axes=('v',))
12 changes: 12 additions & 0 deletions tests/transforms/augmentation/test_random_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ def test_with_movement(self):
transformed.t1.data,
)

def test_anatomical_axis(self):
transform = RandomMotion(axes=('a',))
transformed = transform(self.sample_subject)
self.assert_tensor_not_equal(
self.sample_subject.t1.data,
transformed.t1.data,
)

def test_negative_degrees(self):
with pytest.raises(ValueError):
RandomMotion(degrees=-10)
Expand Down Expand Up @@ -70,3 +78,7 @@ def test_out_of_range_axis_in_tuple(self):
def test_wrong_axes_type(self):
with pytest.raises(ValueError):
RandomMotion(axes=None)

def test_wrong_anatomical_axis(self):
with pytest.raises(ValueError):
RandomMotion(axes=('C',))

0 comments on commit a7eb171

Please sign in to comment.