Skip to content

Commit

Permalink
small bug fixes and analysis features
Browse files Browse the repository at this point in the history
  • Loading branch information
GregGlickert committed Nov 12, 2024
1 parent 7448314 commit 74e866c
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 31 deletions.
13 changes: 7 additions & 6 deletions bmtool/analysis/lfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,15 +322,16 @@ def calculate_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: f

elif method == 'hilbert':
if lowcut is None or highcut is None:
raise ValueError("lowcut and highcut must be provided for the Hilbert method.")
print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc")

# Bandpass filter and get the analytic signal using the Hilbert transform
filtered_x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
filtered_x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)
if lowcut and highcut:
# Bandpass filter and get the analytic signal using the Hilbert transform
x1 = butter_bandpass_filter(x1, lowcut, highcut, fs)
x2 = butter_bandpass_filter(x2, lowcut, highcut, fs)

# Get phase using the Hilbert transform
theta1 = np.angle(signal.hilbert(filtered_x1))
theta2 = np.angle(signal.hilbert(filtered_x2))
theta1 = signal.hilbert(x1)
theta2 = signal.hilbert(x2)

else:
raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.")
Expand Down
62 changes: 38 additions & 24 deletions bmtool/analysis/spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,67 +77,80 @@ def pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[fl
return spike_rate



def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None,
save: bool = False, save_path: Optional[str] = None) -> Dict[str, np.ndarray]:
config: Optional[str] = None, network_name: Optional[str] = None,
save: bool = False, save_path: Optional[str] = None,
normalize: bool = False) -> Dict[str, np.ndarray]:
"""
Calculate the population spike rate for each population in the given spike data.
Calculate the population spike rate for each population in the given spike data, with an option to normalize.
Args:
spikes (pd.DataFrame): A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'.
fs (float, optional): Sampling frequency in Hz. Default is 400.
fs (float, optional): Sampling frequency in Hz, which determines the time bin size for calculating the spike rate. Default is 400.
t_start (float, optional): Start time (in milliseconds) for spike rate calculation. Default is 0.
t_stop (Optional[float], optional): Stop time (in milliseconds) for spike rate calculation. If None, it defaults to the maximum timestamp in the data. Default is None.
save (bool, optional): Whether to save the population spike rate to a file. Default is False.
save_path (Optional[str], optional): Directory path where the file should be saved if `save` is True. Default is None.
t_stop (Optional[float], optional): Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data.
config (Optional[str], optional): Path to a configuration file containing node information, used to determine the correct number of nodes per population.
If None, node count is estimated from unique node spikes. Default is None.
network_name (Optional[str], optional): Name of the network used in the configuration file, allowing selection of nodes for that network.
Required if `config` is provided. Default is None.
save (bool, optional): Whether to save the calculated population spike rate to a file. Default is False.
save_path (Optional[str], optional): Directory path where the file should be saved if `save` is True. If `save` is True and `save_path` is None, a ValueError is raised.
normalize (bool, optional): Whether to normalize the spike rates for each population to a range of [0, 1]. Default is False.
Returns:
Dict[str, np.ndarray]: A dictionary where keys are population names, and values are arrays of spike rates.
Dict[str, np.ndarray]: A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population.
If `normalize` is True, each population's spike rate is scaled to [0, 1].
Raises:
ValueError: If `save` is True but `save_path` is not provided.
Notes:
- If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate.
- If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values.
"""
pop_spikes = {} # Dictionary to store filtered spike data by population
node_number = {} # Dictionary to store the number of unique nodes for each population
pop_spikes = {}
node_number = {}

print("Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect.")
if config is None:
print("Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect.")
print("You can provide a config to calculate the correct amount of nodes!")

for pop_name in spikes['pop_name'].unique():
# Get the number of cells for each population by counting unique node IDs in the spike data.
# This approach assumes the simulation ran long enough for all cells to fire.
ps = spikes[spikes['pop_name'] == pop_name]
node_number[pop_name] = ps['node_ids'].nunique()

if config:
nodes = load_nodes_from_config(config)
nodes = nodes[network_name]
nodes = nodes[nodes['pop_name'] == pop_name]
node_number[pop_name] = nodes.index.nunique()
else:
node_number[pop_name] = ps['node_ids'].nunique()

# Set `t_stop` to the maximum timestamp if not specified
if t_stop is None:
t_stop = spikes['timestamps'].max()

# Filter spikes by population name and timestamp range
filtered_spikes = spikes[
(spikes['pop_name'] == pop_name) &
(spikes['timestamps'] > t_start) &
(spikes['timestamps'] < t_stop)
]
pop_spikes[pop_name] = filtered_spikes

