From 88f0e94db0341b873bf6ab9ec8eac44b0d30edce Mon Sep 17 00:00:00 2001 From: Blake Dewey Date: Sun, 7 Jul 2024 08:58:15 -0400 Subject: [PATCH] update type hints for mypy --- src/torchio/data/io.py | 6 +++--- .../augmentation/intensity/random_motion.py | 1 + .../augmentation/spatial/random_flip.py | 2 +- src/torchio/transforms/transform.py | 21 ++++++++++++------- src/torchio/utils.py | 5 ++--- 5 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/torchio/data/io.py b/src/torchio/data/io.py index 8513f9fe..576f3903 100644 --- a/src/torchio/data/io.py +++ b/src/torchio/data/io.py @@ -96,7 +96,7 @@ def read_shape(path: TypePath) -> TypeQuartetInt: message = f'Error loading image with SimpleITK:\n{e}\n\nTrying NiBabel...' warnings.warn(message, stacklevel=2) try: - obj = nib.load(str(path)) + obj: SpatialImage = nib.load(str(path)) # type: ignore[assignment] except nib.loadsave.ImageFileError as e: message = ( f'File "{path}" not understood.' @@ -105,8 +105,8 @@ def read_shape(path: TypePath) -> TypeQuartetInt: ' and https://nipy.org/nibabel/api.html#file-formats' ) raise RuntimeError(message) from e - num_dimensions = obj.header.get('dim')[0] - shape = obj.header.get('dim')[1 : 1 + num_dimensions] + num_dimensions = obj.ndim + shape = obj.shape num_channels = 1 if num_dimensions < 4 else shape[-1] assert 2 <= num_dimensions <= 4 if num_dimensions == 2: diff --git a/src/torchio/transforms/augmentation/intensity/random_motion.py b/src/torchio/transforms/augmentation/intensity/random_motion.py index a9111f3a..b29271a6 100644 --- a/src/torchio/transforms/augmentation/intensity/random_motion.py +++ b/src/torchio/transforms/augmentation/intensity/random_motion.py @@ -206,6 +206,7 @@ def apply_transform(self, subject: Subject) -> Subject: np.asarray(translation), sitk_image, ) + assert isinstance(axis, int) assert isinstance(image_interpolation, str) transformed_channel = self.add_artifact( sitk_image, diff --git a/src/torchio/transforms/augmentation/spatial/random_flip.py b/src/torchio/transforms/augmentation/spatial/random_flip.py index f1951d18..abbacc17 100644 --- a/src/torchio/transforms/augmentation/spatial/random_flip.py +++ b/src/torchio/transforms/augmentation/spatial/random_flip.py @@ -89,7 +89,7 @@ def __init__( ): super().__init__(**kwargs) self.axes = self.parse_axes(axes) - self.args_names = ('axes',) + self.args_names = ['axes'] def apply_transform(self, subject: Subject) -> Subject: axes = self.ensure_axes_indices(subject, self.axes) diff --git a/src/torchio/transforms/transform.py b/src/torchio/transforms/transform.py index 37397596..2618af79 100644 --- a/src/torchio/transforms/transform.py +++ b/src/torchio/transforms/transform.py @@ -396,7 +396,7 @@ def parse_include_and_exclude_keys( @staticmethod def parse_axes( - axes: Union[int, Tuple[int, ...], str, Tuple[str, ...]], + axes: Union[int, str, Tuple[int, ...], Tuple[str, ...]], ) -> Union[Tuple[int, ...], Tuple[str, ...]]: axes_tuple = to_tuple(axes) for axis in axes_tuple: @@ -416,14 +416,19 @@ def ensure_axes_indices( axes: Union[Tuple[int, ...], Tuple[str, ...]], ) -> Tuple[int, ...]: image = subject.get_first_image() - if any(isinstance(n, str) for n in axes): + if any(isinstance(n, str) for n in axes): # axis strings subject.check_consistent_orientation() - axes = tuple(sorted({3 + image.axis_name_to_index(n) for n in axes})) - if image.is_2d() and 2 in axes: - axes = list(axes) - axes.remove(2) - axes = tuple(axes) - return axes + int_axes = tuple( + { + (3 + image.axis_name_to_index(n)) if isinstance(n, str) else int(n) + for n in axes + } + ) + if image.is_2d() and 2 in int_axes: + list_axes = list(int_axes) + list_axes.remove(2) + int_axes = tuple(list_axes) + return int_axes @staticmethod def validate_keys_sequence(keys: TypeKeys, name: str) -> None: diff --git a/src/torchio/utils.py b/src/torchio/utils.py index 17da1295..dadaf232 100644 --- a/src/torchio/utils.py +++ b/src/torchio/utils.py @@ -25,14 +25,13 @@ from tqdm.auto import trange from . import constants -from .typing import TypeNumber from .typing import TypePath def to_tuple( - value: Any, + value: Union[Any, Iterable[Any]], length: int = 1, -) -> Tuple[TypeNumber, ...]: +) -> Tuple[Any, ...]: """Convert variable to tuple of length n. Example: