Skip to content

Commit

Permalink
[fix][index] Fix vector ivf pq in no data search failed.
Browse files Browse the repository at this point in the history
  • Loading branch information
Haijun Yu authored and ketor committed Mar 11, 2024
1 parent 3b89c9b commit 302c4ef
Show file tree
Hide file tree
Showing 11 changed files with 46 additions and 143 deletions.
28 changes: 2 additions & 26 deletions src/vector/vector_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -946,33 +946,9 @@ bool VectorIndexWrapper::IsPermanentHoldVectorIndex(int64_t region_id) {
}

butil::Status VectorIndexWrapper::SetVectorIndexRangeFilter(
VectorIndexPtr vector_index, std::vector<std::shared_ptr<VectorIndex::FilterFunctor>>& filters,
VectorIndexPtr /*vector_index*/, std::vector<std::shared_ptr<VectorIndex::FilterFunctor>>& filters,
int64_t min_vector_id, int64_t max_vector_id) {
if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_HNSW) {
filters.push_back(std::make_shared<VectorIndex::RangeFilterFunctor>(min_vector_id, max_vector_id));
} else if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_FLAT ||
vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_BRUTEFORCE) {
filters.push_back(std::make_shared<VectorIndex::RangeFilterFunctor>(min_vector_id, max_vector_id));
} else if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_IVF_FLAT) {
filters.push_back(std::make_shared<VectorIndex::RangeFilterFunctor>(min_vector_id, max_vector_id));
} else if (vector_index->VectorIndexType() == pb::common::VECTOR_INDEX_TYPE_IVF_PQ) {
if (vector_index->VectorIndexSubType() == pb::common::VECTOR_INDEX_TYPE_IVF_PQ) {
filters.push_back(std::make_shared<VectorIndex::RangeFilterFunctor>(min_vector_id, max_vector_id));
} else if (vector_index->VectorIndexSubType() == pb::common::VECTOR_INDEX_TYPE_FLAT) {
filters.push_back(std::make_shared<VectorIndex::RangeFilterFunctor>(min_vector_id, max_vector_id));
} else {
return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT,
fmt::format("SetVectorIndexFilter not support index type: {} sub type: {}",
pb::common::VectorIndexType_Name(vector_index->VectorIndexType()),
pb::common::VectorIndexType_Name(vector_index->VectorIndexSubType())));
}
} else {
return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT,
fmt::format("SetVectorIndexFilter not support index type: {} sub type: {}",
pb::common::VectorIndexType_Name(vector_index->VectorIndexType()),
pb::common::VectorIndexType_Name(vector_index->VectorIndexSubType())));
}

filters.push_back(std::make_shared<VectorIndex::RangeFilterFunctor>(min_vector_id, max_vector_id));
return butil::Status::OK();
}

Expand Down
49 changes: 0 additions & 49 deletions src/vector/vector_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,6 @@ class VectorIndex {
int64_t max_vector_id_;
};

// Range filter just for flat
// Range transform list
// deprecated
class FlatRangeFilterFunctor : public FilterFunctor {
public:
FlatRangeFilterFunctor(int64_t min_vector_id, int64_t max_vector_id)
: min_vector_id_(min_vector_id), max_vector_id_(max_vector_id) {}

void Build(std::vector<faiss::idx_t>& id_map) override { this->id_map_ = &id_map; }

bool Check(int64_t index) override {
return (*id_map_)[index] >= min_vector_id_ && (*id_map_)[index] < max_vector_id_;
}

private:
int64_t min_vector_id_;
int64_t max_vector_id_;
std::vector<faiss::idx_t>* id_map_{nullptr};
};

