forked from KirstieJane/DTI_PROCESSING
-
Notifications
You must be signed in to change notification settings - Fork 0
/
calculate_connectivity_matrix.py
executable file
·288 lines (220 loc) · 10.1 KB
/
calculate_connectivity_matrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
#!/usr/bin/env python
"""
Name: calculate_connectivity_matrix.py
Created by: Kirstie Whitaker
kw401@cam.ac.uk
"""
#=============================================================================
# IMPORTS
#=============================================================================
import os
import sys
from glob import glob
import argparse
import numpy as np
import matplotlib.pylab as plt
import nibabel as nib
import dipy.reconst.dti as dti
from dipy.reconst.dti import quantize_evecs
from dipy.data import get_sphere
from dipy.io import read_bvals_bvecs
from dipy.core.gradients import gradient_table
from dipy.tracking.eudx import EuDX
from dipy.reconst import peaks, shm
from dipy.tracking import utils
try:
from dipy.viz import fvtk
except ImportError:
raise ImportError('Python vtk module is not installed')
from dipy.viz.colormap import line_colors
import networkx as nx
import matplotlib.colors as colors
from condition_seeds import condition_seeds
#=============================================================================
# FUNCTIONS
#=============================================================================
# Set up the argparser so you can read arguments from the command line
def setup_argparser(argv=None):
    '''
    Build the command-line parser for this script and parse the arguments.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. When None (the default) argparse falls back to
        sys.argv[1:], i.e. the real command line.

    Returns
    -------
    arguments : argparse.Namespace
        Has the string attributes dti_dir, parcellation_file and
        white_matter_file.
    parser : argparse.ArgumentParser
        The parser itself (e.g. for printing usage).

    Raises
    ------
    SystemExit
        Raised by argparse when the arguments are missing or invalid.
    '''
    help_text = ('Create a connectivity matrix from a diffusion weighted dataset')
    sign_off = 'Author: Kirstie Whitaker <kw401@cam.ac.uk>'
    parser = argparse.ArgumentParser(description=help_text, epilog=sign_off)

    # Three required positional string arguments, in this order
    positional_args = [('dti_dir', 'DTI directory'),
                       ('parcellation_file', 'Parcellation filename'),
                       ('white_matter_file', 'White matter filename')]
    for name, description in positional_args:
        parser.add_argument(dest=name,
                            type=str,
                            metavar=name,
                            help=description)

    arguments = parser.parse_args(argv)
    return arguments, parser
#-----------------------------------------------------------------------------
def save_mat(M, M_text_name):
    '''
    Write M without its first row and first column to a tab-separated
    text file, five decimal places per value.

    An existing file at M_text_name is left untouched (nothing is written).
    Row/column 0 presumably corresponds to label 0 of the parcellation
    (background) -- TODO confirm against the matrix producer.
    '''
    if os.path.exists(M_text_name):
        return
    trimmed = M[1:, 1:]
    np.savetxt(M_text_name, trimmed,
               fmt='%.5f', delimiter='\t', newline='\n')
#-----------------------------------------------------------------------------
def save_png(M, M_fig_name):
    '''
    Save a heat-map PNG of M (first row and column dropped) on a
    log(1 + x) colour scale, with a colorbar.

    Does nothing if M_fig_name already exists. The figure is always
    closed afterwards so that repeated calls do not pile up open pyplot
    figures.
    '''
    if os.path.exists(M_fig_name):
        return
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        # log1p keeps zero counts at zero while compressing large counts
        image = ax.imshow(np.log1p(M[1:, 1:]),
                          interpolation='nearest',
                          cmap='jet')
        fig.colorbar(image)
        fig.savefig(M_fig_name, bbox_inches=0, dpi=600)
    finally:
        # pyplot keeps every figure alive until it is explicitly closed
        plt.close(fig)
