Skip to content

Commit

Permalink
Repair graph method nmslib#515
Browse files Browse the repository at this point in the history
  • Loading branch information
kishorenc committed Oct 30, 2023
1 parent 5aba40d commit 5100d3f
Show file tree
Hide file tree
Showing 4 changed files with 253 additions and 31 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,5 @@ jobs:
./multiThread_replace_test
./test_updates
./test_updates update
./repair_test
shell: bash
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,7 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)

add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp)
target_link_libraries(main hnswlib)

add_executable(repair_test tests/cpp/repair_test.cpp)
target_link_libraries(repair_test hnswlib)
endif()
155 changes: 124 additions & 31 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
std::mutex deleted_elements_lock; // lock for deleted_elements
std::unordered_set<tableint> deleted_elements; // contains internal ids of deleted elements

std::mutex repair_lock; // locks graph repair


HierarchicalNSW(SpaceInterface<dist_t> *s) {
}
Expand Down Expand Up @@ -190,9 +192,9 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}


int getRandomLevel(double reverse_size) {
int getRandomLevel(double ml) {
std::uniform_real_distribution<double> distribution(0.0, 1.0);
double r = -log(distribution(level_generator_)) * reverse_size;
double r = -log(distribution(level_generator_)) * ml;
return (int) r;
}

Expand Down Expand Up @@ -240,14 +242,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);

int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
if (layer == 0) {
data = (int*)get_linklist0(curNodeNum);
} else {
data = (int*)get_linklist(curNodeNum, layer);
// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
}
size_t size = getListCount((linklistsizeint*)data);
linklistsizeint *data = get_linklist_at_level(curNodeNum, layer);
size_t size = getListCount(data);
tableint *datal = (tableint *) (data + 1);
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
Expand Down Expand Up @@ -325,8 +321,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
candidate_set.pop();

