Add Robinson model #149

Draft · wants to merge 7 commits into base: dev
47 changes: 47 additions & 0 deletions doc/getting_started/how_to_contribute.md
@@ -0,0 +1,47 @@
# Contributing to WhoBPyT Codebase
Authors: Andrew Clappison, Kevin Kadak, John Griffiths

The instructions below outline the version-control workflow for contributing to the WhoBPyT repo.

## Setting Up (Done Once):

- **Downloading your Fork**
  - You must have already configured an authentication key and forked the repository on github.com; ensure your fork is up to date with the whobpyt/dev branch, from which your new branch will be created.
- Open terminal and go to the desired directory.
- `git clone git@github.com:<<your_github_account>>/whobpyt.git`

- **Adding Upstream**
- `cd whobpyt`
- `git remote add upstream https://github.com/GriffithsLab/whobpyt.git`
- `git fetch upstream`

## Coding Cycle (Done for each new feature or group of features):

- **Creating a New Branch**
- `git fetch upstream`
- `git checkout --track upstream/dev`
- `git push origin dev`
- `git checkout -b <<new_branch_name>>`

- **Editing Code**
- Add/Delete/Edit code

- **Testing (WhoBPyT Sphinx Examples should run successfully on Linux, but may fail to run on Windows)**
  - Optional: rename Sphinx examples ending in “r” to “x” if they are not relevant to your code changes (for quicker builds). Example: “eg001r...” to “eg001x...”.
- `cd doc`
- `make clean`
- `make html`
  - Open and inspect in a web browser: whobpyt/doc/_build/html/index.html
  - Additional testing may also be advisable, depending on the change.

- **Committing Code**
- `git status`
- `git add <<edited_file_names>>`
  - `git commit -m "<<commit_message>>"`

- **Pushing Code**
- `git push --set-upstream origin <<new_branch_name>>`

- **Creating a pull request**
  - On github.com, open a pull request from the new branch on your fork to the main repo’s dev branch. Any merge conflict will have to be resolved before proceeding.

6 changes: 3 additions & 3 deletions doc/index.rst
@@ -16,7 +16,7 @@ Welcome to WhoBPyT documentation!
:maxdepth: 1
:caption: Getting Started
:glob:

getting_started/installation
getting_started/running_in_colab
getting_started/running_in_codespaces
@@ -36,9 +36,9 @@ Welcome to WhoBPyT documentation!
:maxdepth: 1
:caption: Examples
:glob:

auto_examples/eg002r__multimodal_simulation
auto_examples/eg003r__fitting_rww_example
auto_examples/eg004r__fitting_JR_example
auto_examples/eg005r__gpu_support

auto_examples/eg006r__replicate_Momi2023
193 changes: 193 additions & 0 deletions examples/eg006r__replicate_Momi2023.py
@@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
r"""
==================================================
Replicate Momi et al. (2023): TMS-evoked Responses
==================================================

This script replicates the findings of the paper:

Momi, D., Wang, Z., Griffiths, J.D. (2023).
"TMS-evoked responses are driven by recurrent large-scale network dynamics."
eLife, [doi: 10.7554/eLife.83232](https://elifesciences.org/articles/83232)

The code includes data fetching, model fitting, and result visualization based on the methods presented in the paper.

"""


# sphinx_gallery_thumbnail_number = 1
#
# %%
# Importage
# --------------------------------------------------

# whobpyt stuff
import whobpyt
from whobpyt.datatypes import par, Recording
from whobpyt.models.JansenRit import RNNJANSEN, ParamsJR
from whobpyt.run import Model_fitting
from whobpyt.optimization.custom_cost_JR import CostsJR

# python stuff
import numpy as np
import pandas as pd
import scipy.io
import gdown
import pickle
import warnings
warnings.filterwarnings('ignore')

#neuroimaging packages
import mne

# viz stuff
import matplotlib.pyplot as plt



# %%
# Download and load necessary data for the example
url='https://drive.google.com/drive/folders/1Qu-JyZc3-SL-Evsystg4D-DdpsGU4waB?usp=sharing'
gdown.download_folder(url, quiet=True)


# %%
# Load EEG data from a file
file_name = './data/Subject_1_low_voltage.fif'
epoched = mne.read_epochs(file_name, verbose=False);
evoked = epoched.average()

# %%
# Load Atlas
url = 'https://raw.githubusercontent.com/ThomasYeoLab/CBIG/master/stable_projects/brain_parcellation/Schaefer2018_LocalGlobal/Parcellations/MNI/Centroid_coordinates/Schaefer2018_200Parcels_7Networks_order_FSLMNI152_2mm.Centroid_RAS.csv'
atlas = pd.read_csv(url)
labels = atlas['ROI Name']
coords = np.array([atlas['R'], atlas['A'], atlas['S']]).T
conduction_velocity = 5  # in m/s (equivalently mm/ms); with distances in mm, delays = dist/velocity come out in ms

# %%
# Compute the distance matrix
dist = np.zeros((coords.shape[0], coords.shape[0]))

for roi1 in range(coords.shape[0]):
    for roi2 in range(coords.shape[0]):
        dist[roi1, roi2] = np.sqrt(np.sum((coords[roi1, :] - coords[roi2, :]) ** 2))
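The double loop above can be cross-checked against SciPy's vectorized pairwise-distance routine; a minimal sketch with stand-in coordinates (the random array here is a placeholder for the atlas centroids, not the real data):

```python
import numpy as np
from scipy.spatial.distance import cdist

coords = np.random.rand(10, 3) * 100   # stand-in for the atlas centroids (mm)
dist_vec = cdist(coords, coords)       # all pairwise Euclidean distances in one call

# same result as the explicit double loop, via broadcasting
dist_bc = np.sqrt(((coords[:, None, :] - coords[None, :, :]) ** 2).sum(-1))
assert np.allclose(dist_vec, dist_bc)
```

Either vectorized form avoids the O(N²) Python-level loop, which matters little at 200 ROIs but helps for finer parcellations.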


# %%
# Load the stim weights matrix, which encodes where to inject the external input
stim_weights = np.load('./data/stim_weights.npy')
stim_weights_thr = stim_weights.copy()
labels[np.where(stim_weights_thr > 0)[0]]  # inspect which ROIs receive stimulation

# %%
# Load the structural connectivity matrix
sc_file = './data/Schaefer2018_200Parcels_7Networks_count.csv'
sc_df = pd.read_csv(sc_file, header=None, sep=' ')
sc = sc_df.values
sc = np.log1p(sc) / np.linalg.norm(np.log1p(sc))
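The log-scaling step above can be sanity-checked on a toy matrix: after dividing by the Frobenius norm the connectivity matrix has unit norm, and zero entries stay zero since log1p(0) = 0:

```python
import numpy as np

sc_toy = np.array([[0.0, 10.0],
                   [100.0, 0.0]])   # stand-in fibre-count matrix
sc_norm = np.log1p(sc_toy) / np.linalg.norm(np.log1p(sc_toy))

assert np.isclose(np.linalg.norm(sc_norm), 1.0)   # unit Frobenius norm
assert sc_norm[0, 0] == 0.0                       # absent connections stay absent
```

The log1p transform compresses the heavy-tailed fibre-count distribution before normalization, so a few very strong connections do not dominate the coupling.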

# %%
# Load the leadfield matrix
lm = np.load('./data/Subject_1_low_voltage_lf.npy')
ki0 = stim_weights_thr[:, np.newaxis]
delays = dist / conduction_velocity

# %%
# define options for JR model
eeg_data = evoked.data.copy()
time_start = np.where(evoked.times==-0.1)[0][0]
time_end = np.where(evoked.times==0.3)[0][0]
eeg_data = eeg_data[:,time_start:time_end]/np.abs(eeg_data).max()*4
node_size = sc.shape[0]
output_size = eeg_data.shape[0]
batch_size = 20
step_size = 0.0001
num_epoches = 120
tr = 0.001
state_size = 6
base_batch_num = 20
time_dim = 400
hidden_size = int(tr/step_size)
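As a quick sanity check on the timing arithmetic above (values copied from this script): the EEG is sampled every 1 ms, the integrator runs at 0.1 ms, and the 400-sample window matches the -0.1 s to 0.3 s epoch selected earlier:

```python
tr = 0.001           # EEG sampling period in seconds (1 ms)
step_size = 0.0001   # integration step in seconds (0.1 ms)
hidden_size = int(tr / step_size)   # integration steps per EEG sample
time_dim = 400                      # samples in the fitted window

assert hidden_size == 10
assert abs(time_dim * tr - (0.3 - (-0.1))) < 1e-12   # 400 ms window
```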


# %%
# prepare data structure of the model
data_mean = Recording(eeg_data, num_epoches, batch_size)

# %%
# get model parameters structure and define the fitted parameters by setting non-zero variance for the model
lm = np.zeros((output_size, 200))    # note: overwrites the loaded leadfield; lm is fitted from scratch
lm_v = np.zeros((output_size, 200))
params = ParamsJR(A=par(3.25), a=par(100, 100, 2, True, True), B=par(22), b=par(50, 50, 1, True, True),
                  g=par(500, 500, 2, True, True), g_f=par(10, 10, 1, True, True), g_b=par(10, 10, 1, True, True),
                  c1=par(135, 135, 1, True, True), c2=par(135*0.8, 135*0.8, 1, True, True),
                  c3=par(135*0.25, 135*0.25, 1, True, True), c4=par(135*0.25, 135*0.25, 1, True, True),
                  std_in=par(0, 0, 1, True, True), vmax=par(5), v0=par(6), r=par(0.56),
                  y0=par(-2, -2, 1/4, True, True), mu=par(1., 1., 0.4, True, True), k=par(5., 5., 0.2, True, True),
                  k0=par(0), cy0=par(50, 50, 1, True, True), ki=par(ki0),
                  lm=par(lm, lm, 1 * np.ones((output_size, node_size)) + lm_v, True, True))


# %%
# instantiate the model to be fitted
model = RNNJANSEN(node_size, batch_size, step_size, output_size, tr, sc, lm, dist, True, False, params)


# create objective function
ObjFun = CostsJR(model)


# %%
# create the model-fitting object
F = Model_fitting(model, ObjFun)

# %%
# model training
u = np.zeros((node_size,hidden_size,time_dim))
u[:,:,80:120]= 1000
F.train(u=u, empRecs = [data_mean], num_epochs = num_epoches, TPperWindow = batch_size)
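The stimulus indices above can be related to peristimulus time. With a 1 ms sampling period and the fitted window starting at -0.1 s, sample 100 corresponds to stimulus onset (t = 0), so the input injected at indices 80:120 spans roughly -20 ms to +19 ms. A sketch of that conversion (values assumed from this script):

```python
import numpy as np

tr = 0.001            # 1 ms per sample
epoch_start = -0.1    # first sample of the fitted window, in seconds
on_idx = np.arange(80, 120)          # samples where u is non-zero
t_on = epoch_start + on_idx * tr     # corresponding peristimulus times

assert np.isclose(t_on[0], -0.020)   # input switches on 20 ms before onset
assert np.isclose(t_on[-1], 0.019)   # last driven sample at +19 ms
```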

# %%
# model test with 20 windows for warm-up
F.evaluate(u = u, empRec = data_mean, TPperWindow = batch_size, base_window_num = 20)

# filename = 'Subject_1_low_voltage_fittingresults_stim_exp.pkl'
# with open(filename, 'wb') as f:
#     pickle.dump(F, f)

# %%
# Plot the original and simulated EEG data
epoched = mne.read_epochs(file_name, verbose=False);
evoked = epoched.average()
ts_args = dict(xlim=[-0.1,0.3])
ch, peak_locs1 = evoked.get_peak(ch_type='eeg', tmin=-0.05, tmax=0.01)
ch, peak_locs2 = evoked.get_peak(ch_type='eeg', tmin=0.01, tmax=0.02)
ch, peak_locs3 = evoked.get_peak(ch_type='eeg', tmin=0.03, tmax=0.05)
ch, peak_locs4 = evoked.get_peak(ch_type='eeg', tmin=0.07, tmax=0.15)
ch, peak_locs5 = evoked.get_peak(ch_type='eeg', tmin=0.15, tmax=0.20)
times = [peak_locs1, peak_locs2, peak_locs3, peak_locs4, peak_locs5]
plot = evoked.plot_joint(ts_args=ts_args, times=times);


simulated_EEG_st = evoked.copy()
simulated_EEG_st.data[:,time_start:time_end] = F.lastRec['eeg'].npTS()
times = [peak_locs1, peak_locs2, peak_locs3, peak_locs4, peak_locs5]
simulated_joint_st = simulated_EEG_st.plot_joint(ts_args=ts_args, times=times)


# %%
# Results Description
# ---------------------------------------------------
#

# The plot above shows the original EEG data and the simulated EEG data using the fitted Jansen-Rit model.
# The simulated data closely resembles the original EEG data, indicating that the model fitting was successful.
# Peak locations extracted from different time intervals are marked on the plots, demonstrating the model's ability
# to capture key features of the EEG signal.

# %%
# Reference:
# Momi, D., Wang, Z., Griffiths, J.D. (2023). "TMS-evoked responses are driven by recurrent large-scale network dynamics."
# eLife, 10.7554/eLife.83232. https://doi.org/10.7554/eLife.83232
13 changes: 6 additions & 7 deletions requirements.txt
@@ -12,13 +12,12 @@ mne


# Docs requirements
sphinx
#sphinx==3.1.1
sphinx-gallery==0.8.1
sphinx_rtd_theme==0.5.0
sphinx-tabs==1.3.0
sphinx-copybutton==0.3.1
sphinxcontrib-httpdomain==1.7.0
sphinx==5.0.0
sphinx-gallery==0.15.0
sphinx_rtd_theme==2.0.0
sphinx-tabs==3.4.4
sphinx-copybutton==0.5.2
sphinxcontrib-httpdomain==1.8.1
numpydoc==1.1.0
recommonmark==0.6.0
versioneer==0.19
60 changes: 60 additions & 0 deletions whobpyt/models/robinson/params_robinson.py
@@ -0,0 +1,60 @@
"""
Authors: Zheng Wang, John Griffiths, Davide Momi, Kevin Kadak, Parsa Oveisi, Taha Morshedzadeh, Sorenza Bastiaens
Neural Mass Model fitting
module for Robinson with forward backward and lateral connection for EEG
"""

"""
Importage
"""

# Pytorch stuff
import torch
from torch.nn.parameter import Parameter
from whobpyt.datatypes.parameter import par
from whobpyt.datatypes.AbstractParams import AbstractParams
from whobpyt.datatypes.AbstractNMM import AbstractNMM
import numpy as np # for numerical operations

class ParamsRobinsonTime(AbstractParams):

    def __init__(self, **kwargs):

        param = {
            "Q_max": par(250),
            "sig_theta": par(15/1000),
            "sigma": par(3.3/1000),
            "gamma": par(100),
            "beta": par(200),
            "alpha": par(200/4),
            "t0": par(0.08),
            "g": par(100),
            "nu_ee": par(0.0528/1000),
            "nu_ii": par(0.0528/1000),
            "nu_ie": par(0.02/1000),
            "nu_es": par(1.2/1000),
            "nu_is": par(1.2/1000),
            "nu_se": par(1.2/1000),
            "nu_si": par(0.0),
            "nu_ei": par(0.4/1000),
            "nu_sr": par(0.01/1000),
            "nu_sn": par(0.0),
            "nu_re": par(0.1/1000),
            "nu_ri": par(0.0),
            "nu_rs": par(0.1/1000),
            "nu_ss": par(0.0),
            "nu_rr": par(0.0),
            "nu_rn": par(0.0),
            "mu": par(5),
            "cy0": par(5),
            "y0": par(2)
        }

        for var in param:
            setattr(self, var, param[var])

        for var in kwargs:
            setattr(self, var, kwargs[var])