Skip to content

Commit

Permalink
remove matrix to transform and back. add axis for last_index
Browse files Browse the repository at this point in the history
  • Loading branch information
blakedewey committed Jul 7, 2024
1 parent a7eb171 commit aaa7be9
Showing 1 changed file with 6 additions and 23 deletions.
29 changes: 6 additions & 23 deletions src/torchio/transforms/augmentation/intensity/random_motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,35 +224,18 @@ def get_rigid_transforms(
) -> List[sitk.Euler3DTransform]:
center_ijk = np.array(image.GetSize()) / 2
center_lps = image.TransformContinuousIndexToPhysicalPoint(center_ijk)
identity = np.eye(4)
matrices = [identity]
ident_transform = sitk.Euler3DTransform()
ident_transform.SetCenter(center_lps)
transforms = [ident_transform]
for degrees, translation in zip(degrees_params, translation_params):
radians = np.radians(degrees).tolist()
motion = sitk.Euler3DTransform()
motion.SetCenter(center_lps)
motion.SetRotation(*radians)
motion.SetTranslation(translation.tolist())
motion_matrix = self.transform_to_matrix(motion)
matrices.append(motion_matrix)
transforms = [self.matrix_to_transform(m) for m in matrices]
transforms.append(motion)
return transforms

@staticmethod
def transform_to_matrix(transform: sitk.Euler3DTransform) -> np.ndarray:
matrix = np.eye(4)
rotation = np.array(transform.GetMatrix()).reshape(3, 3)
matrix[:3, :3] = rotation
matrix[:3, 3] = transform.GetTranslation()
return matrix

@staticmethod
def matrix_to_transform(matrix: np.ndarray) -> sitk.Euler3DTransform:
transform = sitk.Euler3DTransform()
rotation = matrix[:3, :3].flatten().tolist()
transform.SetMatrix(rotation)
transform.SetTranslation(matrix[:3, 3])
return transform

def resample_images(
self,
image: sitk.Image,
Expand All @@ -261,10 +244,10 @@ def resample_images(
) -> List[sitk.Image]:
floating = reference = image
default_value = np.float64(sitk.GetArrayViewFromImage(image).min())
interpolator = self.get_sitk_interpolator(interpolation)
transforms = transforms[1:] # first is identity
images = [image] # first is identity
for transform in transforms:
interpolator = self.get_sitk_interpolator(interpolation)
resampler = sitk.ResampleImageFilter()
resampler.SetInterpolator(interpolator)
resampler.SetReferenceImage(reference)
Expand Down Expand Up @@ -301,7 +284,7 @@ def add_artifact(
spectra.append(spectrum)
self.sort_spectra(spectra, times)
result_spectrum = torch.empty_like(spectra[0])
last_index = result_spectrum.shape[2]
last_index = result_spectrum.shape[axis]
indices = (last_index * times).astype(int).tolist()
indices.append(last_index)
ini = 0
Expand Down

0 comments on commit aaa7be9

Please sign in to comment.