-
Notifications
You must be signed in to change notification settings - Fork 1
/
run.py
149 lines (120 loc) · 7.77 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# -*- coding: utf-8 -*-
"""
Created on Thu Jul 28 22:31:20 2022
@author: Amin
"""
from causality import causality_indices as ci
from causality import interventional as intcnn
from causality import helpers as inth
from causality import granger
from delay_embedding import ccm
import visualizations as viz
import data_loader
import numpy as np
import argparse
import yaml
import os
# %%
def get_args():
    '''Parse the arguments when this file is called from the console.

    Returns
    -------
    argparse.Namespace
        ``config``: path of the YAML configuration file (required).
        ``output``: folder in which outputs and results are saved
        (defaults to ``'/'``).
    '''
    parser = argparse.ArgumentParser(
        description='Runner for CCM',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # The configuration file is mandatory. The former default ('/') is a
    # directory, so the subsequent open() on it could only fail; requiring the
    # flag turns that into a clear argparse usage error instead.
    parser.add_argument(
        '--config', '-c', metavar='Configuration',
        help='Configuration file address', required=True,
    )
    parser.add_argument(
        '--output', '-o', metavar='Output',
        help='Folder to save the output and results', default='/',
    )
    return parser.parse_args()
# %%
if __name__ == '__main__':
    # Read the command-line arguments and run the pipeline described by the
    # YAML configuration: load the resting-state data, compute the requested
    # functional connectivity indices, optionally compute interventional
    # connectivity ('ic') from stimulation data, and save the requested
    # results and figures into the output folder.
    args = get_args()
    # exist_ok removes the race between an existence check and the creation.
    os.makedirs(args.output, exist_ok=True)
    with open(args.config, 'r') as stream:
        pm = yaml.safe_load(stream)

    def out_path(name):
        '''Return the path of the output artifact `name` inside args.output.'''
        # os.path.join also works when the folder is given without a trailing
        # separator (plain concatenation would write next to the folder).
        return os.path.join(args.output, name)

    visualizations = pm['visualizations']
    # %%
    # Resolve the loader class by attribute lookup instead of eval() on config
    # text, so the `dataset` entry cannot execute arbitrary code.
    loader_cls = getattr(data_loader, pm['dataset'])
    # `not` rather than `~`: on Python bools `~True == -2` and `~False == -1`,
    # both truthy, so `~pm['load']` always requested saving.
    dl = loader_cls(pm, save=not pm['load'], load=pm['load'], file=args.output)
    if 'visualize_adjacency' in visualizations:
        viz.visualize_adjacency(
            dl.network.pm['J'], fontsize=pm['fontsize'],
            save=True, file=out_path('cnn'),
        )
    # %%
    # Resting-state recording
    r, t, out = dl.load_rest(pm)
    mask = dl.mask
    t_range = (min(t), max(t))
    if 'visualize_rates' in visualizations:
        viz.visualize_signals(
            t, [r.T], ['Rest Rates'], t_range=t_range,
            fontsize=pm['fontsize'], save=True, file=out_path('rest_rates'),
        )
    if 'visualize_voltages' in visualizations:
        viz.visualize_signals(
            out['t'], [out['x'].T], ['Rest Voltages'], t_range=t_range,
            fontsize=pm['fontsize'], save=True, file=out_path('rest_voltages'),
        )
    if 'visualize_spikes' in visualizations:
        viz.visualize_spikes(
            [out['spikes_flat']], ['Rest Spikes'], t_range=t_range,
            fontsize=pm['fontsize'], distinct_colors=True,
            distinction_point=pm['distinction_point'],
            save=True, file=out_path('rest_spikes'),
        )

    # Functional connectivity indices; dict insertion order fixes plot order.
    indices, indices_pval = {}, {}
    if 'gc' in pm['indices']:
        indices['gc'], indices_pval['gc'] = granger.univariate_gc(
            r.T, maxlag=pm['max_lag'], mask=mask, load=pm['load'],
            save=True, file=out_path('gc.npy'),
        )
    if 'te' in pm['indices']:
        indices['te'], indices_pval['te'] = ci.transfer_entropy_ksg(
            r.T, mask=mask, load=pm['load'],
            save=True, file=out_path('te.npy'),
        )
    if 'egc' in pm['indices']:
        indices['egc'], indices_pval['egc'] = ci.extended_granger_causality(
            r.T, mask=mask, mx=pm['mx'], my=pm['my'], L=pm['L'],
            delta=pm['delta'], load=pm['load'],
            save=True, file=out_path('egc.npy'),
        )
    if 'ngc' in pm['indices']:
        indices['ngc'], indices_pval['ngc'] = ci.nonlinear_granger_causality(
            r.T, mask=mask, mx=pm['mx'], my=pm['my'], load=pm['load'],
            save=True, file=out_path('ngc.npy'),
        )
    if 'mgc' in pm['indices']:
        indices['mgc'], indices_pval['mgc'] = granger.multivariate_gc(
            r.T, maxlag=pm['max_lag'], mask=mask, load=pm['load'],
            save=True, file=out_path('mgc.npy'),
        )
    if 'fcf' in pm['indices']:
        indices['fcf'], indices_pval['fcf'], _ = ccm.connectivity(
            r, mask=mask, test_ratio=pm['test_ratio'], delay=pm['tau'],
            dim=pm['D'], n_neighbors=pm['n_neighbors'], return_pval=True,
            n_surrogates=pm['n_surrogates'], load=pm['load'],
            save=True, file=out_path('fcf.npy'),
        )
    # %%
    if 'ic' in pm['indices']:
        # Stimulation recording; the loader may update dl.mask for it.
        r, t, out = dl.load_stim(pm)
        mask = dl.mask
        t_range = (min(t), max(t))
        if 'I' in out:
            if 'visualize_rates' in visualizations:
                viz.visualize_signals(
                    t, [r.T], ['Stim Rates'], t_range=t_range,
                    stim=out['I'][:, dl.recorded], stim_t=out['t_stim'],
                    fontsize=pm['fontsize'], save=True, file=out_path('stim_rates'),
                )
            if 'visualize_voltages' in visualizations:
                viz.visualize_signals(
                    out['t'], [out['x'].T], ['Stim Voltages'], t_range=t_range,
                    stim=out['I'][:, dl.recorded], stim_t=out['t_stim'],
                    fontsize=pm['fontsize'], save=True, file=out_path('stim_voltages'),
                )
            if 'visualize_spikes' in visualizations:
                viz.visualize_spikes(
                    [out['spikes_flat']], ['Stim Spikes'], t_range=t_range,
                    stim=out['I'], stim_t=out['t_stim'], distinct_colors=True,
                    distinction_point=pm['distinction_point'],
                    fontsize=pm['fontsize'], save=True, file=out_path('stim_spikes'),
                )
            if 'visualize_stim_protocol' in visualizations:
                viz.visualize_stim_protocol(
                    out['I'], t_range[0], t_range[1], pm['N'],
                    fontsize=pm['fontsize'], save=True, file=out_path('stim_protocol'),
                )

        if 'stim_info' in out:
            # A loader-provided stimulation table takes precedence.
            stim_info = out['stim_info']
        elif 'I' in out:
            # Columns of out['I'] selected by dl.recorded, transposed so rows
            # are recorded channels; positive/negative steps of the difference
            # along time mark stimulation onsets/offsets.
            edges = np.diff(out['I'][:, dl.recorded].T, axis=1)
            stim_s = np.where(edges > 0)
            stim_e = np.where(edges < 0)
            # Onsets and offsets are paired by position in np.where's output.
            # (The former `stim_d` "duration" list subtracted the offset time
            # from itself, was always 0 and was never used; it was removed.)
            # Stimulation array [(chn,start,end),...]
            stim_info = [
                (chn, out['t_stim'][s], out['t_stim'][e])
                for chn, s, e in zip(stim_s[0], stim_s[1], stim_e[1])
            ]
        else:
            raise KeyError(
                "Stimulation data provides neither 'stim_info' nor 'I'"
            )

        indices['ic'], indices_pval['ic'] = intcnn.interventional_connectivity(
            r.T,
            stim_info,
            t=t,
            mask=mask,
            bin_size=pm['bin_size'],
            skip_pre=pm['skip_pre'],
            skip_pst=pm['skip_pst'],
            method=pm['intcnn_method'],
            load=pm['load'],
            save=True, file=out_path('ic.npy'),
        )
        if 'visualize_scatters' in visualizations:
            # Every index (including 'ic' itself) is plotted against 'ic', and
            # every panel uses the significance of 'ic'.
            viz.visualize_scatters(
                [index[~mask] for index in indices.values()],
                [indices['ic'][~mask] for _ in indices],
                [indices_pval['ic'][~mask] < pm['pval_thresh'] for _ in indices],
                xlabel=list(indices.keys()),
                ylabel=[pm['intcnn_method']] * len(indices),
                titlestr='Functional vs. Interventional Correlation',
                fontsize=pm['fontsize'],
                save=True, file=out_path('stim_rest_corr'),
            )
        if 'visualize_bars' in visualizations:
            viz.visualize_bars(
                [index[~mask] for index in indices.values()],
                [indices_pval['ic'][~mask] < pm['pval_thresh'] for _ in indices],
                titlestr=list(indices.keys()),
                fontsize=pm['fontsize'],
                save=True, file=out_path('stim_rest_bar'),
            )

    # NOTE(review): the original indentation was lost; these per-channel plots
    # are placed outside the 'ic' branch. Confirm dl.stimulated_recorded is
    # available when 'ic' is not requested.
    if 'visualize_cnn_physical_layout' in visualizations:
        for key in indices:
            for ch in dl.stimulated_recorded:
                viz.visualize_cnn_physical_layout(
                    dl.layout,
                    indices[key][:, ch],
                    indices_pval[key][:, ch] <= pm['pval_thresh'],
                    cmap=pm['cmap_' + key],
                    titlestr=key, fontsize=pm['fontsize'],
                    save=True, file=out_path('layout_' + key + '_' + str(ch)),
                )
    if 'plot_index_vs_distance' in visualizations:
        for key in indices:
            for ch in dl.stimulated_recorded:
                viz.plot_index_vs_distance(
                    dl.layout,
                    indices[key][:, ch],
                    titlestr=key, fontsize=pm['fontsize'],
                    save=True, file=out_path('index_v_dist_' + key + '_' + str(ch)),
                )
    # Connectivity matrix of every computed index, thresholded on p-value.
    for key in indices:
        viz.visualize_cnn(
            indices[key], indices_pval[key] <= pm['pval_thresh'], titlestr=key,
            cmap=pm['cmap_' + key], fontsize=pm['fontsize'],
            save=True, file=out_path(key),
        )