Skip to content

Commit

Permalink
Merge pull request #149 from jhlegarreta/SimplifyEddyMotionEstimatorFit
Browse files Browse the repository at this point in the history
ENH: Simplify `eddymotion.estimator.EddyMotionEstimator.fit`
  • Loading branch information
oesteban authored Apr 3, 2024
2 parents c8c6c0d + 18ba53f commit 2d71ba0
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 80 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ inline-quotes = "double"
"*/__init__.py" = ["F401"]
"docs/conf.py" = ["E265"]
"/^\\s*\\.\\. _.*?: http/" = ["E501"]
"src/eddymotion/estimator.py" = ["C901"]

[tool.ruff.format]
quote-style = "double"
Expand Down
317 changes: 238 additions & 79 deletions src/eddymotion/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,23 +76,18 @@ def fit(
Number of parallel jobs.
seed : :obj:`int` or :obj:`bool`
Seed the random number generator (necessary when we want deterministic
estimation). If an integer, the value is used to initialize the
generator; if ``True``, the arbitrary value of ``20210324`` is used
to initialize it.
estimation). See :func:`_sort_dwdata_indices`.
Return
------
affines : :obj:`list` of :obj:`numpy.ndarray`
A list of :math:`4 \times 4` affine matrices encoding the estimated
parameters of the deformations caused by head-motion and eddy-currents.
"""
align_kwargs = align_kwargs or {}

_seed = None
if seed or seed == 0:
_seed = 20210324 if seed is True else seed
align_kwargs = align_kwargs or {}

rng = np.random.default_rng(_seed)
index_order = _sort_dwdata_indices(seed, len(dwdata))

if "num_threads" not in align_kwargs and omp_nthreads is not None:
align_kwargs["num_threads"] = omp_nthreads
Expand All @@ -104,28 +99,9 @@ def fit(
)

# When downsampling these need to be set per-level
bmask_img = None
if dwdata.brainmask is not None:
_, bmask_img = mkstemp(suffix="_bmask.nii.gz")
nb.Nifti1Image(dwdata.brainmask.astype("uint8"), dwdata.affine, None).to_filename(
bmask_img
)
kwargs["mask"] = dwdata.brainmask

if hasattr(dwdata, "bzero") and dwdata.bzero is not None:
kwargs["S0"] = _advanced_clip(dwdata.bzero)

if hasattr(dwdata, "gradients"):
kwargs["gtab"] = dwdata.gradients

if hasattr(dwdata, "frame_time"):
kwargs["timepoints"] = dwdata.frame_time
bmask_img = _prepare_brainmask_data(dwdata.brainmask, dwdata.affine)

if hasattr(dwdata, "total_duration"):
kwargs["xlim"] = dwdata.total_duration

index_order = np.arange(len(dwdata))
rng.shuffle(index_order)
_prepare_kwargs(dwdata, kwargs)

single_model = model.lower() in (
"b0",
Expand All @@ -149,6 +125,7 @@ def fit(

with TemporaryDirectory() as tmp_dir:
print(f"Processing in <{tmp_dir}>")
ptmp_dir = Path(tmp_dir)
with tqdm(total=len(index_order), unit="dwi") as pbar:
# run a original-to-synthetic affine registration
for i in index_order:
Expand Down Expand Up @@ -179,62 +156,28 @@ def fit(
predicted = dwmodel.predict(data_test[1])

# prepare data for running ANTs
tmp_dir = Path(tmp_dir)
moving = tmp_dir / f"moving{i:05d}.nii.gz"
fixed = tmp_dir / f"fixed{i:05d}.nii.gz"
_to_nifti(data_test[0], dwdata.affine, moving)
_to_nifti(
predicted,
dwdata.affine,
fixed,
clip=reg_target_type == "dwi",
fixed, moving = _prepare_registration_data(
data_test[0], predicted, dwdata.affine, i, ptmp_dir, reg_target_type
)

pbar.set_description_str(
f"Pass {i_iter + 1}/{n_iter} | Realign b-index <{i}>"
)
registration = Registration(
terminal_output="file",
from_file=pkg_fn(
"eddymotion",
f"config/dwi-to-{reg_target_type}_level{i_iter}.json",
),
fixed_image=str(fixed.absolute()),
moving_image=str(moving.absolute()),
**align_kwargs,
)
if bmask_img:
registration.inputs.fixed_image_masks = ["NULL", bmask_img]

if dwdata.em_affines is not None and np.any(dwdata.em_affines[i, ...]):
reference = namedtuple("ImageGrid", ("shape", "affine"))(
shape=dwdata.dataobj.shape[:3], affine=dwdata.affine
)

# create a nitransforms object
if dwdata.fieldmap:
# compose fieldmap into transform
raise NotImplementedError
else:
initial_xform = Affine(
matrix=dwdata.em_affines[i], reference=reference
)
mat_file = tmp_dir / f"init_{i_iter}_{i:05d}.mat"
initial_xform.to_filename(mat_file, fmt="itk")
registration.inputs.initial_moving_transform = str(mat_file)

# execute ants command line
result = registration.run(cwd=str(tmp_dir)).outputs

# read output transform
xform = nt.linear.Affine(
nt.io.itk.ITKLinearTransform.from_filename(
result.forward_transforms[0]
).to_ras(reference=fixed, moving=moving),
)
# debugging: generate aligned file for testing
xform.apply(moving, reference=fixed).to_filename(
tmp_dir / f"aligned{i:05d}_{int(data_test[1][3]):04d}.nii.gz"
xform = _run_registration(
fixed,
moving,
bmask_img,
dwdata.em_affines,
dwdata.affine,
dwdata.dataobj.shape[:3],
data_test[1][3],
dwdata.fieldmap,
i_iter,
i,
ptmp_dir,
reg_target_type,
align_kwargs,
)

# update
Expand Down Expand Up @@ -295,3 +238,219 @@ def _to_nifti(data, affine, filename, clip=True):
nii.header.set_sform(affine, code=1)
nii.header.set_qform(affine, code=1)
nii.to_filename(filename)


def _sort_dwdata_indices(seed, dwi_vol_count):
"""Sort the DWI data volume indices.
Parameters
----------
seed : :obj:`int` or :obj:`bool`
Seed the random number generator. If an integer, the value is used to initialize the
generator; if ``True``, the arbitrary value of ``20210324`` is used to initialize it.
dwi_vol_count : :obj:`int`
Number of DWI volumes.
Returns
-------
index_order : :obj:`numpy.ndarray`
Index order.
"""

_seed = None
if seed or seed == 0:
_seed = 20210324 if seed is True else seed

rng = np.random.default_rng(_seed)

index_order = np.arange(dwi_vol_count)
rng.shuffle(index_order)

return index_order


def _prepare_brainmask_data(brainmask, affine):
"""Prepare the brainmask data: save the data to disk.
Parameters
----------
brainmask : :obj:`numpy.ndarray`
Brainmask data.
affine : :obj:`numpy.ndarray`
Affine transformation matrix.
Returns
-------
bmask_img : :class:`~nibabel.nifti1.Nifti1Image`
Brainmask image.
"""

bmask_img = None
if brainmask is not None:
_, bmask_img = mkstemp(suffix="_bmask.nii.gz")
nb.Nifti1Image(brainmask.astype("uint8"), affine, None).to_filename(bmask_img)
return bmask_img


def _prepare_kwargs(dwdata, kwargs):
"""Prepare the keyword arguments depending on the DWI data: add attributes corresponding to
the ``brainmask``, ``bzero``, ``gradients``, ``frame_time``, and ``total_duration`` DWI data
properties.
Modifies kwargs in-place.
Parameters
----------
dwdata : :class:`eddymotion.data.dmri.DWI`
DWI data object.
kwargs : :obj:`dict`
Keyword arguments.
"""

if dwdata.brainmask is not None:
kwargs["mask"] = dwdata.brainmask

if hasattr(dwdata, "bzero") and dwdata.bzero is not None:
kwargs["S0"] = _advanced_clip(dwdata.bzero)

if hasattr(dwdata, "gradients"):
kwargs["gtab"] = dwdata.gradients

if hasattr(dwdata, "frame_time"):
kwargs["timepoints"] = dwdata.frame_time

if hasattr(dwdata, "total_duration"):
kwargs["xlim"] = dwdata.total_duration


def _prepare_registration_data(dwframe, predicted, affine, vol_idx, dirname, reg_target_type):
"""Prepare the registration data: save the fixed and moving images to disk.
Parameters
----------
dwframe : :obj:`numpy.ndarray`
DWI data object.
predicted : :obj:`numpy.ndarray`
Predicted data.
affine : :obj:`numpy.ndarray`
Affine transformation matrix.
vol_idx : :obj:`int
DWI volume index.
dirname : :obj:`Path`
Directory name where the data is saved.
reg_target_type : :obj:`str`
Target registration type.
Returns
-------
fixed : :obj:`Path`
Fixed image filename.
moving : :obj:`Path`
Moving image filename.
"""

moving = dirname / f"moving{vol_idx:05d}.nii.gz"
fixed = dirname / f"fixed{vol_idx:05d}.nii.gz"
_to_nifti(dwframe, affine, moving)
_to_nifti(
predicted,
affine,
fixed,
clip=reg_target_type == "dwi",
)
return fixed, moving


def _run_registration(
fixed,
moving,
bmask_img,
em_affines,
affine,
shape,
bval,
fieldmap,
i_iter,
vol_idx,
dirname,
reg_target_type,
align_kwargs,
):
"""Register the moving image to the fixed image.
Parameters
----------
fixed : :obj:`Path`
Fixed image filename.
moving : :obj:`Path`
Moving image filename.
bmask_img : :class:`~nibabel.nifti1.Nifti1Image`
Brainmask image.
em_affines : :obj:`numpy.ndarray`
Estimated eddy motion affine transformation matrices.
affine : :obj:`numpy.ndarray`
Affine transformation matrix.
shape : :obj:`tuple`
Shape of the DWI frame.
bval : :obj:`int`
b-value of the corresponding DWI volume.
fieldmap : :class:`~nibabel.nifti1.Nifti1Image`
Fieldmap.
i_iter : :obj:`int`
Iteration number.
vol_idx : :obj:`int`
DWI frame index.
dirname : :obj:`Path`
Directory name where the transformation is saved.
reg_target_type : :obj:`str`
Target registration type.
align_kwargs : :obj:`dict`
Parameters to configure the image registration process.
Returns
-------
xform : :class:`~nitransforms.linear.Affine`
Registration transformation.
"""

registration = Registration(
terminal_output="file",
from_file=pkg_fn(
"eddymotion",
f"config/dwi-to-{reg_target_type}_level{i_iter}.json",
),
fixed_image=str(fixed.absolute()),
moving_image=str(moving.absolute()),
**align_kwargs,
)
if bmask_img:
registration.inputs.fixed_image_masks = ["NULL", bmask_img]

if em_affines is not None and np.any(em_affines[vol_idx, ...]):
reference = namedtuple("ImageGrid", ("shape", "affine"))(shape=shape, affine=affine)

# create a nitransforms object
if fieldmap:
# compose fieldmap into transform
raise NotImplementedError
else:
initial_xform = Affine(matrix=em_affines[vol_idx], reference=reference)
mat_file = dirname / f"init_{i_iter}_{vol_idx:05d}.mat"
initial_xform.to_filename(mat_file, fmt="itk")
registration.inputs.initial_moving_transform = str(mat_file)

# execute ants command line
result = registration.run(cwd=str(dirname)).outputs

# read output transform
xform = nt.linear.Affine(
nt.io.itk.ITKLinearTransform.from_filename(result.forward_transforms[0]).to_ras(
reference=fixed, moving=moving
),
)
# debugging: generate aligned file for testing
xform.apply(moving, reference=fixed).to_filename(
dirname / f"aligned{vol_idx:05d}_{int(bval):04d}.nii.gz"
)

return xform

0 comments on commit 2d71ba0

Please sign in to comment.