Skip to content

Commit

Permalink
Fixed missing/innacurate docs in run_kilosort
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Jul 23, 2024
1 parent 35bc3cc commit 5c2178e
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 11 deletions.
4 changes: 2 additions & 2 deletions kilosort/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,10 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None,
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
same shape as `st[:,0]`.
tF : np.ndarray
tF : torch.Tensor
PC features for each spike, with shape
(n_spikes, nearest_chans, n_pcs)
Wall : np.ndarray
Wall : torch.Tensor
PC feature representation of spike waveforms for each cluster, with shape
(n_clusters, n_channels, n_pcs).
probe : dict; optional.
Expand Down
55 changes: 46 additions & 9 deletions kilosort/run_kilosort.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None,
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
same shape as `st[:,0]`.
tF : np.ndarray
tF : torch.Tensor
PC features for each spike, with shape
(n_spikes, nearest_chans, n_pcs)
Wall : np.ndarray
Wall : torch.Tensor
PC feature representation of spike waveforms for each cluster, with shape
(n_clusters, n_channels, n_pcs).
similar_templates : np.ndarray.
Expand Down Expand Up @@ -477,8 +477,12 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,
Returns
-------
ops : dict
Dictionary storing settings and results for all algorithmic steps.
bfile : kilosort.io.BinaryFiltered
Wrapped file object for handling data.
st0 : np.ndarray.
Intermediate spike times variable with 6 columns. This is only used
for generating the 'Drift Scatter' plot through the GUI.
"""

Expand Down Expand Up @@ -523,7 +527,7 @@ def compute_drift_correction(ops, device, tic0=np.nan, progress_bar=None,


def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
"""Run spike sorting algorithm and save intermediate results to `ops`.
"""Detect spikes via template deconvolution.
Parameters
----------
Expand All @@ -546,10 +550,12 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
same shape as `st`.
tF : np.ndarray
TODO
Wall : np.ndarray
TODO
tF : torch.Tensor
PC features for each spike, with shape
(n_spikes, nearest_chans, n_pcs)
Wall : torch.Tensor
PC feature representation of spike waveforms for each cluster, with shape
(n_clusters, n_channels, n_pcs).
"""

Expand Down Expand Up @@ -595,6 +601,37 @@ def detect_spikes(ops, device, bfile, tic0=np.nan, progress_bar=None):


def cluster_spikes(st, tF, ops, device, bfile, tic0=np.nan, progress_bar=None):
"""Cluster spikes using graph-based methods.
Parameters
----------
st : np.ndarray
3-column array of peak time (in samples), template, and amplitude for
each spike.
tF : torch.Tensor
PC features for each spike, with shape
(n_spikes, nearest_chans, n_pcs)
ops : dict
Dictionary storing settings and results for all algorithmic steps.
device : torch.device
Indicates whether `pytorch` operations should be run on cpu or gpu.
bfile : kilosort.io.BinaryFiltered
Wrapped file object for handling data.
tic0 : float; default=np.nan.
Start time of `run_kilosort`.
progress_bar : TODO; optional.
Informs `tqdm` package how to report progress, type unclear.
Returns
-------
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
same shape as `st`.
Wall : torch.Tensor
PC feature representation of spike waveforms for each cluster, with shape
(n_clusters, n_channels, n_pcs).
"""
tic = time.time()
logger.info(' ')
logger.info('Final clustering')
Expand Down Expand Up @@ -639,10 +676,10 @@ def save_sorting(ops, results_dir, st, clu, tF, Wall, imin, tic0=np.nan,
clu : np.ndarray
1D vector of cluster ids indicating which spike came from which cluster,
same shape as `st[:,0]`.
tF : np.ndarray
tF : torch.Tensor
PC features for each spike, with shape
(n_spikes, nearest_chans, n_pcs)
Wall : np.ndarray
Wall : torch.Tensor
PC feature representation of spike waveforms for each cluster, with shape
(n_clusters, n_channels, n_pcs).
imin : int
Expand Down

0 comments on commit 5c2178e

Please sign in to comment.