tableint current_node_id = current_node_pair.second;
int *data = (int *) get_linklist0(current_node_id);
size_t size = getListCount((linklistsizeint*)data);
linklistsizeint *data = get_linklist0(current_node_id);
size_t size = getListCount(data);
// bool cur_node_deleted = isMarkedDeleted(current_node_id);
if (collect_metrics) {
metric_hops++;
Expand Down Expand Up @@ -471,11 +467,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
if (isUpdate) {
lock.lock();
}
linklistsizeint *ll_cur;
if (level == 0)
ll_cur = get_linklist0(cur_c);
else
ll_cur = get_linklist(cur_c, level);
linklistsizeint *ll_cur = get_linklist_at_level(cur_c, level);

if (*ll_cur && !isUpdate) {
throw std::runtime_error("The newly inserted element should have blank link list");
Expand All @@ -495,12 +487,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) {
std::unique_lock <std::mutex> lock(link_list_locks_[selectedNeighbors[idx]]);

linklistsizeint *ll_other;
if (level == 0)
ll_other = get_linklist0(selectedNeighbors[idx]);
else
ll_other = get_linklist(selectedNeighbors[idx], level);

linklistsizeint *ll_other = get_linklist_at_level(selectedNeighbors[idx], level);
size_t sz_link_list_other = getListCount(ll_other);

if (sz_link_list_other > Mcurmax)
Expand Down Expand Up @@ -969,8 +956,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

{
std::unique_lock <std::mutex> lock(link_list_locks_[neigh]);
linklistsizeint *ll_cur;
ll_cur = get_linklist_at_level(neigh, layer);
linklistsizeint *ll_cur = get_linklist_at_level(neigh, layer);
size_t candSize = candidates.size();
setListCount(ll_cur, candSize);
tableint *data = (tableint *) (ll_cur + 1);
Expand Down Expand Up @@ -999,7 +985,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;
linklistsizeint *data;
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
data = get_linklist_at_level(currObj, level);
int size = getListCount(data);
Expand Down Expand Up @@ -1057,7 +1043,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {

std::vector<tableint> getConnectionsWithLock(tableint internalId, int level) {
std::unique_lock <std::mutex> lock(link_list_locks_[internalId]);
unsigned int *data = get_linklist_at_level(internalId, level);
linklistsizeint *data = get_linklist_at_level(internalId, level);
int size = getListCount(data);
std::vector<tableint> result(size);
tableint *ll = (tableint *) (data + 1);
Expand Down Expand Up @@ -1095,6 +1081,10 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}

cur_c = cur_element_count;
// use the element level as a flag to show that an element is not added yet
// the element count is increased but no lock is aquired
// so someone can start using the new element
element_levels_[cur_c] = -1;
cur_element_count++;
label_lookup_[label] = cur_c;
}
Expand Down Expand Up @@ -1134,7 +1124,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;
linklistsizeint *data;
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
data = get_linklist(currObj, level);
int size = getListCount(data);
Expand Down Expand Up @@ -1196,9 +1186,7 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
bool changed = true;
while (changed) {
changed = false;
unsigned int *data;

data = (unsigned int *) get_linklist(currObj, level);
linklistsizeint *data = get_linklist(currObj, level);
int size = getListCount(data);
metric_hops++;
metric_distance_computations+=size;
Expand Down Expand Up @@ -1271,5 +1259,110 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
}
std::cout << "integrity ok, checked " << connections_checked << " connections\n";
}


void repair_zero_indegree() {
// only one repair is allowed to be in progress at a time
std::unique_lock <std::mutex> lock_repair(repair_lock);

int maxlevel_copy = maxlevel_;
size_t element_count_copy = cur_element_count;
std::vector<size_t> indegree(element_count_copy);

for (int level = maxlevel_copy; level >=0 ; level--) {
std::fill(indegree.begin(), indegree.end(), 0);

size_t m_max = level ? maxM_ : maxM0_;
int num_elements = 0;
// calculate in-degree
for (tableint internal_id = 0; internal_id < element_count_copy; internal_id++) {
// lock until addition is finished
std::unique_lock <std::mutex> lock_el(link_list_locks_[internal_id]);
// skip elements that are not in the current level
// Note: if the element was not added to the graph before the lock
// then element_level = -1 and we skip it as well
int element_level = element_levels_[internal_id];
if (element_level < level) {
continue;
}

linklistsizeint *ll = get_linklist_at_level(internal_id, level);
int size = getListCount(ll);
tableint *datal = (tableint *) (ll + 1);
for (int i = 0; i < size; i++) {
tableint nei_id = datal[i];
// skip newly added elements
if (nei_id >= element_count_copy) {
continue;
}
indegree[nei_id] += 1;
}
num_elements += 1;
}

// skip levels with 1 element
if (num_elements <= 1) {
continue;
}

// fix elements with 0 in-degree
for (tableint internal_id = 0; internal_id < element_count_copy; internal_id++) {
int element_level = element_levels_[internal_id];
if (element_level < level || indegree[internal_id] > 0) {
continue;
}

char* data_point = getDataByInternalId(internal_id);
tableint currObj = enterpoint_node_;

dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_);
for (int level_above = maxlevel_copy; level_above > level; level_above--) {
bool changed = true;
while (changed) {
changed = false;
linklistsizeint *data;
std::unique_lock <std::mutex> lock(link_list_locks_[currObj]);
data = get_linklist_at_level(currObj, level_above);
int size = getListCount(data);

tableint *datal = (tableint *) (data + 1);
for (int i = 0; i < size; i++) {
tableint cand = datal[i];
dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_);
if (d < curdist) {
curdist = d;
currObj = cand;
changed = true;
}
}
}
}

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidates = searchBaseLayer(
currObj, data_point, level);

while (candidates.size() > 0) {
tableint cand_id = candidates.top().second;
// skip same element
if (cand_id == internal_id) {
candidates.pop();
continue;
}

// try to connect candidate to the element
// add an edge if there is space
std::unique_lock <std::mutex> lock(link_list_locks_[cand_id]);
linklistsizeint *ll_cand = get_linklist_at_level(cand_id, level);
tableint *data_cand = (tableint *) (ll_cand + 1);
size_t size = getListCount(ll_cand);
if (size < m_max) {
data_cand[size] = internal_id;
setListCount(ll_cand, size + 1);
}
candidates.pop();
}
}
}
}
};
} // namespace hnswlib
125 changes: 125 additions & 0 deletions tests/cpp/repair_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#include "../../hnswlib/hnswlib.h"
#include <thread>


