From 5c2178ec6734a8f0548eb824a41ccbd6a2cb33a4 Mon Sep 17 00:00:00 2001 From: jacobpennington Date: Mon, 22 Jul 2024 17:41:43 -0700 Subject: [PATCH] Fixed missing/innacurate docs in run_kilosort --- kilosort/io.py | 4 +-- kilosort/run_kilosort.py | 55 +++++++++++++++++++++++++++++++++------- 2 files changed, 48 insertions(+), 11 deletions(-) diff --git a/kilosort/io.py b/kilosort/io.py index 4637d148..3164ee2e 100644 --- a/kilosort/io.py +++ b/kilosort/io.py @@ -222,10 +222,10 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None, clu : np.ndarray 1D vector of cluster ids indicating which spike came from which cluster, same shape as `st[:,0]`. - tF : np.ndarray + tF : torch.Tensor PC features for each spike, with shape (n_spikes, nearest_chans, n_pcs) - Wall : np.ndarray + Wall : torch.Tensor PC feature representation of spike waveforms for each cluster, with shape (n_clusters, n_channels, n_pcs). probe : dict; optional. diff --git a/kilosort/run_kilosort.py b/kilosort/run_kilosort.py index 6f81cedf..bb837d50 100644 --- a/kilosort/run_kilosort.py +++ b/kilosort/run_kilosort.py @@ -107,10 +107,10 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None, clu : np.ndarray 1D vector of cluster ids indicating which spike came from which cluster, same shape as `st[:,0]`. - tF : np.ndarray + tF : torch.Tensor PC features for each spike, with shape (n_spikes, nearest_chans, n_pcs) - Wall : np.ndarray + Wall : torch.Tensor PC feature representation of spike waveforms for each cluster, with shape (n_clusters, n_channels, n_pcs). similar_templates : np.ndarray. @@ -477,8 +477,12 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None, Returns ------- ops : dict + Dictionary storing settings and results for all algorithmic steps. bfile : kilosort.io.BinaryFiltered Wrapped file object for handling data. + st0 : np.ndarray. + Intermediate spike times variable with 6 columns. This is only used + for generating the 'Drift Scatter' plot through the GUI. """ @@ -523,7 +527,7 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None, def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None): - """Run spike sorting algorithm and save intermediate results to `ops`. + """Detect spikes via template deconvolution. Parameters ---------- @@ -546,10 +550,12 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None): clu : np.ndarray 1D vector of cluster ids indicating which spike came from which cluster, same shape as `st`. - tF : np.ndarray - TODO - Wall : np.ndarray - TODO + tF : torch.Tensor + PC features for each spike, with shape + (n_spikes, nearest_chans, n_pcs) + Wall : torch.Tensor + PC feature representation of spike waveforms for each cluster, with shape + (n_clusters, n_channels, n_pcs). """ @@ -595,6 +601,37 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None): def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None): + """Cluster spikes using graph-based methods. + + Parameters + ---------- + st : np.ndarray + 3-column array of peak time (in samples), template, and amplitude for + each spike. + tF : torch.Tensor + PC features for each spike, with shape + (n_spikes, nearest_chans, n_pcs) + ops : dict + Dictionary storing settings and results for all algorithmic steps. + device : torch.device + Indicates whether `pytorch` operations should be run on cpu or gpu. + bfile : kilosort.io.BinaryFiltered + Wrapped file object for handling data. + tic0 : float; default=np.nan. + Start time of `run_kilosort`. + progress_bar : TODO; optional. + Informs `tqdm` package how to report progress, type unclear. + + Returns + ------- + clu : np.ndarray + 1D vector of cluster ids indicating which spike came from which cluster, + same shape as `st`. + Wall : torch.Tensor + PC feature representation of spike waveforms for each cluster, with shape + (n_clusters, n_channels, n_pcs). + + """ tic = time.time() logger.info(' ') logger.info('Final clustering') @@ -639,10 +676,10 @@ def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan, clu : np.ndarray 1D vector of cluster ids indicating which spike came from which cluster, same shape as `st[:,0]`. - tF : np.ndarray + tF : torch.Tensor PC features for each spike, with shape (n_spikes, nearest_chans, n_pcs) - Wall : np.ndarray + Wall : torch.Tensor PC feature representation of spike waveforms for each cluster, with shape (n_clusters, n_channels, n_pcs). imin : int