#=============================================================================
# Define some variables
#=============================================================================
# Read in the arguments from argparse
arguments, parser = setup_argparser()
dti_dir = arguments.dti_dir
parcellation_file = arguments.parcellation_file
wm_file = arguments.white_matter_file

# Input files that don't exist at the path given are looked for inside dti_dir
if not os.path.exists(parcellation_file):
    parcellation_file = os.path.join(dti_dir, parcellation_file)
if not os.path.exists(wm_file):
    wm_file = os.path.join(dti_dir, wm_file)

# Check that the inputs exist. sys.exit(message) prints the message to
# stderr and exits with status 1; a bare sys.exit() would exit with
# status 0 and report success to whatever called this script.
if not os.path.isdir(dti_dir):
    sys.exit("DTI directory doesn't exist")
if not os.path.exists(parcellation_file):
    sys.exit("Parcellation file doesn't exist")
if not os.path.exists(wm_file):
    sys.exit("White matter file doesn't exist")

# Define the output directory and make it if it doesn't yet exist
connectivity_dir = os.path.join(dti_dir, 'CONNECTIVITY')
if not os.path.isdir(connectivity_dir):
    os.makedirs(connectivity_dir)

# Now define a couple of variables: fixed input names inside dti_dir and the
# matrix text files inside the output directory
dwi_file = os.path.join(dti_dir, 'dti_ec.nii.gz')
mask_file = os.path.join(dti_dir, 'dti_ec_brain.nii.gz')
bvals_file = os.path.join(dti_dir, 'bvals')
bvecs_file = os.path.join(dti_dir, 'bvecs')
Msym_file = os.path.join(connectivity_dir, 'Msym.txt')
Mdir_file = os.path.join(connectivity_dir, 'Mdir.txt')
#=============================================================================
# Load in the data
#=============================================================================
# NOTE(review): img.get_data() is removed in recent nibabel releases;
# np.asanyarray(img.dataobj) is the replacement -- confirm installed version.
print('PARCELLATION FILE: {}'.format(parcellation_file))

# Diffusion-weighted volumes
dwi_img = nib.load(dwi_file)
dwi_data = dwi_img.get_data()

# Brain mask, binarised to 0/1. Comparing with 0 (instead of only
# overwriting positive voxels with 1) also zeroes any negative voxels.
# Plain int is used because the np.int alias no longer exists in NumPy.
mask_img = nib.load(mask_file)
mask_data = mask_img.get_data().astype(int)
mask_data_bin = (mask_data > 0).astype(int)

# White-matter mask, binarised to 0/1 (keeps the image's dtype)
wm_img = nib.load(wm_file)
wm_data = wm_img.get_data()
wm_data_bin = (wm_data > 0).astype(wm_data.dtype)

# Mask the dwi_data so that only white-matter voxels are investigated.
# The trailing length-1 axis lets the mask broadcast over every volume.
dwi_data = dwi_data * wm_data_bin.reshape(wm_data_bin.shape[:3] + (1,))

# Parcellation labels
parcellation_img = nib.load(parcellation_file)
parcellation_data = parcellation_img.get_data().astype(int)

# Gradient table from the b-values / b-vectors files
bvals, bvecs = read_bvals_bvecs(bvals_file, bvecs_file)
gtab = gradient_table(bvals, bvecs)

# Restrict the labels to the brain mask, then to white matter
parcellation_data = parcellation_data * mask_data_bin
parcellation_wm_data = (parcellation_data * wm_data_bin).astype(int)
#=============================================================================
# Track all of white matter using EuDX
#=============================================================================
# Tracking is only skipped when BOTH matrix files are already on disk.
# (Testing "neither file exists" would skip tracking when just one file is
# present, and the reload below would then fail on the missing one.)
matrices_exist = os.path.exists(Msym_file) and os.path.exists(Mdir_file)