bool is_indegree_ok(hnswlib::HierarchicalNSW<float>* alg_hnsw) {
bool is_ok_flag = true;
std::vector<int> indegree(alg_hnsw->cur_element_count);

for (int level = alg_hnsw->maxlevel_; level >=0 ; level--) {
std::fill(indegree.begin(), indegree.end(), 0);
int num_elements = 0;
// calculate in-degree
for (int internal_id = 0; internal_id < alg_hnsw->cur_element_count; internal_id++) {
int element_level = alg_hnsw->element_levels_[internal_id];
if (element_level < level) {
continue;
}
std::vector<hnswlib::tableint> neis = alg_hnsw->getConnectionsWithLock(internal_id, level);
for (hnswlib::tableint nei : neis) {
indegree[nei] += 1;
}
num_elements += 1;
}
// skip levels with 1 element
if (num_elements <= 1) {
continue;
}

// check in-degree
for (int internal_id = 0; internal_id < alg_hnsw->cur_element_count; internal_id++) {
int element_level = alg_hnsw->element_levels_[internal_id];
if (element_level < level) {
continue;
}
if (indegree[internal_id] == 0) {
std::cout << "zero in-degree node found, level=" << level << " id=" << internal_id << "\n" << std::flush;
is_ok_flag = false;
}
}
}

return is_ok_flag;
}


int main() {
int dim = 4; // Dimension of the elements
int n = 1000; // Maximum number of elements, should be known beforehand
int M = 8; // Tightly connected with internal dimensionality of the data
// strongly affects the memory consumption
int ef_construction = 200; // Controls index search speed/build speed tradeoff
int num_test_iter = 5;

int test_id = 0;

std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib_real;
while (test_id < num_test_iter) {
// Initing index
std::cout << "Initing index" << std::endl;
hnswlib::L2Space space(dim);
hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, n * 3, M, ef_construction, 100, true);

// Generate random data
float* data = new float[dim * n];
for (int i = 0; i < dim * n; i++) {
data[i] = distrib_real(rng);
}

std::cout << "Add data to index" << std::endl;

// Add data to index
for (int i = 0; i < n; i++) {
//std::cout << "insert " << i << std::endl;
alg_hnsw->addPoint(data + i * dim, i, true);
}
std::cout << test_id << " Index is ready\n";

std::vector<std::thread> threads;

// mix new inserts with modifications (50% of operations are new)
for(int i = 0; i < n; i += 100) {
//std::cout << "mixed insert " << i << std::endl;

threads.emplace_back([alg_hnsw, data, i, dim]() {
std::uniform_real_distribution<> distrib_real;
std::mt19937 rng;
rng.seed(49);

for(auto j = 0; j < 10; j++) {
auto actual_index = i + j;
auto id = ( actual_index % 2 != 0) ? actual_index + 10000 : actual_index;
std::vector<float> values;
for (size_t j = 0; j < dim; j++) {
values.push_back(distrib_real(rng) + 0.01);
}
alg_hnsw->addPoint(values.data(), id, true);
}
});
}

// add repair method to check concurrency
threads.emplace_back([alg_hnsw] {
alg_hnsw->repair_zero_indegree();
});

for(auto& t: threads) {
t.join();
}

bool is_ok_before_flag = is_indegree_ok(alg_hnsw);
// fix in-degree if it is broken
if (!is_ok_before_flag) {
alg_hnsw->repair_zero_indegree();
}
bool is_ok_after_flag = is_indegree_ok(alg_hnsw);
assert(is_ok_after_flag);
test_id += 1;

delete[] data;
delete alg_hnsw;
}
return 0;
}

0 comments on commit 5100d3f

Please sign in to comment.