diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 8cfa469a..42f55416 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -73,4 +73,5 @@ jobs: ./multiThread_replace_test ./test_updates ./test_updates update + ./repair_test shell: bash diff --git a/CMakeLists.txt b/CMakeLists.txt index 7cebe600..d5a22baf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,4 +53,7 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp) target_link_libraries(main hnswlib) + + add_executable(repair_test tests/cpp/repair_test.cpp) + target_link_libraries(repair_test hnswlib) endif() diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index a97a6540..f7d7f264 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -69,6 +69,8 @@ class HierarchicalNSW : public AlgorithmInterface { std::mutex deleted_elements_lock; // lock for deleted_elements std::unordered_set deleted_elements; // contains internal ids of deleted elements + std::mutex repair_lock; // locks graph repair + HierarchicalNSW(SpaceInterface *s) { } @@ -190,9 +192,9 @@ class HierarchicalNSW : public AlgorithmInterface { } - int getRandomLevel(double reverse_size) { + int getRandomLevel(double ml) { std::uniform_real_distribution distribution(0.0, 1.0); - double r = -log(distribution(level_generator_)) * reverse_size; + double r = -log(distribution(level_generator_)) * ml; return (int) r; } @@ -240,14 +242,8 @@ class HierarchicalNSW : public AlgorithmInterface { std::unique_lock lock(link_list_locks_[curNodeNum]); - int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); - if (layer == 0) { - data = (int*)get_linklist0(curNodeNum); - } else { - data = (int*)get_linklist(curNodeNum, layer); -// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); - } - size_t size = getListCount((linklistsizeint*)data); + linklistsizeint *data = get_linklist_at_level(curNodeNum, layer); + 
size_t size = getListCount(data); tableint *datal = (tableint *) (data + 1); #ifdef USE_SSE _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); @@ -325,8 +321,8 @@ class HierarchicalNSW : public AlgorithmInterface { candidate_set.pop(); tableint current_node_id = current_node_pair.second; - int *data = (int *) get_linklist0(current_node_id); - size_t size = getListCount((linklistsizeint*)data); + linklistsizeint *data = get_linklist0(current_node_id); + size_t size = getListCount(data); // bool cur_node_deleted = isMarkedDeleted(current_node_id); if (collect_metrics) { metric_hops++; @@ -471,11 +467,7 @@ class HierarchicalNSW : public AlgorithmInterface { if (isUpdate) { lock.lock(); } - linklistsizeint *ll_cur; - if (level == 0) - ll_cur = get_linklist0(cur_c); - else - ll_cur = get_linklist(cur_c, level); + linklistsizeint *ll_cur = get_linklist_at_level(cur_c, level); if (*ll_cur && !isUpdate) { throw std::runtime_error("The newly inserted element should have blank link list"); @@ -495,12 +487,7 @@ class HierarchicalNSW : public AlgorithmInterface { for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); - linklistsizeint *ll_other; - if (level == 0) - ll_other = get_linklist0(selectedNeighbors[idx]); - else - ll_other = get_linklist(selectedNeighbors[idx], level); - + linklistsizeint *ll_other = get_linklist_at_level(selectedNeighbors[idx], level); size_t sz_link_list_other = getListCount(ll_other); if (sz_link_list_other > Mcurmax) @@ -969,8 +956,7 @@ class HierarchicalNSW : public AlgorithmInterface { { std::unique_lock lock(link_list_locks_[neigh]); - linklistsizeint *ll_cur; - ll_cur = get_linklist_at_level(neigh, layer); + linklistsizeint *ll_cur = get_linklist_at_level(neigh, layer); size_t candSize = candidates.size(); setListCount(ll_cur, candSize); tableint *data = (tableint *) (ll_cur + 1); @@ -999,7 +985,7 @@ class HierarchicalNSW : public AlgorithmInterface { bool 
changed = true; while (changed) { changed = false; - unsigned int *data; + linklistsizeint *data; std::unique_lock lock(link_list_locks_[currObj]); data = get_linklist_at_level(currObj, level); int size = getListCount(data); @@ -1057,7 +1043,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::vector getConnectionsWithLock(tableint internalId, int level) { std::unique_lock lock(link_list_locks_[internalId]); - unsigned int *data = get_linklist_at_level(internalId, level); + linklistsizeint *data = get_linklist_at_level(internalId, level); int size = getListCount(data); std::vector result(size); tableint *ll = (tableint *) (data + 1); @@ -1095,6 +1081,10 @@ class HierarchicalNSW : public AlgorithmInterface { } cur_c = cur_element_count; + // use the element level as a flag to show that an element is not added yet + // the element count is increased but no lock is acquired + // so someone can start using the new element + element_levels_[cur_c] = -1; cur_element_count++; label_lookup_[label] = cur_c; } @@ -1134,7 +1124,7 @@ class HierarchicalNSW : public AlgorithmInterface { bool changed = true; while (changed) { changed = false; - unsigned int *data; + linklistsizeint *data; std::unique_lock lock(link_list_locks_[currObj]); data = get_linklist(currObj, level); int size = getListCount(data); @@ -1196,9 +1186,7 @@ class HierarchicalNSW : public AlgorithmInterface { bool changed = true; while (changed) { changed = false; - unsigned int *data; - - data = (unsigned int *) get_linklist(currObj, level); + linklistsizeint *data = get_linklist(currObj, level); int size = getListCount(data); metric_hops++; metric_distance_computations+=size;
@@ -1271,5 +1259,129 @@ class HierarchicalNSW : public AlgorithmInterface {
         }
         std::cout << "integrity ok, checked " << connections_checked << " connections\n";
     }
+
+
+    /**
+     * Adds incoming links to elements that no other element points to
+     * (zero in-degree), level by level from the top level down to level 0.
+     *
+     * Per level: count the incoming links of every element stored on that
+     * level, then for each element with none search the level for candidates
+     * and append a link to the element on every candidate whose list still has
+     * room (maxM0_ on level 0, maxM_ above). Levels holding at most one
+     * element are skipped. Only elements that existed when the call started
+     * are examined.
+     *
+     * repair_lock allows a single repair at a time; link lists accessed directly
+     * here are read or written while holding the matching link_list_locks_ entry.
+     */
+    void repair_zero_indegree() {
+        // only one repair is allowed to be in progress at a time
+        std::unique_lock<std::mutex> lock_repair(repair_lock);
+
+        const int maxlevel_copy = maxlevel_;
+        const size_t element_count_copy = cur_element_count;
+        std::vector<int> indegree(element_count_copy);
+        // on_level[id] is set when id was seen on the level being processed
+        std::vector<char> on_level(element_count_copy);
+
+        for (int level = maxlevel_copy; level >= 0; level--) {
+            std::fill(indegree.begin(), indegree.end(), 0);
+            std::fill(on_level.begin(), on_level.end(), 0);
+
+            const size_t m_max = level ? maxM_ : maxM0_;
+            size_t num_elements = 0;
+            // calculate in-degree
+            for (tableint internal_id = 0; internal_id < element_count_copy; internal_id++) {
+                // lock until addition is finished
+                std::unique_lock<std::mutex> lock_el(link_list_locks_[internal_id]);
+                // skip elements that are not in the current level
+                // Note: if the element was not added to the graph before the lock
+                // then element_level = -1 and we skip it as well
+                if (element_levels_[internal_id] < level) {
+                    continue;
+                }
+                on_level[internal_id] = 1;
+                num_elements += 1;
+
+                linklistsizeint *ll = get_linklist_at_level(internal_id, level);
+                const size_t size = getListCount(ll);
+                const tableint *datal = (tableint *) (ll + 1);
+                for (size_t i = 0; i < size; i++) {
+                    // skip newly added elements
+                    if (datal[i] < element_count_copy) {
+                        indegree[datal[i]] += 1;
+                    }
+                }
+            }
+
+            // skip levels with 1 element
+            if (num_elements <= 1) {
+                continue;
+            }
+
+            // fix elements with 0 in-degree
+            for (tableint internal_id = 0; internal_id < element_count_copy; internal_id++) {
+                // rely on the membership recorded under the element lock instead of
+                // reading element_levels_ again without it; elements that were still
+                // being added during the scan were not counted and are left alone
+                if (!on_level[internal_id] || indegree[internal_id] > 0) {
+                    continue;
+                }
+
+                char *data_point = getDataByInternalId(internal_id);
+                tableint currObj = enterpoint_node_;
+
+                // greedy walk through the levels above to pick the search entry point
+                dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
+                for (int level_above = maxlevel_copy; level_above > level; level_above--) {
+                    bool changed = true;
+                    while (changed) {
+                        changed = false;
+                        std::unique_lock<std::mutex> lock(link_list_locks_[currObj]);
+                        linklistsizeint *data = get_linklist_at_level(currObj, level_above);
+                        const size_t size = getListCount(data);
+
+                        const tableint *datal = (tableint *) (data + 1);
+                        for (size_t i = 0; i < size; i++) {
+                            tableint cand = datal[i];
+                            dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
+                            if (d < curdist) {
+                                curdist = d;
+                                currObj = cand;
+                                changed = true;
+                            }
+                        }
+                    }
+                }
+
+                auto candidates = searchBaseLayer(currObj, data_point, level);
+
+                while (!candidates.empty()) {
+                    const tableint cand_id = candidates.top().second;
+                    candidates.pop();
+                    // skip same element
+                    if (cand_id == internal_id) {
+                        continue;
+                    }
+
+                    // try to connect candidate to the element
+                    std::unique_lock<std::mutex> lock(link_list_locks_[cand_id]);
+                    linklistsizeint *ll_cand = get_linklist_at_level(cand_id, level);
+                    tableint *data_cand = (tableint *) (ll_cand + 1);
+                    const size_t size = getListCount(ll_cand);
+                    // the in-degree count can be outdated when insertions run in
+                    // parallel, so never store the same neighbor twice
+                    const bool already_linked =
+                        std::find(data_cand, data_cand + size, internal_id) != data_cand + size;
+                    // add an edge if there is space
+                    if (size < m_max && !already_linked) {
+                        data_cand[size] = internal_id;
+                        setListCount(ll_cand, size + 1);
+                    }
+                }
+            }
+        }
+    }
 };
 } // namespace hnswlib
