added PLV and spike rate code
GregGlickert committed Oct 16, 2024
1 parent 259c146 commit b2e690f
Showing 3 changed files with 253 additions and 2 deletions.
141 changes: 140 additions & 1 deletion bmtool/analysis/lfp.py
@@ -9,6 +9,8 @@
from fooof.sim.gen import gen_model
import matplotlib.pyplot as plt
from scipy import signal
import pywt
from bmtool.bmplot import is_notebook


def load_ecp_to_xarray(ecp_file: str, demean: bool = False) -> xr.DataArray:
@@ -209,7 +211,10 @@ def set_range(x, upper=f[-1]):
    plt.gcf().set_size_inches(figsize)
    if title:
        plt.title(title)
    if not is_notebook():
        plt.show()

    return results, fm

@@ -266,3 +271,137 @@ def calculate_SNR(fooof_model: FOOOF, freq_band: tuple) -> float:
    ap_power = np.trapz(band_ap, band_freq)  # Integrate aperiodic power over the band
    normalized_power = periodic_power / ap_power  # Periodic power relative to the aperiodic floor
    return normalized_power
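
A hypothetical call for context, assuming `fm` is a fitted FOOOF model such as the `fm` returned above, with gamma taken as 30-80 Hz:

gamma_snr = calculate_SNR(fm, freq_band=(30.0, 80.0))  # periodic power over aperiodic power in the gamma band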


def wavelet_filter(x: np.ndarray, freq: float, fs: float, bandwidth: float = 1.0, axis: int = -1) -> np.ndarray:
    """
    Compute the Continuous Wavelet Transform (CWT) of a signal at a single frequency of interest
    using a complex Morlet wavelet. The output is complex: its magnitude gives the instantaneous
    amplitude and its angle the instantaneous phase at that frequency.
    """
    # Complex Morlet wavelet 'cmorB-C' with B = 2 * bandwidth**2 and center frequency C = 1.0
    wavelet = 'cmor' + str(2 * bandwidth ** 2) + '-1.0'
    # Convert the frequency of interest into the matching wavelet scale at this sampling rate
    scale = pywt.scale2frequency(wavelet, 1) * fs / freq
    # Single-scale CWT; keep the coefficient row for that scale
    x_a = pywt.cwt(x, [scale], wavelet=wavelet, axis=axis)[0][0]
    return x_a
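
A minimal usage sketch with illustrative values (a 2 s, 1 kHz test signal and a 10 Hz target frequency):

import numpy as np

fs = 1000.0                                    # assumed sampling rate (Hz)
t = np.arange(0, 2, 1 / fs)
x = np.sin(2 * np.pi * 10 * t) + 0.5 * np.random.randn(t.size)

coeffs = wavelet_filter(x, freq=10.0, fs=fs, bandwidth=1.0)
amplitude = np.abs(coeffs)    # instantaneous amplitude at 10 Hz
phase = np.angle(coeffs)      # instantaneous phase at 10 Hz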


def butter_bandpass_filter(data: np.ndarray, lowcut: float, highcut: float, fs: float, order: int = 5, axis: int = -1) -> np.ndarray:
    """
    Apply a zero-phase Butterworth bandpass filter to the input data.
    """
    sos = signal.butter(order, [lowcut, highcut], fs=fs, btype='band', output='sos')
    # Forward-backward filtering cancels the phase shift introduced by the filter
    x_a = signal.sosfiltfilt(sos, data, axis=axis)
    return x_a
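
A corresponding sketch for the Butterworth path, assuming a 13-30 Hz beta band:

import numpy as np

fs = 1000.0                           # assumed sampling rate (Hz)
x = np.random.randn(int(2 * fs))      # 2 s of white noise
beta = butter_bandpass_filter(x, lowcut=13.0, highcut=30.0, fs=fs)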


def calculate_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: float = None,
                  method: str = 'wavelet', lowcut: float = None, highcut: float = None,
                  bandwidth: float = 2.0) -> np.ndarray:
    """
    Calculate the Phase Locking Value (PLV) between two signals using the wavelet or Hilbert method.
    Parameters:
    - x1, x2: Input signals (1D arrays, same length)
    - fs: Sampling frequency
    - freq_of_interest: Desired frequency for the wavelet PLV calculation
    - method: 'wavelet' or 'hilbert' to choose the PLV calculation method
    - lowcut, highcut: Cutoff frequencies for the Hilbert method
    - bandwidth: Bandwidth parameter for the wavelet
    Returns:
    - plv: Phase Locking Value (a scalar in [0, 1] for 1D inputs)
    """
    if len(x1) != len(x2):
        raise ValueError("Input signals must have the same length.")

    if method == 'wavelet':
        if freq_of_interest is None:
            raise ValueError("freq_of_interest must be provided for the wavelet method.")

        # Complex CWT coefficients at the frequency of interest
        theta1 = wavelet_filter(x=x1, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)
        theta2 = wavelet_filter(x=x2, freq=freq_of_interest, fs=fs, bandwidth=bandwidth)

    elif method == 'hilbert':
        if lowcut is None or highcut is None:
            raise ValueError("lowcut and highcut must be provided for the Hilbert method.")

        # Bandpass filter to the band of interest
        filtered_x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
        filtered_x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)

        # Analytic signals via the Hilbert transform; the phase is extracted below
        theta1 = signal.hilbert(filtered_x1)
        theta2 = signal.hilbert(filtered_x2)

    else:
        raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")

    # Instantaneous phase difference between the two complex signals
    phase_diff = np.angle(theta1) - np.angle(theta2)

    # PLV = |mean(exp(i * phase_diff))|, following Lachaux et al. (1999),
    # "Measuring phase synchrony in brain signals"
    plv = np.abs(np.mean(np.exp(1j * phase_diff), axis=-1))

    return plv
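
A quick self-check sketch with hypothetical test signals: two noisy 8 Hz sinusoids with a fixed phase offset should yield a PLV near 1 under either method.

import numpy as np

fs = 1000.0
t = np.arange(0, 5, 1 / fs)
x1 = np.sin(2 * np.pi * 8 * t) + 0.1 * np.random.randn(t.size)
x2 = np.sin(2 * np.pi * 8 * t + np.pi / 4) + 0.1 * np.random.randn(t.size)

plv_wavelet = calculate_plv(x1, x2, fs=fs, method='wavelet', freq_of_interest=8.0)
plv_hilbert = calculate_plv(x1, x2, fs=fs, method='hilbert', lowcut=6.0, highcut=10.0)
# Both values should be close to 1 for a constant phase offset.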


def calculate_plv_over_time(x1: np.ndarray, x2: np.ndarray, fs: float,
                            window_size: float, step_size: float,
                            method: str = 'wavelet', freq_of_interest: float = None,
                            lowcut: float = None, highcut: float = None,
                            bandwidth: float = 2.0):
    """
    Calculate the time-resolved Phase Locking Value (PLV) between two signals using a sliding window.
    Parameters
    ----------
    x1, x2 : array-like
        Input signals (1D arrays, same length).
    fs : float
        Sampling frequency of the input signals.
    window_size : float
        Length of the window in seconds for the PLV calculation.
    step_size : float
        Step size in seconds to slide the window across the signals.
    method : str, optional
        Method to calculate PLV ('wavelet' or 'hilbert'). Defaults to 'wavelet'.
    freq_of_interest : float, optional
        Frequency of interest for the wavelet method. Required if method is 'wavelet'.
    lowcut, highcut : float, optional
        Cutoff frequencies for the Hilbert method. Required if method is 'hilbert'.
    bandwidth : float, optional
        Bandwidth parameter for the wavelet. Defaults to 2.0.
    Returns
    -------
    plv_over_time : 1D array
        PLV value for each window.
    times : 1D array
        The center time (in seconds) of each window.
    """
    # Convert window and step size from seconds to samples
    window_samples = int(window_size * fs)
    step_samples = int(step_size * fs)

    # Initialize results
    plv_over_time = []
    times = []

    # Slide the window across the signals
    for start in range(0, len(x1) - window_samples + 1, step_samples):
        end = start + window_samples
        window_x1 = x1[start:end]
        window_x2 = x2[start:end]

        # Compute the PLV within this window
        plv = calculate_plv(x1=window_x1, x2=window_x2, fs=fs,
                            method=method, freq_of_interest=freq_of_interest,
                            lowcut=lowcut, highcut=highcut, bandwidth=bandwidth)
        plv_over_time.append(plv)

        # Store the time at the center of the window
        center_time = (start + end) / 2 / fs
        times.append(center_time)

    return np.array(plv_over_time), np.array(times)
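
Continuing the sketch above, a sliding-window call with assumed 1 s windows and 250 ms steps:

plv_t, times = calculate_plv_over_time(x1, x2, fs=fs,
                                       window_size=1.0, step_size=0.25,
                                       method='wavelet', freq_of_interest=8.0)
# One PLV per window; a dip in plv_t flags a transient loss of phase locking.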


112 changes: 112 additions & 0 deletions bmtool/analysis/spikes.py
@@ -5,6 +5,9 @@
import h5py
import pandas as pd
from bmtool.util.util import load_nodes_from_config
from typing import Dict, Optional, Tuple, Union
import numpy as np
import os

def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True, config: str = None) -> pd.DataFrame:
"""
@@ -35,3 +38,112 @@ def load_spikes_to_df(spike_file: str, network_name: str, sort: bool = True,conf
spikes_df = spikes_df.merge(nodes['pop_name'], left_on='node_ids', right_index=True, how='left')

return spikes_df


def pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[float, float, float]] = None,
                   time_points: Optional[Union[np.ndarray, list]] = None, frequency: bool = False) -> np.ndarray:
    """
    Calculate the spike count or frequency histogram over specified time intervals.
    Args:
        spike_times (Union[np.ndarray, list]): Array or list of spike times in milliseconds.
        time (Optional[Tuple[float, float, float]], optional): Tuple specifying (start, stop, step) in milliseconds.
            Used to create evenly spaced time points if `time_points` is not provided. Default is None.
        time_points (Optional[Union[np.ndarray, list]], optional): Array or list of specific time points for binning.
            If provided, `time` is ignored. Default is None.
        frequency (bool, optional): If True, returns spike frequency in Hz; otherwise, returns spike counts. Default is False.
    Returns:
        np.ndarray: Array of spike counts or frequencies, depending on the `frequency` flag.
    Raises:
        ValueError: If both `time` and `time_points` are None.
    """
    if time_points is None:
        if time is None:
            raise ValueError("Either `time` or `time_points` must be provided.")
        time_points = np.arange(*time)
        dt = time[2]
    else:
        time_points = np.asarray(time_points).ravel()
        # Mean spacing; assumes the supplied time points are (approximately) evenly spaced
        dt = (time_points[-1] - time_points[0]) / (time_points.size - 1)

    # One bin per time point, closed on the right by one extra edge
    bins = np.append(time_points, time_points[-1] + dt)
    spike_rate, _ = np.histogram(np.asarray(spike_times), bins)

    if frequency:
        # Convert counts per dt-millisecond bin into a rate in Hz
        spike_rate = 1000 / dt * spike_rate

    return spike_rate
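
A small sketch with made-up spike times in milliseconds, binned at 10 ms over one second:

import numpy as np

spike_times = np.array([12.0, 15.5, 103.2, 250.0, 251.1])                # ms (hypothetical)
counts = pop_spike_rate(spike_times, time=(0, 1000, 10))                 # spikes per 10 ms bin
rates = pop_spike_rate(spike_times, time=(0, 1000, 10), frequency=True)  # same bins, in Hz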



def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
                              save: bool = False, save_path: Optional[str] = None) -> Dict[str, np.ndarray]:
    """
    Calculate the population spike rate for each population in the given spike data.
    Args:
        spikes (pd.DataFrame): A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'.
        fs (float, optional): Sampling frequency in Hz, which sets the bin size (1000 / fs ms). Default is 400.
        t_start (float, optional): Start time (in milliseconds) for the spike rate calculation. Default is 0.
        t_stop (Optional[float], optional): Stop time (in milliseconds) for the spike rate calculation. If None, it defaults to the maximum timestamp in the data.
        save (bool, optional): Whether to save the population spike rates to a file. Default is False.
        save_path (Optional[str], optional): Directory path where the file should be saved if `save` is True. Default is None.
    Returns:
        Dict[str, np.ndarray]: A dictionary mapping population names to arrays of per-cell spike rates in Hz.
    Raises:
        ValueError: If `save` is True but `save_path` is not provided.
    """
    pop_spikes = {}   # Filtered spike data by population
    node_number = {}  # Number of unique nodes per population

    print("Note: Node counts are obtained from the unique node IDs that spiked. If the simulation did not run long enough for every cell to fire, these counts (and hence the rates) will be off.")

    # Set `t_stop` to the maximum timestamp if not specified
    if t_stop is None:
        t_stop = spikes['timestamps'].max()

    for pop_name in spikes['pop_name'].unique():
        # Count cells per population from the unique node IDs in the spike data.
        # This assumes the simulation ran long enough for all cells to fire.
        ps = spikes[spikes['pop_name'] == pop_name]
        node_number[pop_name] = ps['node_ids'].nunique()

        # Filter spikes by population name and timestamp range
        filtered_spikes = spikes[
            (spikes['pop_name'] == pop_name) &
            (spikes['timestamps'] > t_start) &
            (spikes['timestamps'] < t_stop)
        ]
        pop_spikes[pop_name] = filtered_spikes

    # Time points for the spike rate bins: (start, stop, step) in milliseconds
    time = np.array([t_start, t_stop, 1000 / fs])

    # Spike counts per bin for each population
    pop_rspk = {p: pop_spike_rate(spk['timestamps'], time) for p, spk in pop_spikes.items()}

    # Convert bin counts to the average per-cell firing rate in Hz
    # (counts per 1/fs-second bin, times fs, divided by the cell count)
    spike_rate = {p: fs / node_number[p] * pop_rspk[p] for p in pop_rspk}

    # Save results to file if required
    if save:
        if save_path is None:
            raise ValueError("save_path must be provided if save is True.")

        # Create the directory if it does not exist
        os.makedirs(save_path, exist_ok=True)

        # Write the rates to an HDF5 file, one group per population
        save_file = os.path.join(save_path, 'spike_rate.h5')
        with h5py.File(save_file, 'w') as f:
            f.create_dataset('time', data=time)
            grp = f.create_group('populations')
            for p, rspk in spike_rate.items():
                pop_grp = grp.create_group(p)
                pop_grp.create_dataset('data', data=rspk)

    return spike_rate
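
A sketch with a toy spikes DataFrame (column names as required above); the result maps each population name to its average per-cell rate in Hz, one sample per 1000 / fs ms bin:

import pandas as pd

spikes = pd.DataFrame({
    'pop_name':   ['PV', 'PV', 'Exc', 'Exc', 'Exc'],
    'timestamps': [10.0, 35.2, 12.1, 48.7, 90.3],   # ms (hypothetical)
    'node_ids':   [0, 1, 2, 2, 3],
})
rates = get_population_spike_rate(spikes, fs=400.0)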
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@

setup(
    name="bmtool",
    version='0.5.8',
    author="Neural Engineering Laboratory at the University of Missouri",
    author_email="gregglickert@mail.missouri.edu",
    description="BMTool",
