diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 30b33ae9..8cf4acf9 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -96,6 +96,10 @@ class BruteforceSearch : public AlgorithmInterface { cur_element_count--; } + std::priority_queue> + searchKnn(const void *query_data, size_t k, const size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { + return searchKnn(query_data, k, isIdAllowed); + } std::priority_queue> searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index f7d7f264..e9e4f7f3 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -1173,9 +1173,14 @@ class HierarchicalNSW : public AlgorithmInterface { return cur_c; } - std::priority_queue> searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + return searchKnn(query_data, k, this->ef_, isIdAllowed); + } + + + std::priority_queue> + searchKnn(const void *query_data, size_t k, const size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { std::priority_queue> result; if (cur_element_count == 0) return result; @@ -1210,10 +1215,10 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> top_candidates; if (num_deleted_) { top_candidates = searchBaseLayerST( - currObj, query_data, std::max(ef_, k), isIdAllowed); + currObj, query_data, std::max(ef, k), isIdAllowed); } else { top_candidates = searchBaseLayerST( - currObj, query_data, std::max(ef_, k), isIdAllowed); + currObj, query_data, std::max(ef, k), isIdAllowed); } while (top_candidates.size() > k) { diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index fb7118fa..b93c2599 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -160,6 +160,9 @@ class AlgorithmInterface { public: virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0; + virtual std::priority_queue> + searchKnn(const void*, size_t, const size_t ef_, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; + virtual std::priority_queue> searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; @@ -167,11 +170,34 @@ class AlgorithmInterface { virtual std::vector> searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const; + virtual std::vector> + searchKnnCloserFirst(const void* query_data, size_t k, const size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const; + virtual void saveIndex(const std::string &location) = 0; virtual ~AlgorithmInterface(){ } }; +template +std::vector> +AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, + const size_t ef, BaseFilterFunctor* isIdAllowed) const { + std::vector> result; + + // here searchKnn returns the result in the order of further first + auto ret = searchKnn(query_data, k, ef, isIdAllowed); + { + size_t sz = ret.size(); + result.resize(sz); + while (!ret.empty()) { + result[--sz] = ret.top(); + ret.pop(); + } + } + + return result; +} + template std::vector> AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k,