# Generate time array for calculating spike rates
time = np.array([t_start, t_stop, 1000 / fs])

# Calculate the population spike rate for each population
pop_rspk = {p: pop_spike_rate(spk['timestamps'], time) for p, spk in pop_spikes.items()}

# Adjust spike rate by the number of cells in each population
spike_rate = {p: fs / node_number[p] * pop_rspk[p] for p in pop_rspk}

# Save results to file if required
# Normalize each spike rate series if normalize=True
if normalize:
spike_rate = {p: (sr - sr.min()) / (sr.max() - sr.min()) for p, sr in spike_rate.items()}

if save:
if save_path is None:
raise ValueError("save_path must be provided if save is True.")

# Create directory if it does not exist
os.makedirs(save_path, exist_ok=True)

# Define the save file path and write data to an HDF5 file
save_file = os.path.join(save_path, 'spike_rate.h5')
with h5py.File(save_file, 'w') as f:
f.create_dataset('time', data=time)
Expand All @@ -147,3 +160,4 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start:
pop_grp.create_dataset('data', data=rspk)

return spike_rate

70 changes: 70 additions & 0 deletions bmtool/bmplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,76 @@ def connection_pair_histogram(**kwargs):
tids = []
util.relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=connection_pair_histogram,synaptic_info=synaptic_info)

def connection_distance(config: str,source: str,target: str,
                        source_cell_id: int,target_id_type: str) -> None:
    """
    Plots the 3D spatial distribution of target nodes relative to a source node
    and a histogram of distances from the source node to each target node.

    Parameters:
    ----------
    config: (str) A BMTK simulation config
    source: (str) name of the network that contains the source cell
    target: (str) name of the network that contains the target cells
    source_cell_id : (int) ID of the source cell for calculating distances to target nodes.
    target_id_type : (str) pattern matched against the 'target_query' column of the
        edges (via pandas `str.contains`) to keep only some targets.
        An empty/None value disables the filter.

    Raises:
    ------
    ValueError: if config, source or target is missing, or if source_cell_id is not
        present in the source network.
    KeyError: if the "<source>_to_<target>" edge network or either node network is
        not present in the loaded config.
    """
    # ValueError subclasses Exception, so callers catching Exception still work.
    if not config:
        raise ValueError("config not defined")
    if not source or not target:
        raise ValueError("Sources or targets not defined")

    # Load nodes and edges based on config file
    nodes, edges = util.load_nodes_edges_from_config(config)

    edge_network = source + "_to_" + target

    # Keep only the connections that originate from the requested source cell
    edge = edges[edge_network]
    edge = edge[edge['source_node_id'] == source_cell_id]
    if target_id_type:
        edge = edge[edge['target_query'].str.contains(target_id_type, na=False)]

    target_node_ids = edge['target_node_id']

    # The source cell lives in the source network ...
    source_network_nodes = nodes[source]
    source_node = source_network_nodes.loc[source_network_nodes.index == source_cell_id]
    if source_node.empty:
        raise ValueError(f"source_cell_id {source_cell_id} not found in network '{source}'")

    # ... while target node ids index into the *target* network. Looking them up
    # in the source network is only correct when source == target.
    target_network_nodes = nodes[target]
    target_nodes = target_network_nodes.loc[target_network_nodes.index.isin(target_node_ids)]

    # Calculate distances between source node and each target node
    position_columns = ['pos_x', 'pos_y', 'pos_z']
    target_positions = target_nodes[position_columns].values
    source_position = source_node[position_columns].values[0]  # 1D (x, y, z)
    distances = np.linalg.norm(target_positions - source_position, axis=1)

    # Plot positions of source and target nodes in 3D space
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')

    ax.scatter(target_nodes['pos_x'], target_nodes['pos_y'], target_nodes['pos_z'], c='blue', label="target cells")
    ax.scatter(source_node['pos_x'], source_node['pos_y'], source_node['pos_z'], c='red', label="source cell")

    plt.legend()
    plt.show()

    # Plot distances in a separate 2D plot
    plt.figure(figsize=(8, 6))
    plt.hist(distances, bins=20, color='blue', edgecolor='black')
    plt.xlabel("Distance")
    plt.ylabel("Count")
    plt.title("Distance from Source Node to Each Target Node")
    plt.grid(True)
    plt.show()

def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids=None,no_prepend_pop=None,edge_property = None,time = None,time_compare = None,report=None,title=None,save_file=None):
"""
write about function here
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="bmtool",
version='0.5.9',
version='0.5.9.5',
author="Neural Engineering Laboratory at the University of Missouri",
author_email="gregglickert@mail.missouri.edu",
description="BMTool",
Expand Down

0 comments on commit 74e866c

Please sign in to comment.