Skip to content

Commit

Permalink
Update snnpy.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chenxinye authored Jan 25, 2024
1 parent a4bc124 commit 2b02da7
Showing 1 changed file with 19 additions and 34 deletions.
53 changes: 19 additions & 34 deletions python/snnpy/snnpy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
# MIT License
# Copyright (c) 2022 Stefan Güttel, Xinye Chen
# See license file for details

import numba
import numpy as np
from scipy.linalg import get_blas_funcs, eigh
Expand Down Expand Up @@ -100,29 +96,26 @@ def _r_batch_query_mef(mu, v, xxt, sort_vals, sort_id, data, queries, r, return_
return query_batches_mef(queries, data, inner_queries, xxt, r, lefts, rights, sort_id)



@numba.njit(cache=False)
def query_batches_mef(queries, data, inner_queries, xxt, r, lefts, rights, sort_id):
def query_batches(queries, data, inner_queries, xxt, r, lefts, rights, sort_id):
knn_ind = dict()

ddata_queries = np.dot(queries, data.T)
for i in range(queries.shape[0]):
ddata_query = np.dot(data, queries[i])
batch_dist_set = (xxt + inner_queries[i] - 2*ddata_query)
batch_dist_set = (xxt + inner_queries[i] - 2*ddata_queries[i])

batch_dist_set = batch_dist_set[lefts[i]:rights[i]]

knn_ind[i] = sort_id[lefts[i]:rights[i]][batch_dist_set <= r]

return knn_ind


@numba.njit(cache=False)
def query_batches_dist_mef(queries, data, inner_queries, xxt, r, lefts, rights, sort_id):
def query_batches_dist(queries, data, inner_queries, xxt, r, lefts, rights, sort_id):
knn_ind = dict()
knn_dist = dict()

ddata_queries = np.dot(queries, data.T)
for i in range(queries.shape[0]):
ddata_query = np.dot(data, queries[i])
batch_dist_set = (xxt + inner_queries[i] - 2*ddata_query)
batch_dist_set = (xxt + inner_queries[i] - 2*ddata_queries[i])

batch_dist_set = batch_dist_set[lefts[i]:rights[i]]

filter_r = batch_dist_set <= r
Expand All @@ -132,49 +125,41 @@ def query_batches_dist_mef(queries, data, inner_queries, xxt, r, lefts, rights,
return knn_ind, knn_dist



@numba.njit(cache=True) # return (ids)
def query_batches(queries, data, inner_queries, xxt, r, lefts, rights, sort_id):
@numba.njit(cache=False)
def query_batches_mef(queries, data, inner_queries, xxt, r, lefts, rights, sort_id):
knn_ind = dict()
ddata_queries = np.dot(queries, data.T)

for i in range(queries.shape[0]):
batch_dist_set = (xxt + inner_queries[i] - 2*ddata_queries[i])

batch_dist_set = batch_dist_set[lefts[i]:rights[i]]
ddata_query = np.dot(data[lefts[i]:rights[i]], queries[i])
batch_dist_set = (xxt[lefts[i]:rights[i]] + inner_queries[i] - 2*ddata_query)
knn_ind[i] = sort_id[lefts[i]:rights[i]][batch_dist_set <= r]

return knn_ind


@numba.njit(cache=False) # return (ids, distances)
def query_batches_dist(queries, data, inner_queries, xxt, r, lefts, rights, sort_id):
@numba.njit(cache=False)
def query_batches_dist_mef(queries, data, inner_queries, xxt, r, lefts, rights, sort_id):
knn_ind = dict()
knn_dist = dict()
ddata_queries = np.dot(queries, data.T)

for i in range(queries.shape[0]):
batch_dist_set = (xxt + inner_queries[i] - 2*ddata_queries[i])

batch_dist_set = batch_dist_set[lefts[i]:rights[i]]
ddata_query = np.dot(data[lefts[i]:rights[i]], queries[i])
batch_dist_set = (xxt[lefts[i]:rights[i]] + inner_queries[i] - 2*ddata_query)

filter_r = batch_dist_set <= r
knn_ind[i] = sort_id[lefts[i]:rights[i]][filter_r]
knn_dist[i] = np.sqrt(batch_dist_set[filter_r])

return knn_ind, knn_dist




def euclid(xxt, X, v):
return (xxt + np.inner(v,v).ravel() -2*X.dot(v)).astype(float)



@numba.njit(cache=False)
def bisection_sort(queries, mu, v, sort_vals, r):
queries = np.subtract(queries, mu)
sv_qs = np.dot(queries, v)
lefts = np.searchsorted(sort_vals, sv_qs-r)
rights = np.searchsorted(sort_vals, sv_qs+r)
return queries, lefts, rights


0 comments on commit 2b02da7

Please sign in to comment.