Skip to content

Commit

Permalink
MB-62230 - Support for pre-filtering with kNN
Browse files Browse the repository at this point in the history
  • Loading branch information
metonymic-smokey committed Aug 12, 2024
1 parent e177648 commit 356757b
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 37 deletions.
44 changes: 29 additions & 15 deletions faiss_vector_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,33 +53,37 @@ func (vc *vectorIndexCache) Clear() {
}

func (vc *vectorIndexCache) loadOrCreate(fieldID uint16, mem []byte, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, err error) {
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, docVecIDMap map[uint32]int64,
vecIDsToExclude []int64, err error) {
var found bool
index, vecDocIDMap, vecIDsToExclude, found = vc.loadFromCache(fieldID, except)
index, vecDocIDMap, docVecIDMap, vecIDsToExclude, found = vc.loadFromCache(fieldID, except)
if !found {
index, vecDocIDMap, vecIDsToExclude, err = vc.createAndCache(fieldID, mem, except)
index, vecDocIDMap, docVecIDMap, vecIDsToExclude, err =
vc.createAndCache(fieldID, mem, except)
}
return index, vecDocIDMap, vecIDsToExclude, err
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, err
}

func (vc *vectorIndexCache) loadFromCache(fieldID uint16, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, found bool) {
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, docVecIDMap map[uint32]int64,
vecIDsToExclude []int64, found bool) {
vc.m.RLock()
defer vc.m.RUnlock()

entry, ok := vc.cache[fieldID]
if !ok {
return nil, nil, nil, false
return nil, nil, nil, nil, false
}

index, vecDocIDMap = entry.load()
index, vecDocIDMap, docVecIDMap = entry.load()
vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)

return index, vecDocIDMap, vecIDsToExclude, true
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, true
}

