Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MB-62230 - Support for pre-filtering with kNN #255

Merged
merged 12 commits into from
Sep 9, 2024
122 changes: 82 additions & 40 deletions faiss_vector_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,54 +52,84 @@ func (vc *vectorIndexCache) Clear() {
vc.m.Unlock()
}

func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, err error) {
var found bool
index, vecDocIDMap, vecIDsToExclude, found = vc.loadFromCache(fieldID, except)
if !found {
index, vecDocIDMap, vecIDsToExclude, err = vc.createAndCache(fieldID, mem, except)
}
return index, vecDocIDMap, vecIDsToExclude, err
// loadDocVecIDMap indicates if a non-nil docVecIDMap should be returned.
// It is true when a filtered kNN query accesses the cache since it requires the
// map. It's false otherwise.
func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte,
loadDocVecIDMap bool, except *roaring.Bitmap) (
metonymic-smokey marked this conversation as resolved.
Show resolved Hide resolved
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, docVecIDMap map[uint32][]int64,
vecIDsToExclude []int64, err error) {
index, vecDocIDMap, docVecIDMap, vecIDsToExclude, err = vc.loadFromCache(
fieldID, loadDocVecIDMap, mem, except)
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, err
}

func (vc *vectorIndexCache) loadFromCache(fieldID uint16, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, found bool) {
// function to load the vectorDocIDMap and if required, docVecIDMap from cache
// If not, it will create these and add them to the cache.
func (vc *vectorIndexCache) loadFromCache(fieldID uint16, loadDocVecIDMap bool,
mem []byte, except *roaring.Bitmap) (index *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
docVecIDMap map[uint32][]int64, vecIDsToExclude []int64, err error) {

vc.m.RLock()
defer vc.m.RUnlock()

entry, ok := vc.cache[fieldID]
if !ok {
return nil, nil, nil, false
}

index, vecDocIDMap = entry.load()
vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
if ok {
vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
index, vecDocIDMap, docVecIDMap = entry.load()
if !loadDocVecIDMap || (loadDocVecIDMap && len(entry.docVecIDMap) > 0) {
vc.m.RUnlock()
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
}

return index, vecDocIDMap, vecIDsToExclude, true
}
vc.m.RUnlock()
vc.m.Lock()
// in cases where only the docVecID isn't part of the cache, build it and
// add it to the cache, while holding a lock to avoid concurrent modifications.
// typically seen for the first filtered query.
docVecIDMap = vc.addDocVecIDMapToCacheLOCKED(entry)
vc.m.Unlock()
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
abhinavdangeti marked this conversation as resolved.
Show resolved Hide resolved
}

func (vc *vectorIndexCache) createAndCache(fieldID uint16, mem []byte, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, err error) {
vc.m.RUnlock()
// acquiring a lock since this is modifying the cache.
vc.m.Lock()
defer vc.m.Unlock()
return vc.createAndCacheLOCKED(fieldID, mem, loadDocVecIDMap, except)
}

// when there are multiple threads trying to build the index, guard redundant
// index creation by doing a double check and return if already created and
// cached.
entry, ok := vc.cache[fieldID]
if ok {
index, vecDocIDMap = entry.load()
vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
return index, vecDocIDMap, vecIDsToExclude, nil
func (vc *vectorIndexCache) addDocVecIDMapToCacheLOCKED(ce *cacheEntry) map[uint32][]int64 {
// Handle concurrent accesses (to avoid unnecessary work) by adding a
// check within the write lock here.
if ce.docVecIDMap != nil {
return ce.docVecIDMap
}

// if the cache doesn't have entry, construct the vector to doc id map and the
// vector index out of the mem bytes and update the cache under lock.
docVecIDMap := make(map[uint32][]int64)
for vecID, docID := range ce.vecDocIDMap {
docVecIDMap[docID] = append(docVecIDMap[docID], vecID)
}

ce.docVecIDMap = docVecIDMap
return docVecIDMap
}

// Rebuilding the cache on a miss.
func (vc *vectorIndexCache) createAndCacheLOCKED(fieldID uint16, mem []byte,
loadDocVecIDMap bool, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
docVecIDMap map[uint32][]int64, vecIDsToExclude []int64, err error) {

// if the cache doesn't have the entry, construct the vector to doc id map and
// the vector index out of the mem bytes and update the cache under lock.
pos := 0
numVecs, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
pos += n

vecDocIDMap = make(map[int64]uint32, numVecs)
if loadDocVecIDMap {
docVecIDMap = make(map[uint32][]int64, numVecs)
}
isExceptNotEmpty := except != nil && !except.IsEmpty()
for i := 0; i < int(numVecs); i++ {
vecID, n := binary.Varint(mem[pos : pos+binary.MaxVarintLen64])
Expand All @@ -113,22 +143,26 @@ func (vc *vectorIndexCache) createAndCache(fieldID uint16, mem []byte, except *r
continue
}
vecDocIDMap[vecID] = docIDUint32
if loadDocVecIDMap {
docVecIDMap[docIDUint32] = append(docVecIDMap[docIDUint32], vecID)
}
}

indexSize, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
pos += n

index, err = faiss.ReadIndexFromBuffer(mem[pos:pos+int(indexSize)], faissIOFlags)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}

vc.insertLOCKED(fieldID, index, vecDocIDMap)
return index, vecDocIDMap, vecIDsToExclude, nil
vc.insertLOCKED(fieldID, index, vecDocIDMap, loadDocVecIDMap, docVecIDMap)
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
}

func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16,
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32) {
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, loadDocVecIDMap bool,
docVecIDMap map[uint32][]int64) {
// the first time we've hit the cache, try to spawn a monitoring routine
// which will reconcile the moving averages for all the fields being hit
if len(vc.cache) == 0 {
Expand All @@ -142,7 +176,8 @@ func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16,
// this makes the average to be kept above the threshold value for a
// longer time and thereby the index to be resident in the cache
// for longer time.
vc.cache[fieldIDPlus1] = createCacheEntry(index, vecDocIDMap, 0.4)
vc.cache[fieldIDPlus1] = createCacheEntry(index, vecDocIDMap,
loadDocVecIDMap, docVecIDMap, 0.4)
}
}

