diff --git a/bmtool/analysis/lfp.py b/bmtool/analysis/lfp.py index af35772..a5f5276 100644 --- a/bmtool/analysis/lfp.py +++ b/bmtool/analysis/lfp.py @@ -322,15 +322,16 @@ def calculate_plv(x1: np.ndarray, x2: np.ndarray, fs: float, freq_of_interest: f elif method == 'hilbert': if lowcut is None or highcut is None: - raise ValueError("lowcut and highcut must be provided for the Hilbert method.") + print("Lowcut and or highcut were not definded, signal will not be filter and just take hilbert transform for plv calc") - # Bandpass filter and get the analytic signal using the Hilbert transform - filtered_x1 = butter_bandpass_filter(x1, lowcut, highcut, fs) - filtered_x2 = butter_bandpass_filter(x2, lowcut, highcut, fs) + if lowcut and highcut: + # Bandpass filter and get the analytic signal using the Hilbert transform + x1 = butter_bandpass_filter(x1, lowcut, highcut, fs) + x2 = butter_bandpass_filter(x2, lowcut, highcut, fs) # Get phase using the Hilbert transform - theta1 = np.angle(signal.hilbert(filtered_x1)) - theta2 = np.angle(signal.hilbert(filtered_x2)) + theta1 = signal.hilbert(x1) + theta2 = signal.hilbert(x2) else: raise ValueError("Invalid method. Choose 'wavelet' or 'hilbert'.") diff --git a/bmtool/analysis/spikes.py b/bmtool/analysis/spikes.py index 0e04392..d6095dd 100644 --- a/bmtool/analysis/spikes.py +++ b/bmtool/analysis/spikes.py @@ -77,42 +77,59 @@ def pop_spike_rate(spike_times: Union[np.ndarray, list], time: Optional[Tuple[fl return spike_rate - def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: float = 0, t_stop: Optional[float] = None, - save: bool = False, save_path: Optional[str] = None) -> Dict[str, np.ndarray]: + config: Optional[str] = None, network_name: Optional[str] = None, + save: bool = False, save_path: Optional[str] = None, + normalize: bool = False) -> Dict[str, np.ndarray]: """ - Calculate the population spike rate for each population in the given spike data. + Calculate the population spike rate for each population in the given spike data, with an option to normalize. Args: spikes (pd.DataFrame): A DataFrame containing spike data with columns 'pop_name', 'timestamps', and 'node_ids'. - fs (float, optional): Sampling frequency in Hz. Default is 400. + fs (float, optional): Sampling frequency in Hz, which determines the time bin size for calculating the spike rate. Default is 400. t_start (float, optional): Start time (in milliseconds) for spike rate calculation. Default is 0. - t_stop (Optional[float], optional): Stop time (in milliseconds) for spike rate calculation. If None, it defaults to the maximum timestamp in the data. Default is None. - save (bool, optional): Whether to save the population spike rate to a file. Default is False. - save_path (Optional[str], optional): Directory path where the file should be saved if `save` is True. Default is None. + t_stop (Optional[float], optional): Stop time (in milliseconds) for spike rate calculation. If None, defaults to the maximum timestamp in the data. + config (Optional[str], optional): Path to a configuration file containing node information, used to determine the correct number of nodes per population. + If None, node count is estimated from unique node spikes. Default is None. + network_name (Optional[str], optional): Name of the network used in the configuration file, allowing selection of nodes for that network. + Required if `config` is provided. Default is None. + save (bool, optional): Whether to save the calculated population spike rate to a file. Default is False. + save_path (Optional[str], optional): Directory path where the file should be saved if `save` is True. If `save` is True and `save_path` is None, a ValueError is raised. + normalize (bool, optional): Whether to normalize the spike rates for each population to a range of [0, 1]. Default is False. Returns: - Dict[str, np.ndarray]: A dictionary where keys are population names, and values are arrays of spike rates. + Dict[str, np.ndarray]: A dictionary where keys are population names, and values are arrays representing the spike rate over time for each population. + If `normalize` is True, each population's spike rate is scaled to [0, 1]. Raises: ValueError: If `save` is True but `save_path` is not provided. + + Notes: + - If `config` is None, the function assumes all cells in each population have fired at least once; otherwise, the node count may be inaccurate. + - If normalization is enabled, each population's spike rate is scaled using Min-Max normalization based on its own minimum and maximum values. + """ - pop_spikes = {} # Dictionary to store filtered spike data by population - node_number = {} # Dictionary to store the number of unique nodes for each population + pop_spikes = {} + node_number = {} - print("Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect.") + if config is None: + print("Note: Node number is obtained by counting unique node spikes in the network.\nIf the network did not run for a sufficient duration, and not all cells fired, this count might be incorrect.") + print("You can provide a config to calculate the correct amount of nodes!") for pop_name in spikes['pop_name'].unique(): - # Get the number of cells for each population by counting unique node IDs in the spike data. - # This approach assumes the simulation ran long enough for all cells to fire. ps = spikes[spikes['pop_name'] == pop_name] - node_number[pop_name] = ps['node_ids'].nunique() + + if config: + nodes = load_nodes_from_config(config) + nodes = nodes[network_name] + nodes = nodes[nodes['pop_name'] == pop_name] + node_number[pop_name] = nodes.index.nunique() + else: + node_number[pop_name] = ps['node_ids'].nunique() - # Set `t_stop` to the maximum timestamp if not specified if t_stop is None: t_stop = spikes['timestamps'].max() - # Filter spikes by population name and timestamp range filtered_spikes = spikes[ (spikes['pop_name'] == pop_name) & (spikes['timestamps'] > t_start) & @@ -120,24 +137,20 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: ] pop_spikes[pop_name] = filtered_spikes - # Generate time array for calculating spike rates time = np.array([t_start, t_stop, 1000 / fs]) - - # Calculate the population spike rate for each population pop_rspk = {p: pop_spike_rate(spk['timestamps'], time) for p, spk in pop_spikes.items()} - - # Adjust spike rate by the number of cells in each population spike_rate = {p: fs / node_number[p] * pop_rspk[p] for p in pop_rspk} - # Save results to file if required + # Normalize each spike rate series if normalize=True + if normalize: + spike_rate = {p: (sr - sr.min()) / (sr.max() - sr.min()) for p, sr in spike_rate.items()} + if save: if save_path is None: raise ValueError("save_path must be provided if save is True.") - # Create directory if it does not exist os.makedirs(save_path, exist_ok=True) - # Define the save file path and write data to an HDF5 file save_file = os.path.join(save_path, 'spike_rate.h5') with h5py.File(save_file, 'w') as f: f.create_dataset('time', data=time) @@ -147,3 +160,4 @@ def get_population_spike_rate(spikes: pd.DataFrame, fs: float = 400.0, t_start: pop_grp.create_dataset('data', data=rspk) return spike_rate + diff --git a/bmtool/bmplot.py b/bmtool/bmplot.py index 5fcd08d..b3ade72 100644 --- a/bmtool/bmplot.py +++ b/bmtool/bmplot.py @@ -405,6 +405,76 @@ def connection_pair_histogram(**kwargs): tids = [] util.relation_matrix(config,nodes,edges,sources,targets,sids,tids,prepend_pop,relation_func=connection_pair_histogram,synaptic_info=synaptic_info) +def connection_distance(config: str,source: str,target: str, + source_cell_id: int,target_id_type: str) -> None: + """ + Plots the 3D spatial distribution of target nodes relative to a source node + and a histogram of distances from the source node to each target node. + + Parameters: + ---------- + config: (str) A BMTK simulation config + sources: (str) network name(s) to plot + targets: (str) network name(s) to plot + source_cell_id : (int) ID of the source cell for calculating distances to target nodes. + target_id_type : (str) A string to filter target nodes based off the target_query. + + """ + if not config: + raise Exception("config not defined") + if not source or not target: + raise Exception("Sources or targets not defined") + #if source != target: + #raise Exception("Code is setup for source and target to be the same! Look at source code for function to add feature") + + # Load nodes and edges based on config file + nodes, edges = util.load_nodes_edges_from_config(config) + + edge_network = source + "_to_" + target + node_network = source + + # Filter edges to obtain connections originating from the source node + edge = edges[edge_network] + edge = edge[edge['source_node_id'] == source_cell_id] + if target_id_type: + edge = edge[edge['target_query'].str.contains(target_id_type, na=False)] + + target_node_ids = edge['target_node_id'] + + # Filter nodes to obtain only the target and source nodes + node = nodes[node_network] + target_nodes = node.loc[node.index.isin(target_node_ids)] + source_node = node.loc[node.index == source_cell_id] + + # Calculate distances between source node and each target node + target_positions = target_nodes[['pos_x', 'pos_y', 'pos_z']].values + source_position = np.array([source_node['pos_x'], source_node['pos_y'], source_node['pos_z']]).ravel() # Ensure 1D shape + distances = np.linalg.norm(target_positions - source_position, axis=1) + + # Plot positions of source and target nodes in 3D space + fig = plt.figure(figsize=(8, 6)) + ax = fig.add_subplot(111, projection='3d') + + ax.scatter(target_nodes['pos_x'], target_nodes['pos_y'], target_nodes['pos_z'], c='blue', label="target cells") + ax.scatter(source_node['pos_x'], source_node['pos_y'], source_node['pos_z'], c='red', label="source cell") + + # Optional: Add text annotations for distances + # for i, distance in enumerate(distances): + # ax.text(target_nodes['pos_x'].iloc[i], target_nodes['pos_y'].iloc[i], target_nodes['pos_z'].iloc[i], + # f'{distance:.2f}', color='black', fontsize=8, ha='center') + + plt.legend() + plt.show() + + # Plot distances in a separate 2D plot + plt.figure(figsize=(8, 6)) + plt.hist(distances, bins=20, color='blue', edgecolor='black') + plt.xlabel("Distance") + plt.ylabel("Count") + plt.title("Distance from Source Node to Each Target Node") + plt.grid(True) + plt.show() + def edge_histogram_matrix(config=None,sources = None,targets=None,sids=None,tids=None,no_prepend_pop=None,edge_property = None,time = None,time_compare = None,report=None,title=None,save_file=None): """ write about function here diff --git a/setup.py b/setup.py index 3accac8..a0ef129 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="bmtool", - version='0.5.9', + version='0.5.9.5', author="Neural Engineering Laboratory at the University of Missouri", author_email="gregglickert@mail.missouri.edu", description="BMTool",