diff --git a/faiss/IndexRefine.cpp b/faiss/IndexRefine.cpp index 9bb63ba029..9aade766a6 100644 --- a/faiss/IndexRefine.cpp +++ b/faiss/IndexRefine.cpp @@ -96,12 +96,26 @@ void IndexRefine::search( idx_t k, float* distances, idx_t* labels, - const SearchParameters* params) const { - FAISS_THROW_IF_NOT_MSG( - !params, "search params not supported for this index"); + const SearchParameters* params_in) const { + const IndexRefineSearchParameters* params = nullptr; + /* params_in is optional; when given it must be an IndexRefineSearchParameters */ if (params_in) { + params = dynamic_cast<const IndexRefineSearchParameters*>(params_in); + FAISS_THROW_IF_NOT_MSG( + params, "IndexRefine params have incorrect type"); + } + + idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor) + : idx_t(k * k_factor); + SearchParameters* base_index_params = + (params != nullptr) ? params->base_index_params : nullptr; + + /* for k > 0 this fails when the effective k_factor is below 1 */ FAISS_THROW_IF_NOT(k_base >= k); + + FAISS_THROW_IF_NOT(base_index); + FAISS_THROW_IF_NOT(refine_index); + FAISS_THROW_IF_NOT(k > 0); FAISS_THROW_IF_NOT(is_trained); - idx_t k_base = idx_t(k * k_factor); idx_t* base_labels = labels; float* base_distances = distances; ScopeDeleter del1; @@ -114,7 +128,8 @@ void IndexRefine::search( del2.set(base_distances); } - base_index->search(n, x, k_base, base_distances, base_labels); + base_index->search( + n, x, k_base, base_distances, base_labels, base_index_params); for (int i = 0; i < n * k_base; i++) assert(base_labels[i] >= -1 && base_labels[i] < ntotal); @@ -225,12 +240,26 @@ void IndexRefineFlat::search( idx_t k, float* distances, idx_t* labels, - const SearchParameters* params) const { - FAISS_THROW_IF_NOT_MSG( - !params, "search params not supported for this index"); + const SearchParameters* params_in) const { + const IndexRefineSearchParameters* params = nullptr; + if (params_in) { + params = dynamic_cast<const IndexRefineSearchParameters*>(params_in); + FAISS_THROW_IF_NOT_MSG( + params, "IndexRefineFlat params have incorrect type"); + } + + idx_t k_base = (params != nullptr) ? 
idx_t(k * params->k_factor) + : idx_t(k * k_factor); + SearchParameters* base_index_params = + (params != nullptr) ? params->base_index_params : nullptr; + + FAISS_THROW_IF_NOT(k_base >= k); + + FAISS_THROW_IF_NOT(base_index); + FAISS_THROW_IF_NOT(refine_index); + FAISS_THROW_IF_NOT(k > 0); FAISS_THROW_IF_NOT(is_trained); - idx_t k_base = idx_t(k * k_factor); idx_t* base_labels = labels; float* base_distances = distances; ScopeDeleter del1; @@ -243,7 +272,8 @@ void IndexRefineFlat::search( del2.set(base_distances); } - base_index->search(n, x, k_base, base_distances, base_labels); + base_index->search( + n, x, k_base, base_distances, base_labels, base_index_params); for (int i = 0; i < n * k_base; i++) assert(base_labels[i] >= -1 && base_labels[i] < ntotal); diff --git a/faiss/IndexRefine.h b/faiss/IndexRefine.h index 79b671b56a..23687af9f8 100644 --- a/faiss/IndexRefine.h +++ b/faiss/IndexRefine.h @@ -11,6 +11,13 @@ namespace faiss { +/** Per-search parameters accepted by IndexRefine::search and IndexRefineFlat::search. When supplied, k_factor is used instead of the index's own k_factor to compute the number of candidates requested from the base index (k_base = k * k_factor), and base_index_params is passed through to base_index->search(). */ struct IndexRefineSearchParameters : SearchParameters { + float k_factor = 1; + SearchParameters* base_index_params = nullptr; // non-owning + + virtual ~IndexRefineSearchParameters() = default; +}; + /** Index that queries in a base_index (a fast one) and refines the * results with an exact search, hopefully improving the results. 
*/ diff --git a/tests/test_refine.py b/tests/test_refine.py index aff285f400..4e85ee11ec 100644 --- a/tests/test_refine.py +++ b/tests/test_refine.py @@ -64,3 +64,58 @@ def test_distance_computer_AQ_LUT(self): def test_distance_computer_AQ_LUT_IP(self): self.do_test("RQ3x4_Nqint8", faiss.METRIC_INNER_PRODUCT) + + +class TestIndexRefineSearchParams(unittest.TestCase): + + def do_test(self, factory_string): + ds = datasets.SyntheticDataset(32, 256, 100, 40) + + index = faiss.index_factory(32, factory_string) + index.train(ds.get_train()) + index.add(ds.get_database()) + index.nprobe = 4 + xq = ds.get_queries() + + # do a search with k_factor = 1 + D1, I1 = index.search(xq, 10) + inter1 = faiss.eval_intersection(I1, ds.get_groundtruth(10)) + + # do a search with k_factor = 1.1 + params = faiss.IndexRefineSearchParameters(k_factor=1.1) + D2, I2 = index.search(xq, 10, params=params) + inter2 = faiss.eval_intersection(I2, ds.get_groundtruth(10)) + + # do a search with k_factor = 2 + params = faiss.IndexRefineSearchParameters(k_factor=2) + D3, I3 = index.search(xq, 10, params=params) + inter3 = faiss.eval_intersection(I3, ds.get_groundtruth(10)) + + # make sure that the recall rate increases with k_factor + self.assertGreater(inter2, inter1) + self.assertGreater(inter3, inter2) + + # make sure that the baseline k_factor is unchanged + self.assertEqual(index.k_factor, 1) + + # try passing params for the baseline index, change nprobe + base_params = faiss.IVFSearchParameters(nprobe=10) + params = faiss.IndexRefineSearchParameters(k_factor=1, base_index_params=base_params) + D4, I4 = index.search(xq, 10, params=params) + inter4 = faiss.eval_intersection(I4, ds.get_groundtruth(10)) + + base_params = faiss.IVFSearchParameters(nprobe=2) + params = faiss.IndexRefineSearchParameters(k_factor=1, base_index_params=base_params) + D5, I5 = index.search(xq, 10, params=params) + inter5 = faiss.eval_intersection(I5, ds.get_groundtruth(10)) + + # make sure that the recall rate changes +
self.assertNotEqual(inter4, inter5) + + def test_rflat(self): + # flat is handled by the IndexRefineFlat class + self.do_test("IVF8,PQ2x4np,RFlat") + + def test_refine_sq8(self): + # this case uses the IndexRefine class + self.do_test("IVF8,PQ2x4np,Refine(SQ8)")