func (vc *vectorIndexCache) createAndCache(fieldID uint16, mem []byte, except *roaring.Bitmap) (
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, vecIDsToExclude []int64, err error) {
index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, docVecIDMap map[uint32]int64,
vecIDsToExclude []int64, err error) {
vc.m.Lock()
defer vc.m.Unlock()

Expand All @@ -88,9 +92,9 @@ func (vc *vectorIndexCache) createAndCache(fieldID uint16, mem []byte, except *r
// cached.
entry, ok := vc.cache[fieldID]
if ok {
index, vecDocIDMap = entry.load()
index, vecDocIDMap, docVecIDMap = entry.load()
vecIDsToExclude = getVecIDsToExclude(vecDocIDMap, except)
return index, vecDocIDMap, vecIDsToExclude, nil
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
}

// if the cache doesn't have entry, construct the vector to doc id map and the
Expand All @@ -100,6 +104,7 @@ func (vc *vectorIndexCache) createAndCache(fieldID uint16, mem []byte, except *r
pos += n

vecDocIDMap = make(map[int64]uint32, numVecs)
docVecIDMap = make(map[uint32]int64, numVecs)
isExceptNotEmpty := except != nil && !except.IsEmpty()
for i := 0; i < int(numVecs); i++ {
vecID, n := binary.Varint(mem[pos : pos+binary.MaxVarintLen64])
Expand All @@ -113,18 +118,19 @@ func (vc *vectorIndexCache) createAndCache(fieldID uint16, mem []byte, except *r
continue
}
vecDocIDMap[vecID] = docIDUint32
docVecIDMap[docIDUint32] = vecID
}

indexSize, n := binary.Uvarint(mem[pos : pos+binary.MaxVarintLen64])
pos += n

index, err = faiss.ReadIndexFromBuffer(mem[pos:pos+int(indexSize)], faissIOFlags)
if err != nil {
return nil, nil, nil, err
return nil, nil, nil, nil, err
}

vc.insertLOCKED(fieldID, index, vecDocIDMap)
return index, vecDocIDMap, vecIDsToExclude, nil
return index, vecDocIDMap, docVecIDMap, vecIDsToExclude, nil
}

func (vc *vectorIndexCache) insertLOCKED(fieldIDPlus1 uint16,
Expand Down Expand Up @@ -236,9 +242,15 @@ func (e *ewma) add(val uint64) {
// -----------------------------------------------------------------------------

func createCacheEntry(index *faiss.IndexImpl, vecDocIDMap map[int64]uint32, alpha float64) *cacheEntry {
docVecIDMap := make(map[uint32]int64)
for vecID, docID := range vecDocIDMap {
docVecIDMap[docID] = vecID
}

return &cacheEntry{
index: index,
vecDocIDMap: vecDocIDMap,
docIDVecMap: docVecIDMap,
tracker: &ewma{
alpha: alpha,
sample: 1,
Expand All @@ -257,6 +269,7 @@ type cacheEntry struct {

index *faiss.IndexImpl
vecDocIDMap map[int64]uint32
docIDVecMap map[uint32]int64
}

func (ce *cacheEntry) incHit() {
Expand All @@ -271,17 +284,18 @@ func (ce *cacheEntry) decRef() {
atomic.AddInt64(&ce.refs, -1)
}

func (ce *cacheEntry) load() (*faiss.IndexImpl, map[int64]uint32) {
func (ce *cacheEntry) load() (*faiss.IndexImpl, map[int64]uint32, map[uint32]int64) {
ce.incHit()
ce.addRef()
return ce.index, ce.vecDocIDMap
return ce.index, ce.vecDocIDMap, ce.docIDVecMap
}

func (ce *cacheEntry) close() {
go func() {
ce.index.Close()
ce.index = nil
ce.vecDocIDMap = nil
ce.docIDVecMap = nil
}()
}

Expand Down
65 changes: 43 additions & 22 deletions faiss_vector_posting.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,17 @@ func (vpl *VecPostingsIterator) BytesWritten() uint64 {

// vectorIndexWrapper conforms to scorch_segment_api's VectorIndex interface
type vectorIndexWrapper struct {
search func(qVector []float32, k int64, params json.RawMessage) (segment.VecPostingsList, error)
close func()
size func() uint64
search func(qVector []float32, k int64, eligibleDocIDs *roaring.Bitmap,
params json.RawMessage) (segment.VecPostingsList, error)
close func()
size func() uint64
}

func (i *vectorIndexWrapper) Search(qVector []float32, k int64, params json.RawMessage) (
func (i *vectorIndexWrapper) Search(qVector []float32, k int64,
eligibleDocIDs *roaring.Bitmap, params json.RawMessage) (
segment.VecPostingsList, error) {
return i.search(qVector, k, params)
return i.search(qVector, k, eligibleDocIDs, params)

}

func (i *vectorIndexWrapper) Close() {
Expand All @@ -295,13 +298,15 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
// Params needed for the closures
var vecIndex *faiss.IndexImpl
var vecDocIDMap map[int64]uint32
var vectorIDsToExclude []int64
var docVecIDMap map[uint32]int64
// var vectorIDsToExclude []int64
var fieldIDPlus1 uint16
var vecIndexSize uint64

var (
wrapVecIndex = &vectorIndexWrapper{
search: func(qVector []float32, k int64, params json.RawMessage) (segment.VecPostingsList, error) {
search: func(qVector []float32, k int64, eligibleDocIDs *roaring.Bitmap,
params json.RawMessage) (segment.VecPostingsList, error) {
// 1. returned postings list (of type PostingsList) has two types of information - docNum and its score.
// 2. both the values can be represented using roaring bitmaps.
// 3. the Iterator (of type PostingsIterator) returned would operate in terms of VecPostings.
Expand All @@ -318,20 +323,36 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
return rv, nil
}

scores, ids, err := vecIndex.SearchWithoutIDs(qVector, k, vectorIDsToExclude, params)
if err != nil {
return nil, err
}
// for every similar vector returned by the Search() API, add the corresponding
// docID and the score to the newly created vecPostingsList
for i := 0; i < len(ids); i++ {
vecID := ids[i]
// Checking if it's present in the vecDocIDMap.
// If -1 is returned as an ID(insufficient vectors), this will ensure
// it isn't added to the final postings list.
if docID, ok := vecDocIDMap[vecID]; ok {
code := getVectorCode(docID, scores[i])
rv.postings.Add(uint64(code))
var scores []float32
var ids []int64
var err error

// None of the documents are eligible per the filter query.
if eligibleDocIDs.Stats().Cardinality > 0 {
// vector IDs corresponding to the local doc numbers to be
// considered for the search
vectorIDsToInclude := make([]int64, eligibleDocIDs.Stats().Cardinality)
for i, id := range eligibleDocIDs.ToArray() {
vectorIDsToInclude[int64(i)] = int64(docVecIDMap[id])
}

scores, ids, err = vecIndex.SearchWithIDs(qVector, k,
vectorIDsToInclude, params)
if err != nil {
return nil, err
}
// for every similar vector returned by the Search() API,
// add the corresponding docID and the score to the newly
// created vecPostingsList
for i := 0; i < len(ids); i++ {
vecID := ids[i]
// Checking if it's present in the vecDocIDMap.
// If -1 is returned as an ID(insufficient vectors), this will ensure
// it isn't added to the final postings list.
if docID, ok := vecDocIDMap[vecID]; ok {
code := getVectorCode(docID, scores[i])
rv.postings.Add(uint64(code))
}
}
}

Expand Down Expand Up @@ -372,7 +393,7 @@ func (sb *SegmentBase) InterpretVectorIndex(field string, except *roaring.Bitmap
pos += n
}

vecIndex, vecDocIDMap, vectorIDsToExclude, err =
vecIndex, vecDocIDMap, docVecIDMap, _, err =
sb.vecIndexCache.loadOrCreate(fieldIDPlus1, sb.mem[pos:], except)

if vecIndex != nil {
Expand Down

0 comments on commit 356757b

Please sign in to comment.