Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix duplicate results for Knn search #3

Merged
merged 10 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ jobs:
./searchKnnWithFilter_test
./multiThreadLoad_test
./multiThread_replace_test
./replaceSameLabel_test
./test_updates
./test_updates update
./repair_test
Expand Down
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
add_executable(multiThread_replace_test tests/cpp/multiThread_replace_test.cpp)
target_link_libraries(multiThread_replace_test hnswlib)

add_executable(replaceSameLabel_test tests/cpp/replaceSameLabel_test.cpp)
target_link_libraries(replaceSameLabel_test hnswlib)

add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp)
target_link_libraries(main hnswlib)

Expand Down
9 changes: 8 additions & 1 deletion hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,14 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
setExternalLabel(internal_id_replaced, label);

std::unique_lock <std::mutex> lock_table(label_lookup_lock);
label_lookup_.erase(label_replaced);
// check if the label is already in the index
if (label_lookup_.find(label) != label_lookup_.end() && !isMarkedDeleted(label_lookup_[label])) {
markDeletedInternal(label_lookup_[label]);
}
auto label_replaced_lookup = label_lookup_.find(label_replaced);
if(label_replaced_lookup != label_lookup_.end() && label_replaced_lookup->second == internal_id_replaced) {
label_lookup_.erase(label_replaced);
}
label_lookup_[label] = internal_id_replaced;
lock_table.unlock();

Expand Down
80 changes: 80 additions & 0 deletions tests/cpp/replaceSameLabel_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#include "../../hnswlib/hnswlib.h"
#include <thread>
#include <chrono>
#include <iomanip>

int main() {
std::cout << "Running replace same label back-to-back test" << std::endl;
int d = 16;
int num_elements = 50;
int max_elements = 2 * num_elements;

std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib_real;
std::uniform_int_distribution<> distrib_int(0, max_elements - 1);

hnswlib::InnerProductSpace space(d);

std::cout << "Generating random data" << std::endl;
std::cout << "Initial batch size: " << max_elements << std::endl;
float* initial_batch = new float[d * max_elements];
for (int i = 0; i < d * max_elements; i++) {
initial_batch[i] = distrib_real(rng);
}

std::cout << "Update batch size: " << num_elements << std::endl;
float* update_batch = new float[d * num_elements];
for (int i = 0; i < d * num_elements; i++) {
update_batch[i] = distrib_real(rng);
}

std::cout << "Building index" << std::endl;
hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, 1024, 16, 200, 456, true);

std::vector<int> labels;

std::cout << "Adding initial batch" << std::endl;
for(int row = 0; row < max_elements; row++) {
int label = distrib_int(rng);
labels.push_back(label);
alg_hnsw->addPoint((void*)(initial_batch + d * row), label, true);
};

std::cout << "Deleting half of the initial batch" << std::endl;
for (int i = 0; i < labels.size() / 2; i++) {
if(!alg_hnsw->isMarkedDeleted(alg_hnsw->label_lookup_[labels[i]]))
alg_hnsw->markDelete(labels[i]);
}

std::cout << "Updating index size" << std::endl;
size_t curr_ele_count = alg_hnsw->getCurrentElementCount();
if(curr_ele_count + max_elements > alg_hnsw->getMaxElements()) {
alg_hnsw->resizeIndex((curr_ele_count + max_elements) * 1.3);
}

std::cout << "Adding update batch" << std::endl;
for(int row; row < num_elements; row++) {
alg_hnsw->addPoint((void*)(update_batch + d * row), 42, true);
};


std::cout << "Searching for 10 nearest neighbors" << std::endl;
auto results = alg_hnsw->searchKnnCloserFirst((void*)(initial_batch), 10);

// check if the search results contain duplicate labels
std::cout << "Checking search results for duplicate labels" << std::endl;
std::unordered_set<int> labels_set;
for (int i = 0; i < results.size(); i++) {
labels_set.insert(results[i].second);
}
if (labels_set.size() != 10) {
std::cout << "Search results contain duplicate labels" << std::endl;
throw std::runtime_error("Search results contain duplicate labels");
}

delete[] initial_batch;
delete[] update_batch;
delete alg_hnsw;
return 0;
}
Loading