diff --git a/doc/getting_started/how_to_contribute.md b/doc/getting_started/how_to_contribute.md
new file mode 100644
index 00000000..41bab50e
--- /dev/null
+++ b/doc/getting_started/how_to_contribute.md
@@ -0,0 +1,47 @@
+# Contributing to the WhoBPyT Codebase
+Authors: Andrew Clappison, Kevin Kadak, John Griffiths
+
+The instructions below outline how to follow proper version control when contributing to the WhoBPyT repo.
+
+## Setting Up (Done Once):
+
+- **Downloading your Fork**
+  - You must already have configured an authentication key and forked the repository on github.com; ensure your fork is up to date with the whobpyt/dev branch, from which your new branch will be created.
+  - Open a terminal and go to the desired directory.
+  - `git clone git@github.com:<>/whobpyt.git`
+
+- **Adding Upstream**
+  - `cd whobpyt`
+  - `git remote add upstream https://github.com/GriffithsLab/whobpyt.git`
+  - `git fetch upstream`
+
+## Coding Cycle (Done for each new feature or group of features):
+
+- **Creating a New Branch**
+  - `git fetch upstream`
+  - `git checkout --track upstream/dev`
+  - `git push origin dev`
+  - `git checkout -b <>`
+
+- **Editing Code**
+  - Add/delete/edit code.
+
+- **Testing (WhoBPyT Sphinx examples should run successfully on Linux, but may fail to run on Windows)**
+  - Optionally: rename Sphinx examples ending in "r" to "x" if they are not relevant to your code changes (for quicker debugging). Example: "eg001r..." to "eg001x...".
+  - `cd doc`
+  - `make clean`
+  - `make html`
+  - Open and inspect the generated pages in a web browser: whobpyt/doc/_build/html/index.html
+  - Additional testing may also be advisable.
+
+- **Committing Code**
+  - `git status`
+  - `git add <>`
+  - `git commit -m "<>"`
+
+- **Pushing Code**
+  - `git push --set-upstream origin <>`
+
+- **Creating a Pull Request**
+  - On github.com, open a pull request from the new branch on your fork to the main repo's dev branch. If there is a merge conflict, it will have to be addressed before proceeding. A consolidated command sequence for the whole cycle is sketched below.
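+
+## Quick Reference
+
+A minimal end-to-end sketch of the cycle above (the branch name `my-new-feature` and the commit message are placeholders; substitute your own):
+
+```bash
+git fetch upstream
+git checkout --track upstream/dev
+git push origin dev
+git checkout -b my-new-feature
+# ... add/delete/edit code, run the doc build test, then:
+git status
+git add <changed files>
+git commit -m "Add my new feature"
+git push --set-upstream origin my-new-feature
+```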
+ +""" + + +# sphinx_gallery_thumbnail_number = 1 +# +# %% +# Importage +# -------------------------------------------------- + +# whobpyt stuff +import whobpyt +from whobpyt.datatypes import par, Recording +from whobpyt.models.JansenRit import RNNJANSEN, ParamsJR +from whobpyt.run import Model_fitting +from whobpyt.optimization.custom_cost_JR import CostsJR + +# python stuff +import numpy as np +import pandas as pd +import scipy.io +import gdown +import pickle +import warnings +warnings.filterwarnings('ignore') + +#neuroimaging packages +import mne + +# viz stuff +import matplotlib.pyplot as plt + + + +# %% +# Download and load necessary data for the example +url='https://drive.google.com/drive/folders/1Qu-JyZc3-SL-Evsystg4D-DdpsGU4waB?usp=sharing' +gdown.download_folder(url, quiet=True) + + +# %% +# Load EEG data from a file +file_name = './data/Subject_1_low_voltage.fif' +epoched = mne.read_epochs(file_name, verbose=False); +evoked = epoched.average() + +# %% +# Load Atlas +url = 'https://raw.githubusercontent.com/ThomasYeoLab/CBIG/master/stable_projects/brain_parcellation/Schaefer2018_LocalGlobal/Parcellations/MNI/Centroid_coordinates/Schaefer2018_200Parcels_7Networks_order_FSLMNI152_2mm.Centroid_RAS.csv' +atlas = pd.read_csv(url) +labels = atlas['ROI Name'] +coords = np.array([atlas['R'], atlas['A'], atlas['S']]).T +conduction_velocity = 5 #in ms + +# %% +# Compute the distance matrix +dist = np.zeros((coords.shape[0], coords.shape[0])) + +for roi1 in range(coords.shape[0]): + for roi2 in range(coords.shape[0]): + dist[roi1, roi2] = np.sqrt(np.sum((coords[roi1,:] - coords[roi2,:])**2, axis=0)) + dist[roi1, roi2] = np.sqrt(np.sum((coords[roi1,:] - coords[roi2,:])**2, axis=0)) + + +# %% +# Load the stim weights matrix which encode where to inject the external input +stim_weights = np.load('./data/stim_weights.npy') +stim_weights_thr = stim_weights.copy() +labels[np.where(stim_weights_thr>0)[0]] + +# %% +# Load the structural connectivity matrix +sc_file = './data/Schaefer2018_200Parcels_7Networks_count.csv' +sc_df = pd.read_csv(sc_file, header=None, sep=' ') +sc = sc_df.values +sc = np.log1p(sc) / np.linalg.norm(np.log1p(sc)) + +# %% +# Load the leadfield matrix +lm = np.load('./data/Subject_1_low_voltage_lf.npy') +ki0 =stim_weights_thr[:,np.newaxis] +delays = dist/conduction_velocity + +# %% +# define options for JR model +eeg_data = evoked.data.copy() +time_start = np.where(evoked.times==-0.1)[0][0] +time_end = np.where(evoked.times==0.3)[0][0] +eeg_data = eeg_data[:,time_start:time_end]/np.abs(eeg_data).max()*4 +node_size = sc.shape[0] +output_size = eeg_data.shape[0] +batch_size = 20 +step_size = 0.0001 +num_epoches = 120 +tr = 0.001 +state_size = 6 +base_batch_num = 20 +time_dim = 400 +state_size = 6 +base_batch_num = 20 +hidden_size = int(tr/step_size) + + +# %% +# prepare data structure of the model +data_mean = Recording(eeg_data, num_epoches, batch_size) + +# %% +# get model parameters structure and define the fitted parameters by setting non-zero variance for the model +lm = np.zeros((output_size,200)) +lm_v = np.zeros((output_size,200)) +params = ParamsJR(A = par(3.25), a= par(100,100, 2, True, True), B = par(22), b = par(50, 50, 1, True, True), + g=par(500,500,2, True, True), g_f=par(10,10,1, True, True), g_b=par(10,10,1, True, True), + c1 = par(135, 135, 1, True, True), c2 = par(135*0.8, 135*0.8, 1, True, True), c3 = par(135*0.25, 135*0.25, 1, True, True), + c4 = par(135*0.25, 135*0.25, 1, True, True), std_in= par(0,0, 1, True, True), vmax= par(5), v0=par(6), r=par(0.56), + 
+
+# %%
+# Load the stim weights matrix, which encodes where the external input is injected
+stim_weights = np.load('./data/stim_weights.npy')
+stim_weights_thr = stim_weights.copy()
+labels[np.where(stim_weights_thr > 0)[0]]  # inspect which ROIs receive stimulation
+
+# %%
+# Load the structural connectivity matrix
+sc_file = './data/Schaefer2018_200Parcels_7Networks_count.csv'
+sc_df = pd.read_csv(sc_file, header=None, sep=' ')
+sc = sc_df.values
+sc = np.log1p(sc) / np.linalg.norm(np.log1p(sc))
+
+# %%
+# Load the leadfield matrix
+lm = np.load('./data/Subject_1_low_voltage_lf.npy')
+ki0 = stim_weights_thr[:, np.newaxis]
+delays = dist / conduction_velocity
+
+# %%
+# Define options for the Jansen-Rit (JR) model
+eeg_data = evoked.data.copy()
+time_start = np.where(evoked.times == -0.1)[0][0]
+time_end = np.where(evoked.times == 0.3)[0][0]
+eeg_data = eeg_data[:, time_start:time_end] / np.abs(eeg_data).max() * 4
+node_size = sc.shape[0]
+output_size = eeg_data.shape[0]
+batch_size = 20
+step_size = 0.0001
+num_epoches = 120
+tr = 0.001
+state_size = 6
+base_batch_num = 20
+time_dim = 400
+hidden_size = int(tr / step_size)
+
+# %%
+# Prepare the data structure of the model
+data_mean = Recording(eeg_data, num_epoches, batch_size)
+
+# %%
+# Get the model parameters structure and define the fitted parameters by setting a non-zero variance
+lm = np.zeros((output_size, 200))
+lm_v = np.zeros((output_size, 200))
+params = ParamsJR(A=par(3.25), a=par(100, 100, 2, True, True), B=par(22), b=par(50, 50, 1, True, True),
+                  g=par(500, 500, 2, True, True), g_f=par(10, 10, 1, True, True), g_b=par(10, 10, 1, True, True),
+                  c1=par(135, 135, 1, True, True), c2=par(135*0.8, 135*0.8, 1, True, True),
+                  c3=par(135*0.25, 135*0.25, 1, True, True), c4=par(135*0.25, 135*0.25, 1, True, True),
+                  std_in=par(0, 0, 1, True, True), vmax=par(5), v0=par(6), r=par(0.56),
+                  y0=par(-2, -2, 1/4, True, True), mu=par(1., 1., 0.4, True, True),
+                  k=par(5., 5., 0.2, True, True), k0=par(0),
+                  cy0=par(50, 50, 1, True, True), ki=par(ki0),
+                  lm=par(lm, lm, 1 * np.ones((output_size, node_size)) + lm_v, True, True))
+
+# %%
+# Create the model to fit
+model = RNNJANSEN(node_size, batch_size, step_size, output_size, tr, sc, lm, dist, True, False, params)
+
+# Create the objective function
+ObjFun = CostsJR(model)
+
+# %%
+# Create the model-fitting object
+F = Model_fitting(model, ObjFun)
+
+# %%
+# Model training
+u = np.zeros((node_size, hidden_size, time_dim))
+u[:, :, 80:120] = 1000  # stimulus at samples 80-120 of the 400 ms window, i.e. roughly -20 ms to +20 ms around the TMS pulse (window spans -100 to +300 ms)
+F.train(u=u, empRecs=[data_mean], num_epochs=num_epoches, TPperWindow=batch_size)
+
+# %%
+# Model test, with 20 windows for warmup
+F.evaluate(u=u, empRec=data_mean, TPperWindow=batch_size, base_window_num=20)
+
+# To save the fitted object:
+# filename = 'Subject_1_low_voltage_fittingresults_stim_exp.pkl'
+# with open(filename, 'wb') as f:
+#     pickle.dump(F, f)
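+
+# A saved fit can later be reloaded with the standard pickle counterpart:
+# with open(filename, 'rb') as f:
+#     F = pickle.load(f)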
+
+# %%
+# Plot the original and simulated EEG data
+epoched = mne.read_epochs(file_name, verbose=False)
+evoked = epoched.average()
+ts_args = dict(xlim=[-0.1, 0.3])
+ch, peak_locs1 = evoked.get_peak(ch_type='eeg', tmin=-0.05, tmax=0.01)
+ch, peak_locs2 = evoked.get_peak(ch_type='eeg', tmin=0.01, tmax=0.02)
+ch, peak_locs3 = evoked.get_peak(ch_type='eeg', tmin=0.03, tmax=0.05)
+ch, peak_locs4 = evoked.get_peak(ch_type='eeg', tmin=0.07, tmax=0.15)
+ch, peak_locs5 = evoked.get_peak(ch_type='eeg', tmin=0.15, tmax=0.20)
+times = [peak_locs1, peak_locs2, peak_locs3, peak_locs4, peak_locs5]
+plot = evoked.plot_joint(ts_args=ts_args, times=times)
+
+simulated_EEG_st = evoked.copy()
+simulated_EEG_st.data[:, time_start:time_end] = F.lastRec['eeg'].npTS()
+simulated_joint_st = simulated_EEG_st.plot_joint(ts_args=ts_args, times=times)
+
+# %%
+# Results Description
+# ---------------------------------------------------
+#
+# The plots above show the original EEG data and the EEG data simulated with the fitted Jansen-Rit model.
+# The simulated data closely resemble the original EEG data, indicating that the model fitting was successful.
+# Peak locations extracted from different time intervals are marked on the plots, demonstrating the model's
+# ability to capture key features of the EEG signal.
+
+# %%
+# Reference:
+# Momi, D., Wang, Z., Griffiths, J.D. (2023). "TMS-evoked responses are driven by recurrent large-scale network dynamics."
+# eLife. https://doi.org/10.7554/eLife.83232
diff --git a/requirements.txt b/requirements.txt
index 54a3b260..93c92828 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,13 +12,12 @@ mne

 # Docs requirements
-sphinx
-#sphinx==3.1.1
-sphinx-gallery==0.8.1
-sphinx_rtd_theme==0.5.0
-sphinx-tabs==1.3.0
-sphinx-copybutton==0.3.1
-sphinxcontrib-httpdomain==1.7.0
+sphinx==5.0.0
+sphinx-gallery==0.15.0
+sphinx_rtd_theme==2.0.0
+sphinx-tabs==3.4.4
+sphinx-copybutton==0.5.2
+sphinxcontrib-httpdomain==1.8.1
 numpydoc==1.1.0
 recommonmark==0.6.0
 versioneer==0.19
diff --git a/whobpyt/models/robinson/params_robinson.py b/whobpyt/models/robinson/params_robinson.py
new file mode 100644
index 00000000..788658fc
--- /dev/null
+++ b/whobpyt/models/robinson/params_robinson.py
@@ -0,0 +1,60 @@
+"""
+Authors: Zheng Wang, John Griffiths, Davide Momi, Kevin Kadak, Parsa Oveisi, Taha Morshedzadeh, Sorenza Bastiaens
+
+Neural mass model fitting: parameter defaults for the Robinson corticothalamic model
+(with forward, backward, and lateral connections) for EEG.
+"""
+
+"""
+Importage
+"""
+import torch
+from torch.nn.parameter import Parameter
+from whobpyt.datatypes.parameter import par
+from whobpyt.datatypes.AbstractParams import AbstractParams
+from whobpyt.datatypes.AbstractNMM import AbstractNMM
+import numpy as np  # for numerical operations
+
+
+class ParamsRobinsonTime(AbstractParams):
+
+    def __init__(self, **kwargs):
+
+        param = {
+            "Q_max": par(250),
+            "sig_theta": par(15/1000),
+            "sigma": par(3.3/1000),
+            "gamma": par(100),
+            "beta": par(200),
+            "alpha": par(200/4),
+            "t0": par(0.08),
+            "g": par(100),
+            "nu_ee": par(0.0528/1000),
+            "nu_ii": par(0.0528/1000),
+            "nu_ie": par(0.02/1000),
+            "nu_es": par(1.2/1000),
+            "nu_is": par(1.2/1000),
+            "nu_se": par(1.2/1000),
+            "nu_si": par(0.0),
+            "nu_ei": par(0.4/1000),
+            "nu_sr": par(0.01/1000),
+            "nu_sn": par(0.0),
+            "nu_re": par(0.1/1000),
+            "nu_ri": par(0.0),
+            "nu_rs": par(0.1/1000),
+            "nu_ss": par(0.0),
+            "nu_rr": par(0.0),
+            "nu_rn": par(0.0),
+            "mu": par(5),
+            "cy0": par(5),
+            "y0": par(2)
+        }
+
+        # Set the defaults, then let keyword arguments override them
+        for var in param:
+            setattr(self, var, param[var])
+
+        for var in kwargs:
+            setattr(self, var, kwargs[var])
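+
+# Example usage (a sketch; the override values here are hypothetical): any default
+# above can be replaced at construction time via keyword arguments, e.g.
+#
+#   params = ParamsRobinsonTime(g=par(120), t0=par(0.085))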
diff --git a/whobpyt/models/robinson/robinson.py b/whobpyt/models/robinson/robinson.py
new file mode 100644
index 00000000..dcdd98dd
--- /dev/null
+++ b/whobpyt/models/robinson/robinson.py
@@ -0,0 +1,422 @@
+"""
+Authors: Zheng Wang, John Griffiths, Davide Momi, Kevin Kadak, Parsa Oveisi, Taha Morshedzadeh, Sorenza Bastiaens
+
+Neural mass model fitting: forward-model module for the Robinson corticothalamic model
+(with forward, backward, and lateral connections) for EEG.
+"""
+
+"""
+Importage
+"""
+import torch
+from torch.nn.parameter import Parameter
+from whobpyt.datatypes.AbstractParams import AbstractParams
+from whobpyt.datatypes.AbstractNMM import AbstractNMM
+from whobpyt.models.robinson.params_robinson import ParamsRobinsonTime
+from whobpyt.datatypes.parameter import par
+import numpy as np  # for numerical operations
+
+
+class RNNROBINSON(AbstractNMM):
+    """
+    A module for the forward model (Robinson) to simulate a batch of EEG signals.
+
+    Attributes
+    ----------
+    state_size : int
+        the number of states in the Robinson model
+    input_size : int
+        the number of states with noise as input
+    tr : float
+        tr of image
+    step_size : float
+        integration step for the forward model
+    hidden_size : int
+        the number of step_size intervals in a tr
+    TRs_per_window : int
+        the number of EEG signals to simulate
+    node_size : int
+        the number of ROIs
+    sc : float node_size x node_size array
+        structural connectivity
+    fit_gains : bool
+        flag for fitting gains (True: fit, False: do not fit)
+    g : float
+        global gain parameter
+    w_bb : tensor, node_size x node_size (gradient on depends on fit_gains)
+        connection gains
+    std_in, std_out : tensors with gradient on
+        std of the state noise and output noise; hyperparameters for the prior
+        distributions of the model parameters
+
+    Methods
+    -------
+    forward(input, noise_out, hx)
+        forward model (Robinson) generating a number of EEG signals with the current model parameters
+    """
+
+    def __init__(self, node_size: int,
+                 TRs_per_window: int, step_size: float, output_size: int, tr: float, sc: float, lm: float, dist: float,
+                 use_fit_gains: bool, use_fit_lfm: bool, param: ParamsRobinsonTime) -> None:
+        """
+        Parameters
+        ----------
+        tr : float
+            tr of image
+        step_size : float
+            integration step for the forward model
+        TRs_per_window : int
+            the number of EEG signals to simulate
+        node_size : int
+            the number of ROIs
+        output_size : int
+            the number of EEG channels
+        sc : float node_size x node_size array
+            structural connectivity
+        use_fit_gains : bool
+            flag for fitting the connection gains (True: fit, False: do not fit)
+        use_fit_lfm : bool
+            flag for fitting the leadfield matrix (True: fit, False: do not fit)
+        param : ParamsRobinsonTime
+            model parameters object
+        """
+        super(RNNROBINSON, self).__init__()
+
+        self.state_names = ['V_e', 'V_e_dot', 'phi_e', 'phi_e_dot', 'V_i', 'V_i_dot', 'phi_i', 'phi_i_dot']
+        self.output_names = ["eeg"]
+        self.model_name = "CT"
+
+        self.state_size = 8  # 8 states in the corticothalamic (CT) model
+        self.tr = tr  # tr in ms (integration step 0.1 ms)
+        self.step_size = torch.tensor(step_size, dtype=torch.float32)  # integration step 0.1 ms
+        self.steps_per_TR = int(tr / step_size)
+        self.TRs_per_window = TRs_per_window  # size of the batch used at each step
+        self.node_size = node_size  # number of ROIs
+        self.output_size = output_size  # number of EEG channels
+        self.sc = sc  # node_size x node_size structural connectivity matrix
+        self.dist = torch.tensor(dist, dtype=torch.float32)
+        self.lm = lm
+        self.use_fit_gains = use_fit_gains  # flag for fitting gains
+        self.use_fit_lfm = use_fit_lfm
+        self.param = param
+
+        self.output_size = lm.shape[0]  # number of EEG channels
+
+    def info(self):
+        return {"state_names": ['V_e', 'V_e_dot', 'phi_e', 'phi_e_dot', 'V_i', 'V_i_dot', 'phi_i', 'phi_i_dot'],
+                "output_names": ["eeg"]}
+
+    def createIC(self, ver):
+        # initial state
+        if (ver == 0):  # training
+            state_lb = 0.5
+            state_ub = 2
+            return torch.tensor(np.random.uniform(state_lb, state_ub, (self.node_size+1, self.state_size)),
+                                dtype=torch.float32)
+        if (ver == 1):  # testing
+            state_lb = 0
+            state_ub = 5
+            return torch.tensor(np.random.uniform(state_lb, state_ub, (self.node_size+1, self.state_size)),
+                                dtype=torch.float32)
+
+        # TODO: Note that version 0 is training and version 1 is testing, so creating a version 2 is
+        # probably not what was intended. createDelayIC should likely be updated as well.
+        if (ver == 2):  # for testing the Robinson corticothalamic model
+            state_lb = -1.5*1e-4
+            state_ub = 1.5*1e-4
+            return torch.tensor(np.random.uniform(state_lb, state_ub, (self.node_size+1, self.state_size)),
+                                dtype=torch.float32)
+
+    def createDelayIC(self, ver):
+        state_lb = 0
+        state_ub = 5
+        delays_max = 500
+        return torch.tensor(np.random.uniform(state_lb, state_ub, (self.node_size, delays_max)), dtype=torch.float32)
+
+    def setModelParameters(self):
+        # set states E I f v mean and 1/sqrt(variance)
+        return setModelParameters(self)
+
+    def forward(self, external, hx, hE):
+        return integration_forward(self, external, hx, hE)
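+
+# Example construction (a sketch; the shapes and values here are hypothetical):
+#
+#   params = ParamsRobinsonTime()
+#   model = RNNROBINSON(node_size=200, TRs_per_window=20, step_size=0.0001,
+#                       output_size=62, tr=0.001, sc=sc, lm=lm, dist=dist,
+#                       use_fit_gains=True, use_fit_lfm=False, param=params)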
+
+
+def sigmoid(x, Q_max, sig_theta, sigma):
+    # Robinson sigmoid firing-rate response: Q_max / (1 + exp(-(x - theta) / sigma))
+    return Q_max / (1 + torch.exp(-(x - sig_theta) / sigma))
+
+
+def setModelParameters(model):
+    param_reg = []
+    param_hyper = []
+    # Set model parameters that need gradients as Parameter; the others as plain tensors.
+    # Set w_bb, w_ff, and w_ll as Parameters if use_fit_gains is True.
+    if model.use_fit_gains:
+        model.w_bb = Parameter(torch.tensor(np.zeros((model.node_size, model.node_size)) + 0.05,
+                                            dtype=torch.float32))  # connection gains to modify the empirical sc
+        model.w_ff = Parameter(torch.tensor(np.zeros((model.node_size, model.node_size)) + 0.05,
+                                            dtype=torch.float32))
+        model.w_ll = Parameter(torch.tensor(np.zeros((model.node_size, model.node_size)) + 0.05,
+                                            dtype=torch.float32))
+        param_reg.append(model.w_ll)
+        param_reg.append(model.w_ff)
+        param_reg.append(model.w_bb)
+    else:
+        model.w_bb = torch.tensor(np.zeros((model.node_size, model.node_size)), dtype=torch.float32)
+        model.w_ff = torch.tensor(np.zeros((model.node_size, model.node_size)), dtype=torch.float32)
+        model.w_ll = torch.tensor(np.zeros((model.node_size, model.node_size)), dtype=torch.float32)
+
+    if model.use_fit_lfm:
+        model.lm = Parameter(torch.tensor(model.lm, dtype=torch.float32))  # leadfield matrix from source data to EEG
+        param_reg.append(model.lm)
+    else:
+        model.lm = torch.tensor(model.lm, dtype=torch.float32)  # leadfield matrix from source data to EEG
+
+    var_names = [a for a in dir(model.param) if (type(getattr(model.param, a)) == par)]
+    for var_name in var_names:
+        var = getattr(model.param, var_name)
+        if (var.fit_hyper == True):
+            if var_name == 'lm':
+                size = var.prior_var.shape
+                var.val = Parameter(var.val.detach() - 1 * torch.ones((size[0], size[1])))  # TODO: This is not consistent with the variance the user would expect to give
+                param_hyper.append(var.prior_mean)
+                param_hyper.append(var.prior_var)
+            elif (var_name != 'std_in'):
+                var.randSet()  # TODO: This should be done before the params are given to the model class
+                param_hyper.append(var.prior_mean)
+                param_hyper.append(var.prior_var)
+
+        if (var.fit_par):
+            param_reg.append(var.val)  # TODO: This should go before fit_hyper, but where randomness gets added in the code needs to change first
+            setattr(model, var_name, var.val)
+
+    model.params_fitted = {'modelparameter': param_reg, 'hyperparameter': param_hyper}
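+
+# The two lists collected in model.params_fitted are what a fitting routine hands to the
+# optimizer. A minimal sketch (assuming a plain PyTorch training setup, not necessarily
+# whobpyt's own fitting class):
+#
+#   optimizer = torch.optim.Adam(model.params_fitted['modelparameter'], lr=0.05)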
+
+
+def integration_forward(model, external, hx, hE):
+
+    # define some constants
+    conduct_lb = 1.5    # lower bound for conduction velocity
+    u_2ndsys_ub = 500   # bound on the input of the second-order system
+    noise_std_lb = 150  # lower bound of the noise std
+    lb = 0.01           # lower bound of local gains
+    s2o_coef = 0.0001   # coefficient from states (source EEG) to EEG
+    k_lb = 0.5          # lower bound of the coefficient of external inputs
+
+    next_state = {}
+
+    # Cortical populations (one row per ROI)
+    V_e = hx[:model.node_size, 0:1]        # membrane potential of the cortical excitatory population
+    V_e_dot = hx[:model.node_size, 1:2]    # rate of change of the excitatory potential
+    phi_e = hx[:model.node_size, 2:3]      # firing rate of the excitatory population
+    phi_e_dot = hx[:model.node_size, 3:4]  # change in firing rate of the excitatory population
+    V_i = hx[:model.node_size, 4:5]        # membrane potential of the cortical inhibitory population
+    V_i_dot = hx[:model.node_size, 5:6]    # rate of change of the inhibitory potential
+    phi_i = hx[:model.node_size, 6:7]      # firing rate of the inhibitory population
+    phi_i_dot = hx[:model.node_size, 7:8]  # change in firing rate of the inhibitory population
+
+    # Thalamic populations (single extra row): relay nuclei (s) and reticular nucleus (r)
+    V_s = hx[model.node_size:model.node_size+1, 0:1]        # membrane potential of the thalamic relay population
+    V_s_dot = hx[model.node_size:model.node_size+1, 1:2]    # rate of change of the relay potential
+    phi_s = hx[model.node_size:model.node_size+1, 2:3]      # firing rate of the relay population
+    phi_s_dot = hx[model.node_size:model.node_size+1, 3:4]  # change in firing rate of the relay population
+    V_r = hx[model.node_size:model.node_size+1, 4:5]        # membrane potential of the thalamic reticular population
+    V_r_dot = hx[model.node_size:model.node_size+1, 5:6]    # rate of change of the reticular potential
+    phi_r = hx[model.node_size:model.node_size+1, 6:7]      # firing rate of the reticular population
+    phi_r_dot = hx[model.node_size:model.node_size+1, 7:8]  # change in firing rate of the reticular population
+
+    dt = model.step_size
+
+    # ReLU module used to keep the fitted gains non-negative
+    m = torch.nn.ReLU()
+
+    # define a constant 1 tensor
+    con_1 = torch.tensor(1.0, dtype=torch.float32)
+    if model.sc.shape[0] > 1:
+        # Update the Laplacian based on the updated connection gains w_ll.
+        w = torch.exp(model.w_ll) * torch.tensor(model.sc, dtype=torch.float32)
+        w_n_l = (0.5 * (w + torch.transpose(w, 0, 1))) / torch.linalg.norm(
+            0.5 * (w + torch.transpose(w, 0, 1)))
+
+        model.sc_fitted = w_n_l
+        dg_l = -torch.diag(torch.sum(w_n_l, dim=1))
+    else:
+        l_s = torch.tensor(np.zeros((1, 1)), dtype=torch.float32)
+        dg_l = 0
+        w_n_l = 0
+
+    model.delays = (model.dist / (conduct_lb * con_1 + m(model.mu))).type(torch.int64)
+    # print(torch.max(model.delays), model.delays.shape)
+
+    # placeholder for the updated current state
+    current_state = torch.zeros_like(hx)
+
+    # placeholders for the output EEG and the history of the model states
+    eeg_window = []
+    V_e_window = []
+    V_e_dot_window = []
+    phi_e_window = []
+    phi_e_dot_window = []
+    V_i_window = []
+    V_i_dot_window = []
+    phi_i_window = []
+    phi_i_dot_window = []
+
+    # Local aliases for the model parameters
+    alpha = model.alpha
+    beta = model.beta
+    alphaxbeta = model.alpha * model.beta
+    gamma = model.gamma
+    gamma_rs = model.gamma * 1
+    nu_ee = model.nu_ee  # torch.exp(model.nu_ee)
+    nu_ei = model.nu_ei  # torch.exp(model.nu_ei)
+    nu_es = model.nu_es  # torch.exp(model.nu_es)
+    nu_ie = model.nu_ie  # torch.exp(model.nu_ie)
+    nu_ii = model.nu_ii  # torch.exp(model.nu_ii)
+    nu_is = model.nu_is  # torch.exp(model.nu_is)
+    nu_se = model.nu_se  # torch.exp(model.nu_se)
+    nu_si = model.nu_si
+    nu_ss = model.nu_ss
+    nu_sr = model.nu_sr  # was set to log in params
+    nu_sn = model.nu_sn
+    nu_re = model.nu_re  # torch.exp(model.nu_re)
+    nu_ri = model.nu_ri
+    nu_rs = model.nu_rs  # torch.exp(model.nu_rs)
+    Q = model.Q_max
+    sig_theta = model.sig_theta
+    sigma = model.sigma
+
+    # Use the forward model to get the EEG signal at the i-th element in the window.
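+    # Each membrane potential V below follows a damped second-order ODE of the form
+    #     d2V/dt2 = alpha*beta*(rV - V) - (alpha + beta)*dV/dt
+    # (note that (1/alpha + 1/beta)*alpha*beta = alpha + beta), and each firing-rate
+    # field phi follows the analogous equation with rate gamma,
+    #     d2phi/dt2 = gamma**2*(rphi - phi) - 2*gamma*dphi/dt.
+    # Both are integrated with a forward Euler step of size dt.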
+    for i_window in range(model.TRs_per_window):
+
+        for step_i in range(model.steps_per_TR):
+            Ed = torch.tensor(np.zeros((model.node_size, model.node_size)), dtype=torch.float32)  # delayed phi_e
+
+            """for ind in range(model.node_size):
+                # print(ind, hE[ind, :].shape, model.delays[ind, :].shape)
+                Ed[ind] = torch.index_select(hE[ind, :], 0, model.delays[ind, :])"""
+            hE_new = hE.clone()
+            # Gather the delayed excitatory firing rates: Ed[i, j] = hE[i, delays[i, j]]
+            Ed = hE_new.gather(1, model.delays)
+
+            LEd_l = torch.reshape(torch.sum(w_n_l * torch.transpose(Ed, 0, 1), 1),
+                                  (model.node_size, 1))  # weights on delayed phi_e
+
+            # External (TMS) input for this integration step
+            u_tms = external[:, step_i:step_i + 1, i_window]
+            # u_aud = external[:, i_hidden:i_hidden + 1, i_window, 1]
+            # u_0 = external[:, i_hidden:i_hidden + 1, i_window, 2]
+
+            # LEd_l + torch.matmul(dg_l, phi_e): Laplacian on delayed phi_e
+
+            # Update the states by one step of size dt.
+            ddVe = V_e + dt * V_e_dot
+            ddphie = phi_e + dt * phi_e_dot
+            ddVi = V_i + dt * V_i_dot
+            ddphii = phi_i + dt * phi_i_dot
+            ddVs = V_s + dt * V_s_dot
+            ddphis = phi_s + dt * phi_s_dot
+            ddVr = V_r + dt * V_r_dot
+            ddphir = phi_r + dt * phi_r_dot
+
+            ones_mx = torch.ones((1, 1))  # 1x1 ones matrix
+            noise_phi_e = torch.randn_like(phi_e) * 10  # noise for phi_e (stdev = 10), added to rphi_e
+            noise_phi_i = torch.randn_like(phi_i) * 10
+            rVe = m(nu_ee) * phi_e - m(nu_ei) * phi_i + m(nu_es) * phi_s
+            rVi = m(nu_ie) * phi_e + m(nu_ii) * phi_i + m(nu_is) * phi_s
+            rVs = torch.mean(m(nu_se) * phi_e, axis=0) * ones_mx + torch.mean(m(nu_si) * phi_i, axis=0) * ones_mx + \
+                  m(nu_ss) * phi_s - m(nu_sr) * phi_r + m(nu_sn) * 0.025 * ones_mx + \
+                  0.001 * torch.randn(1, 1)
+            rVr = m(nu_re) * torch.mean(phi_e, axis=0) * ones_mx + \
+                  m(nu_ri) * torch.mean(phi_i, axis=0) * ones_mx + \
+                  m(nu_rs) * phi_s
+            # Global gain and structural connectivity enter through the delayed Laplacian term.
+            network_interactions = (lb * con_1 + m(model.g)) * (LEd_l +
+                                                                1 * torch.matmul(dg_l, phi_e))
+            rphi_e = sigmoid(V_e, Q, sig_theta, sigma) + network_interactions + noise_phi_e + u_tms
+            rphi_i = sigmoid(V_i, Q, sig_theta, sigma) + noise_phi_i
+            rphi_s = sigmoid(V_s, Q, sig_theta, sigma)
+            rphi_r = sigmoid(V_r, Q, sig_theta, sigma)
+
+            ddVedot = V_e_dot + dt * (-(1/alpha + 1/beta) * alphaxbeta * V_e_dot - alphaxbeta * V_e +
+                                      alphaxbeta * rVe)
+            ddVidot = V_i_dot + dt * (-(1/alpha + 1/beta) * alphaxbeta * V_i_dot - alphaxbeta * V_i +
+                                      alphaxbeta * rVi)
+            ddVsdot = V_s_dot + dt * (-(1/alpha + 1/beta) * alphaxbeta * V_s_dot - alphaxbeta * V_s +
+                                      alphaxbeta * rVs)
+            ddVrdot = V_r_dot + dt * (-(1/alpha + 1/beta) * alphaxbeta * V_r_dot - alphaxbeta * V_r +
+                                      alphaxbeta * rVr)
+
+            ddphiedot = phi_e_dot + dt * (-2 * gamma * phi_e_dot - gamma**2 * phi_e +
+                                          gamma**2 * rphi_e)
+            ddphiidot = phi_i_dot + dt * (-2 * gamma_rs * phi_i_dot - gamma_rs**2 * phi_i +
+                                          gamma_rs**2 * rphi_i)
+            ddphisdot = phi_s_dot + dt * (-2 * gamma_rs * phi_s_dot - gamma_rs**2 * phi_s +
+                                          gamma_rs**2 * rphi_s)
+            ddphirdot = phi_r_dot + dt * (-2 * gamma_rs * phi_r_dot - gamma_rs**2 * phi_r +
+                                          gamma_rs**2 * rphi_r)
+
+            # Saturate the model states (for stability and well-behaved gradients).
+            V_e = ddVe  # e.g. 1000*torch.tanh(ddVe/1000) if saturation were also applied to the potentials
+            V_i = ddVi
+            V_s = ddVs
+            V_r = ddVr
+            phi_e = ddphie
+            phi_i = ddphii
+            phi_s = ddphis
+            phi_r = ddphir
+
+            V_e_dot = 1000 * torch.tanh(ddVedot / 1000)
+            V_i_dot = 1000 * torch.tanh(ddVidot / 1000)
+            V_s_dot = 1000 * torch.tanh(ddVsdot / 1000)
+            V_r_dot = 1000 * torch.tanh(ddVrdot / 1000)
+            phi_e_dot = 1000 * torch.tanh(ddphiedot / 1000)
+            phi_i_dot = 1000 * torch.tanh(ddphiidot / 1000)
+            phi_s_dot = 1000 * torch.tanh(ddphisdot / 1000)
+            phi_r_dot = 1000 * torch.tanh(ddphirdot / 1000)
+
+            # update the placeholder for the phi_e history buffer
+            hE[:, 0] = phi_e[:, 0]
+            # hE = torch.cat([M, hE[:, :-1]], axis=1)
+
+        # Append the states at every tr to the placeholders for visual inspection.
+        V_e_window.append(torch.cat([V_e, V_s], axis=0))
+        V_i_window.append(torch.cat([V_i, V_r], axis=0))
+        phi_e_window.append(torch.cat([phi_e, phi_s], axis=0))
+        phi_i_window.append(torch.cat([phi_i, phi_r], axis=0))
+        V_e_dot_window.append(torch.cat([V_e_dot, V_s_dot], axis=0))
+        V_i_dot_window.append(torch.cat([V_i_dot, V_r_dot], axis=0))
+        phi_e_dot_window.append(torch.cat([phi_e_dot, phi_s_dot], axis=0))
+        phi_i_dot_window.append(torch.cat([phi_i_dot, phi_r_dot], axis=0))
+        hE = torch.cat([phi_e, hE[:, :-1]], dim=1)  # shift the phi_e history buffer
+
+        # Project to EEG and store one sample per tr for use in the cost calculation:
+        # normalize each channel's leadfield row by its absolute sum, subtract the
+        # channel mean (average reference), then eeg = cy0 * lm_t @ phi_e - y0.
+        lm_t = (model.lm.T / torch.sqrt(model.lm ** 2).sum(1)).T
+        model.lm_t = (lm_t - 1 / model.output_size * torch.matmul(torch.ones((1, model.output_size)),
+                                                                  lm_t))  # s2o_coef *
+        temp = model.cy0 * torch.matmul(model.lm_t, phi_e[:model.node_size, :]) - 1 * model.y0
+        eeg_window.append(temp)
+
+    # Update the current state.
+    current_state = torch.cat([torch.cat([V_e, V_s], axis=0),
+                               torch.cat([V_e_dot, V_s_dot], axis=0),
+                               torch.cat([phi_e, phi_s], axis=0),
+                               torch.cat([phi_e_dot, phi_s_dot], axis=0),
+                               torch.cat([V_i, V_r], axis=0),
+                               torch.cat([V_i_dot, V_r_dot], axis=0),
+                               torch.cat([phi_i, phi_r], axis=0),
+                               torch.cat([phi_i_dot, phi_r_dot], axis=0)], dim=1)
+    next_state['current_state'] = current_state
+    # Each *_window entry concatenates one column per tr, e.g. next_state['eeg_window']
+    # has shape (output_size, TRs_per_window).
+    next_state['eeg_window'] = torch.cat(eeg_window, dim=1)
+    next_state['V_e_window'] = torch.cat(V_e_window, dim=1)
+    next_state['V_i_window'] = torch.cat(V_i_window, dim=1)
+    next_state['phi_e_window'] = torch.cat(phi_e_window, dim=1)
+    next_state['phi_i_window'] = torch.cat(phi_i_window, dim=1)
+    next_state['V_e_dot_window'] = torch.cat(V_e_dot_window, dim=1)
+    next_state['V_i_dot_window'] = torch.cat(V_i_dot_window, dim=1)
+    next_state['phi_e_dot_window'] = torch.cat(phi_e_dot_window, dim=1)
+    next_state['phi_i_dot_window'] = torch.cat(phi_i_dot_window, dim=1)
+
+    return next_state, hE
diff --git a/whobpyt/models/robinson_freq/params_robinson_freq.py b/whobpyt/models/robinson_freq/params_robinson_freq.py
new file mode 100644
index 00000000..94a727c5
--- /dev/null
+++ b/whobpyt/models/robinson_freq/params_robinson_freq.py
@@ -0,0 +1,40 @@
+"""
+Authors: Zheng Wang, John Griffiths, Davide Momi, Kevin Kadak, Parsa Oveisi, Taha Morshedzadeh, Sorenza Bastiaens
+
+Neural mass model fitting: parameter defaults for the frequency-domain Robinson model
+(with forward, backward, and lateral connections) for EEG.
+"""
+
+import torch
+from torch.nn.parameter import Parameter
+from whobpyt.datatypes.parameter import par
+from whobpyt.datatypes.AbstractParams import AbstractParams
+from whobpyt.datatypes.AbstractNMM import AbstractNMM
+import numpy as np  # for numerical operations
+
+
+class ParamsRobinsonFreq(AbstractParams):
+
+    def __init__(self, **kwargs):
+        param = {
+            'gamma': par(100),
+            'beta': par(200),
+            'alpha': par(50),
+            't0_2': par(0.08),
+            'ii': par(0.0528),
+            'ee': par(0.0528),
+
+            'es': par(1.2),
+            'sr': par(-0.01),
+            'sn': par(10.0),
+
+            'eis': par(-0.48),
+            'eie': par(-0.008),
+            'srs': par(-0.001),
+            'g_ese': par(-0.576),
+            'g_esre': par(0.00048)
+        }
+
+        # Set the defaults, then let keyword arguments override them
+        for var in param:
+            setattr(self, var, param[var])
+
+        for var in kwargs:
+            setattr(self, var, kwargs[var])
diff --git a/whobpyt/models/robinson_freq/robinson_fq.py b/whobpyt/models/robinson_freq/robinson_fq.py
new file mode 100644
index 00000000..73d9a79b
--- /dev/null
+++ b/whobpyt/models/robinson_freq/robinson_fq.py
@@ -0,0 +1,111 @@
+"""
+Authors: Zheng Wang, John Griffiths, Davide Momi, Kevin Kadak, Parsa Oveisi, Taha Morshedzadeh, Sorenza Bastiaens
+
+Neural mass model fitting: frequency-domain forward-model module for the Robinson model
+(with forward, backward, and lateral connections) for EEG.
+"""
+
+import torch
+from torch.nn.parameter import Parameter
+from whobpyt.datatypes.AbstractParams import AbstractParams
+from whobpyt.datatypes.AbstractNMM import AbstractNMM
+from whobpyt.models.robinson_freq.params_robinson_freq import ParamsRobinsonFreq
+from whobpyt.datatypes.parameter import par
+import numpy as np
+
+
+class RNNROBINSON_FQ(AbstractNMM):
+    """
+    A module for the forward model (Robinson) to simulate the PSD of EEG signals.
+
+    Attributes
+    ----------
+    node_size : int
+        the number of ROIs
+    output_size : int
+        the number of EEG channels
+
+    Methods
+    -------
+    forward(input)
+        forward model (Robinson) generating the PSD of EEG signals with the current model parameters
+    """
+    def __init__(self, node_size: int, output_size: int, param: ParamsRobinsonFreq,
+                 use_fit_gains=False, use_fit_lfm=False) -> None:
+        """
+        Parameters
+        ----------
+        node_size : int
+            the number of ROIs
+        output_size : int
+            the number of EEG channels
+        param : ParamsRobinsonFreq
+            model parameters object
+        """
+        super(RNNROBINSON_FQ, self).__init__()
+
+        self.param = param
+        self.node_size = node_size
+        self.output_size = node_size  # TODO: should this be output_size?
+        self.use_fit_gains = use_fit_gains
+        self.use_fit_lfm = use_fit_lfm
+
+    def setModelParameters(self):
+        vars_name = [a for a in dir(self.param) if not a.startswith('__') and not callable(getattr(self.param, a))]
+        for var in vars_name:
+            if np.any(getattr(self.param, var)[1] > 0):
+                if var == 'lm':
+                    size = getattr(self.param, var)[1].shape
+                    setattr(self, var, Parameter(
+                        torch.tensor(getattr(self.param, var)[0] - np.ones((size[0], size[1])),
+                                     dtype=torch.float32)))
+                    print(getattr(self, var))
+                else:
+                    setattr(self, var, Parameter(torch.tensor(getattr(self.param, var)[0] + getattr(self.param, var)[1] * np.random.randn(1, )[0],
+                                                              dtype=torch.float32)))
+
+                if var not in ['std_in']:
+                    dict_nv = {'m': getattr(self.param, var)[0], 'v': 1 / (getattr(self.param, var)[1]) ** 2}
+                    dict_np = {'m': var + '_m', 'v': var + '_v_inv'}
+
+                    for key in dict_nv:
+                        setattr(self, dict_np[key], Parameter(torch.tensor(dict_nv[key], dtype=torch.float32)))
+            else:
+                setattr(self, var, torch.tensor(getattr(self.param, var)[0], dtype=torch.float32))
+
+    def forward(self, input):
+        """
+        Forward step: simulate the EEG power spectrum.
+
+        Parameters
+        ----------
+        input : tensor
+            list of frequencies
+
+        Outputs
+        -------
+        next_state : tensor
+            power spectrum at the given frequencies, same size as the input
+        """
+
+        next_state = []
+
+        for i_fq in range(input.shape[0]):
+            # print(i_fq)
+            omega = i_fq * 2 * np.pi * torch.ones((self.node_size, 1))  # note: uses the index i_fq, not input[i_fq]
+            j = complex(0, 1)  # imaginary unit
+            s = omega * j
+            # Second-order (alpha-beta) synaptodendritic transfer function
+            tf = (self.alpha * self.beta) / ((s + self.alpha) * (s + self.beta))
+
+            closed_loop_ei = ((self.eis * tf) / (1 - self.ii * tf) + self.es)
+            closed_loop_rs = (self.sn * tf**3 * torch.exp(-s * torch.exp(self.t0_2))) / (1 - (self.sr * tf**2))
+
+            q2r2 = (1 + s / self.gamma)**2 + self.K - self.ee * tf - tf**2 / (1 - self.ii * tf) \
+                   * (self.eie + ((self.g_ese + self.g_esre * tf) * tf) \
+                   / (1 - self.srs * tf**2) * torch.exp(-s * torch.exp(self.t0_2) * 2)) \
+                   - (self.g_ese + self.g_esre * tf) / (1 - self.srs * tf**2) * torch.exp(-s * torch.exp(self.t0_2) * 2) * tf**2
+
+            closed_loop_g = closed_loop_ei * closed_loop_rs * (1 / q2r2)
+            # print(torch.abs(closed_loop_g))
+            lm_n = self.lm / torch.sqrt((self.lm**2).sum())
+            next_state.append(torch.exp(self.gain_tune) * torch.abs(torch.matmul(lm_n + 0*j, closed_loop_g)))
+
+        return torch.cat(next_state, dim=1)
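+
+# Note: forward() also references self.K, self.gain_tune, and self.lm, which are not among
+# the ParamsRobinsonFreq defaults; they must be supplied (e.g. as extra keyword parameters)
+# before forward() can run. A rough usage sketch, with hypothetical sizes and a
+# channels-x-ROIs leadfield lm; whether par objects support the (mean, variance) indexing
+# used in setModelParameters() is an assumption here:
+#
+#   params = ParamsRobinsonFreq(K=par(0.1), gain_tune=par(0.0), lm=par(lm))
+#   model = RNNROBINSON_FQ(node_size=200, output_size=62, param=params)
+#   model.setModelParameters()
+#   psd = model.forward(torch.arange(1, 45, dtype=torch.float32))  # one column per frequency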