Skip to content

Commit

Permalink
Merge pull request #228 from jhlegarreta/AddGPPredictionPlotScript
Browse files Browse the repository at this point in the history
ENH: Add a script to plot the signal estimated by the GP
  • Loading branch information
oesteban authored Oct 26, 2024
2 parents 1293454 + 8835e63 commit 33992b6
Show file tree
Hide file tree
Showing 4 changed files with 457 additions and 41 deletions.
160 changes: 160 additions & 0 deletions scripts/dwi_gp_estimation_analysis_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright The NiPreps Developers <nipreps@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#

"""
Plot the RMSE (mean and std dev) and prediction surface from the predicted DWI
signal estimated using Gaussian processes k-fold cross-validation.
"""

from __future__ import annotations

import argparse
from pathlib import Path

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
from dipy.core.gradients import gradient_table
from dipy.io import read_bvals_bvecs

from eddymotion.viz.signals import plot_error, plot_prediction_surface


def _build_arg_parser() -> argparse.ArgumentParser:
"""
Build argument parser for command-line interface.
Returns
-------
:obj:`~argparse.ArgumentParser`
Argument parser for the script.
"""
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawTextHelpFormatter
)
parser.add_argument(
"error_data_fname",
help="Filename of TSV file containing the error data to plot",
type=Path,
)
parser.add_argument(
"dwi_gt_data_fname",
help="Filename of NIfTI file containing the ground truth DWI signal",
type=Path,
)
parser.add_argument(
"bval_data_fname",
help="Filename of b-val file containing the diffusion-encoding gradient b-vals",
type=Path,
)
parser.add_argument(
"bvec_data_fname",
help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs",
type=Path,
)
parser.add_argument(
"dwi_pred_data_fname",
help="Filename of NIfTI file containing the predicted DWI signal",
type=Path,
)
parser.add_argument(
"error_plot_fname",
help="Filename of SVG file where the error plot will be saved",
type=Path,
)
parser.add_argument(
"signal_surface_plot_fname",
help="Filename of SVG file where the predicted signal plot will be saved",
type=Path,
)
return parser


def _parse_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
"""
Parse command-line arguments.
Parameters
----------
parser : :obj:`~argparse.ArgumentParser`
Argument parser for the script.
Returns
-------
:obj:`~argparse.Namespace`
Parsed arguments.
"""
return parser.parse_args()


def main() -> None:
"""Main function for running the experiment and plotting the results."""
parser = _build_arg_parser()
args = _parse_args(parser)

df = pd.read_csv(args.error_data_fname, sep="\t", keep_default_na=False, na_values="n/a")

# Plot the prediction error
kfolds = sorted(np.unique(df["n_folds"].values))
snr = np.unique(df["snr"].values).item()
rmse_data = [df.groupby("n_folds").get_group(k)["rmse"].values for k in kfolds]
axis = 1
mean = np.mean(rmse_data, axis=axis)
std_dev = np.std(rmse_data, axis=axis)
xlabel = "k"
ylabel = "RMSE"
title = f"Gaussian process estimation\n(SNR={snr})"
fig = plot_error(kfolds, mean, std_dev, xlabel, ylabel, title)
fig.savefig(args.error_plot_fname)
plt.close(fig)

# Plot the predicted DWI signal at a single voxel

# Load the dMRI data
signal = nib.load(args.dwi_gt_data_fname).get_fdata()
y_pred = nib.load(args.dwi_pred_data_fname).get_fdata()

bvals, bvecs = read_bvals_bvecs(str(args.bval_data_fname), str(args.bvec_data_fname))
gtab = gradient_table(bvals, bvecs)

# Pick one voxel randomly
rng = np.random.default_rng(1234)
idx = rng.integers(0, signal.shape[0], size=1).item()

title = "GP model signal prediction"
fig, _, _ = plot_prediction_surface(
signal[idx, ~gtab.b0s_mask],
y_pred[idx],
signal[idx, gtab.b0s_mask].item(),
gtab[~gtab.b0s_mask].bvecs,
gtab[~gtab.b0s_mask].bvecs,
title,
"gray",
)
fig.savefig(args.signal_surface_plot_fname, format="svg")


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@

import argparse
from collections import defaultdict
from pathlib import Path

# import nibabel as nib
import numpy as np
import pandas as pd
from sklearn.model_selection import RepeatedKFold, cross_val_score
from sklearn.model_selection import KFold, RepeatedKFold, cross_val_predict, cross_val_score

