diff --git a/kilosort/io.py b/kilosort/io.py index 1ee595dd..acef4246 100644 --- a/kilosort/io.py +++ b/kilosort/io.py @@ -162,8 +162,8 @@ def save_to_phy(st, clu, tF, Wall, probe, ops, imin, results_dir=None, np.save((results_dir / 'whitening_mat_inv.npy'), whitening_mat_inv) # spike properties - spike_times = st[:,0] + imin # shift by minimum sample index - spike_templates = st[:,1] + spike_times = st[:,0].astype('int64') + imin # shift by minimum sample index + spike_templates = st[:,1].astype('int32') spike_clusters = clu xs, ys = compute_spike_positions(st, tF, ops) spike_positions = np.vstack([xs, ys]).T diff --git a/kilosort/postprocessing.py b/kilosort/postprocessing.py index c9b6c82a..4244d883 100644 --- a/kilosort/postprocessing.py +++ b/kilosort/postprocessing.py @@ -4,13 +4,19 @@ import torch -@njit +@njit("(int64[:], int32[:], int32)") def remove_duplicates(spike_times, spike_clusters, dt=15): '''Removes same-cluster spikes that occur within `dt` samples.''' keep = np.zeros_like(spike_times, bool_) cluster_t0 = {} - for (i,t), c in zip(enumerate(spike_times), spike_clusters): - t0 = cluster_t0.get(c, t-dt) + for i in range(spike_times.size): + t = spike_times[i] + c = spike_clusters[i] + if c in cluster_t0: + t0 = cluster_t0[c] + else: + t0 = t - dt + if t >= (t0 + dt): # Separate spike, reset t0 and keep spike cluster_t0[c] = t