diff --git a/tests/cpp/repair_test.cpp b/tests/cpp/repair_test.cpp new file mode 100644 index 00000000..f756814a --- /dev/null +++ b/tests/cpp/repair_test.cpp @@ -0,0 +1,125 @@ +#include "../../hnswlib/hnswlib.h" +#include + + +bool is_indegree_ok(hnswlib::HierarchicalNSW* alg_hnsw) { + bool is_ok_flag = true; + std::vector indegree(alg_hnsw->cur_element_count); + + for (int level = alg_hnsw->maxlevel_; level >=0 ; level--) { + std::fill(indegree.begin(), indegree.end(), 0); + int num_elements = 0; + // calculate in-degree + for (int internal_id = 0; internal_id < alg_hnsw->cur_element_count; internal_id++) { + int element_level = alg_hnsw->element_levels_[internal_id]; + if (element_level < level) { + continue; + } + std::vector neis = alg_hnsw->getConnectionsWithLock(internal_id, level); + for (hnswlib::tableint nei : neis) { + indegree[nei] += 1; + } + num_elements += 1; + } + // skip levels with 1 element + if (num_elements <= 1) { + continue; + } + + // check in-degree + for (int internal_id = 0; internal_id < 
alg_hnsw->cur_element_count; internal_id++) { + int element_level = alg_hnsw->element_levels_[internal_id]; + if (element_level < level) { + continue; + } + if (indegree[internal_id] == 0) { + std::cout << "zero in-degree node found, level=" << level << " id=" << internal_id << "\n" << std::flush; + is_ok_flag = false; + } + } + } + + return is_ok_flag; +}
+
+
+// Concurrency test for repair_zero_indegree(): each iteration builds an index,
+// runs inserts/updates in parallel with one repair call, repairs once more if
+// needed and returns 1 when a zero in-degree node is still reported.
+int main() {
+    int dim = 4;                // Dimension of the elements
+    int n = 1000;               // Number of elements in the initial build
+    int M = 8;                  // Tightly connected with internal dimensionality of the data
+                                // strongly affects the memory consumption
+    int ef_construction = 200;  // Controls index search speed/build speed tradeoff
+    int num_test_iter = 5;
+
+    std::mt19937 rng;
+    rng.seed(47);
+    std::uniform_real_distribution<> distrib_real;
+    for (int test_id = 0; test_id < num_test_iter; test_id++) {
+        // Initing index
+        std::cout << "Initing index" << std::endl;
+        hnswlib::L2Space space(dim);
+        // automatic storage: the index is released on every exit path
+        hnswlib::HierarchicalNSW<float> index(&space, n * 3, M, ef_construction, 100, true);
+        hnswlib::HierarchicalNSW<float>* alg_hnsw = &index;
+
+        // Generate random data
+        std::vector<float> data(dim * n);
+        for (float& value : data) {
+            value = distrib_real(rng);
+        }
+
+        std::cout << "Add data to index" << std::endl;
+
+        // Add data to index
+        for (int i = 0; i < n; i++) {
+            alg_hnsw->addPoint(data.data() + i * dim, i, true);
+        }
+        std::cout << test_id << " Index is ready\n";
+
+        std::vector<std::thread> threads;
+
+        // mix new inserts with modifications (50% of operations are new)
+        for (int i = 0; i < n; i += 100) {
+            threads.emplace_back([alg_hnsw, i, dim]() {
+                std::uniform_real_distribution<> distrib_real;
+                std::mt19937 rng;
+                rng.seed(49);
+
+                for (int j = 0; j < 10; j++) {
+                    const int actual_index = i + j;
+                    // odd indices get fresh labels, even ones reuse labels added above
+                    const int id = (actual_index % 2 != 0) ? actual_index + 10000 : actual_index;
+                    std::vector<float> values;
+                    for (int k = 0; k < dim; k++) {
+                        values.push_back(distrib_real(rng) + 0.01);
+                    }
+                    alg_hnsw->addPoint(values.data(), id, true);
+                }
+            });
+        }
+
+        // add repair method to check concurrency
+        threads.emplace_back([alg_hnsw] {
+            alg_hnsw->repair_zero_indegree();
+        });
+
+        for (auto& t : threads) {
+            t.join();
+        }
+
+        // fix in-degree if it is broken
+        if (!is_indegree_ok(alg_hnsw)) {
+            alg_hnsw->repair_zero_indegree();
+        }
+
+        // checked explicitly because assert() is compiled out when NDEBUG is defined
+        if (!is_indegree_ok(alg_hnsw)) {
+            std::cout << "zero in-degree nodes remain after repair\n";
+            return 1;
+        }
+    }
+    return 0;
+}