Skip to content

Commit

Permalink
update type hints for mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
blakedewey committed Jul 7, 2024
1 parent 671a373 commit 88f0e94
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 15 deletions.
6 changes: 3 additions & 3 deletions src/torchio/data/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def read_shape(path: TypePath) -> TypeQuartetInt:
message = f'Error loading image with SimpleITK:\n{e}\n\nTrying NiBabel...'
warnings.warn(message, stacklevel=2)
try:
obj = nib.load(str(path))
obj: SpatialImage = nib.load(str(path)) # type: ignore[assignment]
except nib.loadsave.ImageFileError as e:
message = (
f'File "{path}" not understood.'
Expand All @@ -105,8 +105,8 @@ def read_shape(path: TypePath) -> TypeQuartetInt:
' and https://nipy.org/nibabel/api.html#file-formats'
)
raise RuntimeError(message) from e
num_dimensions = obj.header.get('dim')[0]
shape = obj.header.get('dim')[1 : 1 + num_dimensions]
num_dimensions = obj.ndim
shape = obj.shape
num_channels = 1 if num_dimensions < 4 else shape[-1]
assert 2 <= num_dimensions <= 4
if num_dimensions == 2:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def apply_transform(self, subject: Subject) -> Subject:
np.asarray(translation),
sitk_image,
)
assert isinstance(axis, int)
assert isinstance(image_interpolation, str)
transformed_channel = self.add_artifact(
sitk_image,
Expand Down
2 changes: 1 addition & 1 deletion src/torchio/transforms/augmentation/spatial/random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
):
super().__init__(**kwargs)
self.axes = self.parse_axes(axes)
self.args_names = ('axes',)
self.args_names = ['axes']

def apply_transform(self, subject: Subject) -> Subject:
axes = self.ensure_axes_indices(subject, self.axes)
Expand Down
21 changes: 13 additions & 8 deletions src/torchio/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def parse_include_and_exclude_keys(

@staticmethod
def parse_axes(
axes: Union[int, Tuple[int, ...], str, Tuple[str, ...]],
axes: Union[int, str, Tuple[int, ...], Tuple[str, ...]],
) -> Union[Tuple[int, ...], Tuple[str, ...]]:
axes_tuple = to_tuple(axes)
for axis in axes_tuple:
Expand All @@ -416,14 +416,19 @@ def ensure_axes_indices(
axes: Union[Tuple[int, ...], Tuple[str, ...]],
) -> Tuple[int, ...]:
image = subject.get_first_image()
if any(isinstance(n, str) for n in axes):
if any(isinstance(n, str) for n in axes): # axis strings
subject.check_consistent_orientation()
axes = tuple(sorted({3 + image.axis_name_to_index(n) for n in axes}))
if image.is_2d() and 2 in axes:
axes = list(axes)
axes.remove(2)
axes = tuple(axes)
return axes
int_axes = tuple(
{
(3 + image.axis_name_to_index(n)) if isinstance(n, str) else int(n)
for n in axes
}
)
if image.is_2d() and 2 in int_axes:
list_axes = list(int_axes)
list_axes.remove(2)
int_axes = tuple(list_axes)
return int_axes

@staticmethod
def validate_keys_sequence(keys: TypeKeys, name: str) -> None:
Expand Down
5 changes: 2 additions & 3 deletions src/torchio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,13 @@
from tqdm.auto import trange

from . import constants
from .typing import TypeNumber
from .typing import TypePath


def to_tuple(
value: Any,
value: Union[Any, Iterable[Any]],
length: int = 1,
) -> Tuple[TypeNumber, ...]:
) -> Tuple[Any, ...]:
"""Convert variable to tuple of length n.
Example:
Expand Down

0 comments on commit 88f0e94

Please sign in to comment.