class ConcreteFilterFunctor : public FilterFunctor, public faiss::IDSelectorBatch {
public:
ConcreteFilterFunctor(const ConcreteFilterFunctor&) = delete;
Expand All @@ -101,35 +81,6 @@ class VectorIndex {
bool Check(int64_t vector_id) override { return is_member(vector_id); }
};

// List filter
// be careful not to use the parent class to release,
// otherwise there will be memory leaks
class HnswListFilterFunctor : public ConcreteFilterFunctor {
public:
explicit HnswListFilterFunctor(const std::vector<int64_t>& vector_ids) : ConcreteFilterFunctor(vector_ids) {}
~HnswListFilterFunctor() override = default;
};

class FlatListFilterFunctor : public ConcreteFilterFunctor {
public:
explicit FlatListFilterFunctor(const std::vector<int64_t>& vector_ids) : ConcreteFilterFunctor(vector_ids) {}
~FlatListFilterFunctor() override = default;
};

// List filter just for ivf flat
class IvfFlatListFilterFunctor : public ConcreteFilterFunctor {
public:
explicit IvfFlatListFilterFunctor(const std::vector<int64_t>& vector_ids) : ConcreteFilterFunctor(vector_ids) {}
~IvfFlatListFilterFunctor() override = default;
};

// List filter just for ivf pq
class IvfPqListFilterFunctor : public ConcreteFilterFunctor {
public:
explicit IvfPqListFilterFunctor(const std::vector<int64_t>& vector_ids) : ConcreteFilterFunctor(vector_ids) {}
~IvfPqListFilterFunctor() override = default;
};

virtual int32_t GetDimension() = 0;
virtual pb::common::MetricType GetMetricType() = 0;
virtual butil::Status GetCount(int64_t& count);
Expand Down
28 changes: 2 additions & 26 deletions src/vector/vector_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1384,34 +1384,10 @@ butil::Status VectorReader::DoVectorSearchForScalarPreFilterDebug(
return butil::Status::OK();
}

butil::Status VectorReader::SetVectorIndexIdsFilter(VectorIndexWrapperPtr vector_index,
butil::Status VectorReader::SetVectorIndexIdsFilter(VectorIndexWrapperPtr /*vector_index*/,
std::vector<std::shared_ptr<VectorIndex::FilterFunctor>>& filters,
const std::vector<int64_t>& vector_ids) {
if (vector_index->Type() == pb::common::VECTOR_INDEX_TYPE_HNSW) {
filters.push_back(std::make_shared<VectorIndex::HnswListFilterFunctor>(vector_ids));
} else if (vector_index->Type() == pb::common::VECTOR_INDEX_TYPE_FLAT ||
vector_index->Type() == pb::common::VECTOR_INDEX_TYPE_BRUTEFORCE) {
filters.push_back(std::make_shared<VectorIndex::FlatListFilterFunctor>(vector_ids));
} else if (vector_index->Type() == pb::common::VECTOR_INDEX_TYPE_IVF_FLAT) {
filters.push_back(std::make_shared<VectorIndex::IvfFlatListFilterFunctor>(vector_ids));
} else if (vector_index->Type() == pb::common::VECTOR_INDEX_TYPE_IVF_PQ) {
if (vector_index->SubType() == pb::common::VECTOR_INDEX_TYPE_IVF_PQ) {
filters.push_back(std::make_shared<VectorIndex::IvfPqListFilterFunctor>(vector_ids));
} else if (vector_index->SubType() == pb::common::VECTOR_INDEX_TYPE_FLAT) {
filters.push_back(std::make_shared<VectorIndex::FlatListFilterFunctor>(vector_ids));
} else {
return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT,
fmt::format("SetVectorIndexFilter not support index type: {} sub type: {}",
pb::common::VectorIndexType_Name(vector_index->Type()),
pb::common::VectorIndexType_Name(vector_index->SubType())));
}
} else {
return butil::Status(pb::error::Errno::EVECTOR_NOT_SUPPORT,
fmt::format("SetVectorIndexFilter not support index type: {} sub type: {}",
pb::common::VectorIndexType_Name(vector_index->Type()),
pb::common::VectorIndexType_Name(vector_index->SubType())));
}

filters.push_back(std::make_shared<VectorIndex::ConcreteFilterFunctor>(vector_ids));
return butil::Status::OK();
}

Expand Down
10 changes: 5 additions & 5 deletions test/unit_test/test_vector_index_flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ TEST_F(VectorIndexFlatTest, RangeSearchNoData) {
vector_ids.push_back(cnt + data_base_size);
}

auto flat_list_filter_functor = std::make_shared<VectorIndex::FlatListFilterFunctor>(vector_ids);
auto flat_list_filter_functor = std::make_shared<VectorIndex::ConcreteFilterFunctor>(vector_ids);

ok = vector_index_flat_l2->RangeSearch(vector_with_ids, radius, {flat_list_filter_functor}, false, {}, results_l2);
EXPECT_EQ(ok.error_code(), pb::error::Errno::OK);
Expand Down Expand Up @@ -1168,7 +1168,7 @@ TEST_F(VectorIndexFlatTest, RangeSearch) {
vector_ids.push_back(cnt + data_base_size);
}

auto flat_list_filter_functor = std::make_shared<VectorIndex::FlatListFilterFunctor>(vector_ids);
auto flat_list_filter_functor = std::make_shared<VectorIndex::ConcreteFilterFunctor>(vector_ids);

ok = vector_index_flat_l2->RangeSearch(vector_with_ids, radius, {flat_list_filter_functor}, false, {}, results_l2);
EXPECT_EQ(ok.error_code(), pb::error::Errno::OK);
Expand Down Expand Up @@ -1351,8 +1351,8 @@ TEST_F(VectorIndexFlatTest, SearchAfterLoad) {
std::vector<int64_t> vector_select_ids(vector_ids.begin(), vector_ids.begin() + (data_base_size / 2));
std::vector<int64_t> vector_select_ids_clone = vector_select_ids;

std::shared_ptr<VectorIndex::IvfFlatListFilterFunctor> filter =
std::make_shared<VectorIndex::IvfFlatListFilterFunctor>(std::move(vector_select_ids));
std::shared_ptr<VectorIndex::ConcreteFilterFunctor> filter =
std::make_shared<VectorIndex::ConcreteFilterFunctor>(std::move(vector_select_ids));
const bool reconstruct = false;
pb::common::VectorSearchParameter parameter;
parameter.mutable_ivf_flat()->set_nprobe(10);
Expand Down Expand Up @@ -1694,7 +1694,7 @@ TEST_F(VectorIndexFlatTest, RangeSearchAfterLoad) {
vector_ids.push_back(cnt + data_base_size);
}

auto flat_list_filter_functor = std::make_shared<VectorIndex::FlatListFilterFunctor>(vector_ids);
auto flat_list_filter_functor = std::make_shared<VectorIndex::ConcreteFilterFunctor>(vector_ids);

ok = vector_index_flat_l2->RangeSearch(vector_with_ids, radius, {flat_list_filter_functor}, false, {}, results_l2);
EXPECT_EQ(ok.error_code(), pb::error::Errno::OK);
Expand Down
4 changes: 2 additions & 2 deletions test/unit_test/test_vector_index_flat_search_limit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ TEST_F(VectorIndexFlatSearchParamLimitTest, Search) {

std::vector<std::shared_ptr<VectorIndex::FilterFunctor>> filters;
if (has_filter) {
filters.emplace_back(std::make_shared<VectorIndex::FlatListFilterFunctor>(std::move(vector_ids_for_search)));
filters.emplace_back(std::make_shared<VectorIndex::ConcreteFilterFunctor>(std::move(vector_ids_for_search)));
}

ok = vector_index_flat->Search(vector_with_ids, topk, filters, false, {}, results);
Expand Down Expand Up @@ -365,7 +365,7 @@ TEST_F(VectorIndexFlatSearchParamLimitTest, SearchAfterInsert) {

std::vector<std::shared_ptr<VectorIndex::FilterFunctor>> filters;
if (has_filter) {
filters.emplace_back(std::make_shared<VectorIndex::FlatListFilterFunctor>(std::move(vector_ids_for_search)));
filters.emplace_back(std::make_shared<VectorIndex::ConcreteFilterFunctor>(std::move(vector_ids_for_search)));
}

ok = vector_index_flat->Search(vector_with_ids, topk, filters, false, {}, results);
Expand Down
2 changes: 1 addition & 1 deletion test/unit_test/test_vector_index_flat_search_param.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ TEST_F(VectorIndexFlatSearchParamTest, Search) {
auto vector_ids_for_search_copy = vector_ids_for_search;

std::vector<std::shared_ptr<VectorIndex::FilterFunctor>> filters;
filters.emplace_back(std::make_shared<VectorIndex::FlatListFilterFunctor>(std::move(vector_ids_for_search)));
filters.emplace_back(std::make_shared<VectorIndex::ConcreteFilterFunctor>(std::move(vector_ids_for_search)));

ok = vector_index_flat->Search(vector_with_ids, topk, filters, false, {}, results);
EXPECT_EQ(ok.error_code(), pb::error::Errno::OK);
Expand Down
4 changes: 2 additions & 2 deletions test/unit_test/test_vector_index_hnsw_search_param.cc
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ TEST_F(VectorIndexHnswSearchParamTest, Search) {
auto [vector_ids, vector_ids_for_search] = lambda_random_function();

std::vector<std::shared_ptr<VectorIndex::FilterFunctor>> filters;
filters.push_back(std::make_shared<VectorIndex::HnswListFilterFunctor>(vector_ids_for_search));
filters.push_back(std::make_shared<VectorIndex::ConcreteFilterFunctor>(vector_ids_for_search));
ok = vector_index_hnsw->Search(vector_with_ids, topk, filters, false, {}, results);
EXPECT_EQ(ok.error_code(), pb::error::Errno::OK);

Expand Down Expand Up @@ -348,7 +348,7 @@ TEST_F(VectorIndexHnswSearchParamTest, SearchOrder) {
// auto [vector_ids, vector_ids_for_search] = lambda_random_function();

// std::vector<std::shared_ptr<VectorIndex::FilterFunctor>> filters;
// filters.push_back(std::make_shared<VectorIndex::HnswListFilterFunctor>(vector_ids_for_search));
// filters.push_back(std::make_shared<VectorIndex::ConcreteFilterFunctor>(vector_ids_for_search));
ok = vector_index_hnsw->Search(vector_with_ids, topk, {}, false, {}, results);
EXPECT_EQ(ok.error_code(), pb::error::Errno::OK);

Expand Down
16 changes: 8 additions & 8 deletions test/unit_test/test_vector_index_ivf_flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -947,8 +947,8 @@ TEST_F(VectorIndexIvfFlatTest, Search) {
std::vector<int64_t> vector_select_ids(vector_ids.begin(), vector_ids.begin() + (data_base_size / 2));
std::vector<int64_t> vector_select_ids_clone = vector_select_ids;

std::shared_ptr<VectorIndex::IvfFlatListFilterFunctor> filter =
std::make_shared<VectorIndex::IvfFlatListFilterFunctor>(std::move(vector_select_ids));
std::shared_ptr<VectorIndex::ConcreteFilterFunctor> filter =
std::make_shared<VectorIndex::ConcreteFilterFunctor>(std::move(vector_select_ids));
const bool reconstruct = false;
pb::common::VectorSearchParameter parameter;
parameter.mutable_ivf_flat()->set_nprobe(10);
Expand Down Expand Up @@ -1132,8 +1132,8 @@ TEST_F(VectorIndexIvfFlatTest, RangeSearch) {
std::vector<int64_t> vector_select_ids(vector_ids.begin(), vector_ids.begin() + (data_base_size / 2));
std::vector<int64_t> vector_select_ids_clone = vector_select_ids;

std::shared_ptr<VectorIndex::IvfFlatListFilterFunctor> filter =
std::make_shared<VectorIndex::IvfFlatListFilterFunctor>(std::move(vector_select_ids));
std::shared_ptr<VectorIndex::ConcreteFilterFunctor> filter =
std::make_shared<VectorIndex::ConcreteFilterFunctor>(std::move(vector_select_ids));
const bool reconstruct = false;
pb::common::VectorSearchParameter parameter;
parameter.mutable_ivf_flat()->set_nprobe(10);
Expand Down Expand Up @@ -1387,8 +1387,8 @@ TEST_F(VectorIndexIvfFlatTest, SearchAfterLoad) {
std::vector<int64_t> vector_select_ids(vector_ids.begin(), vector_ids.begin() + (data_base_size / 2));
std::vector<int64_t> vector_select_ids_clone = vector_select_ids;

std::shared_ptr<VectorIndex::IvfFlatListFilterFunctor> filter =
std::make_shared<VectorIndex::IvfFlatListFilterFunctor>(std::move(vector_select_ids));
std::shared_ptr<VectorIndex::ConcreteFilterFunctor> filter =
std::make_shared<VectorIndex::ConcreteFilterFunctor>(std::move(vector_select_ids));
const bool reconstruct = false;
pb::common::VectorSearchParameter parameter;
parameter.mutable_ivf_flat()->set_nprobe(10);
Expand Down Expand Up @@ -1572,8 +1572,8 @@ TEST_F(VectorIndexIvfFlatTest, RangeSearchAfterLoad) {
std::vector<int64_t> vector_select_ids(vector_ids.begin(), vector_ids.begin() + (data_base_size / 2));
std::vector<int64_t> vector_select_ids_clone = vector_select_ids;

std::shared_ptr<VectorIndex::IvfFlatListFilterFunctor> filter =
std::make_shared<VectorIndex::IvfFlatListFilterFunctor>(std::move(vector_select_ids));
std::shared_ptr<VectorIndex::ConcreteFilterFunctor> filter =
std::make_shared<VectorIndex::ConcreteFilterFunctor>(std::move(vector_select_ids));
const bool reconstruct = false;
pb::common::VectorSearchParameter parameter;
parameter.mutable_ivf_flat()->set_nprobe(10);
Expand Down
Loading

0 comments on commit 302c4ef

Please sign in to comment.