forked from nmslib/hnswlib
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
107 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,127 @@ | ||
#include "../../hnswlib/hnswlib.h" | ||
#include <iostream> | ||
#include <regex> | ||
#include <string> | ||
#include <fstream> | ||
#include "json.hpp" | ||
|
||
std::pair<std::vector<float>, int> extractFields(const std::string& input) { | ||
auto doc = nlohmann::json::parse(input); | ||
std::vector<float> embedding = doc["embedding"]; | ||
int id = std::stol(doc["id"].get<std::string>()); | ||
return std::make_pair(embedding, id); | ||
} | ||
|
||
/// Scales `src` to (approximately) unit L2 length and stores the result in
/// `norm_dest`.
///
/// @param src        Input vector; may be empty.
/// @param norm_dest  Output vector. It is grown to src.size() when shorter;
///                   when longer, elements past src.size() are left untouched.
///
/// A tiny epsilon is added to the computed length so that an all-zero input
/// produces zeros instead of dividing by zero.
static void normalize_vector(const std::vector<float>& src, std::vector<float>& norm_dest) {
    // Writing norm_dest[i] into a destination shorter than src would be an
    // out-of-bounds write, so make room first. A larger destination is kept
    // as-is to stay compatible with existing callers.
    if (norm_dest.size() < src.size()) {
        norm_dest.resize(src.size());
    }

    float sum_sq = 0.0f;
    for (const float v : src) {
        sum_sq += v * v;
    }

    const float inv_len = 1.0f / (sqrtf(sum_sq) + 1e-30f);
    for (size_t i = 0; i < src.size(); ++i) {
        norm_dest[i] = src[i] * inv_len;
    }
}
|
||
// Example driver: builds an hnswlib index, fills it with vectors, queries it,
// saves/reloads it, and then adds a second batch of vectors before querying
// again.
//
// NOTE(review): this body is a rendered diff (see the `@@ -1,58 +1,127 @@`
// hunk header) in which lines from the earlier and the later revision appear
// together without +/- markers. As shown it declares `dim`, `max_elements`,
// `space` and `alg_hnsw` twice and has loops whose braces do not balance, so
// it is not compilable in this form. The comments below describe what each
// visible statement does; which revision a given line belongs to is an
// inference and should be confirmed against the actual commit.
int main() { | ||
int dim = 16; // Dimension of the elements | ||
int max_elements = 10000; // Maximum number of elements, should be known beforehand | ||
int dim = 384; // Dimension of the elements | ||
int max_elements = 1000; // Maximum number of elements, should be known beforehand | ||
int M = 16; // Tightly connected with internal dimensionality of the data | ||
// strongly affects the memory consumption | ||
int ef_construction = 200; // Controls index search speed/build speed tradeoff | ||
|
||
// Initing index | ||
hnswlib::L2Space space(dim); | ||
hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction); | ||
// init index | ||
// NOTE(review): the two trailing constructor arguments (100, true) are
// presumably hnswlib's random_seed and allow_replace_deleted - confirm
// against the forked hnswlib header.
hnswlib::InnerProductSpace space(dim); | ||
hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, | ||
ef_construction, 100, true); | ||
|
||
|
||
// First batch: one JSON record per line is read from a fixed path.
// `num_records` is incremented while loading but is not read anywhere in the
// visible code.
std::ifstream file("/tmp/titles1.jsonl"); | ||
std::vector<std::vector<float>> embeddings; | ||
std::vector<int> ids; | ||
size_t num_records = 0; | ||
|
||
// Generate random data | ||
std::mt19937 rng; | ||
rng.seed(47); | ||
std::uniform_real_distribution<> distrib_real; | ||
float* data = new float[dim * max_elements]; | ||
for (int i = 0; i < dim * max_elements; i++) { | ||
data[i] = distrib_real(rng); | ||
// Each line is parsed with extractFields, its embedding is normalized into a
// fresh vector of the same length, and the (vector, id) pair is appended to
// `embeddings` / `ids`. If the file cannot be opened the batch stays empty.
if (file.is_open()) { | ||
std::string line; | ||
while (std::getline(file, line)) { | ||
auto pair = extractFields(line); | ||
auto& embedding = pair.first; | ||
auto& id = pair.second; | ||
|
||
std::vector<float> normalized_vals(embedding.size()); | ||
normalize_vector(embedding, normalized_vals); | ||
|
||
embeddings.push_back(normalized_vals); | ||
ids.push_back(id); | ||
num_records++; | ||
} | ||
|
||
file.close(); | ||
} | ||
|
||
// Add data to index | ||
for (int i = 0; i < max_elements; i++) { | ||
alg_hnsw->addPoint(data + i * dim, i); | ||
// Inserts every loaded embedding under its parsed id.
// NOTE(review): the third addPoint argument is presumably hnswlib's
// replace_deleted flag - confirm. Nothing visible checks embeddings.size()
// against max_elements or embeddings[i].size() against dim.
for(auto i = 0; i < embeddings.size(); i++) { | ||
alg_hnsw->addPoint(embeddings[i].data(), ids[i], true); | ||
} | ||
|
||
// Query the elements for themselves and measure recall | ||
float correct = 0; | ||
for (int i = 0; i < max_elements; i++) { | ||
std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + i * dim, 1); | ||
hnswlib::labeltype label = result.top().second; | ||
if (label == i) correct++; | ||
} | ||
float recall = correct / max_elements; | ||
std::cout << "Recall: " << recall << "\n"; | ||
|
||
// Serialize index | ||
std::string hnsw_path = "hnsw.bin"; | ||
alg_hnsw->saveIndex(hnsw_path); | ||
delete alg_hnsw; | ||
|
||
// Deserialize index and check recall | ||
alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, hnsw_path); | ||
correct = 0; | ||
for (int i = 0; i < max_elements; i++) { | ||
std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + i * dim, 1); | ||
hnswlib::labeltype label = result.top().second; | ||
if (label == i) correct++; | ||
// Queries the index with the first loaded embedding, requesting 5 results.
// NOTE(review): embeddings[0] is read unconditionally, so an empty first
// batch (e.g. the file failed to open) makes this an out-of-range access.
std::priority_queue<std::pair<float, hnswlib::labeltype>> hnsw_result = alg_hnsw->searchKnn(embeddings[0].data(), 5); | ||
std::vector<int> results; | ||
|
||
// Pops exactly 5 entries off the result queue, keeping only the labels.
// NOTE(review): top()/pop() are called without checking that the queue still
// has entries; fewer than 5 results would make this undefined behaviour.
for(auto i = 0; i < 5; i++) { | ||
hnswlib::labeltype label = hnsw_result.top().second; | ||
results.push_back(label); | ||
hnsw_result.pop(); | ||
} | ||
|
||
// The labels are printed in the reverse of pop order - presumably so the
// closest match comes first; confirm the queue's ordering.
std::reverse(results.begin(), results.end()); | ||
for(auto label: results) { | ||
std::cout << "hnsw label: " << label << std::endl; | ||
} | ||
|
||
// reindex with updates | ||
// Second batch: the in-memory vectors and ids are discarded and reloaded from
// a second fixed path using the same parse -> normalize -> append steps.
std::ifstream file2("/tmp/titles2.jsonl"); | ||
embeddings.clear(); | ||
ids.clear(); | ||
|
||
if (file2.is_open()) { | ||
std::string line; | ||
while (std::getline(file2, line)) { | ||
auto pair = extractFields(line); | ||
auto& embedding = pair.first; | ||
auto& id = pair.second; | ||
|
||
std::vector<float> normalized_vals(embedding.size()); | ||
normalize_vector(embedding, normalized_vals); | ||
|
||
embeddings.push_back(normalized_vals); | ||
ids.push_back(id); | ||
num_records++; | ||
} | ||
|
||
file2.close(); | ||
} | ||
|
||
// The second batch is added to the same index object with the same
// addPoint flag as the first batch.
for(auto i = 0; i < embeddings.size(); i++) { | ||
alg_hnsw->addPoint(embeddings[i].data(), ids[i], true); | ||
} | ||
|
||
std::cout << "-----\n\n\n"; | ||
|
||
// Queries with the 29th vector of the second batch.
// NOTE(review): index 28 is hard-coded, so a second batch with fewer than 29
// records makes this an out-of-range access.
hnsw_result = alg_hnsw->searchKnn(embeddings[28].data(), 5); | ||
results.clear(); | ||
|
||
// Same drain-then-reverse-then-print sequence as for the first query, with
// the same unchecked assumption that 5 results are available.
for(auto i = 0; i < 5; i++) { | ||
hnswlib::labeltype label = hnsw_result.top().second; | ||
results.push_back(label); | ||
hnsw_result.pop(); | ||
} | ||
|
||
std::reverse(results.begin(), results.end()); | ||
for(auto label: results) { | ||
std::cout << "hnsw label: " << label << std::endl; | ||
} | ||
recall = (float)correct / max_elements; | ||
std::cout << "Recall of deserialized index: " << recall << "\n"; | ||
|
||
// Releases the random-data buffer allocated with new[] and the index object
// allocated with new.
delete[] data; | ||
delete alg_hnsw; | ||
return 0; | ||
} |