Skip to content

Commit

Permalink
ENH: Add GP estimation notebook
Browse files Browse the repository at this point in the history
Add GP estimation notebook.
  • Loading branch information
jhlegarreta committed Nov 5, 2024
1 parent 795a9b7 commit f549bae
Showing 1 changed file with 237 additions and 0 deletions.
237 changes: 237 additions & 0 deletions docs/notebooks/dwi_gp_estimation.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d11e5969ed6af8a5",
"metadata": {},
"source": [
"Estimate a DWI signal using the eddymotion Gaussian Process (GP) regressor estimator."
]
},
{
"cell_type": "markdown",
"id": "3476a8e9cfefd4b8",
"metadata": {},
"source": [
"Download the \"Sherbrooke 3-shell\" dataset using DIPY and select the b=1000 s/mm^2 shell data."
]
},
{
"cell_type": "code",
"id": "69a3bc6b4fbe7036",
"metadata": {
"jupyter": {
"is_executing": true
},
"ExecuteTime": {
"start_time": "2024-11-05T12:46:54.497856Z"
}
},
"source": [
"import dipy.data as dpd\n",
"import nibabel as nib\n",
"import numpy as np\n",
"from dipy.core.gradients import get_bval_indices\n",
"from dipy.io import read_bvals_bvecs\n",
"from dipy.segment.mask import median_otsu\n",
"\n",
"seed = 1234\n",
"rng = np.random.default_rng(seed)\n",
"\n",
"name = \"sherbrooke_3shell\"\n",
"\n",
"dwi_fname, bval_fname, bvec_fname = dpd.get_fnames(name=name)\n",
"dwi_data = nib.load(dwi_fname).get_fdata()\n",
"bvals, bvecs = read_bvals_bvecs(bval_fname, bvec_fname)\n",
"\n",
"_, brain_mask = median_otsu(dwi_data, vol_idx=[0])\n",
"\n",
"bval = 1000\n",
"indices = get_bval_indices(bvals, bval, tol=20)\n",
"\n",
"bvecs_shell = bvecs[indices]\n",
"shell_data = dwi_data[..., indices]"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"id": "9bd417117afaad49",
"metadata": {},
"source": [
"Visualize a slice of the data for a given DWI volume."
]
},
{
"cell_type": "code",
"id": "d8547475686958f3",
"metadata": {},
"source": [
"# Plot a slice\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"\n",
"import numpy as np\n",
"\n",
"dwi_vol_idx = len(indices) // 2\n",
"slice_idx = list(map(int, np.divide(dwi_data.shape[:-1], 2)))\n",
"\n",
"x_slice = dwi_data[slice_idx[0], :, :, dwi_vol_idx]\n",
"y_slice = dwi_data[:, slice_idx[1], :, dwi_vol_idx]\n",
"z_slice = dwi_data[:, :, slice_idx[2], dwi_vol_idx]\n",
"slices = [x_slice, y_slice, z_slice]\n",
"\n",
"fig, axes = plt.subplots(1, len(slices))\n",
"for i, _slice in enumerate(slices):\n",
" axes[i].imshow(_slice.T, cmap=\"gray\", origin=\"lower\", aspect='equal')\n",
" axes[i].set_axis_off()\n",
"\n",
"plt.show()"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"id": "9dcab811fe667617",
"metadata": {},
"source": [
"Define the EddyMotionGPR instance."
]
},
{
"cell_type": "code",
"id": "7d5d9562339bc849",
"metadata": {},
"source": [
"from eddymotion.model.gpr import EddyMotionGPR, SphericalKriging\n",
"\n",
"beta_a = 1.38\n",
"beta_l = 1 / 2.1\n",
"kernel = SphericalKriging(beta_a=beta_a, beta_l=beta_l)\n",
"\n",
"alpha = 0.1\n",
"disp = True\n",
"optimizer = None\n",
"gpr = EddyMotionGPR(kernel=kernel, alpha=alpha, disp=disp, optimizer=optimizer)\n"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"id": "ea5cc8036fa0ab48",
"metadata": {},
"source": [
"Do not optimize the parameters in the fitting. "
]
},
{
"cell_type": "code",
"id": "7e93b99c3b072d99",
"metadata": {},
"source": [
"X_train = bvecs_shell\n",
"# Consider only brain voxels\n",
"dwi_mask = np.repeat(brain_mask[..., np.newaxis], shell_data.shape[-1], axis=-1)\n",
"y = shell_data[dwi_mask].reshape((X_train.shape[0], -1))\n",
"gpr.fit(X_train, y)"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"id": "dfdd82afbdb22790",
"metadata": {},
"source": [
"Predict on a randomly chosen direction."
]
},
{
"cell_type": "code",
"id": "ae3407b31b14928d",
"metadata": {},
"source": [
"# Pick a direction to predict\n",
"idx = rng.integers(0, len(indices))\n",
"X_test = bvecs_shell[idx][np.newaxis, :]\n",
"y_pred = gpr.predict(X_test)\n",
"\n",
"rmse = np.sqrt(np.mean(np.square(y[idx, ...] - y_pred.squeeze())))\n",
"_rmse_element = np.sqrt(np.square(y[idx, ...] - y_pred.squeeze()))\n",
"\n",
"print(f\"RMSE: {rmse}\")\n",
"threshold = 10\n",
"n_error_thr = len(_rmse_element[_rmse_element > threshold])\n",
"ratio = n_error_thr / len(_rmse_element) * 100\n",
"print(f\"Number of RMSE values above {threshold}: {n_error_thr} ({ratio:.2f}%)\")"
],
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"id": "74b040c05621f2d9",
"metadata": {},
"source": [
"Visualize the prediction."
]
},
{
"cell_type": "code",
"id": "a130de2a03dff2b5",
"metadata": {},
"source": [
"# Reconstruct the data array\n",
"brain_mask_idx = np.where(brain_mask)\n",
"_y = np.zeros((shell_data.shape[:-1]), dtype=y.dtype)\n",
"_y[brain_mask_idx] = y_pred.squeeze()\n",
"\n",
"x_slice = _y[slice_idx[0], :, :]\n",
"y_slice = _y[:, slice_idx[1], :]\n",
"z_slice = _y[:, :, slice_idx[2]]\n",
"slices = [x_slice, y_slice, z_slice]\n",
"\n",
"fig, axes = plt.subplots(1, len(slices))\n",
"for i, _slice in enumerate(slices):\n",
" axes[i].imshow(_slice.T, cmap=\"gray\", origin=\"lower\", aspect='equal')\n",
" axes[i].set_axis_off()\n",
"\n",
"plt.show()"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "",
"id": "fae657ba6d3734a4",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit f549bae

Please sign in to comment.