Expand Down Expand Up @@ -235,8 +270,9 @@ func (e *ewma) add(val uint64) {

// -----------------------------------------------------------------------------

func createCacheEntry(index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, alpha float64) *cacheEntry {
return &cacheEntry{
func createCacheEntry(index *faiss.IndexImpl, vecDocIDMap map[int64]uint32,
loadDocVecIDMap bool, docVecIDMap map[uint32][]int64, alpha float64) *cacheEntry {
ce := &cacheEntry{
index: index,
vecDocIDMap: vecDocIDMap,
tracker: &ewma{
Expand All @@ -245,6 +281,10 @@ func createCacheEntry(index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, alph
},
refs: 1,
}
if loadDocVecIDMap {
ce.docVecIDMap = docVecIDMap
}
return ce
}

type cacheEntry struct {
Expand All @@ -257,6 +297,7 @@ type cacheEntry struct {

index *faiss.IndexImpl
vecDocIDMap map[int64]uint32
docVecIDMap map[uint32][]int64
}

func (ce *cacheEntry) incHit() {
Expand All @@ -271,17 +312,18 @@ func (ce *cacheEntry) decRef() {
atomic.AddInt64(&ce.refs, -1)
}

func (ce *cacheEntry) load() (*faiss.IndexImpl, map[int64]uint32) {
func (ce *cacheEntry) load() (*faiss.IndexImpl, map[int64]uint32, map[uint32][]int64) {
ce.incHit()
ce.addRef()
return ce.index, ce.vecDocIDMap
return ce.index, ce.vecDocIDMap, ce.docVecIDMap
}

func (ce *cacheEntry) close() {
go func() {
ce.index.Close()
ce.index = nil
ce.vecDocIDMap = nil
ce.docVecIDMap = nil
}()
}

Expand Down
105 changes: 83 additions & 22 deletions faiss_vector_posting.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,16 +267,26 @@ func (vpl *VecPostingsIterator) BytesWritten() uint64 {

// vectorIndexWrapper conforms to scorch_segment_api's VectorIndex interface
type vectorIndexWrapper struct {
search func(qVector []float32, k int64, params json.RawMessage) (segment.VecPostingsList, error)
close func()
size func() uint64
search func(qVector []float32, k int64,
params json.RawMessage) (segment.VecPostingsList, error)
searchWithFilter func(qVector []float32, k int64, eligibleDocIDs []uint64,
params json.RawMessage) (segment.VecPostingsList, error)
close func()
size func() uint64
}

func (i *vectorIndexWrapper) Search(qVector []float32, k int64, params json.RawMessage) (
func (i *vectorIndexWrapper) Search(qVector []float32, k int64,
params json.RawMessage) (
segment.VecPostingsList, error) {
return i.search(qVector, k, params)
}

func (i *vectorIndexWrapper) SearchWithFilter(qVector []float32, k int64,
eligibleDocIDs []uint64, params json.RawMessage) (
segment.VecPostingsList, error) {
return i.searchWithFilter(qVector, k, eligibleDocIDs, params)
}

func (i *vectorIndexWrapper) Close() {
i.close()
}
Expand All @@ -288,20 +298,40 @@ func (i *vectorIndexWrapper) Size() uint64 {
// InterpretVectorIndex returns a construct of closures (vectorIndexWrapper)
// that will allow the caller to -
// (1) search within an attached vector index
// (2) close attached vector index
// (3) get the size of the attached vector index
func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap) (
// (2) search limited to a subset of documents within an attached vector index
// (3) close attached vector index
// (4) get the size of the attached vector index
func (sb *SegmentBase) InterpretVectorIndex(field string, requiresFiltering bool,
except *roaring.Bitmap) (
segment.VectorIndex, error) {
// Params needed for the closures
var vecIndex *faiss.IndexImpl
var vecDocIDMap map[int64]uint32
var docVecIDMap map[uint32][]int64
var vectorIDsToExclude []int64
var fieldIDPlus1 uint16
var vecIndexSize uint64

// Utility function to add the corresponding docID and scores for each vector
// returned after the kNN query to the newly
// created vecPostingsList
addIDsToPostingsList := func(pl *VecPostingsList, ids []int64, scores []float32) {
for i := 0; i < len(ids); i++ {
vecID := ids[i]
// Checking if it's present in the vecDocIDMap.
// If -1 is returned as an ID(insufficient vectors), this will ensure
// it isn't added to the final postings list.
if docID, ok := vecDocIDMap[vecID]; ok {
code := getVectorCode(docID, scores[i])
pl.postings.Add(uint64(code))
}
}
}

var (
wrapVecIndex = &vectorIndexWrapper{
search: func(qVector []float32, k int64, params json.RawMessage) (segment.VecPostingsList, error) {
search: func(qVector []float32, k int64, params json.RawMessage) (
segment.VecPostingsList, error) {
// 1. returned postings list (of type PostingsList) has two types of information - docNum and its score.
// 2. both the values can be represented using roaring bitmaps.
// 3. the Iterator (of type PostingsIterator) returned would operate in terms of VecPostings.
Expand All @@ -318,23 +348,53 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
return rv, nil
}

scores, ids, err := vecIndex.SearchWithoutIDs(qVector, k, vectorIDsToExclude, params)
scores, ids, err := vecIndex.SearchWithoutIDs(qVector, k,
vectorIDsToExclude, params)
if err != nil {
return nil, err
}
// for every similar vector returned by the Search() API, add the corresponding
// docID and the score to the newly created vecPostingsList
for i := 0; i < len(ids); i++ {
vecID := ids[i]
// Checking if it's present in the vecDocIDMap.
// If -1 is returned as an ID(insufficient vectors), this will ensure
// it isn't added to the final postings list.
if docID, ok := vecDocIDMap[vecID]; ok {
code := getVectorCode(docID, scores[i])
rv.postings.Add(uint64(code))
}

addIDsToPostingsList(rv, ids, scores)

return rv, nil
},
searchWithFilter: func(qVector []float32, k int64,
eligibleDocIDs []uint64, params json.RawMessage) (
segment.VecPostingsList, error) {
// 1. returned postings list (of type PostingsList) has two types of information - docNum and its score.
// 2. both the values can be represented using roaring bitmaps.
// 3. the Iterator (of type PostingsIterator) returned would operate in terms of VecPostings.
// 4. VecPostings would just have the docNum and the score. Every call of Next()
// and Advance just returns the next VecPostings. The caller would do a vp.Number()
// and the Score() to get the corresponding values
rv := &VecPostingsList{
except: nil, // todo: handle the except bitmap within postings iterator.
postings: roaring64.New(),
}

if vecIndex == nil || vecIndex.D() != len(qVector) {
// vector index not found or dimensionality mismatched
return rv, nil
}

if len(eligibleDocIDs) > 0 {
// Non-zero documents eligible per the filter query.

// vector IDs corresponding to the local doc numbers to be
// considered for the search
vectorIDsToInclude := make([]int64, 0, len(eligibleDocIDs))
for _, id := range eligibleDocIDs {
vectorIDsToInclude = append(vectorIDsToInclude, docVecIDMap[uint32(id)]...)
}

scores, ids, err := vecIndex.SearchWithIDs(qVector, k,
vectorIDsToInclude, params)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't you be removing vectorIDsToExclude from vectorIDsToInclude before this or are we certain there never really is going to be an overlap there because of the pre-filter?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As per my understanding, deleted results aren't returned as part of the initial filter search, whose results form the basis of the vector include list.

if err != nil {
return nil, err
}

addIDsToPostingsList(rv, ids, scores)
}
return rv, nil
},
close: func() {
Expand Down Expand Up @@ -372,8 +432,9 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
pos += n
}

vecIndex, vecDocIDMap, vectorIDsToExclude, err =
sb.vecIndexCache.loadOrCreate(fieldIDPlus1, sb.mem[pos:], except)
vecIndex, vecDocIDMap, docVecIDMap, vectorIDsToExclude, err =
sb.vecIndexCache.loadOrCreate(fieldIDPlus1, sb.mem[pos:], requiresFiltering,
except)

if vecIndex != nil {
vecIndexSize = vecIndex.Size()
Expand Down
Loading