diff --git a/kilosort/postprocessing.py b/kilosort/postprocessing.py index 466fa7d5..a7286aec 100644 --- a/kilosort/postprocessing.py +++ b/kilosort/postprocessing.py @@ -81,6 +81,7 @@ def make_pc_features(ops, spike_templates, spike_clusters, tF): # xy: template centers, iC: channels associated with each template xy, iC = xy_templates(ops) + n_templates = iC.shape[1] n_clusters = np.unique(spike_clusters).size n_chans = ops['nearest_chans'] feature_ind = np.zeros((n_clusters, n_chans), dtype=np.uint32) @@ -89,7 +90,7 @@ def make_pc_features(ops, spike_templates, spike_clusters, tF): # Get templates associated with cluster (often just 1) iunq = np.unique(spike_templates[spike_clusters==i]).astype(int) # Get boolean mask with size (n_templates,), True if they match cluster - ix = torch.from_numpy(np.zeros(int(spike_templates.max())+1, bool)) + ix = torch.from_numpy(np.zeros(n_templates, bool)) ix[iunq] = True # Get PC features for all spikes detected with those templates (Xd), # and the indices in tF where those spikes occur (igood).