From 21de4ef94f6d76807abb3408099816c59003d2ad Mon Sep 17 00:00:00 2001 From: Wang Zhiyong Date: Mon, 23 Sep 2024 10:23:15 +0000 Subject: [PATCH] fix vector index coredump --- src/core/index_manager.cpp | 4 ++-- src/core/vector_index.h | 2 +- src/core/vsag_hnsw.cpp | 6 ++---- src/core/vsag_hnsw.h | 2 +- src/cypher/procedure/procedure.cpp | 16 ++++++++-------- test/test_vsag_index.cpp | 12 ++++++------ 6 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/core/index_manager.cpp b/src/core/index_manager.cpp index 99b0cd0c0..4c46b0947 100644 --- a/src/core/index_manager.cpp +++ b/src/core/index_manager.cpp @@ -100,7 +100,7 @@ IndexManager::IndexManager(KvTransaction& txn, SchemaManager* v_schema_manager, Schema* schema = v_schema_manager->GetSchema(idx.label); FMA_DBG_ASSERT(schema); FMA_DBG_ASSERT(schema->DetachProperty()); - LOG_INFO() << FMA_FMT("start building vertex index for {}:{} in detached model", + LOG_INFO() << FMA_FMT("start building vertex vector index for {}:{} in detached model", idx.label, idx.field); const _detail::FieldExtractor* extractor = schema->GetFieldExtractor(idx.field); FMA_DBG_ASSERT(extractor); @@ -128,7 +128,7 @@ IndexManager::IndexManager(KvTransaction& txn, SchemaManager* v_schema_manager, kv_iter.reset(); LOG_DEBUG() << "index count: " << count; schema->MarkVectorIndexed(extractor->GetFieldId(), vsag_index.release()); - LOG_INFO() << FMA_FMT("end building vector index for {}:{} in detached model", + LOG_INFO() << FMA_FMT("end building vertex vector index for {}:{} in detached model", idx.label, idx.field); } else { LOG_ERROR() << "Unknown index type: " << index_name; diff --git a/src/core/vector_index.h b/src/core/vector_index.h index 1b15d9ca1..f9b956b52 100644 --- a/src/core/vector_index.h +++ b/src/core/vector_index.h @@ -69,7 +69,7 @@ class VectorIndex { const std::vector& vids, int64_t num_vectors) = 0; // build index - virtual bool Build() = 0; + virtual void Build() = 0; // serialize index virtual std::vector Save() = 0; diff --git a/src/core/vsag_hnsw.cpp b/src/core/vsag_hnsw.cpp index c1370afeb..e1dd4b1b7 100644 --- a/src/core/vsag_hnsw.cpp +++ b/src/core/vsag_hnsw.cpp @@ -65,7 +65,7 @@ void HNSW::Add(const std::vector>& vectors, } } -bool HNSW::Build() { +void HNSW::Build() { nlohmann::json hnsw_parameters{ {"max_degree", index_spec_[0]}, {"ef_construction", index_spec_[1]} @@ -81,10 +81,8 @@ bool HNSW::Build() { createindex_ = std::move(temp.value()); index_ = createindex_.get(); } else { - LOG_WARN() << FMA_FMT("create vsag index error: {}", temp.error().message); - return false; + THROW_CODE(VectorIndexException, temp.error().message); } - return true; } // serialize index diff --git a/src/core/vsag_hnsw.h b/src/core/vsag_hnsw.h index b1ca5b448..daef893b9 100644 --- a/src/core/vsag_hnsw.h +++ b/src/core/vsag_hnsw.h @@ -54,7 +54,7 @@ class HNSW : public VectorIndex { const std::vector& vids, int64_t num_vectors) override; // build index - bool Build() override; + void Build() override; // serialize index std::vector Save() override; diff --git a/src/cypher/procedure/procedure.cpp b/src/cypher/procedure/procedure.cpp index 7066c4034..78b2b42ae 100644 --- a/src/cypher/procedure/procedure.cpp +++ b/src/cypher/procedure/procedure.cpp @@ -4192,14 +4192,14 @@ void VectorFunc::AddVertexVectorIndex(RTContext *ctx, const cypher::Record *reco if (parameter.count("hnsm_m")) { hnsm_m = (int)parameter.at("hnsm_m").AsInt64(); } - CYPHER_ARG_CHECK((hnsm_m < 2048 && hnsm_m > 2), - "M should be an integer in the range (2,2048)"); + CYPHER_ARG_CHECK((hnsm_m <= 64 && hnsm_m >= 5), + "hnsm.m should be an integer in the range [5, 64]"); int hnsm_ef_construction = 100; if (parameter.count("hnsm_ef_construction")) { hnsm_ef_construction = (int)parameter.at("hnsm_ef_construction").AsInt64(); } - CYPHER_ARG_CHECK((hnsm_ef_construction < 65536 && hnsm_ef_construction > 1), - "efConstruction should be an integer in the range (1,65536)"); + CYPHER_ARG_CHECK((hnsm_ef_construction <= 1000 && hnsm_ef_construction >= hnsm_m), + "hnsm.efConstruction should be an integer in the range [hnsm.m,1000]"); std::vector index_spec = {hnsm_m, hnsm_ef_construction}; auto ac_db = ctx->galaxy_->OpenGraph(ctx->user_, ctx->graph_); bool success = ac_db.AddVectorIndex(true, label, field, index_type, @@ -4301,13 +4301,13 @@ void VectorFunc::VertexVectorIndexQuery(RTContext *ctx, const cypher::Record *re if (parameter.count("top_k")) { top_k = parameter.at("top_k").AsInt64(); } - int ef_search = 100; + CYPHER_ARG_CHECK((top_k >= 1), "top_k must be greater than 0"); + int ef_search = 200; if (parameter.count("hnsw_ef_search")) { ef_search = parameter.at("hnsw_ef_search").AsInt64(); } - CYPHER_ARG_CHECK((ef_search <= 65536 && ef_search >= top_k), - "Please check the parameter," - "ef should be an integer in the range [top_k, 65536]"); + CYPHER_ARG_CHECK((ef_search <= 1000 && ef_search >= 1), + "hnsw.ef_search should be an integer in the range [1, 1000]"); auto res = index->Search(query_vector, top_k, ef_search); for (auto& item : res) { Record r; diff --git a/test/test_vsag_index.cpp b/test/test_vsag_index.cpp index 1f8605bd3..a37deafe6 100644 --- a/test/test_vsag_index.cpp +++ b/test/test_vsag_index.cpp @@ -57,15 +57,15 @@ class TestVsag : public TuGraphTest { void TearDown() override {} }; -TEST_F(TestVsag, BuildIndex) { ASSERT_TRUE(vector_index->Build()); } +TEST_F(TestVsag, BuildIndex) { EXPECT_NO_THROW(vector_index->Build()); } TEST_F(TestVsag, AddVectors) { - ASSERT_TRUE(vector_index->Build()); + EXPECT_NO_THROW(vector_index->Build()); EXPECT_NO_THROW(vector_index->Add(vectors, vids, num_vectors)); } TEST_F(TestVsag, SearchIndex) { - ASSERT_TRUE(vector_index->Build()); + EXPECT_NO_THROW(vector_index->Build()); EXPECT_NO_THROW(vector_index->Add(vectors, vids, num_vectors)); std::vector query(vectors[0].begin(), vectors[0].end()); std::vector> ret; @@ -75,12 +75,12 @@ TEST_F(TestVsag, SearchIndex) { } TEST_F(TestVsag, SaveAndLoadIndex) { - ASSERT_TRUE(vector_index->Build()); + EXPECT_NO_THROW(vector_index->Build()); EXPECT_NO_THROW(vector_index->Add(vectors, vids, num_vectors)); std::vector serialized_index = vector_index->Save(); ASSERT_FALSE(serialized_index.empty()); lgraph::HNSW vector_index_loaded("label", "name", "l2", "hnsw", dim, index_spec); - ASSERT_TRUE(vector_index_loaded.Build()); + EXPECT_NO_THROW(vector_index_loaded.Build()); vector_index_loaded.Load(serialized_index); std::vector query(vectors[0].begin(), vectors[0].end()); auto ret = vector_index_loaded.Search(query, 10, 10); @@ -89,7 +89,7 @@ TEST_F(TestVsag, SaveAndLoadIndex) { } TEST_F(TestVsag, DeleteVectors) { - ASSERT_TRUE(vector_index->Build()); + EXPECT_NO_THROW(vector_index->Build()); EXPECT_NO_THROW(vector_index->Add(vectors, vids, num_vectors)); std::vector delete_vids = {vids[0], vids[1]}; EXPECT_NO_THROW(vector_index->Add({}, delete_vids, 0));