Skip to content

Commit

Permalink
Fix duplicate results for Knn search (#3)
Browse files Browse the repository at this point in the history
* Fix duplicate results for Knn search

* Fix erasing already updated label

* Fix indentation

* Use iterator to avoid extra lookup

* Add test

* Update build.yml and CMake to run new test

* Refactor test

* Add logs for test

* Fix for loop

* Fix initial value for initial batch loop
  • Loading branch information
ozanarmagan authored Feb 5, 2024
1 parent 2fec56c commit 687d981
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ jobs:
./searchKnnWithFilter_test
./multiThreadLoad_test
./multiThread_replace_test
./replaceSameLabel_test
./test_updates
./test_updates update
./repair_test
Expand Down
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
add_executable(multiThread_replace_test tests/cpp/multiThread_replace_test.cpp)
target_link_libraries(multiThread_replace_test hnswlib)

add_executable(replaceSameLabel_test tests/cpp/replaceSameLabel_test.cpp)
target_link_libraries(replaceSameLabel_test hnswlib)

add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp)
target_link_libraries(main hnswlib)

Expand Down
9 changes: 8 additions & 1 deletion hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,14 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
setExternalLabel(internal_id_replaced, label);

std::unique_lock <std::mutex> lock_table(label_lookup_lock);
label_lookup_.erase(label_replaced);
// check if the label is already in the index
if (label_lookup_.find(label) != label_lookup_.end() && !isMarkedDeleted(label_lookup_[label])) {
markDeletedInternal(label_lookup_[label]);
}
auto label_replaced_lookup = label_lookup_.find(label_replaced);
if(label_replaced_lookup != label_lookup_.end() && label_replaced_lookup->second == internal_id_replaced) {
label_lookup_.erase(label_replaced);
}
label_lookup_[label] = internal_id_replaced;
lock_table.unlock();

Expand Down
80 changes: 80 additions & 0 deletions tests/cpp/replaceSameLabel_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#include "../../hnswlib/hnswlib.h"
#include <thread>
#include <chrono>
#include <iomanip>

/**
 * Regression test: repeatedly adding points under the SAME label while
 * replacing deleted slots must not produce duplicate labels in k-NN results.
 *
 * Steps:
 *   1. Insert `max_elements` random vectors under random labels.
 *   2. Mark the labels of the first half of the inserted rows as deleted.
 *   3. Insert `num_elements` new vectors, every one under label 42.
 *   4. Run a 10-NN query and verify that all returned labels are distinct.
 *
 * @return 0 on success.
 * @throws std::runtime_error if the query does not yield exactly 10 distinct labels.
 */
int main() {
    std::cout << "Running replace same label back-to-back test" << std::endl;

    const int d = 16;
    const int num_elements = 50;
    const int max_elements = 2 * num_elements;
    const int repeated_label = 42;
    const size_t k = 10;

    std::mt19937 rng;
    rng.seed(47);
    std::uniform_real_distribution<> distrib_real;
    std::uniform_int_distribution<> distrib_int(0, max_elements - 1);

    hnswlib::InnerProductSpace space(d);

    // The buffers are owned by std::vector so they are released on every exit
    // path, including the failure path that throws.
    std::cout << "Generating random data" << std::endl;
    std::cout << "Initial batch size: " << max_elements << std::endl;
    std::vector<float> initial_batch(static_cast<size_t>(d) * max_elements);
    for (float& value : initial_batch) {
        value = static_cast<float>(distrib_real(rng));
    }

    std::cout << "Update batch size: " << num_elements << std::endl;
    std::vector<float> update_batch(static_cast<size_t>(d) * num_elements);
    for (float& value : update_batch) {
        value = static_cast<float>(distrib_real(rng));
    }

    std::cout << "Building index" << std::endl;
    // NOTE(review): the trailing `true` presumably enables replacement of
    // deleted elements - confirm against the HierarchicalNSW constructor.
    // Declared after `space`, so the index is destroyed before the space it points to.
    hnswlib::HierarchicalNSW<float> alg_hnsw(&space, 1024, 16, 200, 456, true);

    std::vector<int> labels;
    labels.reserve(max_elements);

    std::cout << "Adding initial batch" << std::endl;
    for (int row = 0; row < max_elements; row++) {
        const int label = distrib_int(rng);
        labels.push_back(label);
        alg_hnsw.addPoint(initial_batch.data() + d * row, label, true);
    }

    std::cout << "Deleting half of the initial batch" << std::endl;
    const size_t half = labels.size() / 2;
    for (size_t i = 0; i < half; i++) {
        // Random labels can repeat, so a label may already have been deleted
        // by an earlier iteration; only delete it if it is still live.
        // find() is used instead of operator[] so the lookup never inserts.
        const auto entry = alg_hnsw.label_lookup_.find(labels[i]);
        if (entry != alg_hnsw.label_lookup_.end() && !alg_hnsw.isMarkedDeleted(entry->second)) {
            alg_hnsw.markDelete(labels[i]);
        }
    }

    std::cout << "Updating index size" << std::endl;
    const size_t curr_ele_count = alg_hnsw.getCurrentElementCount();
    if (curr_ele_count + max_elements > alg_hnsw.getMaxElements()) {
        alg_hnsw.resizeIndex(static_cast<size_t>((curr_ele_count + max_elements) * 1.3));
    }

    std::cout << "Adding update batch" << std::endl;
    // `row` must start at 0: the previous version left it uninitialized, which
    // is undefined behaviour and could skip the whole update batch.
    for (int row = 0; row < num_elements; row++) {
        alg_hnsw.addPoint(update_batch.data() + d * row, repeated_label, true);
    }

    std::cout << "Searching for 10 nearest neighbors" << std::endl;
    const auto results = alg_hnsw.searchKnnCloserFirst(initial_batch.data(), k);

    // check if the search results contain duplicate labels
    std::cout << "Checking search results for duplicate labels" << std::endl;
    std::unordered_set<int> labels_set;
    for (const auto& result : results) {
        // Labels used by this test are small non-negative ints, so the cast is lossless here.
        labels_set.insert(static_cast<int>(result.second));
    }
    if (labels_set.size() != k) {
        std::cout << "Search results contain duplicate labels" << std::endl;
        throw std::runtime_error("Search results contain duplicate labels");
    }

    return 0;
}

0 comments on commit 687d981

Please sign in to comment.