diff --git a/src/torchio/transforms/augmentation/intensity/random_motion.py b/src/torchio/transforms/augmentation/intensity/random_motion.py index f00828da..7aa0b4df 100644 --- a/src/torchio/transforms/augmentation/intensity/random_motion.py +++ b/src/torchio/transforms/augmentation/intensity/random_motion.py @@ -224,35 +224,18 @@ def get_rigid_transforms( ) -> List[sitk.Euler3DTransform]: center_ijk = np.array(image.GetSize()) / 2 center_lps = image.TransformContinuousIndexToPhysicalPoint(center_ijk) - identity = np.eye(4) - matrices = [identity] + ident_transform = sitk.Euler3DTransform() + ident_transform.SetCenter(center_lps) + transforms = [ident_transform] for degrees, translation in zip(degrees_params, translation_params): radians = np.radians(degrees).tolist() motion = sitk.Euler3DTransform() motion.SetCenter(center_lps) motion.SetRotation(*radians) motion.SetTranslation(translation.tolist()) - motion_matrix = self.transform_to_matrix(motion) - matrices.append(motion_matrix) - transforms = [self.matrix_to_transform(m) for m in matrices] + transforms.append(motion) return transforms - @staticmethod - def transform_to_matrix(transform: sitk.Euler3DTransform) -> np.ndarray: - matrix = np.eye(4) - rotation = np.array(transform.GetMatrix()).reshape(3, 3) - matrix[:3, :3] = rotation - matrix[:3, 3] = transform.GetTranslation() - return matrix - - @staticmethod - def matrix_to_transform(matrix: np.ndarray) -> sitk.Euler3DTransform: - transform = sitk.Euler3DTransform() - rotation = matrix[:3, :3].flatten().tolist() - transform.SetMatrix(rotation) - transform.SetTranslation(matrix[:3, 3]) - return transform - def resample_images( self, image: sitk.Image, @@ -261,10 +244,10 @@ def resample_images( ) -> List[sitk.Image]: floating = reference = image default_value = np.float64(sitk.GetArrayViewFromImage(image).min()) + interpolator = self.get_sitk_interpolator(interpolation) transforms = transforms[1:] # first is identity images = [image] # first is identity for transform in transforms: - interpolator = self.get_sitk_interpolator(interpolation) resampler = sitk.ResampleImageFilter() resampler.SetInterpolator(interpolator) resampler.SetReferenceImage(reference) @@ -301,7 +284,7 @@ def add_artifact( spectra.append(spectrum) self.sort_spectra(spectra, times) result_spectrum = torch.empty_like(spectra[0]) - last_index = result_spectrum.shape[2] + last_index = result_spectrum.shape[axis] indices = (last_index * times).astype(int).tolist() indices.append(last_index) ini = 0