Skip to content

Commit

Permalink
add vector range search
Browse files Browse the repository at this point in the history
  • Loading branch information
ljcui committed Sep 24, 2024
1 parent 334c4fc commit 3408a6d
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 48 deletions.
15 changes: 11 additions & 4 deletions docs/zh-CN/source/8.query/3.vector_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,27 @@ CREATE (n1)-[r:like]->(n2),
(n3)-[r:like]->(n1);
```
## 向量查询

### KnnSearch
根据向量搜索出点,第四个参数是个map,里面可以指定一些向量搜索的参数。
```
CALL db.vertexVectorIndexQuery('person','embedding', [1.0,2.0,3.0,4.0], {top_k:2, hnsw_ef_search:10})
CALL db.vertexVectorKnnSearch('person','embedding', [1.0,2.0,3.0,4.0], {top_k:2, hnsw_ef_search:10})
yield node return node
```
根据向量搜索出点,返回`age`小于30的点。
```
CALL db.vertexVectorIndexQuery('person','embedding',[1.0,2.0,3.0,4.0], {top_k:2, hnsw_ef_search:10})
CALL db.vertexVectorKnnSearch('person','embedding',[1.0,2.0,3.0,4.0], {top_k:2, hnsw_ef_search:10})
yield node where node.age < 30 return node
```
根据向量搜索出点,返回age小于30的点,然后再查这些点的一度邻居是谁。
```
CALL db.vertexVectorIndexQuery('person','embedding',[1.0,2.0,3.0,4.0], {top_k:2, hnsw_ef_search:10})
CALL db.vertexVectorKnnSearch('person','embedding',[1.0,2.0,3.0,4.0], {top_k:2, hnsw_ef_search:10})
yield node where node.age < 30 with node as p
match(p)-[r]->(m) return m
```
### RangeSearch
根据向量搜索出距离小于10的、age小于30的点,然后再查这些点的一度邻居是谁。
```
CALL db.vertexVectorRangeSearch('person','embedding',[1.0,2.0,3.0,4.0], {radius:10.0, hnsw_ef_search:10})
yield node where node.age < 30 with node as p
match(p)-[r]->(m) return m
```
5 changes: 4 additions & 1 deletion src/core/vector_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ class VectorIndex {

// search vector in index
virtual std::vector<std::pair<int64_t, float>>
Search(const std::vector<float>& query, int64_t num_results, int ef_search) = 0;
KnnSearch(const std::vector<float>& query, int64_t top_k, int ef_search) = 0;

// Radius-based search over the index for the given query vector.
// Returns a list of (id, distance) pairs.
// How radius, ef_search and limit are interpreted is up to the implementation;
// the HNSW implementation forwards radius and limit to the underlying index
// and passes ef_search as the "hnsw"/"ef_search" search parameter.
// NOTE(review): the Cypher procedure passes limit = -1 by default and rejects 0;
// presumably -1 means "no limit" -- confirm against the underlying index API.
virtual std::vector<std::pair<int64_t, float>>
RangeSearch(const std::vector<float>& query, float radius, int ef_search, int limit) = 0;
};
} // namespace lgraph
27 changes: 25 additions & 2 deletions src/core/vsag_hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ void HNSW::Load(std::vector<uint8_t>& idx_bytes) {

// search vector in index
std::vector<std::pair<int64_t, float>>
HNSW::Search(const std::vector<float>& query, int64_t num_results, int ef_search) {
HNSW::KnnSearch(const std::vector<float>& query, int64_t top_k, int ef_search) {
auto* query_copy = new float[query.size()];
std::copy(query.begin(), query.end(), query_copy);
auto dataset = vsag::Dataset::Make();
Expand All @@ -185,7 +185,30 @@ HNSW::Search(const std::vector<float>& query, int64_t num_results, int ef_search
{"hnsw", {{"ef_search", ef_search}}},
};
std::vector<std::pair<int64_t, float>> ret;
auto result = index_->KnnSearch(dataset, num_results, parameters.dump());
auto result = index_->KnnSearch(dataset, top_k, parameters.dump());
if (result.has_value()) {
for (int64_t i = 0; i < result.value()->GetDim(); ++i) {
ret.emplace_back(result.value()->GetIds()[i], result.value()->GetDistances()[i]);
}
} else {
THROW_CODE(VectorIndexException, result.error().message);
}
return ret;
}

std::vector<std::pair<int64_t, float>>
HNSW::RangeSearch(const std::vector<float>& query, float radius, int ef_search, int limit) {
auto* query_copy = new float[query.size()];
std::copy(query.begin(), query.end(), query_copy);
auto dataset = vsag::Dataset::Make();
dataset->Dim(vec_dimension_)
->NumElements(1)
->Float32Vectors(query_copy);
nlohmann::json parameters{
{"hnsw", {{"ef_search", ef_search}}},
};
std::vector<std::pair<int64_t, float>> ret;
auto result = index_->RangeSearch(dataset, radius, parameters.dump(), limit);
if (result.has_value()) {
for (int64_t i = 0; i < result.value()->GetDim(); ++i) {
ret.emplace_back(result.value()->GetIds()[i], result.value()->GetDistances()[i]);
Expand Down
7 changes: 5 additions & 2 deletions src/core/vsag_hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,11 @@ class HNSW : public VectorIndex {
void Load(std::vector<uint8_t>& idx_bytes) override;

// search vector in index
std::vector<std::pair<int64_t, float>> Search(
const std::vector<float>& query, int64_t num_results, int ef_search) override;
// Nearest-neighbor search for the given query vector.
// top_k is forwarded to the underlying index's KnnSearch call and ef_search is
// passed as the "hnsw"/"ef_search" search parameter.
// Returns the (id, distance) pairs reported by the underlying index.
// Throws VectorIndexException if the underlying search reports an error.
std::vector<std::pair<int64_t, float>> KnnSearch(
const std::vector<float>& query, int64_t top_k, int ef_search) override;

// Radius-based search for the given query vector.
// radius and limit are forwarded to the underlying index's RangeSearch call and
// ef_search is passed as the "hnsw"/"ef_search" search parameter.
// Returns the (id, distance) pairs reported by the underlying index.
// NOTE(review): callers pass limit = -1 by default; presumably that means
// "no limit" -- confirm against the underlying index API.
std::vector<std::pair<int64_t, float>> RangeSearch(
const std::vector<float>& query, float radius, int ef_search, int limit) override;

template <typename T>
static void writeBinaryPOD(std::ostream& out, const T& podRef) {
Expand Down
90 changes: 84 additions & 6 deletions src/cypher/procedure/procedure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4258,21 +4258,21 @@ void VectorFunc::ShowVertexVectorIndex(RTContext *ctx, const cypher::Record *rec
FillProcedureYieldItem("db.showVertexVectorIndex", yield_items, records);
}

void VectorFunc::VertexVectorIndexQuery(RTContext *ctx, const cypher::Record *record,
void VectorFunc::VertexVectorKnnSearch(RTContext *ctx, const cypher::Record *record,
const cypher::VEC_EXPR &args,
const cypher::VEC_STR &yield_items,
struct std::vector<cypher::Record> *records) {
CYPHER_DB_PROCEDURE_GRAPH_CHECK();
CYPHER_ARG_CHECK(args.size() == 4,
"e.g. db.vertexVectorIndexQuery(label_name, field_name,"
"e.g. db.vertexVectorKnnSearch(label_name, field_name,"
" vec, parameter);")
CYPHER_ARG_CHECK(args[0].IsString(),
"label_name type should be string")
CYPHER_ARG_CHECK(args[1].IsString(),
"field_name type should be string")
CYPHER_ARG_CHECK(args[2].IsArray(), "Please check the vector you entered, e.g. [1, 2, 3]")
CYPHER_ARG_CHECK(args[3].IsMap(), "parameter type should be map")
CheckProcedureYieldItem("db.vertexVectorIndexQuery", yield_items);
CheckProcedureYieldItem("db.vertexVectorKnnSearch", yield_items);
std::vector<float> query_vector;
for (size_t i = 0; i < args[2].constant.AsConstantArray().size(); i++) {
float fltValue;
Expand All @@ -4296,8 +4296,9 @@ void VectorFunc::VertexVectorIndexQuery(RTContext *ctx, const cypher::Record *re
auto label = args[0].constant.AsString();
auto field = args[1].constant.AsString();
auto index = ctx->txn_->GetTxn()->GetVertexVectorIndex(label, field);

int top_k = 10;
auto parameter = *args[3].constant.map;
auto parameter = *args[3].constant.map;
if (parameter.count("top_k")) {
top_k = parameter.at("top_k").AsInt64();
}
Expand All @@ -4308,7 +4309,83 @@ void VectorFunc::VertexVectorIndexQuery(RTContext *ctx, const cypher::Record *re
}
CYPHER_ARG_CHECK((ef_search <= 1000 && ef_search >= 1),
"hnsw.ef_search should be an integer in the range [1, 1000]");
auto res = index->Search(query_vector, top_k, ef_search);
auto res = index->KnnSearch(query_vector, top_k, ef_search);
for (auto& item : res) {
Record r;
cypher::Node n;
n.SetVid(item.first);
r.AddNode(&n);
r.AddConstant(lgraph::FieldData(item.second));
records->emplace_back(r.Snapshot());
}
FillProcedureYieldItem("db.vertexVectorKnnSearch", yield_items, records);
}

void VectorFunc::VertexVectorRangeSearch(RTContext *ctx, const cypher::Record *record,
const cypher::VEC_EXPR &args,
const cypher::VEC_STR &yield_items,
struct std::vector<cypher::Record> *records) {
CYPHER_DB_PROCEDURE_GRAPH_CHECK();
CYPHER_ARG_CHECK(args.size() == 4,
"e.g. db.vertexVectorRangeSearch(label_name, field_name,"
" vec, parameter);")
CYPHER_ARG_CHECK(args[0].IsString(),
"label_name type should be string")
CYPHER_ARG_CHECK(args[1].IsString(),
"field_name type should be string")
CYPHER_ARG_CHECK(args[2].IsArray(), "Please check the vector you entered, e.g. [1, 2, 3]")
CYPHER_ARG_CHECK(args[3].IsMap(), "parameter type should be map")
CheckProcedureYieldItem("db.vertexVectorRangeSearch", yield_items);
std::vector<float> query_vector;
for (size_t i = 0; i < args[2].constant.AsConstantArray().size(); i++) {
float fltValue;
if (args[2].constant.AsConstantArray().at(i).IsFloat()) {
float dblValue =
args[2].constant.AsConstantArray().at(i).AsFloat();
fltValue = static_cast<float>(dblValue);
} else if (args[2].constant.AsConstantArray().at(i).IsInteger()) {
int64_t dblValue =
args[2].constant.AsConstantArray().at(i).AsInt64();
fltValue = static_cast<float>(dblValue);
} else if (args[2].constant.AsConstantArray().at(i).IsDouble()) {
double dblValue =
args[2].constant.AsConstantArray().at(i).AsDouble();
fltValue = static_cast<float>(dblValue);
} else {
throw lgraph::ReminderException("Please check the vector");
}
query_vector.push_back(fltValue);
}
auto label = args[0].constant.AsString();
auto field = args[1].constant.AsString();
auto index = ctx->txn_->GetTxn()->GetVertexVectorIndex(label, field);

float radius = 0.1;
auto parameter = *args[3].constant.map;
if (parameter.count("radius")) {
if (parameter.at("radius").scalar.IsDouble()) {
radius = (float)parameter.at("radius").AsDouble();
} else if (parameter.at("radius").scalar.IsInteger()) {
radius = (float)parameter.at("radius").AsInt64();
} else {
throw lgraph::ReminderException("radius type error");
}
} else {
throw lgraph::ReminderException("radius is required for vector range search");
}
CYPHER_ARG_CHECK((radius > 0), "radius must be greater than 0");
int ef_search = 200;
if (parameter.count("hnsw_ef_search")) {
ef_search = parameter.at("hnsw_ef_search").AsInt64();
}
CYPHER_ARG_CHECK((ef_search <= 1000 && ef_search >= 1),
"hnsw.ef_search should be an integer in the range [1, 1000]");
int limit = -1;
if (parameter.count("limit")) {
limit = parameter.at("limit").AsInt64();
}
CYPHER_ARG_CHECK((limit != 0), "limit must not be 0");
auto res = index->RangeSearch(query_vector, radius, ef_search, limit);
for (auto& item : res) {
Record r;
cypher::Node n;
Expand All @@ -4317,6 +4394,7 @@ void VectorFunc::VertexVectorIndexQuery(RTContext *ctx, const cypher::Record *re
r.AddConstant(lgraph::FieldData(item.second));
records->emplace_back(r.Snapshot());
}
FillProcedureYieldItem("db.vertexVectorIndexQuery", yield_items, records);
FillProcedureYieldItem("db.vertexVectorRangeSearch", yield_items, records);
}

} // namespace cypher
19 changes: 16 additions & 3 deletions src/cypher/procedure/procedure.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,10 @@ class VectorFunc {
static void ShowVertexVectorIndex(RTContext *ctx, const Record *record, const VEC_EXPR &args,
const VEC_STR &yield_items, std::vector<Record> *records);

static void VertexVectorIndexQuery(RTContext *ctx, const Record *record, const VEC_EXPR &args,
static void VertexVectorKnnSearch(RTContext *ctx, const Record *record, const VEC_EXPR &args,
const VEC_STR &yield_items, std::vector<Record> *records);
static void VertexVectorRangeSearch(RTContext *ctx, const Record *record, const VEC_EXPR &args,
const VEC_STR &yield_items, std::vector<Record> *records);
};

struct Procedure {
Expand Down Expand Up @@ -950,7 +952,18 @@ static std::vector<Procedure> global_procedures = {
{"hnsm.ef_construction", {6, lgraph_api::LGraphType::INTEGER}},
}),

Procedure("db.vertexVectorIndexQuery", VectorFunc::VertexVectorIndexQuery,
Procedure("db.vertexVectorKnnSearch", VectorFunc::VertexVectorKnnSearch,
Procedure::SIG_SPEC{
{"label_name", {0, lgraph_api::LGraphType::STRING}},
{"field_name", {1, lgraph_api::LGraphType::STRING}},
{"vec", {2, lgraph_api::LGraphType::LIST}},
{"parameter", {3, lgraph_api::LGraphType::MAP}},
},
Procedure::SIG_SPEC{
{"node", {0, lgraph_api::LGraphType::NODE}},
{"distance", {1, lgraph_api::LGraphType::FLOAT}},
}),
Procedure("db.vertexVectorRangeSearch", VectorFunc::VertexVectorRangeSearch,
Procedure::SIG_SPEC{
{"label_name", {0, lgraph_api::LGraphType::STRING}},
{"field_name", {1, lgraph_api::LGraphType::STRING}},
Expand All @@ -959,7 +972,7 @@ static std::vector<Procedure> global_procedures = {
},
Procedure::SIG_SPEC{
{"node", {0, lgraph_api::LGraphType::NODE}},
{"score", {1, lgraph_api::LGraphType::FLOAT}},
{"distance", {1, lgraph_api::LGraphType::FLOAT}},
}),

Procedure("dbms.security.listRoles", BuiltinProcedure::DbmsSecurityListRoles,
Expand Down
Loading

0 comments on commit 3408a6d

Please sign in to comment.