diff --git a/kilosort/spikedetect.py b/kilosort/spikedetect.py index 3698cc5..30df0f9 100644 --- a/kilosort/spikedetect.py +++ b/kilosort/spikedetect.py @@ -117,8 +117,8 @@ def template_centers(ops): nx = np.round((xmax - xmin) / (dminx/2)) + 1 xup = np.concatenate([xup, np.linspace(xmin, xmax, int(nx))]) - ops['yup'] = yup - ops['xup'] = xup + ops['yup'] = np.unique(yup) + ops['xup'] = np.unique(xup) # Set max channel distance based on dmin, dminx, use whichever is greater. if ops.get('max_channel_distance', None) is None: @@ -219,7 +219,7 @@ def run(ops, bfile, device=torch.device('cuda'), progress_bar=None, ops['wPCA'], ops['wTEMP'] = get_waves(ops, device=device) ops = template_centers(ops) - [ys, xs] = np.meshgrid(np.unique(ops['yup']), np.unique(ops['xup'])) + [ys, xs] = np.meshgrid(ops['yup'], ops['xup']) ys, xs = ys.flatten(), xs.flatten() logger.info(f'Number of universal templates: {ys.size}') xc, yc = ops['xc'], ops['yc']