Skip to content

Commit

Permalink
feat: knnsearch (#331)
Browse files Browse the repository at this point in the history
  • Loading branch information
friendlymatthew authored Jun 7, 2024
1 parent ca3bc25 commit 715410f
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
57 changes: 57 additions & 0 deletions pkg/hnsw/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,60 @@ func (h *Hnsw) InsertVector(q Point) error {
func (h *Hnsw) isValidPoint(point Point) bool {
return len(point) == h.vectorDimensionality
}

func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error) {
entryPoint, ok := h.points[h.entryPointId]

if !ok {
return nil, fmt.Errorf("no point found for entry point %v", h.entryPointId)
}

entryPointFriends, ok := h.friends[h.entryPointId]
if !ok {
return nil, fmt.Errorf("no friends found for entry point %v", h.entryPointId)
}

entryPointTopLevel := entryPointFriends.TopLevel()

entryItem := &Item{
id: h.entryPointId,
dist: EuclidDistance(q, *entryPoint),
}

for level := entryPointTopLevel; level > 0; level-- {
nearestNeighborQueueAtLevel, err := h.searchLevel(&q, entryItem, 1, level)

if err != nil {
return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}

entryItem = nearestNeighborQueueAtLevel.Top()
}

// level 0
nearestNeighborQueueAtLevel0, err := h.searchLevel(&q, entryItem, h.efConstruction, 0)
if err != nil {
return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %d", h.entryPointId, 0)
}

if nearestNeighborQueueAtLevel0.Len() <= numNeighborsToReturn {
return nearestNeighborQueueAtLevel0, nil
}

var items []*Item

for !nearestNeighborQueueAtLevel0.IsEmpty() {
if len(items) == numNeighborsToReturn {
return FromItems(items, MinComparator{}), nil
}

nearestNeighborAtLevel0Item, err := nearestNeighborQueueAtLevel0.PopItem()
if err != nil {
return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %d", h.entryPointId, 0)
}

items = append(items, nearestNeighborAtLevel0Item)
}

return FromItems(items, MinComparator{}), nil
}
45 changes: 45 additions & 0 deletions pkg/hnsw/hnsw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -470,3 +470,48 @@ func TestHnsw_InsertVector(t *testing.T) {

})
}

func TestHnsw_KnnSearch(t *testing.T) {
t.Run("basic search knn", func(t *testing.T) {
h := NewHnsw(2, 4, 4, Point{0, 0})

// id: 1
if err := h.InsertVector(Point{3, 3}); err != nil {
t.Fatalf("failed to insert point: %v, err: %v", Point{3, 3}, err)
}

// id: 2
if err := h.InsertVector(Point{4, 4}); err != nil {
t.Fatalf("failed to insert point %v, err: %v", Point{4, 4}, err)
}

// id: 3
if err := h.InsertVector(Point{5, 5}); err != nil {
t.Fatalf("failed to insert point %v, err: %v", Point{5, 5}, err)
}

nearestNeighbors, err := h.KnnSearch(Point{5, 5}, 3)
if err != nil {
t.Fatal(err)
}

if nearestNeighbors.Len() != 3 {
t.Fatalf("expected to have 3 neighbors, got %v", nearestNeighbors)
}

expectedId := Id(3)

for !nearestNeighbors.IsEmpty() {
nearestNeighbor, err := nearestNeighbors.PopItem()
if err != nil {
t.Fatalf("failed to pop item: %v, err: %v", nearestNeighbors, err)
}

if nearestNeighbor.id != expectedId {
t.Fatalf("expected item to be %v, got %v", expectedId, nearestNeighbor.id)
}

expectedId -= 1
}
})
}
12 changes: 12 additions & 0 deletions pkg/hnsw/pq.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,15 @@ func FromBaseQueue(bq *BaseQueue, comparator Comparator) *BaseQueue {

return newBq
}

func FromItems(items []*Item, comparator Comparator) *BaseQueue {
bq := &BaseQueue{
visitedIds: map[Id]*Item{},
items: items,
comparator: comparator,
}

heap.Init(bq)

return bq
}

0 comments on commit 715410f

Please sign in to comment.