Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RF: Upgrade nitransforms and remove workarounds #3378

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fmriprep/interfaces/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import nibabel as nb
import nitransforms as nt
import nitransforms.resampling
import numpy as np
from nipype.interfaces.base import (
File,
Expand Down Expand Up @@ -689,7 +690,7 @@ def reconstruct_fieldmap(
)

if not direct:
fmap_img = transforms.apply(fmap_img, reference=target)
fmap_img = nt.resampling.apply(transforms, fmap_img, reference=target)

fmap_img.header.set_intent('estimate', name='fieldmap Hz')
fmap_img.header.set_data_dtype('float32')
Expand Down
77 changes: 3 additions & 74 deletions fmriprep/utils/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@

from pathlib import Path

import h5py
import nibabel as nb
import nitransforms as nt
import numpy as np
from nitransforms.io.itk import ITKCompositeH5
from transforms3d.affines import compose as compose_affine


def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.TransformBase:
Expand All @@ -24,7 +19,8 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans
for path, inv in zip(xfm_paths[::-1], inverse[::-1], strict=False):
path = Path(path)
if path.suffix == '.h5':
xfm = load_ants_h5(path)
# Load as a TransformChain
xfm = nt.manip.load(path)
else:
xfm = nt.linear.load(path)
if inv:
Expand All @@ -34,72 +30,5 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans
else:
chain += xfm
if chain is None:
chain = nt.base.TransformBase()
chain = nt.Affine() # Identity
return chain


def load_ants_h5(filename: Path) -> nt.base.TransformBase:
"""Load ANTs H5 files as a nitransforms TransformChain"""
# Borrowed from https://github.com/feilong/process
# process.resample.parse_combined_hdf5()
#
# Changes:
# * Tolerate a missing displacement field
# * Return the original affine without a round-trip
# * Always return a nitransforms TransformBase
# * Construct warp affine from fixed parameters
#
# This should be upstreamed into nitransforms
h = h5py.File(filename)
xform = ITKCompositeH5.from_h5obj(h)

# nt.Affine
transforms = [nt.Affine(xform[0].to_ras())]

if '2' not in h['TransformGroup']:
return transforms[0]

transform2 = h['TransformGroup']['2']

# Confirm these transformations are applicable
if transform2['TransformType'][:][0] not in (
b'DisplacementFieldTransform_float_3_3',
b'DisplacementFieldTransform_double_3_3',
):
msg = 'Unknown transform type [2]\n'
for i in h['TransformGroup'].keys():
msg += f'[{i}]: {h["TransformGroup"][i]["TransformType"][:][0]}\n'
raise ValueError(msg)

# Warp field fixed parameters as defined in
# https://itk.org/Doxygen/html/classitk_1_1DisplacementFieldTransform.html
shape = transform2['TransformFixedParameters'][:3]
origin = transform2['TransformFixedParameters'][3:6]
spacing = transform2['TransformFixedParameters'][6:9]
direction = transform2['TransformFixedParameters'][9:].reshape((3, 3))

# We are not yet confident that we handle non-unit spacing
# or direction cosine ordering correctly.
# If we confirm or fix, we can remove these checks.
if not np.allclose(spacing, 1):
raise ValueError(f'Unexpected spacing: {spacing}')
if not np.allclose(direction, direction.T):
raise ValueError(f'Asymmetric direction matrix: {direction}')

# ITK uses LPS affines
lps_affine = compose_affine(T=origin, R=direction, Z=spacing)
ras_affine = np.diag([-1, -1, 1, 1]) @ lps_affine

# ITK stores warps in Fortran-order, where the vector components change fastest
# Vectors are in mm LPS
itk_warp = np.reshape(
transform2['TransformParameters'],
(3, *shape.astype(int)),
order='F',
)

# Nitransforms warps are in RAS, with the vector components changing slowest
nt_warp = itk_warp.transpose(1, 2, 3, 0) * np.array([-1, -1, 1])

transforms.insert(0, nt.DenseFieldTransform(nb.Nifti1Image(nt_warp, ras_affine)))
return nt.TransformChain(transforms)
7 changes: 5 additions & 2 deletions fmriprep/workflows/bold/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ def _conditional_downsampling(in_file, in_mask, zoom_th=4.0):
import nibabel as nb
import nitransforms as nt
import numpy as np
from nitransforms.resampling import apply as applyxfm
from scipy.ndimage.filters import gaussian_filter

img = nb.load(in_file)
Expand All @@ -756,14 +757,16 @@ def _conditional_downsampling(in_file, in_mask, zoom_th=4.0):
offset = old_center - newrot.dot((newshape - 1) * 0.5)
newaffine = nb.affines.from_matvec(newrot, offset)

identity = nt.Affine()

newref = nb.Nifti1Image(np.zeros(newshape, dtype=np.uint8), newaffine)
nt.Affine(reference=newref).apply(img).to_filename(out_file)
applyxfm(identity, img, reference=newref).to_filename(out_file)

mask = nb.load(in_mask)
mask.set_data_dtype(float)
mdata = gaussian_filter(mask.get_fdata(dtype=float), scaling)
floatmask = nb.Nifti1Image(mdata, mask.affine, mask.header)
newmask = nt.Affine(reference=newref).apply(floatmask)
newmask = applyxfm(identity, floatmask, reference=newref)
hdr = newmask.header.copy()
hdr.set_data_dtype(np.uint8)
newmaskdata = (newmask.get_fdata(dtype=float) > 0.5).astype(np.uint8)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies = [
"nipype >= 1.8.5",
"nireports >= 23.2.2",
"nitime",
"nitransforms >= 21.0.0, != 24.0.0",
"nitransforms >= 24.0.2",
"niworkflows >= 1.11.0",
"numpy >= 1.22",
"packaging",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ nireports==23.2.2
# via fmriprep (pyproject.toml)
nitime==0.10.2
# via fmriprep (pyproject.toml)
nitransforms==23.0.1
nitransforms==24.0.2
# via
# fmriprep (pyproject.toml)
# niworkflows
Expand Down