if not matrices_exist:
    print('\tCalculating peaks')
    csamodel = shm.CsaOdfModel(gtab, 6)
    csapeaks = peaks.peaks_from_model(model=csamodel,
                                      data=dwi_data,
                                      sphere=peaks.default_sphere,
                                      relative_peak_threshold=.8,
                                      min_separation_angle=45,
                                      mask=wm_data_bin)
    print('\tTracking')
    seeds = utils.seeds_from_mask(parcellation_wm_data, density=2)
    # Use a new name for the result so the imported condition_seeds
    # function is not shadowed
    conditioned_seeds = condition_seeds(seeds, np.eye(4),
                                        csapeaks.peak_values.shape[:3])
    streamline_generator = EuDX(csapeaks.peak_values, csapeaks.peak_indices,
                                odf_vertices=peaks.default_sphere.vertices,
                                a_low=.05, step_sz=.5,
                                seeds=conditioned_seeds)
    affine = streamline_generator.affine
    streamlines = list(streamline_generator)
else:
    print('\tTracking already complete')

#=============================================================================
# Create two connectivity matrices - symmetric and directional
#=============================================================================
if not matrices_exist:
    print('\tCreating Connectivity Matrix')
    # Only the matrices are used below, so the streamline-to-region
    # mapping is not requested
    Msym = utils.connectivity_matrix(streamlines, parcellation_wm_data,
                                     affine=affine,
                                     symmetric=True)
    Mdir = utils.connectivity_matrix(streamlines, parcellation_wm_data,
                                     affine=affine,
                                     symmetric=False)
else:
    # The text files were written from M[1:, 1:], i.e. without row/column
    # 0. Pad a zero row and column back on so the reloaded matrices are
    # indexed like freshly computed ones; otherwise the [1:, 1:] slicing
    # used when saving and plotting would drop the first region.
    # ndmin=2 keeps tiny (1x1 / 1xN) matrices two-dimensional.
    Msym = np.pad(np.loadtxt(Msym_file, ndmin=2), ((1, 0), (1, 0)),
                  mode='constant')
    Mdir = np.pad(np.loadtxt(Mdir_file, ndmin=2), ((1, 0), (1, 0)),
                  mode='constant')

# Net directed connectivity: for each pair keep only the positive excess
# of one direction over the other
Mdiff = Mdir - Mdir.T
Mdiff[Mdiff < 0] = 0
#=============================================================================
# Save the connectivity matrices as text files, and as figures
#=============================================================================
print('\tMaking Pictures')
for M, name in zip([Msym, Mdir, Mdiff], ['Msym', 'Mdir', 'Mdiff']):
    # Save the matrix as a text file
    M_text_name = os.path.join(connectivity_dir, '{}.txt'.format(name))
    save_mat(M, M_text_name)
    # Make a png image of the matrix
    M_fig_name = os.path.join(connectivity_dir, '{}.png'.format(name))
    save_png(M, M_fig_name)

# Save an image of all three matrices side by side (skipped if it exists)
fig_name = os.path.join(connectivity_dir, 'AllMatrices.png')
if not os.path.exists(fig_name):
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    # Fixed colour range shared by all panels so they are directly comparable
    vmax = np.log1p(1000)
    panels = [(Msym, 'Symmetric'),
              (Mdir, 'Directed'),
              (Mdiff, 'Difference\nA --> B and B --> A')]
    try:
        for axis, (M, title) in zip(ax, panels):
            # Drop row/column 0, as in the single-matrix images
            axis.imshow(np.log1p(M[1:, 1:]), interpolation='nearest',
                        cmap='jet', vmin=0, vmax=vmax)
            axis.set_title(title)
        plt.tight_layout()
        fig.savefig(fig_name, bbox_inches=0, dpi=600)
    finally:
        # Release the figure; pyplot keeps it alive until closed
        plt.close(fig)
#------------------------------------------------
### THE END ###
# Today is April 3rd and the sun is shining in Cambridge
#------------------------------------------------