From 2b02da700f248e58e8343eb5f503419e4921fa8d Mon Sep 17 00:00:00 2001 From: the null Date: Thu, 25 Jan 2024 04:05:39 +0100 Subject: [PATCH] Update snnpy.py --- python/snnpy/snnpy.py | 53 ++++++++++++++++--------------------------- 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/python/snnpy/snnpy.py b/python/snnpy/snnpy.py index 5b4ec63..76e7186 100644 --- a/python/snnpy/snnpy.py +++ b/python/snnpy/snnpy.py @@ -1,7 +1,3 @@ -# MIT License -# Copyright (c) 2022 Stefan Güttel, Xinye Chen -# See license file for details - import numba import numpy as np from scipy.linalg import get_blas_funcs, eigh @@ -100,29 +96,26 @@ def _r_batch_query_mef(mu, v, xxt, sort_vals, sort_id, data, queries, r, return_ return query_batches_mef(queries, data, inner_queries, xxt, r, lefts, rights, sort_id) - @numba.njit(cache=False) -def query_batches_mef(queries, data, inner_queries, xxt, r, lefts, rights, sort_id): +def query_batches(queries, data, inner_queries, xxt, r, lefts, rights, sort_id): knn_ind = dict() - + ddata_queries = np.dot(queries, data.T) for i in range(queries.shape[0]): - ddata_query = np.dot(data, queries[i]) - batch_dist_set = (xxt + inner_queries[i] - 2*ddata_query) + batch_dist_set = (xxt + inner_queries[i] - 2*ddata_queries[i]) + batch_dist_set = batch_dist_set[lefts[i]:rights[i]] - knn_ind[i] = sort_id[lefts[i]:rights[i]][batch_dist_set <= r] return knn_ind - @numba.njit(cache=False) -def query_batches_dist_mef(queries, data, inner_queries, xxt, r, lefts, rights, sort_id): +def query_batches_dist(queries, data, inner_queries, xxt, r, lefts, rights, sort_id): knn_ind = dict() knn_dist = dict() - + ddata_queries = np.dot(queries, data.T) for i in range(queries.shape[0]): - ddata_query = np.dot(data, queries[i]) - batch_dist_set = (xxt + inner_queries[i] - 2*ddata_query) + batch_dist_set = (xxt + inner_queries[i] - 2*ddata_queries[i]) + batch_dist_set = batch_dist_set[lefts[i]:rights[i]] filter_r = batch_dist_set <= r @@ -132,43 +125,37 @@ def query_batches_dist_mef(queries, data, inner_queries, xxt, r, lefts, rights, return knn_ind, knn_dist - -@numba.njit(cache=True) # return (ids) -def query_batches(queries, data, inner_queries, xxt, r, lefts, rights, sort_id): +@numba.njit(cache=False) +def query_batches_mef(queries, data, inner_queries, xxt, r, lefts, rights, sort_id): knn_ind = dict() - ddata_queries = np.dot(queries, data.T) + for i in range(queries.shape[0]): - batch_dist_set = (xxt + inner_queries[i] - 2*ddata_queries[i]) - - batch_dist_set = batch_dist_set[lefts[i]:rights[i]] + ddata_query = np.dot(data[lefts[i]:rights[i]], queries[i]) + batch_dist_set = (xxt[lefts[i]:rights[i]] + inner_queries[i] - 2*ddata_query) knn_ind[i] = sort_id[lefts[i]:rights[i]][batch_dist_set <= r] return knn_ind -@numba.njit(cache=False) # return (ids, distances) -def query_batches_dist(queries, data, inner_queries, xxt, r, lefts, rights, sort_id): +@numba.njit(cache=False) +def query_batches_dist_mef(queries, data, inner_queries, xxt, r, lefts, rights, sort_id): knn_ind = dict() knn_dist = dict() - ddata_queries = np.dot(queries, data.T) + for i in range(queries.shape[0]): - batch_dist_set = (xxt + inner_queries[i] - 2*ddata_queries[i]) - - batch_dist_set = batch_dist_set[lefts[i]:rights[i]] + ddata_query = np.dot(data[lefts[i]:rights[i]], queries[i]) + batch_dist_set = (xxt[lefts[i]:rights[i]] + inner_queries[i] - 2*ddata_query) filter_r = batch_dist_set <= r knn_ind[i] = sort_id[lefts[i]:rights[i]][filter_r] knn_dist[i] = np.sqrt(batch_dist_set[filter_r]) return knn_ind, knn_dist - - - + def euclid(xxt, X, v): return (xxt + np.inner(v,v).ravel() -2*X.dot(v)).astype(float) - @numba.njit(cache=False) def bisection_sort(queries, mu, v, sort_vals, r): queries = np.subtract(queries, mu) @@ -176,5 +163,3 @@ def bisection_sort(queries, mu, v, sort_vals, r): lefts = np.searchsorted(sort_vals, sv_qs-r) rights = np.searchsorted(sort_vals, sv_qs+r) return queries, lefts, rights - -