diff --git a/kilosort/clustering_qr.py b/kilosort/clustering_qr.py index 09064ac2..4d1f1ab1 100644 --- a/kilosort/clustering_qr.py +++ b/kilosort/clustering_qr.py @@ -1,12 +1,16 @@ -from io import StringIO import numpy as np import torch from torch import sparse_coo_tensor as coo -from scipy.sparse import csr_matrix +from scipy.sparse import csr_matrix +from scipy.ndimage.filters import gaussian_filter +from scipy.signal import find_peaks +from scipy.cluster.vq import kmeans import faiss from tqdm import tqdm + from kilosort import hierarchical, swarmsplitter + def neigh_mat(Xd, nskip=10, n_neigh=30): Xsub = Xd[::nskip] n_samples, dim = Xd.shape @@ -50,6 +54,7 @@ def assign_mu(iclust, Xg, cols_mu, tones, nclust = None, lpow = 1): return mu, N + def assign_iclust(rows_neigh, isub, kn, tones2, nclust, lam, m, ki, kj, device=torch.device('cuda')): NN = kn.shape[0] @@ -69,6 +74,7 @@ def assign_iclust(rows_neigh, isub, kn, tones2, nclust, lam, m, ki, kj, device=t return iclust + def assign_isub(iclust, kn, tones2, nclust, nsub, lam, m,ki,kj, device=torch.device('cuda')): n_neigh = kn.shape[1] cols = iclust.unsqueeze(-1).tile((1, n_neigh)) @@ -87,6 +93,7 @@ def assign_isub(iclust, kn, tones2, nclust, nsub, lam, m,ki,kj, device=torch.dev isub = torch.argmax(xS, 1) return isub + def Mstats(M, device=torch.device('cuda')): m = M.sum() ki = np.array(M.sum(1)).flatten() @@ -176,6 +183,7 @@ def kmeans_plusplus(Xg, niter = 200, seed = 1, device=torch.device('cuda')): return iclust + def compute_score(mu, mu2, N, ccN, lam): mu_pairs = ((N*mu).unsqueeze(1) + N*mu) / (1e-6 + N+N[:,0]).unsqueeze(-1) mu2_pairs = ((N*mu2).unsqueeze(1) + N*mu2) / (1e-6 + N+N[:,0]).unsqueeze(-1) @@ -189,20 +197,16 @@ def compute_score(mu, mu2, N, ccN, lam): score = (ccN + ccN.T) - lam * dexp return score -def run_one(Xd, st0, nskip = 20, lam = 0): +def run_one(Xd, st0, nskip = 20, lam = 0): iclust, iclust0, M = cluster(Xd,nskip = nskip, lam = 0, seed = 5) - xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0) - - xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust, my_clus, meta = st0) - + xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust, + my_clus, meta = st0) iclust1 = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat) return iclust1 -import time - def xy_templates(ops): iU = ops['iU'].cpu().numpy() @@ -214,24 +218,51 @@ def xy_templates(ops): iU = ops['iU'].cpu().numpy() iC = ops['iCC'][:, ops['iU']] + return xy, iC + def xy_up(ops): xcup, ycup = ops['xcup'], ops['ycup'] xy = np.vstack((xcup, ycup)) xy = torch.from_numpy(xy) - iC = ops['iC'] + return xy, iC -def xy_c(ops): - xcup, ycup = ops['xc'][::4], ops['yc'][::4] - xy = np.vstack((xcup, ycup+10)) - xy = torch.from_numpy(xy) - iC = ops['iC'] - #print(1) - return xy, iC +def x_centers(ops): + dminx = ops['dminx'] + min_x = ops['xc'].min() + max_x = ops['xc'].max() + + # Make histogram of x-positions with bin size roughly equal to dminx, + # with a bit of padding on either end of the probe so that peaks can be + # detected at edges. + num_bins = int((max_x-min_x)/(dminx)) + 4 + bins = np.linspace(min_x - dminx*2, max_x + dminx*2, num_bins) + hist, edges = np.histogram(ops['xc'], bins=bins) + # Apply smoothing to make peak-finding simpler. + smoothed = gaussian_filter(hist, sigma=0.5) + peaks, _ = find_peaks(smoothed) + # peaks are indices, translate back to position in microns + approx_centers = [edges[p] for p in peaks] + # Use these as initial guesses for centroids in k-means to get + # a more accurate value for the actual centers. Or, if there's only 1, + # just look for one centroid. + if len(approx_centers) == 1: approx_centers = 1 + centers, distortion = kmeans(ops['xc'], approx_centers) + + # TODO: Maybe use distortion to raise warning if it seems too large? + # "The mean (non-squared) Euclidean distance between the observations passed + # and the centroids generated. Note the difference to the standard definition + # of distortion in the context of the k-means algorithm, which is the sum of + # the squared distances." + + # For example, could raise a warning if this is greater than dminx*2? + # Most probes should satisfy that criteria. + + return centers def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_bar=None): @@ -240,83 +271,90 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_b xy, iC = xy_templates(ops) iclust_template = st[:,1].astype('int32') xcup, ycup = ops['xcup'], ops['ycup'] - elif mode == 'spikes_nn': - xy, iC = xy_c(ops) - xcup, ycup = ops['xc'][::4], ops['yc'][::4] - iclust_template = st[:,5].astype('int32') else: xy, iC = xy_up(ops) iclust_template = st[:,5].astype('int32') xcup, ycup = ops['xcup'], ops['ycup'] - d0 = ops['dmin'] - ycent = np.arange(ycup.min()+d0-1, ycup.max()+d0+1, 2*d0) - - nsp = st.shape[0] - clu = np.zeros(nsp, 'int32') - nmax = 0 - + dmin = ops['dmin'] + dminx = ops['dminx'] nskip = ops['settings']['cluster_downsampling'] ncomps = ops['settings']['cluster_pcs'] + ycent = np.arange(ycup.min()+dmin-1, ycup.max()+dmin+1, 2*dmin) + xcent = x_centers(ops) + nsp = st.shape[0] + # Get positions of all grouping centers + ycent_pos, xcent_pos = np.meshgrid(ycent, xcent) + ycent_pos = torch.from_numpy(ycent_pos.flatten()) + xcent_pos = torch.from_numpy(xcent_pos.flatten()) + # Compute distances from templates + center_distance = ( + (xy[0,:] - xcent_pos.unsqueeze(-1))**2 + + (xy[1,:] - ycent_pos.unsqueeze(-1))**2 + ) + # Add some randomness in case of ties + center_distance += 1e-20*torch.rand(center_distance.shape) + # Get flattened index of x-y center that is closest to template + minimum_distance = torch.min(center_distance, 0).indices + + clu = np.zeros(nsp, 'int32') Wall = torch.zeros((0, ops['Nchan'], ops['settings']['n_pcs'])) - t0 = time.time() nearby_chans_empty = 0 - for kk in tqdm(np.arange(len(ycent)), miniters=20 if progress_bar else None, mininterval=10 if progress_bar else None): - # get the data - #iclust_template = st[:,1].astype('int32') - - Xd, ch_min, ch_max, igood = get_data_cpu( - ops, xy, iC, iclust_template, tF, ycent[kk], xcup.mean(), dmin=d0, - dminx = ops['dminx'], ncomps=ncomps - ) + nmax = 0 - if Xd is None: - nearby_chans_empty += 1 - continue - - if Xd.shape[0]<1000: - #clu[igood] = nmax - #nmax += 1 - iclust = torch.zeros((Xd.shape[0],)) - else: - if mode == 'template': - st0 = st[igood,0]/ops['fs'] + for kk in tqdm(np.arange(len(ycent)), miniters=20 if progress_bar else None, + mininterval=10 if progress_bar else None): + for jj in np.arange(len(xcent)): + # Get data for all templates that were closest to this x,y center. + ii = ii = kk + jj*ycent.size + ix = (minimum_distance == ii) + Xd, ch_min, ch_max, igood = get_data_cpu( + ops, xy, iC, iclust_template, tF, ycent[kk], xcent[jj], dmin=dmin, + dminx=dminx, ncomps=ncomps, ix=ix + ) + + if Xd is None: + nearby_chans_empty += 1 + continue + elif Xd.shape[0]<1000: + iclust = torch.zeros((Xd.shape[0],)) else: - st0 = None + if mode == 'template': + st0 = st[igood,0]/ops['fs'] + else: + st0 = None - # find new clusters - iclust, iclust0, M, iclust_init = cluster(Xd, nskip=nskip, lam=1, - seed=5, device=device) + # find new clusters + iclust, iclust0, M, iclust_init = cluster(Xd, nskip=nskip, lam=1, + seed=5, device=device) - xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0) + xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0) - xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust, my_clus, meta = st0) + xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust, my_clus, meta = st0) - iclust = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat) + iclust = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat) - clu[igood] = iclust + nmax - Nfilt = int(iclust.max() + 1) - nmax += Nfilt + clu[igood] = iclust + nmax + Nfilt = int(iclust.max() + 1) + nmax += Nfilt - # we need the new templates here - W = torch.zeros((Nfilt, ops['Nchan'], ops['settings']['n_pcs'])) - for j in range(Nfilt): - w = Xd[iclust==j].mean(0) - W[j, ch_min:ch_max, :] = torch.reshape(w, (-1, ops['settings']['n_pcs'])).cpu() - - Wall = torch.cat((Wall, W), 0) + # we need the new templates here + W = torch.zeros((Nfilt, ops['Nchan'], ops['settings']['n_pcs'])) + for j in range(Nfilt): + w = Xd[iclust==j].mean(0) + W[j, ch_min:ch_max, :] = torch.reshape(w, (-1, ops['settings']['n_pcs'])).cpu() + + Wall = torch.cat((Wall, W), 0) - if progress_bar is not None: - progress_bar.emit(int((kk+1) / len(ycent) * 100)) - - if 0:#kk%50==0: - print(kk, nmax, time.time()-t0) + if progress_bar is not None: + progress_bar.emit(int((kk+1) / len(ycent) * 100)) + if nearby_chans_empty == len(ycent): raise ValueError( f'`get_data_cpu` never found suitable channels in `clustering_qr.run`.' - f'\ndmin, dminx, and xcenter are: {d0, ops["dminx"], xcup.mean()}' + f'\ndmin, dminx, and xcenter are: {dmin, dminx, xcup.mean()}' ) if Wall.sum() == 0: diff --git a/kilosort/gui/probe_view_box.py b/kilosort/gui/probe_view_box.py index 969794fd..5ec66c29 100644 --- a/kilosort/gui/probe_view_box.py +++ b/kilosort/gui/probe_view_box.py @@ -1,7 +1,11 @@ +from qtpy import QtCore, QtWidgets import numpy as np import pyqtgraph as pg + +from kilosort.spikedetect import template_centers, nearest_chans +from kilosort.clustering_qr import x_centers from kilosort.gui.logger import setup_logger -from qtpy import QtCore, QtGui, QtWidgets + logger = setup_logger(__name__) @@ -15,15 +19,24 @@ def __init__(self, parent): self.setTitle("Probe View") self.gui = parent self.probe_view = pg.PlotWidget() + self.template_toggle = QtWidgets.QCheckBox('Universal Templates') + self.center_toggle = QtWidgets.QCheckBox('Grouping Centers') + self.aspect_toggle = QtWidgets.QCheckBox('True Aspect Ratio') + self.spot_scale = QtWidgets.QSlider(QtCore.Qt.Horizontal) self.setup() self.active_layout = None self.kcoords = None self.xc = None self.yc = None + self.xcup = None + self.ycup = None self.total_channels = None self.channel_map = None self.channel_map_dict = {} + self.channel_spots = None + self.template_spots = None + self.center_spots = None self.sorting_status = { "preprocess": False, @@ -34,30 +47,79 @@ def __init__(self, parent): self.active_data_view_mode = "colormap" def setup(self): - self.probe_view.hideAxis("left") - self.probe_view.hideAxis("bottom") - self.probe_view.setMouseEnabled(False, True) + self.aspect_toggle.setCheckState(QtCore.Qt.CheckState.Unchecked) + self.aspect_toggle.stateChanged.connect(self.refresh_plot) + self.template_toggle.setCheckState(QtCore.Qt.CheckState.Unchecked) + self.template_toggle.stateChanged.connect(self.refresh_plot) + self.center_toggle.setCheckState(QtCore.Qt.CheckState.Unchecked) + self.center_toggle.stateChanged.connect(self.refresh_plot) + + self.spot_scale.setMinimum(0) + self.spot_scale.setMaximum(10) + self.spot_scale.setValue(1) + self.spot_scale.valueChanged.connect(self.refresh_plot) layout = QtWidgets.QVBoxLayout() layout.addWidget(self.probe_view, 95) + layout.addWidget(self.aspect_toggle) + layout.addWidget(self.template_toggle) + layout.addWidget(self.center_toggle) + layout.addWidget(self.spot_scale) self.setLayout(layout) def set_layout(self, context): self.probe_view.clear() probe = context.raw_probe - self.set_active_layout(probe) + template_args = self.gui.settings_box.get_probe_template_args() + self.set_active_layout(probe, template_args) self.update_probe_view() - def set_active_layout(self, probe): + def set_active_layout(self, probe, template_args): self.active_layout = probe self.kcoords = self.active_layout["kcoords"] self.xc, self.yc = self.active_layout["xc"], self.active_layout["yc"] + self.xcup, self.ycup, self.ops = self.get_template_spots(*template_args) + self.xcent_pos, self.ycent_pos = self.get_center_spots() self.channel_map_dict = {} for ind, (xc, yc) in enumerate(zip(self.xc, self.yc)): self.channel_map_dict[(xc, yc)] = ind self.total_channels = self.active_layout["n_chan"] self.channel_map = self.active_layout["chanMap"] + def get_template_spots(self, nC, dmin, dminx, max_dist, device): + ops = { + 'yc': self.yc, 'xc': self.xc, 'max_channel_distance': max_dist, + 'settings': {'dmin': dmin, 'dminx': dminx} + } + ops = template_centers(ops) + [ys, xs] = np.meshgrid(ops['yup'], ops['xup']) + ys, xs = ys.flatten(), xs.flatten() + iC, ds = nearest_chans(ys, self.yc, xs, self.xc, nC, device=device) + + igood = ds[0,:] <= ops['max_channel_distance']**2 + iC = iC[:,igood] + ds = ds[:,igood] + ys = ys[igood] + xs = xs[igood] + + return xs, ys, ops + + def get_center_spots(self): + dmin = self.ops['dmin'] + ycent = np.arange(self.ycup.min()+dmin-1, self.ycup.max()+dmin+1, 2*dmin) + xcent = x_centers(self.ops) + + ycent_pos, xcent_pos = np.meshgrid(ycent, xcent) + ycent_pos = ycent_pos.flatten() + xcent_pos = xcent_pos.flatten() + + return xcent_pos, ycent_pos + + @QtCore.Slot() + def refresh_plot(self): + template_args = self.gui.settings_box.get_probe_template_args() + self.preview_probe(self.gui.settings_box.probe_layout, template_args) + @QtCore.Slot(str, int) def synchronize_data_view_mode(self, mode: str): if self.active_data_view_mode != mode: @@ -69,35 +131,72 @@ def change_sorting_status(self, status_dict): self.sorting_status = status_dict def generate_spots_list(self): - spots = [] - size = 10 - symbol = "s" - - for x_pos, y_pos in zip(self.xc, self.yc): - pos = (x_pos, y_pos) - color = 'g' - pen = pg.mkPen(0.5) - brush = pg.mkBrush(color) - spots.append({ - 'pos': pos, 'size': size, 'pen': pen, 'brush': brush, - 'symbol': symbol + channel_spots = [] + template_spots = [] + center_spots = [] + + if self.xc is not None: + size = 10 * self.spot_scale.value() + symbol = "s" + color = "g" + for x_pos, y_pos in zip(self.xc, self.yc): + pen = pg.mkPen(0.5) + brush = pg.mkBrush(color) + channel_spots.append({ + 'pos': (x_pos, y_pos), 'size': size, 'pen': pen, 'brush': brush, + 'symbol': symbol + }) + self.channel_spots = channel_spots + + if self.xcup is not None: + size = 5 * self.spot_scale.value() + symbol = "o" + color = "w" + for x, y in zip(self.xcup, self.ycup): + pen = pg.mkPen(0.5) + brush = pg.mkBrush(color) + template_spots.append({ + 'pos': (x,y), 'size': size, 'pen': pen, 'brush': brush, + 'symbol': symbol + }) + self.template_spots = template_spots + + if self.xcent_pos is not None: + size = 20 * self.spot_scale.value() + symbol = "o" + color = "y" + for x, y in zip(self.xcent_pos, self.ycent_pos): + pen = pg.mkPen(color=color) + brush = None + center_spots.append({ + 'pos': (x,y), 'size': size, 'pen': pen, 'brush': brush, + 'symbol': symbol }) + self.center_spots = center_spots - return spots @QtCore.Slot(int, int) def update_probe_view(self): self.create_plot() - @QtCore.Slot(object) - def preview_probe(self, probe): + @QtCore.Slot(object, object) + def preview_probe(self, probe, template_args): self.probe_view.clear() - self.set_active_layout(probe) + self.set_active_layout(probe, template_args) self.create_plot() def create_plot(self): - spots = self.generate_spots_list() + self.generate_spots_list() + spots = self.channel_spots + if self.template_toggle.isChecked(): + spots += self.template_spots + if self.center_toggle.isChecked(): + spots += self.center_spots scatter_plot = pg.ScatterPlotItem(spots) + if self.aspect_toggle.isChecked(): + self.probe_view.setAspectLocked() + else: + self.probe_view.setAspectLocked(lock=False) self.probe_view.addItem(scatter_plot) def reset(self): diff --git a/kilosort/gui/settings_box.py b/kilosort/gui/settings_box.py index 49b23f8e..a471b1cb 100644 --- a/kilosort/gui/settings_box.py +++ b/kilosort/gui/settings_box.py @@ -21,7 +21,7 @@ class SettingsBox(QtWidgets.QGroupBox): settingsUpdated = QtCore.Signal() - previewProbe = QtCore.Signal(object) + previewProbe = QtCore.Signal(object, object) dataChanged = QtCore.Signal() def __init__(self, parent): @@ -474,7 +474,7 @@ def check_settings(self): if not self.check_valid_binary_path(self.data_file_path): return False - none_allowed = ['dmin', 'nt0min'] + none_allowed = ['dmin', 'nt0min', 'max_channel_distance'] for k, v in self.settings.items(): if v is None and k not in none_allowed: return False @@ -500,9 +500,17 @@ def update_settings(self): else: self.settingsUpdated.emit() + def get_probe_template_args(self): + epw = self.extra_parameters_window + template_args = [ + epw.nearest_chans, epw.dmin, epw.dminx, + epw.max_channel_distance, self.gui.device + ] + return template_args + @QtCore.Slot() def show_probe_layout(self): - self.previewProbe.emit(self.probe_layout) + self.previewProbe.emit(self.probe_layout, self.get_probe_template_args()) @QtCore.Slot(str) def on_probe_layout_selected(self, name): diff --git a/kilosort/parameters.py b/kilosort/parameters.py index c78e15dd..9b6d4f95 100644 --- a/kilosort/parameters.py +++ b/kilosort/parameters.py @@ -238,6 +238,17 @@ """ }, + 'max_channel_distance': { + 'gui_name': 'max channel distance', 'type': float, 'min': 1, + 'max': np.inf, 'exclude': [], 'default': None, 'step': 'spike detection', + 'description': + """ + Templates farther away than this from their nearest channel will + not be used. Also limits distance between compared channels during + clustering. + """ + }, + 'templates_from_data': { 'gui_name': 'templates from data', 'type': bool, 'min': None, 'max': None, 'exclude': [], 'default': True, 'step': 'spike detection', diff --git a/kilosort/run_kilosort.py b/kilosort/run_kilosort.py index 479a2e16..f8988c0a 100644 --- a/kilosort/run_kilosort.py +++ b/kilosort/run_kilosort.py @@ -104,9 +104,6 @@ def run_kilosort(settings, probe=None, probe_name=None, filename=None, ) settings = {**DEFAULT_SETTINGS, **settings} - if settings['nt0min'] is None: - settings['nt0min'] = int(20 * settings['nt']/61) - if data_dtype is None: print("Interpreting binary file as default dtype='int16'. If data was " "saved in a different format, specify `data_dtype`.") @@ -202,6 +199,8 @@ def set_files(settings, filename, probe, probe_name, data_dir, results_dir): def initialize_ops(settings, probe, data_dtype, do_CAR, invert_sign, device) -> dict: """Package settings and probe information into a single `ops` dictionary.""" + if settings['nt0min'] is None: + settings['nt0min'] = int(20 * settings['nt']/61) # TODO: Clean this up during refactor. Lots of confusing duplication here. ops = settings ops['settings'] = settings diff --git a/kilosort/spikedetect.py b/kilosort/spikedetect.py index 1d9dab84..54a66c43 100644 --- a/kilosort/spikedetect.py +++ b/kilosort/spikedetect.py @@ -98,6 +98,10 @@ def template_centers(ops): nx = np.round((xmax - xmin) / (dminx/2)) + 1 ops['xup'] = np.linspace(xmin, xmax, int(nx)) + # Set max channel distance based on dmin, dminx, use whichever is greater. + if ops.get('max_channel_distance', None) is None: + ops['max_channel_distance'] = max(dmin, dminx) + return ops @@ -161,8 +165,10 @@ def nearest_chans(ys, yc, xs, xc, nC, device=torch.device('cuda')): iC = np.argsort(ds, 0)[:nC] iC = torch.from_numpy(iC).to(device) ds = np.sort(ds, 0)[:nC] + return iC, ds + def yweighted(yc, iC, adist, xy, device=torch.device('cuda')): yy = torch.from_numpy(yc).to(device)[iC] @@ -190,17 +196,24 @@ def run(ops, bfile, device=torch.device('cuda'), progress_bar=None): ops['wPCA'], ops['wTEMP'] = get_waves(ops, device=device) ops = template_centers(ops) - [ys, xs] = np.meshgrid(ops['yup'], ops['xup']) ys, xs = ys.flatten(), xs.flatten() - ops['ycup'], ops['xcup'] = ys, xs - xc, yc = ops['xc'], ops['yc'] Nfilt = len(ys) nC = ops['settings']['nearest_chans'] nC2 = ops['settings']['nearest_templates'] iC, ds = nearest_chans(ys, yc, xs, xc, nC, device=device) + + # Don't use templates that are too far away from nearest channel + # (use square of max distance since ds are squared distances) + igood = ds[0,:] <= ops['max_channel_distance']**2 + iC = iC[:,igood] + ds = ds[:,igood] + ys = ys[igood] + xs = xs[igood] + ops['ycup'], ops['xcup'] = ys, xs + iC2, ds2 = nearest_chans(ys, ys, xs, xs, nC2, device=device) ds_torch = torch.from_numpy(ds).to(device).float() diff --git a/tests/test_clustering.py b/tests/test_clustering.py new file mode 100644 index 00000000..ffd1c810 --- /dev/null +++ b/tests/test_clustering.py @@ -0,0 +1,86 @@ +import numpy as np + +from kilosort.clustering_qr import x_centers +from kilosort.io import load_probe +from kilosort.utils import PROBE_DIR + + +def random_np2(n_chans=384, n_shanks=4): + # Generates a probe containing *all* neuropixels 2 contact positions, + # then randomly subsamples from those positions to get a probe layout + # corresponding to 384-channel output data. + + probe = {} + # 12um square contacts with 32um lateral spacing, + # 15um vertical spacing, + # 1280 contacts per shank + + # Want alternating 6um, 38um for lateral positions + xc0 = np.empty(1280) + xc0[::2] = 6 + xc0[1::2] = 38 + # Then add 250um for each additional shank + xc = np.concatenate([xc0 + (250*i) for i in range(4)]) + + # For vertical positions, start at 6 and increase by 15 + yc0 = (np.arange(640)*15) + 6 + # Each position appears twice (two columns on each shank) + yc0 = np.repeat(yc0, 2) + yc = np.concatenate([yc0 for i in range(4)]) + + # Repeat 0 1280 times, then repeat 1 1280 times, etc + kcoords = np.repeat(np.arange(4), 1280) + + # Pick n_chans out of n_shanks + shanks_used = np.random.choice(range(4), n_shanks, replace=False) + shank_indices = np.argwhere(np.isin(kcoords, shanks_used))[:,0] + contact_indices = np.random.choice(shank_indices, n_chans, replace=False) + + return {'xc': xc[contact_indices], 'yc': yc[contact_indices]} + + +class TestCenters: + ops = {'dminx': 32} + + def __init__(self, data_directory): + # This is just here to make sure probes are downloaded before these + # tests are run. + pass + + def test_linear(self): + self.ops['probe'] = load_probe(PROBE_DIR/'Linear16x1_kilosortChanMap.mat') + centers = x_centers(self.ops) + # X positions are all 1um + assert len(centers) == 1 + assert np.abs(centers[0] - 1) < 5 + + def test_np1(self): + self.ops['probe'] = load_probe(PROBE_DIR/'neuropixPhase3B1_kilosortChanMap.mat') + centers = x_centers(self.ops) + # One shank from 11um to 59um, should be 1 center near 35um + assert len(centers) == 1 + assert np.abs(centers[0] - 35) < 5 + + def test_np2_1shank(self): + self.ops['probe'] = load_probe(PROBE_DIR/'NP2_kilosortChanMap.mat') + centers = x_centers(self.ops) + # One shank from 0 to 32um, should be 1 center near 16um + assert len(centers) == 1 + assert np.abs(centers[0] - 16) < 5 + + def test_np2_3shank(self): + self.ops['probe'] = random_np2(n_shanks=3) + centers = x_centers(self.ops) + assert len(centers == 3) + true = np.array([22, 272, 522, 772]) + for c in centers: + # Each center is within 2 microns of exactly one true center + print(f'center: {c}') + assert (np.abs(c - true) < 5).sum() == 1 + + def test_np2_4shank(self): + self.ops['probe'] = random_np2(n_shanks=4) + centers = x_centers(self.ops) + # All centers should be within 2 microns of the true values + print(f'centers: {centers}') + assert np.allclose(np.sort(centers), np.sort([22, 272, 522, 772]), atol=5) diff --git a/tests/test_full_pipeline.py b/tests/test_full_pipeline.py index 247daca8..3b6e039c 100644 --- a/tests/test_full_pipeline.py +++ b/tests/test_full_pipeline.py @@ -48,7 +48,6 @@ def test_pipeline(data_directory, results_directory, saved_ops, torch_device, ca print(f'Proportion difference in total spike count: {spikes_error}') print(f'Count from run_kilosort: {st.size}') print(f'Count from saved test results: {st_load.size}') - assert spikes_error <= 0.025 n = np.unique(clu).size n_load = np.unique(clu_load).size @@ -57,4 +56,6 @@ def test_pipeline(data_directory, results_directory, saved_ops, torch_device, ca print(f'Proportion difference in number of units: {unit_count_error}') print(f'Number of units from run_kilosort: {n}') print(f'Number of units from saved test results: {n_load}') + + assert spikes_error <= 0.025 assert unit_count_error <= 0.05