Skip to content

Commit

Permalink
Add search parameters for IndexRefine::search() and IndexRefineFlat::…
Browse files Browse the repository at this point in the history
…search() (#3122)

Summary:
Add search params for `faiss::IndexRefine` and `faiss::IndexRefineFlat`

Pull Request resolved: #3122

Test Plan: buck test //faiss/tests/:test_refine

Reviewed By: pemazare

Differential Revision: D50968413

Pulled By: mdouze

fbshipit-source-id: 9f020d7e9c9d96b9acba54d9d7fff13bcf703b9e
  • Loading branch information
alexanderguzhva authored and facebook-github-bot committed Nov 5, 2023
1 parent df7280b commit 9a66532
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 10 deletions.
50 changes: 40 additions & 10 deletions faiss/IndexRefine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,26 @@ void IndexRefine::search(
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
const SearchParameters* params_in) const {
const IndexRefineSearchParameters* params = nullptr;
if (params_in) {
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
FAISS_THROW_IF_NOT_MSG(
params, "IndexRefine params have incorrect type");
}

idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
: idx_t(k * k_factor);
SearchParameters* base_index_params =
(params != nullptr) ? params->base_index_params : nullptr;

FAISS_THROW_IF_NOT(k_base >= k);

FAISS_THROW_IF_NOT(base_index);
FAISS_THROW_IF_NOT(refine_index);

FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(is_trained);
idx_t k_base = idx_t(k * k_factor);
idx_t* base_labels = labels;
float* base_distances = distances;
ScopeDeleter<idx_t> del1;
Expand All @@ -114,7 +128,8 @@ void IndexRefine::search(
del2.set(base_distances);
}

base_index->search(n, x, k_base, base_distances, base_labels);
base_index->search(
n, x, k_base, base_distances, base_labels, base_index_params);

for (int i = 0; i < n * k_base; i++)
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
Expand Down Expand Up @@ -225,12 +240,26 @@ void IndexRefineFlat::search(
idx_t k,
float* distances,
idx_t* labels,
const SearchParameters* params) const {
FAISS_THROW_IF_NOT_MSG(
!params, "search params not supported for this index");
const SearchParameters* params_in) const {
const IndexRefineSearchParameters* params = nullptr;
if (params_in) {
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
FAISS_THROW_IF_NOT_MSG(
params, "IndexRefineFlat params have incorrect type");
}

idx_t k_base = (params != nullptr) ? idx_t(k * params->k_factor)
: idx_t(k * k_factor);
SearchParameters* base_index_params =
(params != nullptr) ? params->base_index_params : nullptr;

FAISS_THROW_IF_NOT(k_base >= k);

FAISS_THROW_IF_NOT(base_index);
FAISS_THROW_IF_NOT(refine_index);

FAISS_THROW_IF_NOT(k > 0);
FAISS_THROW_IF_NOT(is_trained);
idx_t k_base = idx_t(k * k_factor);
idx_t* base_labels = labels;
float* base_distances = distances;
ScopeDeleter<idx_t> del1;
Expand All @@ -243,7 +272,8 @@ void IndexRefineFlat::search(
del2.set(base_distances);
}

base_index->search(n, x, k_base, base_distances, base_labels);
base_index->search(
n, x, k_base, base_distances, base_labels, base_index_params);

for (int i = 0; i < n * k_base; i++)
assert(base_labels[i] >= -1 && base_labels[i] < ntotal);
Expand Down
7 changes: 7 additions & 0 deletions faiss/IndexRefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@

namespace faiss {

struct IndexRefineSearchParameters : SearchParameters {
float k_factor = 1;
SearchParameters* base_index_params = nullptr; // non-owning

virtual ~IndexRefineSearchParameters() = default;
};

/** Index that queries in a base_index (a fast one) and refines the
* results with an exact search, hopefully improving the results.
*/
Expand Down
55 changes: 55 additions & 0 deletions tests/test_refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,58 @@ def test_distance_computer_AQ_LUT(self):

def test_distance_computer_AQ_LUT_IP(self):
self.do_test("RQ3x4_Nqint8", faiss.METRIC_INNER_PRODUCT)


class TestIndexRefineSearchParams(unittest.TestCase):

def do_test(self, factory_string):
ds = datasets.SyntheticDataset(32, 256, 100, 40)

index = faiss.index_factory(32, factory_string)
index.train(ds.get_train())
index.add(ds.get_database())
index.nprobe = 4
xq = ds.get_queries()

# do a search with k_factor = 1
D1, I1 = index.search(xq, 10)
inter1 = faiss.eval_intersection(I1, ds.get_groundtruth(10))

# do a search with k_factor = 1.5
params = faiss.IndexRefineSearchParameters(k_factor=1.1)
D2, I2 = index.search(xq, 10, params=params)
inter2 = faiss.eval_intersection(I2, ds.get_groundtruth(10))

# do a search with k_factor = 2
params = faiss.IndexRefineSearchParameters(k_factor=2)
D3, I3 = index.search(xq, 10, params=params)
inter3 = faiss.eval_intersection(I3, ds.get_groundtruth(10))

# make sure that the recall rate increases with k_factor
self.assertGreater(inter2, inter1)
self.assertGreater(inter3, inter2)

# make sure that the baseline k_factor is unchanged
self.assertEqual(index.k_factor, 1)

# try passing params for the baseline index, change nprobe
base_params = faiss.IVFSearchParameters(nprobe=10)
params = faiss.IndexRefineSearchParameters(k_factor=1, base_index_params=base_params)
D4, I4 = index.search(xq, 10, params=params)
inter4 = faiss.eval_intersection(I4, ds.get_groundtruth(10))

base_params = faiss.IVFSearchParameters(nprobe=2)
params = faiss.IndexRefineSearchParameters(k_factor=1, base_index_params=base_params)
D5, I5 = index.search(xq, 10, params=params)
inter5 = faiss.eval_intersection(I5, ds.get_groundtruth(10))

# make sure that the recall rate changes
self.assertNotEqual(inter4, inter5)

def test_rflat(self):
# flat is handled by the IndexRefineFlat class
self.do_test("IVF8,PQ2x4np,RFlat")

def test_refine_sq8(self):
# this case uses the IndexRefine class
self.do_test("IVF8,PQ2x4np,Refine(SQ8)")

0 comments on commit 9a66532

Please sign in to comment.