diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 42f55416..039dce5f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -71,6 +71,7 @@ jobs: ./searchKnnWithFilter_test ./multiThreadLoad_test ./multiThread_replace_test + ./replaceSameLabel_test ./test_updates ./test_updates update ./repair_test diff --git a/CMakeLists.txt b/CMakeLists.txt index d5a22baf..41bd0aca 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,9 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) add_executable(multiThread_replace_test tests/cpp/multiThread_replace_test.cpp) target_link_libraries(multiThread_replace_test hnswlib) + add_executable(replaceSameLabel_test tests/cpp/replaceSameLabel_test.cpp) + target_link_libraries(replaceSameLabel_test hnswlib) + add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp) target_link_libraries(main hnswlib) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index e9e4f7f3..ca3f08d7 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -884,7 +884,14 @@ class HierarchicalNSW : public AlgorithmInterface { setExternalLabel(internal_id_replaced, label); std::unique_lock lock_table(label_lookup_lock); - label_lookup_.erase(label_replaced); + // check if the label is already in the index + if (label_lookup_.find(label) != label_lookup_.end() && !isMarkedDeleted(label_lookup_[label])) { + markDeletedInternal(label_lookup_[label]); + } + auto label_replaced_lookup = label_lookup_.find(label_replaced); + if(label_replaced_lookup != label_lookup_.end() && label_replaced_lookup->second == internal_id_replaced) { + label_lookup_.erase(label_replaced); + } label_lookup_[label] = internal_id_replaced; lock_table.unlock(); diff --git a/tests/cpp/replaceSameLabel_test.cpp b/tests/cpp/replaceSameLabel_test.cpp new file mode 100644 index 00000000..9e675f29 --- /dev/null +++ b/tests/cpp/replaceSameLabel_test.cpp @@ -0,0 +1,80 @@ +#include "../../hnswlib/hnswlib.h" +#include +#include +#include + +int main() { + std::cout << "Running replace same label back-to-back test" << std::endl; + int d = 16; + int num_elements = 50; + int max_elements = 2 * num_elements; + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + std::uniform_int_distribution<> distrib_int(0, max_elements - 1); + + hnswlib::InnerProductSpace space(d); + + std::cout << "Generating random data" << std::endl; + std::cout << "Initial batch size: " << max_elements << std::endl; + float* initial_batch = new float[d * max_elements]; + for (int i = 0; i < d * max_elements; i++) { + initial_batch[i] = distrib_real(rng); + } + + std::cout << "Update batch size: " << num_elements << std::endl; + float* update_batch = new float[d * num_elements]; + for (int i = 0; i < d * num_elements; i++) { + update_batch[i] = distrib_real(rng); + } + + std::cout << "Building index" << std::endl; + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 1024, 16, 200, 456, true); + + std::vector labels; + + std::cout << "Adding initial batch" << std::endl; + for(int row = 0; row < max_elements; row++) { + int label = distrib_int(rng); + labels.push_back(label); + alg_hnsw->addPoint((void*)(initial_batch + d * row), label, true); + }; + + std::cout << "Deleting half of the initial batch" << std::endl; + for (int i = 0; i < labels.size() / 2; i++) { + if(!alg_hnsw->isMarkedDeleted(alg_hnsw->label_lookup_[labels[i]])) + alg_hnsw->markDelete(labels[i]); + } + + std::cout << "Updating index size" << std::endl; + size_t curr_ele_count = alg_hnsw->getCurrentElementCount(); + if(curr_ele_count + max_elements > alg_hnsw->getMaxElements()) { + alg_hnsw->resizeIndex((curr_ele_count + max_elements) * 1.3); + } + + std::cout << "Adding update batch" << std::endl; + for(int row; row < num_elements; row++) { + alg_hnsw->addPoint((void*)(update_batch + d * row), 42, true); + }; + + + std::cout << "Searching for 10 nearest neighbors" << std::endl; + auto results = alg_hnsw->searchKnnCloserFirst((void*)(initial_batch), 10); + + // check if the search results contain duplicate labels + std::cout << "Checking search results for duplicate labels" << std::endl; + std::unordered_set labels_set; + for (int i = 0; i < results.size(); i++) { + labels_set.insert(results[i].second); + } + if (labels_set.size() != 10) { + std::cout << "Search results contain duplicate labels" << std::endl; + throw std::runtime_error("Search results contain duplicate labels"); + } + + delete[] initial_batch; + delete[] update_batch; + delete alg_hnsw; + return 0; +}