from eddymotion.model._sklearn import (
EddyMotionGPR,
Expand All @@ -47,39 +47,31 @@ def cross_validate(
X: np.ndarray,
y: np.ndarray,
cv: int,
gpr: EddyMotionGPR,
) -> dict[int, list[tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]]]:
"""
Perform the experiment by estimating the dMRI signal using a Gaussian process model.
Parameters
----------
gtab : :obj:`~dipy.core.gradients.gradient_table`
Gradient table.
S0 : :obj:`float`
S0 value.
evals1 : :obj:`~numpy.ndarray`
Eigenvalues of the tensor.
evecs : :obj:`~numpy.ndarray`
Eigenvectors of the tensor.
snr : :obj:`float`
Signal-to-noise ratio.
X : :obj:`~numpy.ndarray`
Diffusion-encoding gradient vectors.
y : :obj:`~numpy.ndarray`
DWI signal.
cv : :obj:`int`
number of folds
gpr : obj:`~eddymotion.model._sklearn.EddyMotionGPR`
The eddymotion Gaussian process regressor object.
Returns
-------
:obj:`dict`
Data for the predicted signal and its error.
"""
gpm = EddyMotionGPR(
kernel=SphericalKriging(a=1.15, lambda_s=120),
alpha=100,
optimizer=None,
)

rkf = RepeatedKFold(n_splits=cv, n_repeats=120 // cv)
scores = cross_val_score(gpm, X, y, scoring="neg_root_mean_squared_error", cv=rkf)
scores = cross_val_score(gpr, X, y, scoring="neg_root_mean_squared_error", cv=rkf)
return scores


Expand All @@ -103,7 +95,32 @@ def _build_arg_parser() -> argparse.ArgumentParser:
)
parser.add_argument("bval_shell", help="Shell b-value", type=float)
parser.add_argument("S0", help="S0 value", type=float)
parser.add_argument("--evals1", help="Eigenvalues of the tensor", nargs="+", type=float)
parser.add_argument(
"error_data_fname",
help="Filename of TSV file containing the data to plot",
type=Path,
)
parser.add_argument(
"dwi_gt_data_fname",
help="Filename of NIfTI file containing the generated DWI signal",
type=Path,
)
parser.add_argument(
"bval_data_fname",
help="Filename of b-val file containing the diffusion-encoding gradient b-vals",
type=Path,
)
parser.add_argument(
"bvec_data_fname",
help="Filename of b-vecs file containing the diffusion-encoding gradient b-vecs",
type=Path,
)
parser.add_argument(
"dwi_pred_data_fname",
help="Filename of NIfTI file containing the predicted DWI signal",
type=Path,
)
parser.add_argument("--evals", help="Eigenvalues of the tensor", nargs="+", type=float)
parser.add_argument("--snr", help="Signal to noise ratio", type=float)
parser.add_argument("--repeats", help="Number of repeats", type=int, default=5)
parser.add_argument(
Expand Down Expand Up @@ -134,37 +151,60 @@ def main() -> None:
parser = _build_arg_parser()
args = _parse_args(parser)

n_voxels = 100

data, gtab = testsims.simulate_voxels(
args.S0,
args.evals1,
args.hsph_dirs,
bval_shell=args.bval_shell,
snr=args.snr,
n_voxels=100,
n_voxels=n_voxels,
evals=args.evals,
seed=None,
)

# Save the generated signal and gradient table
testsims.serialize_dmri(
data, gtab, args.dwi_gt_data_fname, args.bval_data_fname, args.bvec_data_fname
)

X = gtab[~gtab.b0s_mask].bvecs
y = data[:, ~gtab.b0s_mask]

snr_str = args.snr if args.snr is not None else "None"

a = 1.15
lambda_s = 120
alpha = 100
gpr = EddyMotionGPR(
kernel=SphericalKriging(a=a, lambda_s=lambda_s),
alpha=alpha,
optimizer=None,
)

# Use Scikit-learn cross validation
scores = defaultdict(list, {})
for n in args.kfold:
for i in range(args.repeats):
cv_scores = -1.0 * cross_validate(X, y.T, n)
cv_scores = -1.0 * cross_validate(X, y.T, n, gpr)
scores["rmse"] += cv_scores.tolist()
scores["repeat"] += [i] * len(cv_scores)
scores["n_folds"] += [n] * len(cv_scores)
scores["snr"] += [snr_str] * len(cv_scores)

print(f"Finished {n}-fold cross-validation")

scores_df = pd.DataFrame(scores)
scores_df.to_csv("cv_scores.tsv", sep="\t", index=None, na_rep="n/a")
scores_df.to_csv(args.error_data_fname, sep="\t", index=None, na_rep="n/a")

grouped = scores_df.groupby(["n_folds"])
print(grouped[["rmse"]].mean())
print(grouped[["rmse"]].std())

cv = KFold(n_splits=3, shuffle=False, random_state=None)
predictions = cross_val_predict(gpr, X, y.T, cv=cv)
testsims.serialize_dwi(predictions.T, args.dwi_pred_data_fname)


if __name__ == "__main__":
main()
Loading

0 comments on commit 33992b6

Please sign in to comment.