Commit 10a768d

Merge pull request #662 from MouseLand/jacob/distance_gating

Jacob/distance gating

jacobpennington authored Apr 13, 2024
2 parents 39bedf1 + 82f82ec
Showing 8 changed files with 359 additions and 104 deletions.
180 changes: 109 additions & 71 deletions kilosort/clustering_qr.py
@@ -1,12 +1,16 @@
from io import StringIO
import numpy as np
import torch
from torch import sparse_coo_tensor as coo
from scipy.sparse import csr_matrix
from scipy.ndimage import gaussian_filter
from scipy.signal import find_peaks
from scipy.cluster.vq import kmeans
import faiss
from tqdm import tqdm

from kilosort import hierarchical, swarmsplitter


def neigh_mat(Xd, nskip=10, n_neigh=30):
Xsub = Xd[::nskip]
n_samples, dim = Xd.shape
@@ -50,6 +54,7 @@ def assign_mu(iclust, Xg, cols_mu, tones, nclust = None, lpow = 1):

return mu, N


def assign_iclust(rows_neigh, isub, kn, tones2, nclust, lam, m, ki, kj, device=torch.device('cuda')):
NN = kn.shape[0]

@@ -69,6 +74,7 @@ def assign_isub(iclust, kn, tones2, nclust, nsub, lam, m,ki,kj, device=torch.device('cuda')):

return iclust


def assign_isub(iclust, kn, tones2, nclust, nsub, lam, m,ki,kj, device=torch.device('cuda')):
n_neigh = kn.shape[1]
cols = iclust.unsqueeze(-1).tile((1, n_neigh))
@@ -87,6 +93,7 @@ def assign_isub(iclust, kn, tones2, nclust, nsub, lam, m,ki,kj, device=torch.device('cuda')):
isub = torch.argmax(xS, 1)
return isub


def Mstats(M, device=torch.device('cuda')):
m = M.sum()
ki = np.array(M.sum(1)).flatten()
@@ -176,6 +183,7 @@ def kmeans_plusplus(Xg, niter = 200, seed = 1, device=torch.device('cuda')):

return iclust


def compute_score(mu, mu2, N, ccN, lam):
mu_pairs = ((N*mu).unsqueeze(1) + N*mu) / (1e-6 + N+N[:,0]).unsqueeze(-1)
mu2_pairs = ((N*mu2).unsqueeze(1) + N*mu2) / (1e-6 + N+N[:,0]).unsqueeze(-1)
@@ -189,20 +197,16 @@ def compute_score(mu, mu2, N, ccN, lam):
score = (ccN + ccN.T) - lam * dexp
return score


def run_one(Xd, st0, nskip = 20, lam = 0):
iclust, iclust0, M = cluster(Xd,nskip = nskip, lam = 0, seed = 5)

xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0)

xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust,
my_clus, meta = st0)
iclust1 = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat)

return iclust1
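
# In brief, run_one chains the same pipeline that run() applies per grouping
# center below: cluster() labels the PC features, hierarchical.maketree builds
# a merge tree from the matrix M returned by cluster(), swarmsplitter.split
# evaluates the tree nodes (spike times st0 serve as per-spike metadata), and
# new_clusters flattens the accepted tree into final labels.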


def xy_templates(ops):
iU = ops['iU'].cpu().numpy()
@@ -214,24 +218,51 @@ def xy_templates(ops):

iU = ops['iU'].cpu().numpy()
iC = ops['iCC'][:, ops['iU']]

return xy, iC


def xy_up(ops):
xcup, ycup = ops['xcup'], ops['ycup']
xy = np.vstack((xcup, ycup))
xy = torch.from_numpy(xy)

iC = ops['iC']

return xy, iC


def xy_c(ops):
xcup, ycup = ops['xc'][::4], ops['yc'][::4]
xy = np.vstack((xcup, ycup+10))
xy = torch.from_numpy(xy)

iC = ops['iC']
    return xy, iC


def x_centers(ops):
dminx = ops['dminx']
min_x = ops['xc'].min()
max_x = ops['xc'].max()

# Make histogram of x-positions with bin size roughly equal to dminx,
# with a bit of padding on either end of the probe so that peaks can be
# detected at edges.
num_bins = int((max_x-min_x)/(dminx)) + 4
bins = np.linspace(min_x - dminx*2, max_x + dminx*2, num_bins)
hist, edges = np.histogram(ops['xc'], bins=bins)
# Apply smoothing to make peak-finding simpler.
smoothed = gaussian_filter(hist, sigma=0.5)
peaks, _ = find_peaks(smoothed)
# peaks are indices, translate back to position in microns
approx_centers = [edges[p] for p in peaks]
# Use these as initial guesses for centroids in k-means to get
# a more accurate value for the actual centers. Or, if there's only 1,
# just look for one centroid.
    if len(approx_centers) == 1:
        approx_centers = 1
centers, distortion = kmeans(ops['xc'], approx_centers)

# TODO: Maybe use distortion to raise warning if it seems too large?
# "The mean (non-squared) Euclidean distance between the observations passed
# and the centroids generated. Note the difference to the standard definition
# of distortion in the context of the k-means algorithm, which is the sum of
# the squared distances."

# For example, could raise a warning if this is greater than dminx*2?
    # Most probes should satisfy that criterion.

return centers
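
# A minimal sketch of how the center-finding above behaves on a synthetic
# two-shank probe (hypothetical xc values, not taken from ops):
#
#     xc = np.concatenate([np.random.normal(0, 5, 200),
#                          np.random.normal(250, 5, 200)])
#     hist, edges = np.histogram(xc, bins=np.linspace(-40, 290, 34))
#     smoothed = gaussian_filter(hist.astype(float), sigma=0.5)
#     peaks, _ = find_peaks(smoothed)            # roughly one peak per shank
#     guesses = np.array([edges[p] for p in peaks], dtype=float)
#     centers, _ = kmeans(xc.astype(float), guesses)
#     # centers should land near 0 and 250, one per shank.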


def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_bar=None):
@@ -240,83 +271,90 @@ def run(ops, st, tF, mode = 'template', device=torch.device('cuda'), progress_bar=None):
xy, iC = xy_templates(ops)
iclust_template = st[:,1].astype('int32')
xcup, ycup = ops['xcup'], ops['ycup']
elif mode == 'spikes_nn':
xy, iC = xy_c(ops)
xcup, ycup = ops['xc'][::4], ops['yc'][::4]
iclust_template = st[:,5].astype('int32')
else:
xy, iC = xy_up(ops)
iclust_template = st[:,5].astype('int32')
xcup, ycup = ops['xcup'], ops['ycup']

dmin = ops['dmin']
dminx = ops['dminx']
nskip = ops['settings']['cluster_downsampling']
ncomps = ops['settings']['cluster_pcs']
ycent = np.arange(ycup.min()+dmin-1, ycup.max()+dmin+1, 2*dmin)
xcent = x_centers(ops)
nsp = st.shape[0]

# Get positions of all grouping centers
ycent_pos, xcent_pos = np.meshgrid(ycent, xcent)
ycent_pos = torch.from_numpy(ycent_pos.flatten())
xcent_pos = torch.from_numpy(xcent_pos.flatten())
# Compute distances from templates
center_distance = (
(xy[0,:] - xcent_pos.unsqueeze(-1))**2
+ (xy[1,:] - ycent_pos.unsqueeze(-1))**2
)
# Add some randomness in case of ties
center_distance += 1e-20*torch.rand(center_distance.shape)
# Get flattened index of x-y center that is closest to template
minimum_distance = torch.min(center_distance, 0).indices
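    # Shape note: xy is (2, n_templates) while xcent_pos and ycent_pos are
    # (n_centers,), so unsqueeze(-1) broadcasts to an (n_centers, n_templates)
    # matrix of squared distances; torch.min(..., 0).indices then gives, for
    # each template, the flattened index of its nearest grouping center.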

clu = np.zeros(nsp, 'int32')
Wall = torch.zeros((0, ops['Nchan'], ops['settings']['n_pcs']))
nearby_chans_empty = 0
    nmax = 0

for kk in tqdm(np.arange(len(ycent)), miniters=20 if progress_bar else None,
mininterval=10 if progress_bar else None):
for jj in np.arange(len(xcent)):
# Get data for all templates that were closest to this x,y center.
            # Flattened meshgrid index of the center (ycent[kk], xcent[jj]).
            ii = kk + jj*ycent.size
ix = (minimum_distance == ii)
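            # ix flags the templates whose nearest center is this (y, x) pair;
            # get_data_cpu below is gated to just that subset of templates.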
Xd, ch_min, ch_max, igood = get_data_cpu(
ops, xy, iC, iclust_template, tF, ycent[kk], xcent[jj], dmin=dmin,
dminx=dminx, ncomps=ncomps, ix=ix
)

if Xd is None:
nearby_chans_empty += 1
continue
elif Xd.shape[0]<1000:
iclust = torch.zeros((Xd.shape[0],))
else:
if mode == 'template':
st0 = st[igood,0]/ops['fs']
else:
st0 = None
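                # st0 is spike times in seconds (st[:,0] is in samples), used
                # as per-spike metadata by swarmsplitter.split below.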

# find new clusters
iclust, iclust0, M, iclust_init = cluster(Xd, nskip=nskip, lam=1,
seed=5, device=device)

xtree, tstat, my_clus = hierarchical.maketree(M, iclust, iclust0)

xtree, tstat = swarmsplitter.split(Xd.numpy(), xtree, tstat, iclust, my_clus, meta = st0)

iclust = swarmsplitter.new_clusters(iclust, my_clus, xtree, tstat)

clu[igood] = iclust + nmax
Nfilt = int(iclust.max() + 1)
nmax += Nfilt

# we need the new templates here
W = torch.zeros((Nfilt, ops['Nchan'], ops['settings']['n_pcs']))
for j in range(Nfilt):
w = Xd[iclust==j].mean(0)
W[j, ch_min:ch_max, :] = torch.reshape(w, (-1, ops['settings']['n_pcs'])).cpu()
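                # w has (ch_max - ch_min) * n_pcs entries, so the reshape
                # scatters the cluster-mean features back onto this block's
                # channel range as (channels, n_pcs).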
Wall = torch.cat((Wall, W), 0)

if progress_bar is not None:
progress_bar.emit(int((kk+1) / len(ycent) * 100))


if nearby_chans_empty == len(ycent):
raise ValueError(
f'`get_data_cpu` never found suitable channels in `clustering_qr.run`.'
f'\ndmin, dminx, and xcenter are: {dmin, dminx, xcup.mean()}'
)

if Wall.sum() == 0: