diff --git a/faiss_vector_posting.go b/faiss_vector_posting.go index 6b9840f..0f169a9 100644 --- a/faiss_vector_posting.go +++ b/faiss_vector_posting.go @@ -475,7 +475,9 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool } } - selector, err = faiss.NewIDSelectorNot(ineligibleVectorIDs) + if len(vectorIDsToInclude) > 0 { + selector, err = faiss.NewIDSelectorNot(ineligibleVectorIDs) + } } else { // Getting the vector IDs corresponding to the eligible // doc IDs. @@ -508,45 +510,49 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool } } - selector, err = faiss.NewIDSelectorBatch(vectorIDsToInclude) + if len(vectorIDsToInclude) > 0 { + selector, err = faiss.NewIDSelectorBatch(vectorIDsToInclude) + } } if err != nil { return nil, err } - // Ordering the retrieved centroid IDs by increasing order - // of distance i.e. decreasing order of proximity to query vector. - closestCentroidIDs, centroidDistances, _ := - vecIndex.ObtainClustersWithDistancesFromIVFIndex(qVector, - eligibleCentroidIDs) - - // Getting the nprobe value set at index time. - nprobe := vecIndex.GetNProbe() - - eligibleDocsTillNow := int64(0) - minEligibleCentroids := 0 - for i, centroidID := range closestCentroidIDs { - eligibleDocsTillNow += int64(centroidVecIDMap[centroidID].GetCardinality()) - if eligibleDocsTillNow >= k && i >= int(nprobe-1) { - // Continue till at least 'K' cumulative vectors are - // collected or 'nprobe' clusters are examined, whichever - // comes later. + if len(vectorIDsToInclude) > 0 { + // Ordering the retrieved centroid IDs by increasing order + // of distance i.e. decreasing order of proximity to query vector. + closestCentroidIDs, centroidDistances, _ := + vecIndex.ObtainClustersWithDistancesFromIVFIndex(qVector, + eligibleCentroidIDs) + + // Getting the nprobe value set at index time. + nprobe := vecIndex.GetNProbe() + + eligibleDocsTillNow := int64(0) + minEligibleCentroids := 0 + for i, centroidID := range closestCentroidIDs { + eligibleDocsTillNow += int64(centroidVecIDMap[centroidID].GetCardinality()) + if eligibleDocsTillNow >= k && i >= int(nprobe-1) { + // Continue till at least 'K' cumulative vectors are + // collected or 'nprobe' clusters are examined, whichever + // comes later. + minEligibleCentroids = i + 1 + break + } minEligibleCentroids = i + 1 - break } - minEligibleCentroids = i + 1 - } - // Search the clusters specified by 'closestCentroidIDs' for - // vectors whose IDs are present in 'vectorIDsToInclude' - scores, ids, err := vecIndex.SearchClustersFromIVFIndex( - selector, len(vectorIDsToInclude), closestCentroidIDs, - minEligibleCentroids, k, qVector, centroidDistances, params) - if err != nil { - return nil, err - } + // Search the clusters specified by 'closestCentroidIDs' for + // vectors whose IDs are present in 'vectorIDsToInclude' + scores, ids, err := vecIndex.SearchClustersFromIVFIndex( + selector, len(vectorIDsToInclude), closestCentroidIDs, + minEligibleCentroids, k, qVector, centroidDistances, params) + if err != nil { + return nil, err + } - addIDsToPostingsList(rv, ids, scores) + addIDsToPostingsList(rv, ids, scores) + } return rv, nil } return rv, nil