Skip to content

Commit

Permalink
Tweak example.
Browse files Browse the repository at this point in the history
  • Loading branch information
kishorenc committed Oct 30, 2023
1 parent 37921cf commit 62fefce
Showing 1 changed file with 107 additions and 38 deletions.
145 changes: 107 additions & 38 deletions examples/cpp/example_search.cpp
Original file line number Diff line number Diff line change
@@ -1,58 +1,127 @@
#include "../../hnswlib/hnswlib.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <exception>
#include <fstream>
#include <iostream>
#include <queue>
#include <regex>
#include <string>
#include <utility>
#include <vector>
#include "json.hpp"

std::pair<std::vector<float>, int> extractFields(const std::string& input) {
auto doc = nlohmann::json::parse(input);
std::vector<float> embedding = doc["embedding"];
int id = std::stol(doc["id"].get<std::string>());
return std::make_pair(embedding, id);
}

// Scales `src` to unit L2 length and stores the result in `norm_dest`.
//
// src:       input vector; may be empty.
// norm_dest: output vector. If it holds fewer elements than `src` it is grown
//            to `src.size()` so every write stays in bounds; if it is larger,
//            the elements past `src.size()` are left untouched.
//
// An all-zero `src` produces an all-zero result: the tiny epsilon added to the
// norm keeps the reciprocal finite instead of dividing by zero.
static void normalize_vector(const std::vector<float>& src, std::vector<float>& norm_dest) {
    // Callers are not required to pre-size the destination.
    if (norm_dest.size() < src.size()) {
        norm_dest.resize(src.size());
    }

    float sum_of_squares = 0.0f;
    for (float value : src) {
        sum_of_squares += value * value;
    }

    const float scale = 1.0f / (sqrtf(sum_of_squares) + 1e-30f);
    for (size_t i = 0; i < src.size(); i++) {
        norm_dest[i] = src[i] * scale;
    }
}

int main() {
int dim = 16; // Dimension of the elements
int max_elements = 10000; // Maximum number of elements, should be known beforehand
int dim = 384; // Dimension of the elements
int max_elements = 1000; // Maximum number of elements, should be known beforehand
int M = 16; // Tightly connected with internal dimensionality of the data
// strongly affects the memory consumption
int ef_construction = 200; // Controls index search speed/build speed tradeoff

// Initing index
hnswlib::L2Space space(dim);
hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M, ef_construction);
// init index
hnswlib::InnerProductSpace space(dim);
hnswlib::HierarchicalNSW<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, max_elements, M,
ef_construction, 100, true);


std::ifstream file("/tmp/titles1.jsonl");
std::vector<std::vector<float>> embeddings;
std::vector<int> ids;
size_t num_records = 0;

// Generate random data
std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib_real;
float* data = new float[dim * max_elements];
for (int i = 0; i < dim * max_elements; i++) {
data[i] = distrib_real(rng);
if (file.is_open()) {
std::string line;
while (std::getline(file, line)) {
auto pair = extractFields(line);
auto& embedding = pair.first;
auto& id = pair.second;

std::vector<float> normalized_vals(embedding.size());
normalize_vector(embedding, normalized_vals);

embeddings.push_back(normalized_vals);
ids.push_back(id);
num_records++;
}

file.close();
}

// Add data to index
for (int i = 0; i < max_elements; i++) {
alg_hnsw->addPoint(data + i * dim, i);
for(auto i = 0; i < embeddings.size(); i++) {
alg_hnsw->addPoint(embeddings[i].data(), ids[i], true);
}

// Query the elements for themselves and measure recall
float correct = 0;
for (int i = 0; i < max_elements; i++) {
std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + i * dim, 1);
hnswlib::labeltype label = result.top().second;
if (label == i) correct++;
}
float recall = correct / max_elements;
std::cout << "Recall: " << recall << "\n";

// Serialize index
std::string hnsw_path = "hnsw.bin";
alg_hnsw->saveIndex(hnsw_path);
delete alg_hnsw;

// Deserialize index and check recall
alg_hnsw = new hnswlib::HierarchicalNSW<float>(&space, hnsw_path);
correct = 0;
for (int i = 0; i < max_elements; i++) {
std::priority_queue<std::pair<float, hnswlib::labeltype>> result = alg_hnsw->searchKnn(data + i * dim, 1);
hnswlib::labeltype label = result.top().second;
if (label == i) correct++;
std::priority_queue<std::pair<float, hnswlib::labeltype>> hnsw_result = alg_hnsw->searchKnn(embeddings[0].data(), 5);
std::vector<int> results;

for(auto i = 0; i < 5; i++) {
hnswlib::labeltype label = hnsw_result.top().second;
results.push_back(label);
hnsw_result.pop();
}

std::reverse(results.begin(), results.end());
for(auto label: results) {
std::cout << "hnsw label: " << label << std::endl;
}

// reindex with updates
std::ifstream file2("/tmp/titles2.jsonl");
embeddings.clear();
ids.clear();

if (file2.is_open()) {
std::string line;
while (std::getline(file2, line)) {
auto pair = extractFields(line);
auto& embedding = pair.first;
auto& id = pair.second;

std::vector<float> normalized_vals(embedding.size());
normalize_vector(embedding, normalized_vals);

embeddings.push_back(normalized_vals);
ids.push_back(id);
num_records++;
}

file2.close();
}

for(auto i = 0; i < embeddings.size(); i++) {
alg_hnsw->addPoint(embeddings[i].data(), ids[i], true);
}

std::cout << "-----\n\n\n";

hnsw_result = alg_hnsw->searchKnn(embeddings[28].data(), 5);
results.clear();

for(auto i = 0; i < 5; i++) {
hnswlib::labeltype label = hnsw_result.top().second;
results.push_back(label);
hnsw_result.pop();
}

std::reverse(results.begin(), results.end());
for(auto label: results) {
std::cout << "hnsw label: " << label << std::endl;
}
recall = (float)correct / max_elements;
std::cout << "Recall of deserialized index: " << recall << "\n";

delete[] data;
delete alg_hnsw;
return 0;
}

0 comments on commit 62fefce

Please sign in to comment.