From b109d086a2a200b1697481dd3d79faedc585a623 Mon Sep 17 00:00:00 2001 From: Matthijs Douze Date: Sat, 25 Nov 2023 13:57:25 -0800 Subject: [PATCH] Search and return codes (#3143) Summary: This PR adds a functionality where an IVF index can be searched and the corresponding codes be returned. It also adds a few functions to compress int arrays into a bit-compact representation. Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3143 Test Plan: ``` buck test //faiss/tests/:test_index_composite -- TestSearchAndReconstruct buck test //faiss/tests/:test_standalone_codec -- test_arrays ``` Reviewed By: algoriddle Differential Revision: D51544613 Pulled By: mdouze fbshipit-source-id: 875f72d0f9140096851592422570efa0f65431fc --- benchs/bench_all_ivf/bench_all_ivf.py | 582 +++++++++++++++----------- benchs/bench_hybrid_cpu_gpu.py | 9 +- contrib/evaluation.py | 18 +- faiss/IndexIVF.cpp | 107 ++++- faiss/IndexIVF.h | 18 + faiss/IndexIVFAdditiveQuantizer.cpp | 1 + faiss/impl/AdditiveQuantizer.cpp | 2 +- faiss/impl/ProductQuantizer.cpp | 3 +- faiss/python/__init__.py | 3 +- faiss/python/class_wrappers.py | 70 ++++ faiss/python/extra_wrappers.py | 72 ++++ faiss/utils/hamming.cpp | 84 ++++ faiss/utils/hamming.h | 58 +++ tests/test_index_composite.py | 157 ++++++- tests/test_standalone_codec.py | 22 +- 15 files changed, 929 insertions(+), 277 deletions(-) diff --git a/benchs/bench_all_ivf/bench_all_ivf.py b/benchs/bench_all_ivf/bench_all_ivf.py index e098e9527a..cb4e097a05 100644 --- a/benchs/bench_all_ivf/bench_all_ivf.py +++ b/benchs/bench_all_ivf/bench_all_ivf.py @@ -7,6 +7,7 @@ import os import sys import time +import json import faiss import numpy as np @@ -19,105 +20,6 @@ sanitize = datasets.sanitize -###################################################### -# Command-line parsing -###################################################### - - -parser = argparse.ArgumentParser() - - -def aa(*args, **kwargs): - group.add_argument(*args, **kwargs) - - -group = parser.add_argument_group('dataset options') - -aa('--db', default='deep1M', help='dataset') -aa('--compute_gt', default=False, action='store_true', - help='compute and store the groundtruth') -aa('--force_IP', default=False, action="store_true", - help='force IP search instead of L2') - -group = parser.add_argument_group('index consturction') - -aa('--indexkey', default='HNSW32', help='index_factory type') -aa('--maxtrain', default=256 * 256, type=int, - help='maximum number of training points (0 to set automatically)') -aa('--indexfile', default='', help='file to read or write index from') -aa('--add_bs', default=-1, type=int, - help='add elements index by batches of this size') - - -group = parser.add_argument_group('IVF options') -aa('--by_residual', default=-1, type=int, - help="set if index should use residuals (default=unchanged)") -aa('--no_precomputed_tables', action='store_true', default=False, - help='disable precomputed tables (uses less memory)') -aa('--get_centroids_from', default='', - help='get the centroids from this index (to speed up training)') -aa('--clustering_niter', default=-1, type=int, - help='number of clustering iterations (-1 = leave default)') -aa('--train_on_gpu', default=False, action='store_true', - help='do training on GPU') - - -group = parser.add_argument_group('index-specific options') -aa('--M0', default=-1, type=int, help='size of base level for HNSW') -aa('--RQ_train_default', default=False, action="store_true", - help='disable progressive dim training for RQ') -aa('--RQ_beam_size', 
default=-1, type=int, - help='set beam size at add time') -aa('--LSQ_encode_ils_iters', default=-1, type=int, - help='ILS iterations for LSQ') -aa('--RQ_use_beam_LUT', default=-1, type=int, - help='use beam LUT at add time') - -group = parser.add_argument_group('searching') - -aa('--k', default=100, type=int, help='nb of nearest neighbors') -aa('--inter', default=False, action='store_true', - help='use intersection measure instead of 1-recall as metric') -aa('--searchthreads', default=-1, type=int, - help='nb of threads to use at search time') -aa('--searchparams', nargs='+', default=['autotune'], - help="search parameters to use (can be autotune or a list of params)") -aa('--n_autotune', default=500, type=int, - help="max nb of autotune experiments") -aa('--autotune_max', default=[], nargs='*', - help='set max value for autotune variables format "var:val" (exclusive)') -aa('--autotune_range', default=[], nargs='*', - help='set complete autotune range, format "var:val1,val2,..."') -aa('--min_test_duration', default=3.0, type=float, - help='run test at least for so long to avoid jitter') - -args = parser.parse_args() - -print("args:", args) - -os.system('echo -n "nb processors "; ' - 'cat /proc/cpuinfo | grep ^processor | wc -l; ' - 'cat /proc/cpuinfo | grep ^"model name" | tail -1') - -###################################################### -# Load dataset -###################################################### - -ds = datasets.load_dataset( - dataset=args.db, compute_gt=args.compute_gt) - -if args.force_IP: - ds.metric = "IP" - -print(ds) - -nq, d = ds.nq, ds.d -nb, d = ds.nq, ds.d - - -###################################################### -# Make index -###################################################### def unwind_index_ivf(index): if isinstance(index, faiss.IndexPreTransform): @@ -125,6 +27,10 @@ def unwind_index_ivf(index): vt = index.chain.at(0) index_ivf, vt2 = unwind_index_ivf(faiss.downcast_index(index.index)) assert vt2 is None + if vt is None: + vt = lambda x: x + else: + vt = faiss.downcast_VectorTransform(vt) return index_ivf, vt if hasattr(faiss, "IndexRefine") and isinstance(index, faiss.IndexRefine): return unwind_index_ivf(faiss.downcast_index(index.base_index)) @@ -157,16 +63,50 @@ def apply_AQ_options(index, args): index.rq.use_beam_LUT = args.RQ_use_beam_LUT -if args.indexfile and os.path.exists(args.indexfile): - print("reading", args.indexfile) - index = faiss.read_index(args.indexfile) +def eval_setting(index, xq, gt, k, inter, min_time): + """ evaluate searching in terms of precision vs. 
speed """ + nq = xq.shape[0] + ivf_stats = faiss.cvar.indexIVF_stats + ivf_stats.reset() + nrun = 0 + t0 = time.time() + while True: + D, I = index.search(xq, k) + nrun += 1 + t1 = time.time() + if t1 - t0 > min_time: + break + ms_per_query = ((t1 - t0) * 1000.0 / nq / nrun) + res = { + "ms_per_query": ms_per_query, + "nrun": nrun + } + res["n"] = ms_per_query + if inter: + rank = k + inter_measure = faiss.eval_intersection(gt[:, :rank], I[:, :rank]) / (nq * rank) + print("%.4f" % inter_measure, end=' ') + res["inter_measure"] = inter_measure + else: + res["recalls"] = {} + for rank in 1, 10, 100: + recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq) + print("%.4f" % recall, end=' ') + res["recalls"][rank] = recall + print(" %9.5f " % ms_per_query, end=' ') + print("%12d " % (ivf_stats.ndis / nrun), end=' ') + print(nrun) + res["ndis"] = ivf_stats.ndis / nrun + return res - index_ivf, vec_transform = unwind_index_ivf(index) - if vec_transform is None: - vec_transform = lambda x: x +###################################################### +# Training +###################################################### -else: +def run_train(args, ds, res): + nq, d = ds.nq, ds.d + nb, d = ds.nq, ds.d print("build index, key=", args.indexkey) @@ -176,10 +116,6 @@ def apply_AQ_options(index, args): ) index_ivf, vec_transform = unwind_index_ivf(index) - if vec_transform is None: - vec_transform = lambda x: x - else: - vec_transform = faiss.downcast_VectorTransform(vec_transform) if args.by_residual != -1: by_residual = args.by_residual == 1 @@ -205,9 +141,14 @@ def apply_AQ_options(index, args): 64) print(base_index.nprobe) elif isinstance(quantizer, faiss.IndexHNSW): - print(" update quantizer efSearch=", quantizer.hnsw.efSearch, end=" -> ") - quantizer.hnsw.efSearch = 40 if index_ivf.nlist < 4e6 else 64 - print(quantizer.hnsw.efSearch) + hnsw = quantizer.hnsw + print( + f" update HNSW quantizer options, before: " + f"{hnsw.efSearch=:} {hnsw.efConstruction=:}" + ) + hnsw.efSearch = 40 if index_ivf.nlist < 4e6 else 64 + hnsw.efConstruction = 200 + print(f" after: {hnsw.efSearch=:} {hnsw.efConstruction=:}") apply_AQ_options(index_ivf or index, args) @@ -286,182 +227,341 @@ def apply_AQ_options(index, args): t0 = time.time() index.train(xt2) - print(" train in %.3f s" % (time.time() - t0)) + res.train_time = time.time() - t0 + print(" train in %.3f s" % res.train_time) + return index + +###################################################### +# Populating index +###################################################### + +def run_add(args, ds, index, res): print("adding") t0 = time.time() if args.add_bs == -1: + assert args.split == [1, 0], "split not supported with full batch add" index.add(sanitize(ds.get_database())) else: + totn = ds.nb // args.split[0] # approximate i0 = 0 - for xblock in ds.database_iterator(bs=args.add_bs): + print(f"Adding in block sizes {args.add_bs} with split {args.split}") + for xblock in ds.database_iterator(bs=args.add_bs, split=args.split): i1 = i0 + len(xblock) print(" adding %d:%d / %d [%.3f s, RSS %d kiB] " % ( - i0, i1, ds.nb, time.time() - t0, + i0, i1, totn, time.time() - t0, faiss.get_mem_usage_kb())) index.add(xblock) i0 = i1 - print(" add in %.3f s" % (time.time() - t0)) - if args.indexfile: - print("storing", args.indexfile) - faiss.write_index(index, args.indexfile) + res.t_add = time.time() - t0 + print(f" add in {res.t_add:.3f} s index size {index.ntotal}") -if args.no_precomputed_tables: - if isinstance(index_ivf, faiss.IndexIVFPQ): - print("disabling precomputed 
table") - index_ivf.use_precomputed_table = -1 - index_ivf.precomputed_table.clear() -if args.indexfile: - print("index size on disk: ", os.stat(args.indexfile).st_size) +###################################################### +# Search +###################################################### -if hasattr(index, "code_size"): - print("vector code_size", index.code_size) +def run_search(args, ds, index, res): -if hasattr(index_ivf, "code_size"): - print("vector code_size (IVF)", index_ivf.code_size) + index_ivf, vec_transform = unwind_index_ivf(index) -print("current RSS:", faiss.get_mem_usage_kb() * 1024) + if args.no_precomputed_tables: + if isinstance(index_ivf, faiss.IndexIVFPQ): + print("disabling precomputed table") + index_ivf.use_precomputed_table = -1 + index_ivf.precomputed_table.clear() -precomputed_table_size = 0 -if hasattr(index_ivf, 'precomputed_table'): - precomputed_table_size = index_ivf.precomputed_table.size() * 4 + if args.indexfile: + print("index size on disk: ", os.stat(args.indexfile).st_size) -print("precomputed tables size:", precomputed_table_size) + if hasattr(index, "code_size"): + print("vector code_size", index.code_size) + if hasattr(index_ivf, "code_size"): + print("vector code_size (IVF)", index_ivf.code_size) -############################################################# -# Index is ready -############################################################# + print("current RSS:", faiss.get_mem_usage_kb() * 1024) -xq = sanitize(ds.get_queries()) -gt = ds.get_groundtruth(k=args.k) -assert gt.shape[1] == args.k + precomputed_table_size = 0 + if hasattr(index_ivf, 'precomputed_table'): + precomputed_table_size = index_ivf.precomputed_table.size() * 4 -if args.searchthreads != -1: - print("Setting nb of threads to", args.searchthreads) - faiss.omp_set_num_threads(args.searchthreads) -else: - print("nb search threads: ", faiss.omp_get_max_threads()) + print("precomputed tables size:", precomputed_table_size) -ps = faiss.ParameterSpace() -ps.initialize(index) + # Index is ready -parametersets = args.searchparams + xq = sanitize(ds.get_queries()) + nq, d = xq.shape + gt = ds.get_groundtruth(k=args.k) + if not args.accept_short_gt: # Deep1B has only a single NN per query + assert gt.shape[1] == args.k + if args.searchthreads != -1: + print("Setting nb of threads to", args.searchthreads) + faiss.omp_set_num_threads(args.searchthreads) + else: + print("nb search threads: ", faiss.omp_get_max_threads()) -if args.inter: - header = ( - '%-40s inter@%3d time(ms/q) nb distances #runs' % - ("parameters", args.k) - ) -else: + ps = faiss.ParameterSpace() + ps.initialize(index) - header = ( - '%-40s R@1 R@10 R@100 time(ms/q) nb distances #runs' % - "parameters" - ) + parametersets = args.searchparams -def compute_inter(a, b): - nq, rank = a.shape - ninter = sum( - np.intersect1d(a[i, :rank], b[i, :rank]).size - for i in range(nq) - ) - return ninter / a.size + if args.inter: + header = ( + '%-40s inter@%3d time(ms/q) nb distances #runs' % + ("parameters", args.k) + ) + else: + header = ( + '%-40s R@1 R@10 R@100 time(ms/q) nb distances #runs' % + "parameters" + ) -def eval_setting(index, xq, gt, k, inter, min_time): - nq = xq.shape[0] - ivf_stats = faiss.cvar.indexIVF_stats - ivf_stats.reset() - nrun = 0 - t0 = time.time() - while True: - D, I = index.search(xq, k) - nrun += 1 - t1 = time.time() - if t1 - t0 > min_time: - break - ms_per_query = ((t1 - t0) * 1000.0 / nq / nrun) - if inter: - rank = k - inter_measure = compute_inter(gt[:, :rank], I[:, :rank]) - print("%.4f" % 
inter_measure, end=' ') - else: - for rank in 1, 10, 100: - n_ok = (I[:, :rank] == gt[:, :1]).sum() - print("%.4f" % (n_ok / float(nq)), end=' ') - print(" %9.5f " % ms_per_query, end=' ') - print("%12d " % (ivf_stats.ndis / nrun), end=' ') - print(nrun) + res.search_results = {} + if parametersets == ['autotune']: + + ps.n_experiments = args.n_autotune + ps.min_test_duration = args.min_test_duration + + for kv in args.autotune_max: + k, vmax = kv.split(':') + vmax = float(vmax) + print("limiting %s to %g" % (k, vmax)) + pr = ps.add_range(k) + values = faiss.vector_to_array(pr.values) + values = np.array([v for v in values if v < vmax]) + faiss.copy_array_to_vector(values, pr.values) + + for kv in args.autotune_range: + k, vals = kv.split(':') + vals = np.fromstring(vals, sep=',') + print("setting %s to %s" % (k, vals)) + pr = ps.add_range(k) + faiss.copy_array_to_vector(vals, pr.values) + + # setup the Criterion object + if args.inter: + print("Optimize for intersection @ ", args.k) + crit = faiss.IntersectionCriterion(nq, args.k) + else: + print("Optimize for 1-recall @ 1") + crit = faiss.OneRecallAtRCriterion(nq, 1) + # by default, the criterion will request only 1 NN + crit.nnn = args.k + crit.set_groundtruth(None, gt.astype('int64')) -if parametersets == ['autotune']: + # then we let Faiss find the optimal parameters by itself + print("exploring operating points, %d threads" % faiss.omp_get_max_threads()); + ps.display() - ps.n_experiments = args.n_autotune - ps.min_test_duration = args.min_test_duration + t0 = time.time() + op = ps.explore(index, xq, crit) + res.t_explore = time.time() - t0 + print("Done in %.3f s, available OPs:" % res.t_explore) - for kv in args.autotune_max: - k, vmax = kv.split(':') - vmax = float(vmax) - print("limiting %s to %g" % (k, vmax)) - pr = ps.add_range(k) - values = faiss.vector_to_array(pr.values) - values = np.array([v for v in values if v < vmax]) - faiss.copy_array_to_vector(values, pr.values) + op.display() - for kv in args.autotune_range: - k, vals = kv.split(':') - vals = np.fromstring(vals, sep=',') - print("setting %s to %s" % (k, vals)) - pr = ps.add_range(k) - faiss.copy_array_to_vector(vals, pr.values) + print("Re-running evaluation on selected OPs") + print(header) + opv = op.optimal_pts + maxw = max(max(len(opv.at(i).key) for i in range(opv.size())), 40) + for i in range(opv.size()): + opt = opv.at(i) + + ps.set_index_parameters(index, opt.key) + + print(opt.key.ljust(maxw), end=' ') + sys.stdout.flush() + + res_i = eval_setting(index, xq, gt, args.k, args.inter, args.min_test_duration) + res.search_results[opt.key] = res_i - # setup the Criterion object - if args.inter: - print("Optimize for intersection @ ", args.k) - crit = faiss.IntersectionCriterion(nq, args.k) else: - print("Optimize for 1-recall @ 1") - crit = faiss.OneRecallAtRCriterion(nq, 1) + print(header) + for param in parametersets: + print("%-40s " % param, end=' ') + sys.stdout.flush() + ps.set_index_parameters(index, param) - # by default, the criterion will request only 1 NN - crit.nnn = args.k - crit.set_groundtruth(None, gt.astype('int64')) + res_i = eval_setting(index, xq, gt, args.k, args.inter, args.min_test_duration) + res.search_results[param] = res_i - # then we let Faiss find the optimal parameters by itself - print("exploring operating points, %d threads" % faiss.omp_get_max_threads()); - ps.display() - t0 = time.time() - op = ps.explore(index, xq, crit) - print("Done in %.3f s, available OPs:" % (time.time() - t0)) - op.display() 
+###################################################### +# Driver function +###################################################### - print("Re-running evaluation on selected OPs") - print(header) - opv = op.optimal_pts - maxw = max(max(len(opv.at(i).key) for i in range(opv.size())), 40) - for i in range(opv.size()): - opt = opv.at(i) +def main(): + + parser = argparse.ArgumentParser() + + def aa(*args, **kwargs): + group.add_argument(*args, **kwargs) + + group = parser.add_argument_group('general options') + aa('--nthreads', default=-1, type=int, + help='nb of threads to use at train and add time') + aa('--json', default=False, action="store_true", + help="output stats in JSON format at the end") + aa('--todo', default=["check_files"], + choices=["train", "add", "search", "check_files"], + nargs="+", help='what to do (check_files means decide depending on which index files exist)') + + group = parser.add_argument_group('dataset options') + aa('--db', default='deep1M', help='dataset') + aa('--compute_gt', default=False, action='store_true', + help='compute and store the groundtruth') + aa('--force_IP', default=False, action="store_true", + help='force IP search instead of L2') + aa('--accept_short_gt', default=False, action='store_true', + help='work around a problem with Deep1B GT') + + group = parser.add_argument_group('index construction') + aa('--indexkey', default='HNSW32', help='index_factory type') + aa('--trained_indexfile', default='', + help='file to read or write a trained index from') + aa('--maxtrain', default=256 * 256, type=int, + help='maximum number of training points (0 to set automatically)') + aa('--indexfile', default='', help='file to read or write index from') + aa('--split', default=[1, 0], type=int, nargs=2, help="database split") + aa('--add_bs', default=-1, type=int, + help='add elements index by batches of this size') + + group = parser.add_argument_group('IVF options') + aa('--by_residual', default=-1, type=int, + help="set if index should use residuals (default=unchanged)") + aa('--no_precomputed_tables', action='store_true', default=False, + help='disable precomputed tables (uses less memory)') + aa('--get_centroids_from', default='', + help='get the centroids from this index (to speed up training)') + aa('--clustering_niter', default=-1, type=int, + help='number of clustering iterations (-1 = leave default)') + aa('--train_on_gpu', default=False, action='store_true', + help='do training on GPU') + + group = parser.add_argument_group('index-specific options') + aa('--M0', default=-1, type=int, help='size of base level for HNSW') + aa('--RQ_train_default', default=False, action="store_true", + help='disable progressive dim training for RQ') + aa('--RQ_beam_size', default=-1, type=int, + help='set beam size at add time') + aa('--LSQ_encode_ils_iters', default=-1, type=int, + help='ILS iterations for LSQ') + aa('--RQ_use_beam_LUT', default=-1, type=int, + help='use beam LUT at add time') + + group = parser.add_argument_group('searching') + aa('--k', default=100, type=int, help='nb of nearest neighbors') + aa('--inter', default=False, action='store_true', + help='use intersection measure instead of 1-recall as metric') + aa('--searchthreads', default=-1, type=int, + help='nb of threads to use at search time') + aa('--searchparams', nargs='+', default=['autotune'], + help="search parameters to use (can be autotune or a list of params)") + aa('--n_autotune', default=500, type=int, + help="max nb of autotune experiments") + aa('--autotune_max', default=[], nargs='*', 
+ help='set max value for autotune variables format "var:val" (exclusive)') + aa('--autotune_range', default=[], nargs='*', + help='set complete autotune range, format "var:val1,val2,..."') + aa('--min_test_duration', default=3.0, type=float, + help='run test at least for so long to avoid jitter') + aa('--indexes_to_merge', default=[], nargs="*", + help="load these indexes to search and merge them before searching") + + args = parser.parse_args() + + if args.todo == ["check_files"]: + if os.path.exists(args.indexfile): + args.todo = ["search"] + elif os.path.exists(args.trained_indexfile): + args.todo = ["add", "search"] + else: + args.todo = ["train", "add", "search"] + print("setting todo to", args.todo) + + print("args:", args) + + os.system('echo -n "nb processors "; ' + 'cat /proc/cpuinfo | grep ^processor | wc -l; ' + 'cat /proc/cpuinfo | grep ^"model name" | tail -1') - ps.set_index_parameters(index, opt.key) + # object to collect results + res = argparse.Namespace() + res.args = args.__dict__ - print(opt.key.ljust(maxw), end=' ') - sys.stdout.flush() + res.cpu_model = [ + l for l in open("/proc/cpuinfo", "r") + if "model name" in l][0] - eval_setting(index, xq, gt, args.k, args.inter, args.min_test_duration) + print("Load dataset") -else: - print(header) - for param in parametersets: - print("%-40s " % param, end=' ') - sys.stdout.flush() - ps.set_index_parameters(index, param) + ds = datasets.load_dataset( + dataset=args.db, compute_gt=args.compute_gt) - eval_setting(index, xq, gt, args.k, args.inter, args.min_test_duration) + if args.force_IP: + ds.metric = "IP" + + print(ds) + + if args.nthreads != -1: + print("Set nb of threads to", args.nthreads) + faiss.omp_set_num_threads(args.nthreads) + else: + print("nb threads: ", faiss.omp_get_max_threads()) + + index = None + if "train" in args.todo: + print("================== Training index") + index = run_train(args, ds, res) + if args.trained_indexfile: + print("storing trained index", args.trained_indexfile) + faiss.write_index(index, args.trained_indexfile) + + if "add" in args.todo: + if not index: + assert args.trained_indexfile + print("reading trained index", args.trained_indexfile) + index = faiss.read_index(args.trained_indexfile) + + print("================== Adding vectors to index") + run_add(args, ds, index, res) + if args.indexfile: + print("storing", args.indexfile) + faiss.write_index(index, args.indexfile) + + if "search" in args.todo: + if not index: + if args.indexfile: + print("reading index", args.indexfile) + index = faiss.read_index(args.indexfile) + elif args.indexes_to_merge: + print(f"Merging {len(args.indexes_to_merge)} indexes") + sz = 0 + for fname in args.indexes_to_merge: + print(f" reading {fname} (current size {sz})") + index_i = faiss.read_index(fname) + if index is None: + index = index_i + else: + index.merge_from(index_i, index.ntotal) + sz = index.ntotal + else: + assert False, "provide --indexfile" + + print("================== Searching") + run_search(args, ds, index, res) + + if args.json: + print("JSON results:", json.dumps(res.__dict__)) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchs/bench_hybrid_cpu_gpu.py b/benchs/bench_hybrid_cpu_gpu.py index 8a509f323f..779a09fefa 100644 --- a/benchs/bench_hybrid_cpu_gpu.py +++ b/benchs/bench_hybrid_cpu_gpu.py @@ -530,14 +530,7 @@ def aa(*args, **kwargs): raise RuntimeError() totex = op.num_experiments() - rs = np.random.RandomState(123) - if totex < args.n_autotune: - experiments = rs.permutation(totex - 2) + 1 
- else: - experiments = rs.randint( - totex - 2, size=args.n_autotune - 2, replace=False) - - experiments = [0, totex - 1] + list(experiments) + experiments = op.sample_experiments() print(f"total nb experiments {totex}, running {len(experiments)}") print("perform search") diff --git a/contrib/evaluation.py b/contrib/evaluation.py index 1f4068734e..50e8a93319 100644 --- a/contrib/evaluation.py +++ b/contrib/evaluation.py @@ -380,7 +380,23 @@ def do_nothing_key(self): return np.zeros(len(self.ranges), dtype=int) def num_experiments(self): - return np.prod([len(values) for name, values in self.ranges]) + return int(np.prod([len(values) for name, values in self.ranges])) + + def sample_experiments(self, n_autotune, rs=np.random): + """ sample a set of experiments of max size n_autotune + (run all experiments in random order if n_autotune is 0) + """ + assert n_autotune == 0 or n_autotune >= 2 + totex = self.num_experiments() + rs = np.random.RandomState(123) + if n_autotune == 0 or totex < n_autotune: + experiments = rs.permutation(totex - 2) + else: + experiments = rs.choice( + totex - 2, size=n_autotune - 2, replace=False) + + experiments = [0, totex - 1] + [int(cno) + 1 for cno in experiments] + return experiments def cno_to_key(self, cno): """Convert a sequential experiment number to a key""" diff --git a/faiss/IndexIVF.cpp b/faiss/IndexIVF.cpp index 6ff21429e5..c7780575e6 100644 --- a/faiss/IndexIVF.cpp +++ b/faiss/IndexIVF.cpp @@ -977,14 +977,12 @@ void IndexIVF::search_and_reconstruct( std::min(nlist, params ? params->nprobe : this->nprobe); FAISS_THROW_IF_NOT(nprobe > 0); - idx_t* idx = new idx_t[n * nprobe]; - ScopeDeleter del(idx); - float* coarse_dis = new float[n * nprobe]; - ScopeDeleter del2(coarse_dis); + std::unique_ptr idx(new idx_t[n * nprobe]); + std::unique_ptr coarse_dis(new float[n * nprobe]); - quantizer->search(n, x, nprobe, coarse_dis, idx); + quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get()); - invlists->prefetch_lists(idx, n * nprobe); + invlists->prefetch_lists(idx.get(), n * nprobe); // search_preassigned() with `store_pairs` enabled to obtain the list_no // and offset into `codes` for reconstruction @@ -992,29 +990,94 @@ void IndexIVF::search_and_reconstruct( n, x, k, - idx, - coarse_dis, + idx.get(), + coarse_dis.get(), distances, labels, true /* store_pairs */, params); - for (idx_t i = 0; i < n; ++i) { - for (idx_t j = 0; j < k; ++j) { - idx_t ij = i * k + j; - idx_t key = labels[ij]; - float* reconstructed = recons + ij * d; - if (key < 0) { - // Fill with NaNs - memset(reconstructed, -1, sizeof(*reconstructed) * d); - } else { - int list_no = lo_listno(key); - int offset = lo_offset(key); +#pragma omp parallel for if (n * k > 1000) + for (idx_t ij = 0; ij < n * k; ij++) { + idx_t key = labels[ij]; + float* reconstructed = recons + ij * d; + if (key < 0) { + // Fill with NaNs + memset(reconstructed, -1, sizeof(*reconstructed) * d); + } else { + int list_no = lo_listno(key); + int offset = lo_offset(key); + + // Update label to the actual id + labels[ij] = invlists->get_single_id(list_no, offset); + + reconstruct_from_offset(list_no, offset, reconstructed); + } + } +} + +void IndexIVF::search_and_return_codes( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + uint8_t* codes, + bool include_listno, + const SearchParameters* params_in) const { + const IVFSearchParameters* params = nullptr; + if (params_in) { + params = dynamic_cast(params_in); + FAISS_THROW_IF_NOT_MSG(params, "IndexIVF params have incorrect type"); + } + 
const size_t nprobe = + std::min(nlist, params ? params->nprobe : this->nprobe); + FAISS_THROW_IF_NOT(nprobe > 0); + + std::unique_ptr idx(new idx_t[n * nprobe]); + std::unique_ptr coarse_dis(new float[n * nprobe]); + + quantizer->search(n, x, nprobe, coarse_dis.get(), idx.get()); + + invlists->prefetch_lists(idx.get(), n * nprobe); + + // search_preassigned() with `store_pairs` enabled to obtain the list_no + // and offset into `codes` for reconstruction + search_preassigned( + n, + x, + k, + idx.get(), + coarse_dis.get(), + distances, + labels, + true /* store_pairs */, + params); + + size_t code_size_1 = code_size; + if (include_listno) { + code_size_1 += coarse_code_size(); + } + +#pragma omp parallel for if (n * k > 1000) + for (idx_t ij = 0; ij < n * k; ij++) { + idx_t key = labels[ij]; + uint8_t* code1 = codes + ij * code_size_1; + + if (key < 0) { + // Fill with 0xff + memset(code1, -1, code_size_1); + } else { + int list_no = lo_listno(key); + int offset = lo_offset(key); + const uint8_t* cc = invlists->get_single_code(list_no, offset); - // Update label to the actual id - labels[ij] = invlists->get_single_id(list_no, offset); + labels[ij] = invlists->get_single_id(list_no, offset); - reconstruct_from_offset(list_no, offset, reconstructed); + if (include_listno) { + encode_listno(list_no, code1); + code1 += code_size_1 - code_size; } + memcpy(code1, cc, code_size); } } } diff --git a/faiss/IndexIVF.h b/faiss/IndexIVF.h index a4a40194f9..d0981caa42 100644 --- a/faiss/IndexIVF.h +++ b/faiss/IndexIVF.h @@ -357,6 +357,24 @@ struct IndexIVF : Index, IndexIVFInterface { float* recons, const SearchParameters* params = nullptr) const override; + /** Similar to search, but also returns the codes corresponding to the + * stored vectors for the search results. + * + * @param codes codes (n, k, code_size) + * @param include_listno + * include the list ids in the code (in this case add + * ceil(log8(nlist)) to the code size) + */ + void search_and_return_codes( + idx_t n, + const float* x, + idx_t k, + float* distances, + idx_t* labels, + uint8_t* recons, + bool include_listno = false, + const SearchParameters* params = nullptr) const; + /** Reconstruct a vector given the location in terms of (inv list index + * inv list offset) instead of the id. 
* diff --git a/faiss/IndexIVFAdditiveQuantizer.cpp b/faiss/IndexIVFAdditiveQuantizer.cpp index 75b8517a0a..edaff51ec7 100644 --- a/faiss/IndexIVFAdditiveQuantizer.cpp +++ b/faiss/IndexIVFAdditiveQuantizer.cpp @@ -149,6 +149,7 @@ struct AQInvertedListScanner : InvertedListScanner { const float* q; /// following codes come from this inverted list void set_list(idx_t list_no, float coarse_dis) override { + this->list_no = list_no; if (ia.metric_type == METRIC_L2 && ia.by_residual) { ia.quantizer->compute_residual(q0, tmp.data(), list_no); q = tmp.data(); diff --git a/faiss/impl/AdditiveQuantizer.cpp b/faiss/impl/AdditiveQuantizer.cpp index c39d870e6d..42d37f32a9 100644 --- a/faiss/impl/AdditiveQuantizer.cpp +++ b/faiss/impl/AdditiveQuantizer.cpp @@ -261,7 +261,7 @@ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const { is_trained, "The additive quantizer is not trained yet."); // standard additive quantizer decoding -#pragma omp parallel for if (n > 1000) +#pragma omp parallel for if (n > 100) for (int64_t i = 0; i < n; i++) { BitstringReader bsr(code + i * code_size, code_size); float* xi = x + i * d; diff --git a/faiss/impl/ProductQuantizer.cpp b/faiss/impl/ProductQuantizer.cpp index 0c048e7c3e..342a6dc8d1 100644 --- a/faiss/impl/ProductQuantizer.cpp +++ b/faiss/impl/ProductQuantizer.cpp @@ -306,7 +306,8 @@ void ProductQuantizer::decode(const uint8_t* code, float* x) const { } void ProductQuantizer::decode(const uint8_t* code, float* x, size_t n) const { - for (size_t i = 0; i < n; i++) { +#pragma omp parallel for if (n > 100) + for (int64_t i = 0; i < n; i++) { this->decode(code + code_size * i, x + d * i); } } diff --git a/faiss/python/__init__.py b/faiss/python/__init__.py index 427cb31625..95be4254dc 100644 --- a/faiss/python/__init__.py +++ b/faiss/python/__init__.py @@ -22,7 +22,8 @@ from faiss.extra_wrappers import kmin, kmax, pairwise_distances, rand, randint, \ lrand, randn, rand_smooth_vectors, eval_intersection, normalize_L2, \ ResultHeap, knn, Kmeans, checksum, matrix_bucket_sort_inplace, bucket_sort, \ - merge_knn_results, MapInt64ToInt64, knn_hamming + merge_knn_results, MapInt64ToInt64, knn_hamming, \ + pack_bitstrings, unpack_bitstrings __version__ = "%d.%d.%d" % (FAISS_VERSION_MAJOR, diff --git a/faiss/python/class_wrappers.py b/faiss/python/class_wrappers.py index 3beb66141c..4a6808d286 100644 --- a/faiss/python/class_wrappers.py +++ b/faiss/python/class_wrappers.py @@ -402,6 +402,74 @@ def replacement_search_and_reconstruct(self, x, k, *, params=None, D=None, I=Non ) return D, I, R + def replacement_search_and_return_codes( + self, x, k, *, + include_listnos=False, params=None, D=None, I=None, codes=None): + """Find the k nearest neighbors of the set of vectors x in the index, + and return the codes stored for these vectors + + Parameters + ---------- + x : array_like + Query vectors, shape (n, d) where d is appropriate for the index. + `dtype` must be float32. + k : int + Number of nearest neighbors. + params : SearchParameters + Search parameters of the current search (overrides the class-level params) + include_listnos : bool, optional + whether to include the list ids in the first bytes of each code + D : array_like, optional + Distance array to store the result. + I : array_like, optional + Labels array to store the result. + codes : array_like, optional + codes array to store + + Returns + ------- + D : array_like + Distances of the nearest neighbors, shape (n, k). When not enough results are found + the label is set to +Inf or -Inf. 
+ I : array_like + Labels of the nearest neighbors, shape (n, k). When not enough results are found, + the label is set to -1 + R : array_like + Approximate (reconstructed) nearest neighbor vectors, shape (n, k, d). + """ + n, d = x.shape + assert d == self.d + x = np.ascontiguousarray(x, dtype='float32') + + assert k > 0 + + if D is None: + D = np.empty((n, k), dtype=np.float32) + else: + assert D.shape == (n, k) + + if I is None: + I = np.empty((n, k), dtype=np.int64) + else: + assert I.shape == (n, k) + + code_size_1 = self.code_size + if include_listnos: + code_size_1 += self.coarse_code_size() + + if codes is None: + codes = np.empty((n, k, code_size_1), dtype=np.uint8) + else: + assert codes.shape == (n, k, code_size_1) + + self.search_and_return_codes_c( + n, swig_ptr(x), + k, swig_ptr(D), + swig_ptr(I), swig_ptr(codes), include_listnos, + params + ) + return D, I, codes + def replacement_remove_ids(self, x): """Remove some ids from the index. This is a O(ntotal) operation by default, so could be expensive. @@ -734,6 +802,8 @@ def replacement_permute_entries(self, perm): ignore_missing=True) replace_method(the_class, 'search_and_reconstruct', replacement_search_and_reconstruct, ignore_missing=True) + replace_method(the_class, 'search_and_return_codes', + replacement_search_and_return_codes, ignore_missing=True) # these ones are IVF-specific replace_method(the_class, 'search_preassigned', diff --git a/faiss/python/extra_wrappers.py b/faiss/python/extra_wrappers.py index 02f9f2954c..d7fd05bc9f 100644 --- a/faiss/python/extra_wrappers.py +++ b/faiss/python/extra_wrappers.py @@ -14,6 +14,9 @@ import faiss +import collections.abc + + ########################################### # Wrapper for a few functions ########################################### @@ -579,3 +582,72 @@ def assign(self, x): self.index.add(self.centroids) D, I = self.index.search(x, 1) return D.ravel(), I.ravel() + + +########################################### +# Packing and unpacking bistrings +########################################### + +def is_sequence(x): + return isinstance(x, collections.abc.Sequence) + +pack_bitstrings_c = pack_bitstrings + +def pack_bitstrings(a, nbit): + """ + Pack a set integers (i, j) where i=0:n and j=0:M into + n bitstrings. + Output is an uint8 array of size (n, code_size), where code_size is + such that at most 7 bits per code are wasted. + + If nbit is an integer: all entries takes nbit bits. + If nbit is an array: entry (i, j) takes nbit[j] bits. + """ + n, M = a.shape + a = np.ascontiguousarray(a, dtype='int32') + if is_sequence(nbit): + nbit = np.ascontiguousarray(nbit, dtype='int32') + assert nbit.shape == (M,) + code_size = int((nbit.sum() + 7) // 8) + b = np.empty((n, code_size), dtype='uint8') + pack_bitstrings_c( + n, M, swig_ptr(nbit), swig_ptr(a), swig_ptr(b), code_size) + else: + code_size = (M * nbit + 7) // 8 + b = np.empty((n, code_size), dtype='uint8') + pack_bitstrings_c(n, M, nbit, swig_ptr(a), swig_ptr(b), code_size) + return b + +unpack_bitstrings_c = unpack_bitstrings + +def unpack_bitstrings(b, M_or_nbits, nbit=None): + """ + Unpack a set integers (i, j) where i=0:n and j=0:M from + n bitstrings (encoded as uint8s). + Input is an uint8 array of size (n, code_size), where code_size is + such that at most 7 bits per code are wasted. 
+ + Two forms: + - when called with (array, M, nbit): there are M entries of size + nbit per row + - when called with (array, nbits): element (i, j) is encoded in + nbits[j] bits + """ + n, code_size = b.shape + if nbit is None: + nbit = np.ascontiguousarray(M_or_nbits, dtype='int32') + M = len(nbit) + min_code_size = int((nbit.sum() + 7) // 8) + assert code_size >= min_code_size + a = np.empty((n, M), dtype='int32') + unpack_bitstrings_c( + n, M, swig_ptr(nbit), + swig_ptr(b), code_size, swig_ptr(a)) + else: + M = M_or_nbits + min_code_size = (M * nbit + 7) // 8 + assert code_size >= min_code_size + a = np.empty((n, M), dtype='int32') + unpack_bitstrings_c( + n, M, nbit, swig_ptr(b), code_size, swig_ptr(a)) + return a diff --git a/faiss/utils/hamming.cpp b/faiss/utils/hamming.cpp index 773cd34530..14b84f7ab6 100644 --- a/faiss/utils/hamming.cpp +++ b/faiss/utils/hamming.cpp @@ -681,4 +681,88 @@ void generalized_hammings_knn_hc( ha->reorder(); } +void pack_bitstrings( + size_t n, + size_t M, + int nbit, + const int32_t* unpacked, + uint8_t* packed, + size_t code_size) { + FAISS_THROW_IF_NOT(code_size >= (M * nbit + 7) / 8); +#pragma omp parallel for if (n > 1000) + for (int64_t i = 0; i < n; i++) { + const int32_t* in = unpacked + i * M; + uint8_t* out = packed + i * code_size; + BitstringWriter wr(out, code_size); + for (int j = 0; j < M; j++) { + wr.write(in[j], nbit); + } + } +} + +void pack_bitstrings( + size_t n, + size_t M, + const int32_t* nbit, + const int32_t* unpacked, + uint8_t* packed, + size_t code_size) { + int totbit = 0; + for (int j = 0; j < M; j++) { + totbit += nbit[j]; + } + FAISS_THROW_IF_NOT(code_size >= (totbit + 7) / 8); +#pragma omp parallel for if (n > 1000) + for (int64_t i = 0; i < n; i++) { + const int32_t* in = unpacked + i * M; + uint8_t* out = packed + i * code_size; + BitstringWriter wr(out, code_size); + for (int j = 0; j < M; j++) { + wr.write(in[j], nbit[j]); + } + } +} + +void unpack_bitstrings( + size_t n, + size_t M, + int nbit, + const uint8_t* packed, + size_t code_size, + int32_t* unpacked) { + FAISS_THROW_IF_NOT(code_size >= (M * nbit + 7) / 8); +#pragma omp parallel for if (n > 1000) + for (int64_t i = 0; i < n; i++) { + const uint8_t* in = packed + i * code_size; + int32_t* out = unpacked + i * M; + BitstringReader rd(in, code_size); + for (int j = 0; j < M; j++) { + out[j] = rd.read(nbit); + } + } +} + +void unpack_bitstrings( + size_t n, + size_t M, + const int32_t* nbit, + const uint8_t* packed, + size_t code_size, + int32_t* unpacked) { + int totbit = 0; + for (int j = 0; j < M; j++) { + totbit += nbit[j]; + } + FAISS_THROW_IF_NOT(code_size >= (totbit + 7) / 8); +#pragma omp parallel for if (n > 1000) + for (int64_t i = 0; i < n; i++) { + const uint8_t* in = packed + i * code_size; + int32_t* out = unpacked + i * M; + BitstringReader rd(in, code_size); + for (int j = 0; j < M; j++) { + out[j] = rd.read(nbit[j]); + } + } +} + } // namespace faiss diff --git a/faiss/utils/hamming.h b/faiss/utils/hamming.h index 84187d245c..7cdc05d252 100644 --- a/faiss/utils/hamming.h +++ b/faiss/utils/hamming.h @@ -222,6 +222,64 @@ void generalized_hammings_knn_hc( size_t code_size, int ordered = true); +/** Pack a set of n codes of size M * nbit + * + * @param n number of codes to pack + * @param M number of elementary codes per code + * @param nbit number of bits per elementary code + * @param unpacked input unpacked codes, size (n, M) + * @param packed output packed codes, size (n, code_size) + * @param code_size should be >= ceil(M * nbit / 8) + */ +void 
pack_bitstrings( + size_t n, + size_t M, + int nbit, + const int32_t* unpacked, + uint8_t* packed, + size_t code_size); + +/** Pack a set of n codes of variable sizes + * + * @param nbit number of bits per entry (size M) + */ +void pack_bitstrings( + size_t n, + size_t M, + const int32_t* nbits, + const int32_t* unpacked, + uint8_t* packed, + size_t code_size); + +/** Unpack a set of n codes of size M * nbit + * + * @param n number of codes to pack + * @param M number of elementary codes per code + * @param nbit number of bits per elementary code + * @param unpacked input unpacked codes, size (n, M) + * @param packed output packed codes, size (n, code_size) + * @param code_size should be >= ceil(M * nbit / 8) + */ +void unpack_bitstrings( + size_t n, + size_t M, + int nbit, + const uint8_t* packed, + size_t code_size, + int32_t* unpacked); + +/** Unpack a set of n codes of variable sizes + * + * @param nbit number of bits per entry (size M) + */ +void unpack_bitstrings( + size_t n, + size_t M, + const int32_t* nbits, + const uint8_t* packed, + size_t code_size, + int32_t* unpacked); + } // namespace faiss #include diff --git a/tests/test_index_composite.py b/tests/test_index_composite.py index 81a00cb938..a760c0cf09 100644 --- a/tests/test_index_composite.py +++ b/tests/test_index_composite.py @@ -14,7 +14,7 @@ import tempfile import platform -from common_faiss_tests import get_dataset_2 +from common_faiss_tests import get_dataset_2, get_dataset from faiss.contrib.datasets import SyntheticDataset from faiss.contrib.inspect_tools import make_LinearTransform_matrix from faiss.contrib.evaluation import check_ref_knn_with_draws @@ -822,3 +822,158 @@ def test_precomputed_tables(self): np.testing.assert_array_equal(Dnew, D2) np.testing.assert_array_equal(Inew, I2) + + + +class TestSearchAndReconstruct(unittest.TestCase): + + def run_search_and_reconstruct(self, index, xb, xq, k=10, eps=None): + n, d = xb.shape + assert xq.shape[1] == d + assert index.d == d + + D_ref, I_ref = index.search(xq, k) + R_ref = index.reconstruct_n(0, n) + D, I, R = index.search_and_reconstruct(xq, k) + + np.testing.assert_almost_equal(D, D_ref, decimal=5) + self.assertTrue((I == I_ref).all()) + self.assertEqual(R.shape[:2], I.shape) + self.assertEqual(R.shape[2], d) + + # (n, k, ..) -> (n * k, ..) 
+ I_flat = I.reshape(-1) + R_flat = R.reshape(-1, d) + # Filter out -1s when not enough results + R_flat = R_flat[I_flat >= 0] + I_flat = I_flat[I_flat >= 0] + + recons_ref_err = np.mean(np.linalg.norm(R_flat - R_ref[I_flat])) + self.assertLessEqual(recons_ref_err, 1e-6) + + def norm1(x): + return np.sqrt((x ** 2).sum(axis=1)) + + recons_err = np.mean(norm1(R_flat - xb[I_flat])) + + print('Reconstruction error = %.3f' % recons_err) + if eps is not None: + self.assertLessEqual(recons_err, eps) + + return D, I, R + + def test_IndexFlat(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + index = faiss.IndexFlatL2(d) + index.add(xb) + + self.run_search_and_reconstruct(index, xb, xq, eps=0.0) + + def test_IndexIVFFlat(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFFlat(quantizer, d, 32, faiss.METRIC_L2) + index.cp.min_points_per_centroid = 5 # quiet warning + index.nprobe = 4 + index.train(xt) + index.add(xb) + + self.run_search_and_reconstruct(index, xb, xq, eps=0.0) + + def test_IndexIVFPQ(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + quantizer = faiss.IndexFlatL2(d) + index = faiss.IndexIVFPQ(quantizer, d, 32, 8, 8) + index.cp.min_points_per_centroid = 5 # quiet warning + index.nprobe = 4 + index.train(xt) + index.add(xb) + + self.run_search_and_reconstruct(index, xb, xq, eps=1.0) + + def test_MultiIndex(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + index = faiss.index_factory(d, "IMI2x5,PQ8np") + faiss.ParameterSpace().set_index_parameter(index, "nprobe", 4) + index.train(xt) + index.add(xb) + + self.run_search_and_reconstruct(index, xb, xq, eps=1.0) + + def test_IndexTransform(self): + d = 32 + nb = 1000 + nt = 1500 + nq = 200 + + (xt, xb, xq) = get_dataset(d, nb, nt, nq) + + index = faiss.index_factory(d, "L2norm,PCA8,IVF32,PQ8np") + faiss.ParameterSpace().set_index_parameter(index, "nprobe", 4) + index.train(xt) + index.add(xb) + + self.run_search_and_reconstruct(index, xb, xq) + + +class TestSearchAndGetCodes(unittest.TestCase): + + def do_test(self, factory_string): + ds = SyntheticDataset(32, 1000, 100, 10) + + index = faiss.index_factory(ds.d, factory_string) + + index.train(ds.get_train()) + index.add(ds.get_database()) + + index.nprobe + index.nprobe = 10 + Dref, Iref = index.search(ds.get_queries(), 10) + + #print(index.search_and_return_codes) + D, I, codes = index.search_and_return_codes( + ds.get_queries(), 10, include_listnos=True) + + np.testing.assert_array_equal(I, Iref) + np.testing.assert_array_equal(D, Dref) + + # verify that we get the same distances when decompressing from + # returned codes (the codes are compatible with sa_decode) + for qi in range(ds.nq): + q = ds.get_queries()[qi] + xbi = index.sa_decode(codes[qi]) + D2 = ((q - xbi) ** 2).sum(1) + np.testing.assert_allclose(D2, D[qi], rtol=1e-5) + + def test_ivfpq(self): + self.do_test("IVF20,PQ4x4np") + + def test_ivfsq(self): + self.do_test("IVF20,SQ8") + + def test_ivfrq(self): + self.do_test("IVF20,RQ3x4") diff --git a/tests/test_standalone_codec.py b/tests/test_standalone_codec.py index 1e1993bb4c..7fdcf6849f 100644 --- a/tests/test_standalone_codec.py +++ b/tests/test_standalone_codec.py @@ -266,9 +266,9 @@ def test_ZnSphereCodecAlt24(self): class TestBitstring(unittest.TestCase): - """ Low-level bit string tests """ def test_rw(self): + """ 
Low-level bit string tests """ rs = np.random.RandomState(1234) nbyte = 1000 sz = 0 @@ -311,6 +311,26 @@ def test_rw(self): # print('nbit %d xref %x xnew %x' % (nbit, xref, xnew)) self.assertTrue(xnew == xref) + def test_arrays(self): + nbit = 5 + M = 10 + n = 20 + rs = np.random.RandomState(123) + a = rs.randint(1<
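
For quick reference, here is a minimal usage sketch of the two APIs this patch adds, following the Python wrappers and tests above. The synthetic data, the `IVF20,PQ4x4np` factory string, and all sizes are illustrative assumptions, not part of the PR.

```
import numpy as np
import faiss

# illustrative synthetic data (any float32 vectors work)
d, nb, nq = 32, 1000, 10
rng = np.random.RandomState(0)
xb = rng.rand(nb, d).astype('float32')
xq = rng.rand(nq, d).astype('float32')

# small IVFPQ index; the parameters are arbitrary for this sketch
index = faiss.index_factory(d, "IVF20,PQ4x4np")
index.train(xb)
index.add(xb)
index.nprobe = 10

# new in this PR: search and also return the stored codes of the results.
# With include_listnos=True the codes are compatible with sa_decode().
D, I, codes = index.search_and_return_codes(xq, 5, include_listnos=True)
recons = index.sa_decode(codes[0])   # decode the codes of the first query's results

# new bit-packing helpers: pack/unpack int arrays into compact bitstrings
a = rng.randint(1 << 5, size=(7, 10)).astype('int32')
packed = faiss.pack_bitstrings(a, 5)             # shape (7, ceil(10 * 5 / 8)), dtype uint8
unpacked = faiss.unpack_bitstrings(packed, 10, 5)
assert np.all(unpacked == a)
```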