From 00a72735ec7f7ff97bd6c34ce7a99fd0100b6698 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Thu, 13 Jun 2024 15:24:29 -0400 Subject: [PATCH] complete refactor --- pkg/hnsw/friends.go | 8 +- pkg/hnsw/friends_test.go | 10 ++- pkg/hnsw/hnsw.go | 65 +++++++-------- pkg/hnsw/hnsw_test.go | 30 +++---- pkg/hnsw/pq.go | 173 --------------------------------------- pkg/hnsw/pq_test.go | 105 ------------------------ 6 files changed, 57 insertions(+), 334 deletions(-) delete mode 100644 pkg/hnsw/pq.go delete mode 100644 pkg/hnsw/pq_test.go diff --git a/pkg/hnsw/friends.go b/pkg/hnsw/friends.go index bc78270..061b2fd 100644 --- a/pkg/hnsw/friends.go +++ b/pkg/hnsw/friends.go @@ -8,15 +8,15 @@ import ( type Point []float32 type Friends struct { - friends []*BaseQueue + friends []*DistHeap } // NewFriends creates a new vector, note the max level is inclusive. func NewFriends(topLevel int) *Friends { - friends := make([]*BaseQueue, topLevel+1) + friends := make([]*DistHeap, topLevel+1) for i := 0; i <= topLevel; i++ { - friends[i] = NewBaseQueue(MinComparator{}) + friends[i] = NewDistHeap() } return &Friends{ @@ -51,7 +51,7 @@ func (v *Friends) InsertFriendsAtLevel(level int, friendId Id, dist float32) { } } -func (v *Friends) GetFriendsAtLevel(level int) (*BaseQueue, error) { +func (v *Friends) GetFriendsAtLevel(level int) (*DistHeap, error) { if !v.HasLevel(level) { return nil, errors.New("failed to get friends at level") } diff --git a/pkg/hnsw/friends_test.go b/pkg/hnsw/friends_test.go index ea52d67..a39f146 100644 --- a/pkg/hnsw/friends_test.go +++ b/pkg/hnsw/friends_test.go @@ -69,12 +69,18 @@ func TestVector_LevelManagement(t *testing.T) { t.Fatalf("expected hex and oct friends list at level %v to be 1, got: %v || %v", i, hexFriends.Len(), octFriends.Len()) } - top := hexFriends.Top() + top, err := hexFriends.PeekMinItem() + if err != nil { + t.Fatal(err) + } if top.id != octId { t.Fatalf("expected %v, got %v", octId, top.id) } - top = octFriends.Top() + top, err = octFriends.PeekMinItem() + if err != nil { + t.Fatal(err) + } if top.id != hexId { t.Fatalf("expected %v, got %v", hexId, top.id) } diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index 99c9d9b..9416df6 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -60,23 +60,26 @@ func (h *Hnsw) GenerateId() Id { return Id(len(h.points)) } -func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, level int) (*BaseQueue, error) { +func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, level int) (*DistHeap, error) { visited := make([]bool, len(h.friends)+1) - candidatesForQ := NewBaseQueue(MinComparator{}) - foundNNToQ := NewBaseQueue(MaxComparator{}) + candidatesForQ := NewDistHeap() + foundNNToQ := NewDistHeap() // this is a max // note entryItem.dist should be the distance to Q candidatesForQ.Insert(entryItem.id, entryItem.dist) foundNNToQ.Insert(entryItem.id, entryItem.dist) for !candidatesForQ.IsEmpty() { - closestCandidate, err := candidatesForQ.PopItem() + closestCandidate, err := candidatesForQ.PopMinItem() if err != nil { return nil, fmt.Errorf("error during searching level %d: %w", level, err) } - furthestFoundNN := foundNNToQ.Top() + furthestFoundNN, err := foundNNToQ.PeekMaxItem() + if err != nil { + return nil, fmt.Errorf("error during searching level %d: %w", level, err) + } // if distance(c, q) > distance(f, q) if closestCandidate.dist > furthestFoundNN.dist { @@ -94,7 +97,10 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev if !visited[ccFriendId] { visited[ccFriendId] = true - furthestFoundNN = foundNNToQ.Top() + furthestFoundNN, err = foundNNToQ.PeekMaxItem() + if err != nil { + return nil, fmt.Errorf("error during searching level %d: %w", level, err) + } ccFriendPoint, ok := h.points[ccFriendId] if !ok { @@ -108,7 +114,7 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev foundNNToQ.Insert(ccFriendId, ccFriendDistToQ) if foundNNToQ.Len() > numNearestToQToReturn { - if _, err = foundNNToQ.PopItem(); err != nil { + if _, err = foundNNToQ.PopMaxItem(); err != nil { return nil, fmt.Errorf("error during searching level %d: %w", level, err) } } @@ -118,7 +124,7 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev } - return FromBaseQueue(foundNNToQ, MinComparator{}), nil + return foundNNToQ, nil } func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item { @@ -141,7 +147,7 @@ func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item { return epItem } - newEpItem, err := closestNeighborsToQ.PopItem() + newEpItem, err := closestNeighborsToQ.PopMinItem() if err != nil { panic(err) } @@ -152,7 +158,7 @@ func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item { return epItem } -func (h *Hnsw) selectNeighbors(nearestNeighbors *BaseQueue) ([]*Item, error) { +func (h *Hnsw) selectNeighbors(nearestNeighbors *DistHeap) ([]*Item, error) { if nearestNeighbors.Len() <= h.M { return nearestNeighbors.items, nil } @@ -160,7 +166,7 @@ func (h *Hnsw) selectNeighbors(nearestNeighbors *BaseQueue) ([]*Item, error) { nearestItems := make([]*Item, h.M) for i := 0; i < h.M; i++ { - nearestItem, err := nearestNeighbors.PopItem() + nearestItem, err := nearestNeighbors.PopMinItem() if err != nil { return nil, err @@ -221,19 +227,17 @@ func (h *Hnsw) InsertVector(q Point) error { return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } - maxNeighborsFriendsAtLevel := FromBaseQueue(neighborFriendsAtLevel, MaxComparator{}) - - for maxNeighborsFriendsAtLevel.Len() > h.M { - _, err = maxNeighborsFriendsAtLevel.PopItem() + for neighborFriendsAtLevel.Len() > h.M { + _, err := neighborFriendsAtLevel.PopMaxItem() if err != nil { return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } } - h.friends[neighbor.id].friends[level] = FromBaseQueue(maxNeighborsFriendsAtLevel, MinComparator{}) + h.friends[neighbor.id].friends[level] = neighborFriendsAtLevel } - newEntryItem, err := nnToQAtLevel.PopItem() + newEntryItem, err := nnToQAtLevel.PopMinItem() if err != nil { return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } @@ -252,7 +256,7 @@ func (h *Hnsw) isValidPoint(point Point) bool { return len(point) == h.vectorDimensionality } -func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error) { +func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*DistHeap, error) { entryPoint, ok := h.points[h.entryPointId] if !ok { @@ -278,7 +282,10 @@ func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error) return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } - entryItem = nearestNeighborQueueAtLevel.Top() + entryItem, err = nearestNeighborQueueAtLevel.PeekMinItem() + if err != nil { + return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) + } } // level 0 @@ -287,24 +294,12 @@ func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error) return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %d", h.entryPointId, 0) } - if nearestNeighborQueueAtLevel0.Len() <= numNeighborsToReturn { - return nearestNeighborQueueAtLevel0, nil - } - - var items []*Item - - for !nearestNeighborQueueAtLevel0.IsEmpty() { - if len(items) == numNeighborsToReturn { - return FromItems(items, MinComparator{}), nil - } - - nearestNeighborAtLevel0Item, err := nearestNeighborQueueAtLevel0.PopItem() + for nearestNeighborQueueAtLevel0.Len() > numNeighborsToReturn { + _, err := nearestNeighborQueueAtLevel0.PopMaxItem() if err != nil { - return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %d", h.entryPointId, 0) + return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", h.entryPointId, err) } - - items = append(items, nearestNeighborAtLevel0Item) } - return FromItems(items, MinComparator{}), nil + return nearestNeighborQueueAtLevel0, nil } diff --git a/pkg/hnsw/hnsw_test.go b/pkg/hnsw/hnsw_test.go index ca36756..905bb1c 100644 --- a/pkg/hnsw/hnsw_test.go +++ b/pkg/hnsw/hnsw_test.go @@ -104,7 +104,7 @@ func TestHnsw_SearchLevel(t *testing.T) { t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor) } - closestItem, err := closestNeighbor.PopItem() + closestItem, err := closestNeighbor.PopMinItem() if err != nil { t.Fatal(err) @@ -149,7 +149,7 @@ func TestHnsw_SearchLevel(t *testing.T) { t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor) } - closestItem, err := closestNeighbor.PopItem() + closestItem, err := closestNeighbor.PopMinItem() if err != nil { t.Fatal(err) @@ -188,7 +188,7 @@ func TestHnsw_SearchLevel(t *testing.T) { t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor) } - closestItem, err := closestNeighbor.PopItem() + closestItem, err := closestNeighbor.PopMinItem() if err != nil { t.Fatal(err) @@ -218,7 +218,7 @@ func TestHnsw_SearchLevel(t *testing.T) { t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor.Len()) } - closestItem, err := closestNeighbor.PopItem() + closestItem, err := closestNeighbor.PopMinItem() if err != nil { t.Fatal(err) } @@ -253,7 +253,7 @@ func TestHnsw_SearchLevel(t *testing.T) { var closestIds []Id for !closestNeighbor.IsEmpty() { - closestItem, err := closestNeighbor.PopItem() + closestItem, err := closestNeighbor.PopMinItem() if err != nil { t.Fatal(err) } @@ -330,7 +330,7 @@ func TestHnsw_FindCloserEntryPoint(t *testing.T) { func TestHnsw_SelectNeighbors(t *testing.T) { t.Run("selects neighbors given overflow", func(t *testing.T) { - nearestNeighbors := NewBaseQueue(MinComparator{}) + nearestNeighbors := NewDistHeap() M := 4 @@ -352,7 +352,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) { } // for the sake of testing, let's rebuild the pq and assert ids are correct - reneighbors := NewBaseQueue(MinComparator{}) + reneighbors := NewDistHeap() for _, item := range neighbors { reneighbors.Insert(item.id, item.dist) @@ -360,7 +360,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) { expectedId := Id(0) for !reneighbors.IsEmpty() { - nn, err := reneighbors.PopItem() + nn, err := reneighbors.PopMinItem() if err != nil { t.Fatal(err) @@ -378,7 +378,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) { M := 10 h := NewHnsw(2, 10, M, Point{0, 0}) - nnQueue := NewBaseQueue(MinComparator{}) + nnQueue := NewDistHeap() for i := 0; i < 3; i++ { nnQueue.Insert(Id(i), float32(i)) @@ -394,7 +394,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) { t.Fatalf("select neighbors should have at least 3 neighbors, got: %v", len(neighbors)) } - reneighbors := NewBaseQueue(MinComparator{}) + reneighbors := NewDistHeap() for _, item := range neighbors { reneighbors.Insert(item.id, item.dist) @@ -402,7 +402,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) { expectedId := Id(0) for !reneighbors.IsEmpty() { - nn, err := reneighbors.PopItem() + nn, err := reneighbors.PopMinItem() if err != nil { t.Fatal(err) @@ -510,7 +510,7 @@ func TestHnsw_KnnSearch(t *testing.T) { expectedId := Id(3) for !nearestNeighbors.IsEmpty() { - nearestNeighbor, err := nearestNeighbors.PopItem() + nearestNeighbor, err := nearestNeighbors.PopMinItem() if err != nil { t.Fatalf("failed to pop item: %v, err: %v", nearestNeighbors, err) } @@ -546,7 +546,7 @@ func TestHnsw_KnnSearch(t *testing.T) { var gotIds []Id for !closestToQ.IsEmpty() { - closest, err := closestToQ.PopItem() + closest, err := closestToQ.PopMinItem() if err != nil { t.Fatalf("failed to pop item: %v, err: %v", closestToQ, err) } @@ -587,7 +587,7 @@ func TestHnsw_KnnSearch(t *testing.T) { var got []Id for !closestNeighbors.IsEmpty() { - closest, err := closestNeighbors.PopItem() + closest, err := closestNeighbors.PopMinItem() if err != nil { t.Fatalf("failed to pop item: %v, err: %v", closestNeighbors, err) } @@ -621,7 +621,7 @@ func TestHnsw_KnnSearch(t *testing.T) { expectedId := Id(0) for found.IsEmpty() { - nnItem, err := found.PopItem() + nnItem, err := found.PopMinItem() if err != nil { t.Fatalf("failed to pop item: %v, err: %v", found, err) } diff --git a/pkg/hnsw/pq.go b/pkg/hnsw/pq.go deleted file mode 100644 index 8f5c0e0..0000000 --- a/pkg/hnsw/pq.go +++ /dev/null @@ -1,173 +0,0 @@ -package hnsw - -import ( - "container/heap" - "fmt" -) - -type Comparator interface { - Less(i, j *Item) bool -} - -// MaxComparator implements the Comparator interface for a max-heap. -type MaxComparator struct{} - -func (c MaxComparator) Less(i, j *Item) bool { - return i.dist > j.dist -} - -// MinComparator implements the Comparator interface for a min-heap. -type MinComparator struct{} - -func (c MinComparator) Less(i, j *Item) bool { - return i.dist < j.dist -} - -type Item struct { - id Id - dist float32 - index int -} - -type Heapy interface { - heap.Interface - Insert(id Id, dist float32) - IsEmpty() bool - Len() int - PopItem() *Item - Top() *Item - Take(count int) (*BaseQueue, error) - update(item *Item, id Id, dist float32) -} - -// Nothing from BaseQueue should be used. Only use the Max and Min queue. -// BaseQueue isn't even a heap! It misses the Less() method which the Min/Max queue implement. -type BaseQueue struct { - visitedIds map[Id]*Item - items []*Item - comparator Comparator -} - -func (bq *BaseQueue) Take(count int, comparator Comparator) (*BaseQueue, error) { - if len(bq.items) < count { - return nil, fmt.Errorf("queue only has %v items, but want to take %v", len(bq.items), count) - } - - pq := NewBaseQueue(comparator) - - ct := 0 - for { - if ct == count { - break - } - - peeled, err := bq.PopItem() - if err != nil { - return nil, err - } - - pq.Insert(peeled.id, peeled.dist) - - ct++ - } - - return pq, nil -} - -func (bq BaseQueue) Len() int { return len(bq.items) } -func (bq BaseQueue) Swap(i, j int) { - pq := bq.items - pq[i], pq[j] = pq[j], pq[i] - pq[i].index = i - pq[j].index = j -} - -func (bq *BaseQueue) Push(x any) { - n := len(bq.items) - item := x.(*Item) - item.index = n - bq.items = append(bq.items, item) -} - -func (bq *BaseQueue) Top() *Item { - if len(bq.items) == 0 { - return nil - } - return bq.items[0] -} - -func (bq *BaseQueue) IsEmpty() bool { - return len(bq.items) == 0 -} - -func (bq *BaseQueue) Pop() any { - old := bq.items - n := len(old) - item := old[n-1] - old[n-1] = nil - item.index = -1 - bq.items = old[0 : n-1] - return item -} - -func (bq *BaseQueue) Less(i, j int) bool { - return bq.comparator.Less(bq.items[i], bq.items[j]) -} - -func (bq *BaseQueue) Insert(id Id, dist float32) { - if item, ok := bq.visitedIds[id]; ok { - bq.update(item, id, dist) - return - } - - newItem := Item{id: id, dist: dist} - heap.Push(bq, &newItem) - bq.visitedIds[id] = &newItem - -} - -func NewBaseQueue(comparator Comparator) *BaseQueue { - bq := &BaseQueue{ - visitedIds: map[Id]*Item{}, - comparator: comparator, - } - heap.Init(bq) - return bq -} - -func (bq *BaseQueue) PopItem() (*Item, error) { - if bq.Len() == 0 { - return nil, fmt.Errorf("no items to peel") - } - popped := heap.Pop(bq).(*Item) - delete(bq.visitedIds, popped.id) - return popped, nil -} - -func (bq *BaseQueue) update(item *Item, id Id, dist float32) { - item.id = id - item.dist = dist - heap.Fix(bq, item.index) -} - -func FromBaseQueue(bq *BaseQueue, comparator Comparator) *BaseQueue { - newBq := NewBaseQueue(comparator) - - for _, item := range bq.items { - newBq.Insert(item.id, item.dist) - } - - return newBq -} - -func FromItems(items []*Item, comparator Comparator) *BaseQueue { - bq := &BaseQueue{ - visitedIds: map[Id]*Item{}, - items: items, - comparator: comparator, - } - - heap.Init(bq) - - return bq -} diff --git a/pkg/hnsw/pq_test.go b/pkg/hnsw/pq_test.go deleted file mode 100644 index 5bc1539..0000000 --- a/pkg/hnsw/pq_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package hnsw - -import ( - "testing" -) - -func TestPQ(t *testing.T) { - - t.Run("bricks and ladders || min heap", func(t *testing.T) { - type Case struct { - heights []int - bricks int - ladders int - expected int - } - - cases := [3]Case{ - { - heights: []int{4, 2, 7, 6, 9, 14, 12}, - bricks: 5, - ladders: 1, - expected: 4, - }, - { - heights: []int{4, 12, 2, 7, 3, 18, 20, 3, 19}, - bricks: 10, - ladders: 2, - expected: 7, - }, - { - heights: []int{14, 3, 19, 3}, - bricks: 17, - ladders: 0, - expected: 3, - }, - } - - for _, c := range cases { - res, err := furthestBuildings(c.heights, c.bricks, c.ladders) - if err != nil { - t.Fatal(err) - } - - if res != c.expected { - t.Errorf("got %d, want %d", res, c.expected) - } - } - }) - - t.Run("interchange", func(t *testing.T) { - bq := NewBaseQueue(MinComparator{}) - for i := 0; i < 100; i++ { - bq.Insert(Id(i), float32(i)) - } - - incBq := FromBaseQueue(bq, MaxComparator{}) - - i := Id(99) - for !incBq.IsEmpty() { - item, err := incBq.PopItem() - if err != nil { - t.Fatal(err) - } - - if item.id != i { - t.Fatalf("got %d, want %d", item.id, i) - } - - i -= 1 - } - }) -} - -func furthestBuildings(heights []int, bricks, ladders int) (int, error) { - - ladderJumps := NewBaseQueue(MinComparator{}) - - for idx := 0; idx < len(heights)-1; idx++ { - height := heights[idx] - nextHeight := heights[idx+1] - - if height >= nextHeight { - continue - } - - jump := nextHeight - height - - ladderJumps.Insert(Id(idx), float32(jump)) - - if ladderJumps.Len() > ladders { - minLadderJump, err := ladderJumps.PopItem() - if err != nil { - return -1, err - } - - if bricks-int(minLadderJump.dist) < 0 { - return idx, nil - } - - bricks -= int(minLadderJump.dist) - } - } - - return len(heights) - 1, nil -}