Skip to content

Commit

Permalink
Merge branch 'main' of github.com:MouseLand/Kilosort
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Oct 29, 2024
2 parents dc56fcd + 9bdd8f2 commit 5117a30
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions kilosort/spikedetect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from io import StringIO
import os
import logging
import warnings
Expand Down Expand Up @@ -118,8 +117,8 @@ def template_centers(ops):
nx = np.round((xmax - xmin) / (dminx/2)) + 1
xup = np.concatenate([xup, np.linspace(xmin, xmax, int(nx))])

ops['yup'] = yup
ops['xup'] = xup
ops['yup'] = np.unique(yup)
ops['xup'] = np.unique(xup)

# Set max channel distance based on dmin, dminx, use whichever is greater.
if ops.get('max_channel_distance', None) is None:
Expand Down Expand Up @@ -222,8 +221,8 @@ def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
ops = template_centers(ops)
[ys, xs] = np.meshgrid(ops['yup'], ops['xup'])
ys, xs = ys.flatten(), xs.flatten()
logger.info(f'Number of universal templates: {ys.size}')
xc, yc = ops['xc'], ops['yc']
Nfilt = len(ys)

nC = ops['settings']['nearest_chans']
nC2 = ops['settings']['nearest_templates']
Expand All @@ -238,7 +237,7 @@ def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
xs = xs[igood]
ops['ycup'], ops['xcup'] = ys, xs

iC2, ds2 = nearest_chans(ys, ys, xs, xs, nC2, device=device)
iC2, _ = nearest_chans(ys, ys, xs, xs, nC2, device=device)

ds_torch = torch.from_numpy(ds).to(device).float()
template_sizes = sig * (1+torch.arange(nsizes, device=device))
Expand All @@ -252,7 +251,6 @@ def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
k = 0
nt = ops['nt']
tarange = torch.arange(-(nt//2),nt//2+1, device = device)
s = StringIO()
logger.info('Detecting spikes...')
prog = tqdm(np.arange(bfile.n_batches), miniters=200 if progress_bar else None,
mininterval=60 if progress_bar else None)
Expand All @@ -266,7 +264,7 @@ def run(ops, bfile, device=torch.device('cuda'), progress_bar=None,
yct = yweighted(yc, iC, adist, xy, device=device)
nsp = len(xy)

if k+nsp>st.shape[0] :
if k+nsp>st.shape[0]:
st = np.concatenate((st, np.zeros_like(st)), 0)
tF = np.concatenate((tF, np.zeros_like(tF)), 0)

Expand Down

0 comments on commit 5117a30

Please sign in to comment.