-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
1 changed file
with
237 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d11e5969ed6af8a5", | ||
"metadata": {}, | ||
"source": [ | ||
"Estimate a DWI signal using the eddymotion Gaussian Process (GP) regressor estimator." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3476a8e9cfefd4b8", | ||
"metadata": {}, | ||
"source": [ | ||
"Download the \"Sherbrooke 3-shell\" dataset using DIPY and select the b=1000 s/mm^2 shell data." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "69a3bc6b4fbe7036", | ||
"metadata": { | ||
"jupyter": { | ||
"is_executing": true | ||
}, | ||
"ExecuteTime": { | ||
"start_time": "2024-11-05T12:46:54.497856Z" | ||
} | ||
}, | ||
"source": [ | ||
"import dipy.data as dpd\n", | ||
"import nibabel as nib\n", | ||
"import numpy as np\n", | ||
"from dipy.core.gradients import get_bval_indices\n", | ||
"from dipy.io import read_bvals_bvecs\n", | ||
"from dipy.segment.mask import median_otsu\n", | ||
"\n", | ||
"seed = 1234\n", | ||
"rng = np.random.default_rng(seed)\n", | ||
"\n", | ||
"name = \"sherbrooke_3shell\"\n", | ||
"\n", | ||
"dwi_fname, bval_fname, bvec_fname = dpd.get_fnames(name=name)\n", | ||
"dwi_data = nib.load(dwi_fname).get_fdata()\n", | ||
"bvals, bvecs = read_bvals_bvecs(bval_fname, bvec_fname)\n", | ||
"\n", | ||
"_, brain_mask = median_otsu(dwi_data, vol_idx=[0])\n", | ||
"\n", | ||
"bval = 1000\n", | ||
"indices = get_bval_indices(bvals, bval, tol=20)\n", | ||
"\n", | ||
"bvecs_shell = bvecs[indices]\n", | ||
"shell_data = dwi_data[..., indices]" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9bd417117afaad49", | ||
"metadata": {}, | ||
"source": [ | ||
"Visualize a slice of the data for a given DWI volume." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "d8547475686958f3", | ||
"metadata": {}, | ||
"source": [ | ||
"# Plot a slice\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"%matplotlib inline\n", | ||
"\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"dwi_vol_idx = len(indices) // 2\n", | ||
"slice_idx = list(map(int, np.divide(dwi_data.shape[:-1], 2)))\n", | ||
"\n", | ||
"x_slice = dwi_data[slice_idx[0], :, :, dwi_vol_idx]\n", | ||
"y_slice = dwi_data[:, slice_idx[1], :, dwi_vol_idx]\n", | ||
"z_slice = dwi_data[:, :, slice_idx[2], dwi_vol_idx]\n", | ||
"slices = [x_slice, y_slice, z_slice]\n", | ||
"\n", | ||
"fig, axes = plt.subplots(1, len(slices))\n", | ||
"for i, _slice in enumerate(slices):\n", | ||
" axes[i].imshow(_slice.T, cmap=\"gray\", origin=\"lower\", aspect='equal')\n", | ||
" axes[i].set_axis_off()\n", | ||
"\n", | ||
"plt.show()" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9dcab811fe667617", | ||
"metadata": {}, | ||
"source": [ | ||
"Define the EddyMotionGPR instance." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "7d5d9562339bc849", | ||
"metadata": {}, | ||
"source": [ | ||
"from eddymotion.model.gpr import EddyMotionGPR, SphericalKriging\n", | ||
"\n", | ||
"beta_a = 1.38\n", | ||
"beta_l = 1 / 2.1\n", | ||
"kernel = SphericalKriging(beta_a=beta_a, beta_l=beta_l)\n", | ||
"\n", | ||
"alpha = 0.1\n", | ||
"disp = True\n", | ||
"optimizer = None\n", | ||
"gpr = EddyMotionGPR(kernel=kernel, alpha=alpha, disp=disp, optimizer=optimizer)\n" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ea5cc8036fa0ab48", | ||
"metadata": {}, | ||
"source": [ | ||
"Do not optimize the parameters in the fitting. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "7e93b99c3b072d99", | ||
"metadata": {}, | ||
"source": [ | ||
"X_train = bvecs_shell\n", | ||
"# Consider only brain voxels\n", | ||
"dwi_mask = np.repeat(brain_mask[..., np.newaxis], shell_data.shape[-1], axis=-1)\n", | ||
"y = shell_data[dwi_mask].reshape((X_train.shape[0], -1))\n", | ||
"gpr.fit(X_train, y)" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "dfdd82afbdb22790", | ||
"metadata": {}, | ||
"source": [ | ||
"Predict on a randomly chosen direction." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "ae3407b31b14928d", | ||
"metadata": {}, | ||
"source": [ | ||
"# Pick a direction to predict\n", | ||
"idx = rng.integers(0, len(indices))\n", | ||
"X_test = bvecs_shell[idx][np.newaxis, :]\n", | ||
"y_pred = gpr.predict(X_test)\n", | ||
"\n", | ||
"rmse = np.sqrt(np.mean(np.square(y[idx, ...] - y_pred.squeeze())))\n", | ||
"_rmse_element = np.sqrt(np.square(y[idx, ...] - y_pred.squeeze()))\n", | ||
"\n", | ||
"print(f\"RMSE: {rmse}\")\n", | ||
"threshold = 10\n", | ||
"n_error_thr = len(_rmse_element[_rmse_element > threshold])\n", | ||
"ratio = n_error_thr / len(_rmse_element) * 100\n", | ||
"print(f\"Number of RMSE values above {threshold}: {n_error_thr} ({ratio:.2f}%)\")" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "74b040c05621f2d9", | ||
"metadata": {}, | ||
"source": [ | ||
"Visualize the prediction." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"id": "a130de2a03dff2b5", | ||
"metadata": {}, | ||
"source": [ | ||
"# Reconstruct the data array\n", | ||
"brain_mask_idx = np.where(brain_mask)\n", | ||
"_y = np.zeros((shell_data.shape[:-1]), dtype=y.dtype)\n", | ||
"_y[brain_mask_idx] = y_pred.squeeze()\n", | ||
"\n", | ||
"x_slice = _y[slice_idx[0], :, :]\n", | ||
"y_slice = _y[:, slice_idx[1], :]\n", | ||
"z_slice = _y[:, :, slice_idx[2]]\n", | ||
"slices = [x_slice, y_slice, z_slice]\n", | ||
"\n", | ||
"fig, axes = plt.subplots(1, len(slices))\n", | ||
"for i, _slice in enumerate(slices):\n", | ||
" axes[i].imshow(_slice.T, cmap=\"gray\", origin=\"lower\", aspect='equal')\n", | ||
" axes[i].set_axis_off()\n", | ||
"\n", | ||
"plt.show()" | ||
], | ||
"outputs": [], | ||
"execution_count": null | ||
}, | ||
{ | ||
"metadata": {}, | ||
"cell_type": "code", | ||
"source": "", | ||
"id": "fae657ba6d3734a4", | ||
"outputs": [], | ||
"execution_count": null | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |