Skip to content

Commit

Permalink
Search and return codes (#3143)
Browse files Browse the repository at this point in the history
Summary:
This PR adds a functionality where an IVF index can be searched and the corresponding codes be returned. It also adds a few functions to compress int arrays into a bit-compact representation.

Pull Request resolved: #3143

Test Plan:
```
buck test //faiss/tests/:test_index_composite -- TestSearchAndReconstruct

buck test //faiss/tests/:test_standalone_codec -- test_arrays
```

Reviewed By: algoriddle

Differential Revision: D51544613

Pulled By: mdouze

fbshipit-source-id: 875f72d0f9140096851592422570efa0f65431fc
  • Loading branch information
mdouze authored and facebook-github-bot committed Nov 25, 2023
1 parent 467f70e commit b109d08
Show file tree
Hide file tree
Showing 15 changed files with 929 additions and 277 deletions.
582 changes: 341 additions & 241 deletions benchs/bench_all_ivf/bench_all_ivf.py

Large diffs are not rendered by default.

9 changes: 1 addition & 8 deletions benchs/bench_hybrid_cpu_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,14 +530,7 @@ def aa(*args, **kwargs):
raise RuntimeError()

totex = op.num_experiments()
rs = np.random.RandomState(123)
if totex < args.n_autotune:
experiments = rs.permutation(totex - 2) + 1
else:
experiments = rs.randint(
totex - 2, size=args.n_autotune - 2, replace=False)

experiments = [0, totex - 1] + list(experiments)
experiments = op.sample_experiments()
print(f"total nb experiments {totex}, running {len(experiments)}")

print("perform search")
Expand Down
18 changes: 17 additions & 1 deletion contrib/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,23 @@ def do_nothing_key(self):
return np.zeros(len(self.ranges), dtype=int)

def num_experiments(self):
return np.prod([len(values) for name, values in self.ranges])
return int(np.prod([len(values) for name, values in self.ranges]))

def sample_experiments(self, n_autotune, rs=np.random):
""" sample a set of experiments of max size n_autotune
(run all experiments in random order if n_autotune is 0)
"""
assert n_autotune == 0 or n_autotune >= 2
totex = self.num_experiments()
rs = np.random.RandomState(123)
if n_autotune == 0 or totex < n_autotune:
experiments = rs.permutation(totex - 2)
else:
experiments = rs.choice(
totex - 2, size=n_autotune - 2, replace=False)

experiments = [0, totex - 1] + [int(cno) + 1 for cno in experiments]
return experiments

def cno_to_key(self, cno):
"""Convert a sequential experiment number to a key"""
Expand Down
107 changes: 85 additions & 22 deletions faiss/IndexIVF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,44 +977,107 @@ void IndexIVF::search_and_reconstruct(
std::min(nlist, params ? params->nprobe : this->nprobe);
FAISS_THROW_IF_NOT(nprobe > 0);

idx_t* idx = new idx_t[n * nprobe];
ScopeDeleter<idx_t> del(idx);
float* coarse_dis = new float[n * nprobe];
ScopeDeleter<float> del2(coarse_dis);
std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

quantizer->search(n, x, nprobe, coarse_dis, idx);
quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());

invlists->prefetch_lists(idx, n * nprobe);
invlists->prefetch_lists(idx.get(), n * nprobe);

// search_preassigned() with `store_pairs` enabled to obtain the list_no
// and offset into `codes` for reconstruction
search_preassigned(
n,
x,
k,
idx,
coarse_dis,
idx.get(),
coarse_dis.get(),
distances,
labels,
true /* store_pairs */,
params);
for (idx_t i = 0; i < n; ++i) {
for (idx_t j = 0; j < k; ++j) {
idx_t ij = i * k + j;
idx_t key = labels[ij];
float* reconstructed = recons + ij * d;
if (key < 0) {
// Fill with NaNs
memset(reconstructed, -1, sizeof(*reconstructed) * d);
} else {
int list_no = lo_listno(key);
int offset = lo_offset(key);
#pragma omp parallel for if (n * k > 1000)
for (idx_t ij = 0; ij < n * k; ij++) {
idx_t key = labels[ij];
float* reconstructed = recons + ij * d;
if (key < 0) {
// Fill with NaNs
memset(reconstructed, -1, sizeof(*reconstructed) * d);
} else {
int list_no = lo_listno(key);
int offset = lo_offset(key);

// Update label to the actual id
labels[ij] = invlists->get_single_id(list_no, offset);

reconstruct_from_offset(list_no, offset, reconstructed);
}
}
}

void IndexIVF::search_and_return_codes(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
uint8_t* codes,
bool include_listno,
const SearchParameters* params_in) const {
const IVFSearchParameters* params = nullptr;
if (params_in) {
params = dynamic_cast<const IVFSearchParameters*>(params_in);
FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type");
}
const size_t nprobe =
std::min(nlist, params ? params->nprobe : this->nprobe);
FAISS_THROW_IF_NOT(nprobe > 0);

std::unique_ptr<idx_t[]> idx(new idx_t[n * nprobe]);
std::unique_ptr<float[]> coarse_dis(new float[n * nprobe]);

quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get());

invlists->prefetch_lists(idx.get(), n * nprobe);

// search_preassigned() with `store_pairs` enabled to obtain the list_no
// and offset into `codes` for reconstruction
search_preassigned(
n,
x,
k,
idx.get(),
coarse_dis.get(),
distances,
labels,
true /* store_pairs */,
params);

size_t code_size_1 = code_size;
if (include_listno) {
code_size_1 += coarse_code_size();
}

#pragma omp parallel for if (n * k > 1000)
for (idx_t ij = 0; ij < n * k; ij++) {
idx_t key = labels[ij];
uint8_t* code1 = codes + ij * code_size_1;

if (key < 0) {
// Fill with 0xff
memset(code1, -1, code_size_1);
} else {
int list_no = lo_listno(key);
int offset = lo_offset(key);
const uint8_t* cc = invlists->get_single_code(list_no, offset);

// Update label to the actual id
labels[ij] = invlists->get_single_id(list_no, offset);
labels[ij] = invlists->get_single_id(list_no, offset);

reconstruct_from_offset(list_no, offset, reconstructed);
if (include_listno) {
encode_listno(list_no, code1);
code1 += code_size_1 - code_size;
}
memcpy(code1, cc, code_size);
}
}
}
Expand Down
18 changes: 18 additions & 0 deletions faiss/IndexIVF.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,24 @@ struct IndexIVF : Index, IndexIVFInterface {
float* recons,
const SearchParameters* params = nullptr) const override;

/** Similar to search, but also returns the codes corresponding to the
* stored vectors for the search results.
*
* @param codes codes (n, k, code_size)
* @param include_listno
* include the list ids in the code (in this case add
* ceil(log8(nlist)) to the code size)
*/
void search_and_return_codes(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels,
uint8_t* recons,
bool include_listno = false,
const SearchParameters* params = nullptr) const;

/** Reconstruct a vector given the location in terms of (inv list index +
* inv list offset) instead of the id.
*
Expand Down
1 change: 1 addition & 0 deletions faiss/IndexIVFAdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ struct AQInvertedListScanner : InvertedListScanner {
const float* q;
/// following codes come from this inverted list
void set_list(idx_t list_no, float coarse_dis) override {
this->list_no = list_no;
if (ia.metric_type == METRIC_L2 && ia.by_residual) {
ia.quantizer->compute_residual(q0, tmp.data(), list_no);
q = tmp.data();
Expand Down
2 changes: 1 addition & 1 deletion faiss/impl/AdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
is_trained, "The additive quantizer is not trained yet.");

// standard additive quantizer decoding
#pragma omp parallel for if (n > 1000)
#pragma omp parallel for if (n > 100)
for (int64_t i = 0; i < n; i++) {
BitstringReader bsr(code + i * code_size, code_size);
float* xi = x + i * d;
Expand Down
3 changes: 2 additions & 1 deletion faiss/impl/ProductQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,8 @@ void ProductQuantizer::decode(const uint8_t* code, float* x) const {
}

void ProductQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
for (size_t i = 0; i < n; i++) {
#pragma omp parallel for if (n > 100)
for (int64_t i = 0; i < n; i++) {
this->decode(code + code_size * i, x + d * i);
}
}
Expand Down
3 changes: 2 additions & 1 deletion faiss/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from faiss.extra_wrappers import kmin, kmax, pairwise_distances, rand, randint, \
lrand, randn, rand_smooth_vectors, eval_intersection, normalize_L2, \
ResultHeap, knn, Kmeans, checksum, matrix_bucket_sort_inplace, bucket_sort, \
merge_knn_results, MapInt64ToInt64, knn_hamming
merge_knn_results, MapInt64ToInt64, knn_hamming, \
pack_bitstrings, unpack_bitstrings


__version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR,
Expand Down
70 changes: 70 additions & 0 deletions faiss/python/class_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,74 @@ def replacement_search_and_reconstruct(self, x, k, *, params=None, D=None, I=Non
)
return D, I, R

def replacement_search_and_return_codes(
self, x, k, *,
include_listnos=False, params=None, D=None, I=None, codes=None):
"""Find the k nearest neighbors of the set of vectors x in the index,
and return the codes stored for these vectors
Parameters
----------
x : array_like
Query vectors, shape (n, d) where d is appropriate for the index.
`dtype` must be float32.
k : int
Number of nearest neighbors.
params : SearchParameters
Search parameters of the current search (overrides the class-level params)
include_listnos : bool, optional
whether to include the list ids in the first bytes of each code
D : array_like, optional
Distance array to store the result.
I : array_like, optional
Labels array to store the result.
codes : array_like, optional
codes array to store
Returns
-------
D : array_like
Distances of the nearest neighbors, shape (n, k). When not enough results are found
the label is set to +Inf or -Inf.
I : array_like
Labels of the nearest neighbors, shape (n, k). When not enough results are found,
the label is set to -1
R : array_like
Approximate (reconstructed) nearest neighbor vectors, shape (n, k, d).
"""
n, d = x.shape
assert d == self.d
x = np.ascontiguousarray(x, dtype='float32')

assert k > 0

if D is None:
D = np.empty((n, k), dtype=np.float32)
else:
assert D.shape == (n, k)

if I is None:
I = np.empty((n, k), dtype=np.int64)
else:
assert I.shape == (n, k)

code_size_1 = self.code_size
if include_listnos:
code_size_1 += self.coarse_code_size()

if codes is None:
codes = np.empty((n, k, code_size_1), dtype=np.uint8)
else:
assert codes.shape == (n, k, code_size_1)

self.search_and_return_codes_c(
n, swig_ptr(x),
k, swig_ptr(D),
swig_ptr(I), swig_ptr(codes), include_listnos,
params
)
return D, I, codes

def replacement_remove_ids(self, x):
"""Remove some ids from the index.
This is a O(ntotal) operation by default, so could be expensive.
Expand Down Expand Up @@ -734,6 +802,8 @@ def replacement_permute_entries(self, perm):
ignore_missing=True)
replace_method(the_class, 'search_and_reconstruct',
replacement_search_and_reconstruct, ignore_missing=True)
replace_method(the_class, 'search_and_return_codes',
replacement_search_and_return_codes, ignore_missing=True)

# these ones are IVF-specific
replace_method(the_class, 'search_preassigned',
Expand Down
72 changes: 72 additions & 0 deletions faiss/python/extra_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

import faiss

import collections.abc


###########################################
# Wrapper for a few functions
###########################################
Expand Down Expand Up @@ -579,3 +582,72 @@ def assign(self, x):
self.index.add(self.centroids)
D, I = self.index.search(x, 1)
return D.ravel(), I.ravel()


###########################################
# Packing and unpacking bistrings
###########################################

def is_sequence(x):
return isinstance(x, collections.abc.Sequence)

pack_bitstrings_c = pack_bitstrings

def pack_bitstrings(a, nbit):
"""
Pack a set integers (i, j) where i=0:n and j=0:M into
n bitstrings.
Output is an uint8 array of size (n, code_size), where code_size is
such that at most 7 bits per code are wasted.
If nbit is an integer: all entries takes nbit bits.
If nbit is an array: entry (i, j) takes nbit[j] bits.
"""
n, M = a.shape
a = np.ascontiguousarray(a, dtype='int32')
if is_sequence(nbit):
nbit = np.ascontiguousarray(nbit, dtype='int32')
assert nbit.shape == (M,)
code_size = int((nbit.sum() + 7) // 8)
b = np.empty((n, code_size), dtype='uint8')
pack_bitstrings_c(
n, M, swig_ptr(nbit), swig_ptr(a), swig_ptr(b), code_size)
else:
code_size = (M * nbit + 7) // 8
b = np.empty((n, code_size), dtype='uint8')
pack_bitstrings_c(n, M, nbit, swig_ptr(a), swig_ptr(b), code_size)
return b

unpack_bitstrings_c = unpack_bitstrings

def unpack_bitstrings(b, M_or_nbits, nbit=None):
"""
Unpack a set integers (i, j) where i=0:n and j=0:M from
n bitstrings (encoded as uint8s).
Input is an uint8 array of size (n, code_size), where code_size is
such that at most 7 bits per code are wasted.
Two forms:
- when called with (array, M, nbit): there are M entries of size
nbit per row
- when called with (array, nbits): element (i, j) is encoded in
nbits[j] bits
"""
n, code_size = b.shape
if nbit is None:
nbit = np.ascontiguousarray(M_or_nbits, dtype='int32')
M = len(nbit)
min_code_size = int((nbit.sum() + 7) // 8)
assert code_size >= min_code_size
a = np.empty((n, M), dtype='int32')
unpack_bitstrings_c(
n, M, swig_ptr(nbit),
swig_ptr(b), code_size, swig_ptr(a))
else:
M = M_or_nbits
min_code_size = (M * nbit + 7) // 8
assert code_size >= min_code_size
a = np.empty((n, M), dtype='int32')
unpack_bitstrings_c(
n, M, nbit, swig_ptr(b), code_size, swig_ptr(a))
return a
Loading

0 comments on commit b109d08

Please sign in to comment.