diff --git a/kilosort/clustering_qr.py b/kilosort/clustering_qr.py index 205e09eb..442f125e 100644 --- a/kilosort/clustering_qr.py +++ b/kilosort/clustering_qr.py @@ -127,12 +127,10 @@ def cluster(Xd, iclust = None, kn = None, nskip = 20, n_neigh = 10, nclust = 200 m, ki, kj = Mstats(M, device=device) - #Xg = torch.from_numpy(Xd).to(dev) - Xg = Xd.to(device) kn = torch.from_numpy(kn).to(device) n_neigh = kn.shape[1] - NN, nfeat = Xg.shape + NN, nfeat = Xd.shape nsub = (NN-1)//nskip + 1 rows_neigh = torch.arange(NN, device = device).unsqueeze(-1).tile((1,n_neigh)) @@ -140,7 +138,7 @@ def cluster(Xd, iclust = None, kn = None, nskip = 20, n_neigh = 10, nclust = 200 tones2 = torch.ones((NN, n_neigh), device = device) if iclust is None: - iclust_init = kmeans_plusplus(Xg, niter = nclust, seed = seed, device=device) + iclust_init = kmeans_plusplus(Xd, niter = nclust, seed = seed, device=device) iclust = iclust_init.clone() else: iclust_init = iclust.clone() @@ -162,9 +160,12 @@ def cluster(Xd, iclust = None, kn = None, nskip = 20, n_neigh = 10, nclust = 200 return iclust, isub, M, iclust_init -def kmeans_plusplus(Xg, niter = 200, seed = 1, device=torch.device('cuda')): - #Xg = torch.from_numpy(Xd).to(dev) - vtot = (Xg**2).sum(1) +def kmeans_plusplus(Xd, niter = 200, seed = 1, device=torch.device('cuda')): + + Xg = Xd.to(device) + + Xd_squared = (Xd ** 2) + vtot = Xd_squared.sum(1).to(device) n1 = vtot.shape[0] if n1 > 2**24: @@ -199,7 +200,10 @@ def kmeans_plusplus(Xg, niter = 200, seed = 1, device=torch.device('cuda')): isamp = torch.multinomial(v2, ntry) Xc = Xg[isamp] - vexp = 2 * Xg @ Xc.T - (Xc**2).sum(1) + Xc_squared_sum = (Xc ** 2).sum(1) + vexp = Xg @ Xc.T + vexp.mul_(2) + vexp = vexp - Xc_squared_sum dexp = vexp - vexp0.unsqueeze(1) dexp = torch.relu(dexp)