Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jacob/distance gating #662

Merged
merged 17 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 109 additions & 71 deletions kilosort/clustering_qr.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from io import StringIO
import numpy as np
import torch
from torch import sparse_coo_tensor as coo
from scipy.sparse import csr_matrix
from scipy.sparse import csr_matrix
from scipy.ndimage.filters import gaussian_filter
from scipy.signal import find_peaks
from scipy.cluster.vq import kmeans
import faiss
from tqdm import tqdm

from kilosort import hierarchical, swarmsplitter


def neigh_mat(Xd, nskip=10, n_neigh=30):
Xsub = Xd[::nskip]
n_samples, dim = Xd.shape
Expand Down Expand Up @@ -50,6 +54,7 @@ def assign_mu(iclust, Xg, cols_mu, tones, nclust = None, lpow = 1):

return mu, N


def assign_iclust(rows_neigh, isub, kn, tones2, nclust, lam, m, ki, kj, device=torch.device('cuda')):
NN = kn.shape[0]

Expand All @@ -69,6 +74,7 @@ def assign_iclust(rows_neigh, isub, kn, tones2, nclust, lam, m, ki, kj, device=t

return iclust


def assign_isub(iclust, kn, tones2, nclust, nsub, lam, m,ki,kj, device=torch.device('cuda')):
n_neigh = kn.shape[1]
cols = iclust.unsqueeze(-1).tile((1, n_neigh))
Expand All @@ -87,6 +93,7 @@ def assign_isub(iclust, kn, tones2, nclust, nsub, lam, m,ki,kj, device=torch.dev
isub = torch.argmax(xS, 1)
return isub


def Mstats(M, device=torch.device('cuda')):
m = M.sum()
ki = np.array(M.sum(1)).flatten()
Expand Down Expand Up @@ -176,6 +183,7 @@ def kmeans_plusplus(Xg, niter = 200, seed = 1, device=torch.device('cuda')):

return iclust


def compute_score(mu, mu2, N, ccN, lam):
mu_pairs = ((N*mu).unsqueeze(1) + N*mu) / (1e-6 + N+N[:,0]).unsqueeze(-1)
mu2_pairs = ((N*mu2).unsqueeze(1) + N*mu2) / (1e-6 + N+N[:,0]).unsqueeze(-1)
Expand All @@ -189,20 +197,16 @@ def compute_score(mu, mu2, N, ccN, lam):
score = (ccN + ccN.T) - lam * dexp
return score

def run_one(Xd, st0, nskip = 20, lam = 0):

def run_one(Xd, st0, nskip = 20, lam = 0):
iclust, iclust0, M = cluster(Xd,nskip = nskip, lam = 0, seed = 5)

xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0)

xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust, my_clus, meta = st0)

xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust,
my_clus, meta = st0)
iclust1 = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat)

return iclust1

import time


def xy_templates(ops):
iU = ops['iU'].cpu().numpy()
Expand All @@ -214,24 +218,51 @@ def xy_templates(ops):

iU = ops['iU'].cpu().numpy()
iC = ops['iCC'][:, ops['iU']]

return xy, iC


def xy_up(ops):
xcup, ycup = ops['xcup'], ops['ycup']
xy = np.vstack((xcup, ycup))
xy = torch.from_numpy(xy)

iC = ops['iC']

return xy, iC

def xy_c(ops):
xcup, ycup = ops['xc'][::4], ops['yc'][::4]
xy = np.vstack((xcup, ycup+10))
xy = torch.from_numpy(xy)

iC = ops['iC']
#print(1)
return xy, iC
def x_centers(ops):
dminx = ops['dminx']
min_x = ops['xc'].min()
max_x = ops['xc'].max()

# Make histogram of x-positions with bin size roughly equal to dminx,
# with a bit of padding on either end of the probe so that peaks can be
# detected at edges.
num_bins = int((max_x-min_x)/(dminx)) + 4
bins = np.linspace(min_x - dminx*2, max_x + dminx*2, num_bins)
hist, edges = np.histogram(ops['xc'], bins=bins)
# Apply smoothing to make peak-finding simpler.
smoothed = gaussian_filter(hist, sigma=0.5)
peaks, _ = find_peaks(smoothed)
# peaks are indices, translate back to position in microns
approx_centers = [edges[p] for p in peaks]
# Use these as initial guesses for centroids in k-means to get
# a more accurate value for the actual centers. Or, if there's only 1,
# just look for one centroid.
if len(approx_centers) == 1: approx_centers = 1
centers, distortion = kmeans(ops['xc'], approx_centers)

# TODO: Maybe use distortion to raise warning if it seems too large?
# "The mean (non-squared) Euclidean distance between the observations passed
# and the centroids generated. Note the difference to the standard definition
# of distortion in the context of the k-means algorithm, which is the sum of
# the squared distances."

# For example, could raise a warning if this is greater than dminx*2?
# Most probes should satisfy that criteria.

return centers


def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_bar=None):
Expand All @@ -240,83 +271,90 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_b
xy, iC = xy_templates(ops)
iclust_template = st[:,1].astype('int32')
xcup, ycup = ops['xcup'], ops['ycup']
elif mode == 'spikes_nn':
xy, iC = xy_c(ops)
xcup, ycup = ops['xc'][::4], ops['yc'][::4]
iclust_template = st[:,5].astype('int32')
else:
xy, iC = xy_up(ops)
iclust_template = st[:,5].astype('int32')
xcup, ycup = ops['xcup'], ops['ycup']

d0 = ops['dmin']
ycent = np.arange(ycup.min()+d0-1, ycup.max()+d0+1, 2*d0)

nsp = st.shape[0]
clu = np.zeros(nsp, 'int32')
nmax = 0

dmin = ops['dmin']
dminx = ops['dminx']
nskip = ops['settings']['cluster_downsampling']
ncomps = ops['settings']['cluster_pcs']
ycent = np.arange(ycup.min()+dmin-1, ycup.max()+dmin+1, 2*dmin)
xcent = x_centers(ops)
nsp = st.shape[0]

# Get positions of all grouping centers
ycent_pos, xcent_pos = np.meshgrid(ycent, xcent)
ycent_pos = torch.from_numpy(ycent_pos.flatten())
xcent_pos = torch.from_numpy(xcent_pos.flatten())
# Compute distances from templates
center_distance = (
(xy[0,:] - xcent_pos.unsqueeze(-1))**2
+ (xy[1,:] - ycent_pos.unsqueeze(-1))**2
)
# Add some randomness in case of ties
center_distance += 1e-20*torch.rand(center_distance.shape)
# Get flattened index of x-y center that is closest to template
minimum_distance = torch.min(center_distance, 0).indices

clu = np.zeros(nsp, 'int32')
Wall = torch.zeros((0, ops['Nchan'], ops['settings']['n_pcs']))
t0 = time.time()
nearby_chans_empty = 0
for kk in tqdm(np.arange(len(ycent)), miniters=20 if progress_bar else None, mininterval=10 if progress_bar else None):
# get the data
#iclust_template = st[:,1].astype('int32')

Xd, ch_min, ch_max, igood = get_data_cpu(
ops, xy, iC, iclust_template, tF, ycent[kk], xcup.mean(), dmin=d0,
dminx = ops['dminx'], ncomps=ncomps
)
nmax = 0

if Xd is None:
nearby_chans_empty += 1
continue

if Xd.shape[0]<1000:
#clu[igood] = nmax
#nmax += 1
iclust = torch.zeros((Xd.shape[0],))
else:
if mode == 'template':
st0 = st[igood,0]/ops['fs']
for kk in tqdm(np.arange(len(ycent)), miniters=20 if progress_bar else None,
mininterval=10 if progress_bar else None):
for jj in np.arange(len(xcent)):
# Get data for all templates that were closest to this x,y center.
ii = ii = kk + jj*ycent.size
ix = (minimum_distance == ii)
Xd, ch_min, ch_max, igood = get_data_cpu(
ops, xy, iC, iclust_template, tF, ycent[kk], xcent[jj], dmin=dmin,
dminx=dminx, ncomps=ncomps, ix=ix
)

if Xd is None:
nearby_chans_empty += 1
continue
elif Xd.shape[0]<1000:
iclust = torch.zeros((Xd.shape[0],))
else:
st0 = None
if mode == 'template':
st0 = st[igood,0]/ops['fs']
else:
st0 = None

# find new clusters
iclust, iclust0, M, iclust_init = cluster(Xd, nskip=nskip, lam=1,
seed=5, device=device)
# find new clusters
iclust, iclust0, M, iclust_init = cluster(Xd, nskip=nskip, lam=1,
seed=5, device=device)

xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0)
xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0)

xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust, my_clus, meta = st0)
xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust, my_clus, meta = st0)

iclust = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat)
iclust = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat)

clu[igood] = iclust + nmax
Nfilt = int(iclust.max() + 1)
nmax += Nfilt
clu[igood] = iclust + nmax
Nfilt = int(iclust.max() + 1)
nmax += Nfilt

# we need the new templates here
W = torch.zeros((Nfilt, ops['Nchan'], ops['settings']['n_pcs']))
for j in range(Nfilt):
w = Xd[iclust==j].mean(0)
W[j, ch_min:ch_max, :] = torch.reshape(w, (-1, ops['settings']['n_pcs'])).cpu()

Wall = torch.cat((Wall, W), 0)
# we need the new templates here
W = torch.zeros((Nfilt, ops['Nchan'], ops['settings']['n_pcs']))
for j in range(Nfilt):
w = Xd[iclust==j].mean(0)
W[j, ch_min:ch_max, :] = torch.reshape(w, (-1, ops['settings']['n_pcs'])).cpu()
Wall = torch.cat((Wall, W), 0)

if progress_bar is not None:
progress_bar.emit(int((kk+1) / len(ycent) * 100))

if 0:#kk%50==0:
print(kk, nmax, time.time()-t0)
if progress_bar is not None:
progress_bar.emit(int((kk+1) / len(ycent) * 100))


if nearby_chans_empty == len(ycent):
raise ValueError(
f'`get_data_cpu` never found suitable channels in `clustering_qr.run`.'
f'\ndmin, dminx, and xcenter are: {d0, ops["dminx"], xcup.mean()}'
f'\ndmin, dminx, and xcenter are: {dmin, dminx, xcup.mean()}'
)

if Wall.sum() == 0:
Expand Down
Loading
Loading