diff --git a/faiss/cppcontrib/factory_tools.cpp b/faiss/cppcontrib/factory_tools.cpp index 5c513a97b5..d1f283b8ff 100644 --- a/faiss/cppcontrib/factory_tools.cpp +++ b/faiss/cppcontrib/factory_tools.cpp @@ -25,7 +25,15 @@ const std::map sq_types = { }; int get_hnsw_M(const faiss::IndexHNSW* index) { - if (index->hnsw.cum_nneighbor_per_level.size() >= 1) { + if (index->hnsw.cum_nneighbor_per_level.size() > 1) { + return index->hnsw.cum_nneighbor_per_level[1] / 2; + } + // Avoid runtime error, just return 0. + return 0; +} + +int get_hnsw_M(const faiss::IndexBinaryHNSW* index) { + if (index->hnsw.cum_nneighbor_per_level.size() > 1) { return index->hnsw.cum_nneighbor_per_level[1] / 2; } // Avoid runtime error, just return 0. @@ -153,4 +161,32 @@ std::string reverse_index_factory(const faiss::Index* index) { return ""; } +std::string reverse_index_factory(const faiss::IndexBinary* index) { + std::string prefix; + if (dynamic_cast(index)) { + return "BFlat"; + } else if ( + const faiss::IndexBinaryIVF* ivf_index = + dynamic_cast(index)) { + const faiss::IndexBinary* quantizer = ivf_index->quantizer; + + if (dynamic_cast(quantizer)) { + return "BIVF" + std::to_string(ivf_index->nlist); + } else if ( + const faiss::IndexBinaryHNSW* hnsw_index = + dynamic_cast( + quantizer)) { + return "BIVF" + std::to_string(ivf_index->nlist) + "_HNSW" + + std::to_string(get_hnsw_M(hnsw_index)); + } + // Add further cases for BinaryIVF here. + } else if ( + const faiss::IndexBinaryHNSW* hnsw_index = + dynamic_cast(index)) { + return "BHNSW" + std::to_string(get_hnsw_M(hnsw_index)); + } + // Avoid runtime error, just return empty string for logging. + return ""; +} + } // namespace faiss diff --git a/faiss/cppcontrib/factory_tools.h b/faiss/cppcontrib/factory_tools.h index 4e4f68cbf8..f83a6db4ad 100644 --- a/faiss/cppcontrib/factory_tools.h +++ b/faiss/cppcontrib/factory_tools.h @@ -9,6 +9,9 @@ #pragma once +#include +#include +#include #include #include #include @@ -21,5 +24,6 @@ namespace faiss { std::string reverse_index_factory(const faiss::Index* index); +std::string reverse_index_factory(const faiss::IndexBinary* index); } // namespace faiss