Skip to content

Commit

Permalink
Updated template placment to handle shanks separately
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobpennington committed Jun 7, 2024
1 parent 1c8828e commit fc78f34
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 14 deletions.
3 changes: 2 additions & 1 deletion kilosort/gui/probe_view_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ def set_active_layout(self, probe, template_args):
def get_template_spots(self, nC, dmin, dminx, max_dist, x_centers, device):
ops = {
'yc': self.yc, 'xc': self.xc, 'max_channel_distance': max_dist,
'x_centers': x_centers, 'settings': {'dmin': dmin, 'dminx': dminx}
'x_centers': x_centers, 'settings': {'dmin': dmin, 'dminx': dminx},
'kcoords': self.kcoords
}
ops = template_centers(ops)
[ys, xs] = np.meshgrid(ops['yup'], ops['xup'])
Expand Down
27 changes: 19 additions & 8 deletions kilosort/spikedetect.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,19 +85,30 @@ def get_waves(ops, device=torch.device('cuda')):
return wPCA, wTEMP

def template_centers(ops):
xmin, xmax, ymin, ymax = ops['xc'].min(), ops['xc'].max(), \
ops['yc'].min(), ops['yc'].max()

shank_idx = ops['kcoords']
xc = ops['xc']
yc = ops['yc']
dmin = ops['settings']['dmin']
if dmin is None:
# Try to determine a good value automatically based on contact positions.
dmin = np.median(np.diff(np.unique(ops['yc'])))
dmin = np.median(np.diff(np.unique(yc)))
ops['dmin'] = dmin
ops['yup'] = np.arange(ymin, ymax+.00001, dmin/2)

ops['dminx'] = dminx = ops['settings']['dminx']
nx = np.round((xmax - xmin) / (dminx/2)) + 1
ops['xup'] = np.linspace(xmin, xmax, int(nx))

# Iteratively determine template placement for each shank separately.
yup = np.array([])
xup = np.array([])
for i in np.unique(shank_idx):
xc_i = xc[shank_idx == i]
yc_i = yc[shank_idx == i]
xmin, xmax, ymin, ymax = xc_i.min(), xc_i.max(), yc_i.min(), yc_i.max()

yup = np.concatenate([yup, np.arange(ymin, ymax+.00001, dmin/2)])
nx = np.round((xmax - xmin) / (dminx/2)) + 1
xup = np.concatenate([xup, np.linspace(xmin, xmax, int(nx))])

ops['yup'] = yup
ops['xup'] = xup

# Set max channel distance based on dmin, dminx, use whichever is greater.
if ops.get('max_channel_distance', None) is None:
Expand Down
13 changes: 8 additions & 5 deletions tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,36 +10,39 @@
def test_dmin():
settings = {'dmin': None, 'dminx': 32}
ops = {'xc': np.array([10, 20, 30]), 'yc': np.array([40, 40, 60]),
'settings': settings}
'settings': settings, 'kcoords': np.array([0, 0, 0])}
ops = template_centers(ops)
assert ops['dmin'] is not None # set based on xc, yc
assert ops['dminx'] is not None
assert ops['settings']['dmin'] is None # shouldn't change

# Neuropixels 1 3B1 (4 columns with stagger)
np1_probe = load_probe(PROBE_DIR / 'neuropixPhase3B1_kilosortChanMap.mat')
ops = {'xc': np1_probe['xc'], 'yc': np1_probe['yc'], 'settings': settings}
ops = {'xc': np1_probe['xc'], 'yc': np1_probe['yc'],
'kcoords': np1_probe['kcoords'], 'settings': settings}
ops = template_centers(ops)
assert ops['dmin'] == 20 # Median vertical spacing of contacts
assert ops['xup'].size == 4 # Number of lateral pos for universal templates

# Just one shank of NP2
np2_probe = load_probe(PROBE_DIR / 'NP2_kilosortChanMap.mat')
ops = {'xc': np2_probe['xc'], 'yc': np2_probe['yc'], 'settings': settings}
ops = {'xc': np2_probe['xc'], 'yc': np2_probe['yc'],
'kcoords': np2_probe['kcoords'], 'settings': settings}
ops = template_centers(ops)
assert ops['dmin'] == 15
assert ops['xup'].size == 3

# Linear probe
lin_probe = load_probe(PROBE_DIR / 'Linear16x1_kilosortChanMap.mat')
ops = {'xc': lin_probe['xc'], 'yc': lin_probe['yc']*20, 'settings': settings}
ops = {'xc': lin_probe['xc'], 'yc': lin_probe['yc']*20,
'kcoords': lin_probe['kcoords'], 'settings': settings}
ops = template_centers(ops)
assert ops['dmin'] == 20
assert ops['xup'].size == 1

settings = {'dmin': 5, 'dminx': 7}
ops = {'xc': np.array([10, 20, 30]), 'yc': np.array([40, 40, 60]),
'settings': settings}
'settings': settings, 'kcoords': np.array([0, 0, 0])}
ops = template_centers(ops)
assert ops['dmin'] == 5
assert ops['dminx'] == 7

0 comments on commit fc78f34

Please sign in to comment.