From 430a4834719da6144166398e71a1fab76c8a7e58 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Fri, 14 Jun 2024 09:57:52 -0400 Subject: [PATCH] Distheap (#342) * minmax dist heap * complete refactor * feat: distheap * fix * refactor --- pkg/hnsw/friends.go | 8 +- pkg/hnsw/friends_test.go | 10 ++- pkg/hnsw/heap.go | 140 +++++++++++++++++++++++++++++++ pkg/hnsw/heap_test.go | 174 +++++++++++++++++++++++++++++++++++++++ pkg/hnsw/hnsw.go | 65 +++++++-------- pkg/hnsw/hnsw_test.go | 30 +++---- pkg/hnsw/minmax.go | 101 +++-------------------- pkg/hnsw/pq.go | 173 -------------------------------------- pkg/hnsw/pq_test.go | 105 ----------------------- 9 files changed, 383 insertions(+), 423 deletions(-) create mode 100644 pkg/hnsw/heap.go create mode 100644 pkg/hnsw/heap_test.go delete mode 100644 pkg/hnsw/pq.go delete mode 100644 pkg/hnsw/pq_test.go diff --git a/pkg/hnsw/friends.go b/pkg/hnsw/friends.go index bc78270..061b2fd 100644 --- a/pkg/hnsw/friends.go +++ b/pkg/hnsw/friends.go @@ -8,15 +8,15 @@ import ( type Point []float32 type Friends struct { - friends []*BaseQueue + friends []*DistHeap } // NewFriends creates a new vector, note the max level is inclusive. func NewFriends(topLevel int) *Friends { - friends := make([]*BaseQueue, topLevel+1) + friends := make([]*DistHeap, topLevel+1) for i := 0; i <= topLevel; i++ { - friends[i] = NewBaseQueue(MinComparator{}) + friends[i] = NewDistHeap() } return &Friends{ @@ -51,7 +51,7 @@ func (v *Friends) InsertFriendsAtLevel(level int, friendId Id, dist float32) { } } -func (v *Friends) GetFriendsAtLevel(level int) (*BaseQueue, error) { +func (v *Friends) GetFriendsAtLevel(level int) (*DistHeap, error) { if !v.HasLevel(level) { return nil, errors.New("failed to get friends at level") } diff --git a/pkg/hnsw/friends_test.go b/pkg/hnsw/friends_test.go index ea52d67..a39f146 100644 --- a/pkg/hnsw/friends_test.go +++ b/pkg/hnsw/friends_test.go @@ -69,12 +69,18 @@ func TestVector_LevelManagement(t *testing.T) { t.Fatalf("expected hex and oct friends list at level %v to be 1, got: %v || %v", i, hexFriends.Len(), octFriends.Len()) } - top := hexFriends.Top() + top, err := hexFriends.PeekMinItem() + if err != nil { + t.Fatal(err) + } if top.id != octId { t.Fatalf("expected %v, got %v", octId, top.id) } - top = octFriends.Top() + top, err = octFriends.PeekMinItem() + if err != nil { + t.Fatal(err) + } if top.id != hexId { t.Fatalf("expected %v, got %v", hexId, top.id) } diff --git a/pkg/hnsw/heap.go b/pkg/hnsw/heap.go new file mode 100644 index 0000000..b7ef650 --- /dev/null +++ b/pkg/hnsw/heap.go @@ -0,0 +1,140 @@ +package hnsw + +import ( + "fmt" +) + +type Item struct { + id Id + dist float32 +} + +var EmptyHeapError = fmt.Errorf("Empty Heap") + +type DistHeap struct { + items []*Item + visited map[Id]bool +} + +func NewDistHeap() *DistHeap { + d := &DistHeap{ + items: make([]*Item, 0), + visited: make(map[Id]bool), + } + return d +} +func FromItems(items []*Item) *DistHeap { + visited := make(map[Id]bool) + for _, item := range items { + visited[item.id] = true + } + + d := &DistHeap{items: items, visited: visited} + d.Init() + + return d +} + +func (d *DistHeap) Init() { + n := d.Len() + for i := n/2 - 1; i >= 0; i-- { + d.down(i, n) + } +} + +func (d *DistHeap) PeekMinItem() (*Item, error) { + if d.IsEmpty() { + return nil, EmptyHeapError + } + + return d.items[0], nil +} +func (d *DistHeap) PeekMaxItem() (*Item, error) { + if d.Len() == 0 { + return nil, EmptyHeapError + } + + // Find the maximum element without removing it + n := d.Len() + + i := 0 + l := lchild(0) + if l < n && !d.Less(l, i) { + i = l + } + + r := rchild(0) + if r < n && !d.Less(r, i) { + i = r + } + + return d.items[i], nil +} +func (d *DistHeap) PopMinItem() (*Item, error) { + if d.IsEmpty() { + return nil, EmptyHeapError + } + + n := d.Len() - 1 + d.Swap(0, n) + d.down(0, n) + return d.Pop(), nil +} +func (d *DistHeap) PopMaxItem() (*Item, error) { + if d.IsEmpty() { + return nil, EmptyHeapError + } + + n := d.Len() + i := 0 + l := lchild(0) + + if l < n && !d.Less(l, i) { + i = l + } + + r := rchild(0) + if r < n && !d.Less(r, i) { + i = r + } + + d.Swap(i, n-1) + d.down(i, n-1) + + return d.Pop(), nil +} +func (d *DistHeap) Insert(id Id, dist float32) { + if d.visited[id] { + for idx, item := range d.items { + if item.id == id { + item.dist = dist + d.Fix(idx) + return + } + } + } else { + d.Push(&Item{id: id, dist: dist}) + d.up(d.Len() - 1) + d.visited[id] = true + } +} +func (d *DistHeap) Fix(i int) { + if !d.down(i, d.Len()) { + d.up(i) + } +} + +func (d DistHeap) IsEmpty() bool { return len(d.items) == 0 } +func (d DistHeap) Len() int { return len(d.items) } +func (d DistHeap) Less(i, j int) bool { return d.items[i].dist < d.items[j].dist } +func (d DistHeap) Swap(i, j int) { d.items[i], d.items[j] = d.items[j], d.items[i] } +func (d *DistHeap) Push(x *Item) { + (*d).items = append((*d).items, x) +} +func (d *DistHeap) Pop() *Item { + old := (*d).items + n := len(old) + x := old[n-1] + (*d).items = old[0 : n-1] + return x +} diff --git a/pkg/hnsw/heap_test.go b/pkg/hnsw/heap_test.go new file mode 100644 index 0000000..ec131c7 --- /dev/null +++ b/pkg/hnsw/heap_test.go @@ -0,0 +1,174 @@ +package hnsw + +import "testing" + +func TestHeap(t *testing.T) { + + t.Run("basic min max properties", func(t *testing.T) { + h := NewDistHeap() + + for i := 10; i > 0; i-- { + h.Insert(Id(i), float32(10-i)) + } + + if h.Len() != 10 { + t.Fatalf("heap length should be 10, got %v", h.Len()) + } + + expectedId := Id(10) + for !h.IsEmpty() { + peekMinItem, err := h.PeekMinItem() + if err != nil { + t.Fatalf("failed to peek min item: %v", err) + } + + minItem, err := h.PopMinItem() + if err != nil { + t.Fatalf("failed to pop min item, err: %v", err) + } + + if peekMinItem.id != minItem.id { + t.Fatalf("mismatched item id, expected %v, got %v", expectedId, peekMinItem.id) + } + + if minItem.id != expectedId { + t.Fatalf("mismatched ids, expected %v, got: %v", expectedId, minItem.id) + } + + expectedId -= 1 + } + }) + + t.Run("basic min max properties 2", func(t *testing.T) { + h := NewDistHeap() + + for i := 0; i <= 10; i++ { + h.Insert(Id(i), float32(10-i)) + } + + maxExpectedId := Id(0) + minExpectedId := Id(10) + + for !h.IsEmpty() { + peekMaxItem, err := h.PeekMaxItem() + + if err != nil { + t.Fatalf("failed to peek max item, err: %v", err) + } + + maxItem, err := h.PopMaxItem() + + if err != nil { + t.Fatalf("failed to pop max item, err: %v", err) + } + + if peekMaxItem.id != maxItem.id { + t.Fatalf("mismatched max ids, expected %v, got: %v", maxItem.id, peekMaxItem.id) + } + + if maxItem.id != maxExpectedId { + t.Fatalf("expected id to be %v, got %v", maxExpectedId, maxItem.id) + } + + if h.IsEmpty() { + continue + } + + peekMinItem, err := h.PeekMinItem() + if err != nil { + t.Fatalf("failed to peek min item, err: %v", err) + } + + minItem, err := h.PopMinItem() + + if err != nil { + t.Fatalf("failed to pop min item, err: %v", err) + } + + if peekMinItem.id != minItem.id { + t.Fatalf("mismatched min ids, expected %v, got: %v", maxItem.id, peekMaxItem.id) + } + + if minItem.id != minExpectedId { + t.Fatalf("expected id to be %v, got %v", minExpectedId, minItem.id) + } + + minExpectedId -= 1 + maxExpectedId += 1 + } + }) + + t.Run("bricks and ladders || min heap", func(t *testing.T) { + type Case struct { + heights []int + bricks int + ladders int + expected int + } + + cases := [3]Case{ + { + heights: []int{4, 2, 7, 6, 9, 14, 12}, + bricks: 5, + ladders: 1, + expected: 4, + }, + { + heights: []int{4, 12, 2, 7, 3, 18, 20, 3, 19}, + bricks: 10, + ladders: 2, + expected: 7, + }, + { + heights: []int{14, 3, 19, 3}, + bricks: 17, + ladders: 0, + expected: 3, + }, + } + + for _, c := range cases { + res, err := furthestBuildings(c.heights, c.bricks, c.ladders) + if err != nil { + t.Fatal(err) + } + + if res != c.expected { + t.Errorf("got %d, want %d", res, c.expected) + } + } + }) +} + +func furthestBuildings(heights []int, bricks, ladders int) (int, error) { + + ladderJumps := NewDistHeap() + + for idx := 0; idx < len(heights)-1; idx++ { + height := heights[idx] + nextHeight := heights[idx+1] + + if height >= nextHeight { + continue + } + + jump := nextHeight - height + + ladderJumps.Insert(Id(idx), float32(jump)) + + if ladderJumps.Len() > ladders { + minLadderJump, err := ladderJumps.PopMinItem() + if err != nil { + return -1, err + } + + if bricks-int(minLadderJump.dist) < 0 { + return idx, nil + } + + bricks -= int(minLadderJump.dist) + } + } + + return len(heights) - 1, nil +} diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index 99c9d9b..9416df6 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -60,23 +60,26 @@ func (h *Hnsw) GenerateId() Id { return Id(len(h.points)) } -func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, level int) (*BaseQueue, error) { +func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, level int) (*DistHeap, error) { visited := make([]bool, len(h.friends)+1) - candidatesForQ := NewBaseQueue(MinComparator{}) - foundNNToQ := NewBaseQueue(MaxComparator{}) + candidatesForQ := NewDistHeap() + foundNNToQ := NewDistHeap() // this is a max // note entryItem.dist should be the distance to Q candidatesForQ.Insert(entryItem.id, entryItem.dist) foundNNToQ.Insert(entryItem.id, entryItem.dist) for !candidatesForQ.IsEmpty() { - closestCandidate, err := candidatesForQ.PopItem() + closestCandidate, err := candidatesForQ.PopMinItem() if err != nil { return nil, fmt.Errorf("error during searching level %d: %w", level, err) } - furthestFoundNN := foundNNToQ.Top() + furthestFoundNN, err := foundNNToQ.PeekMaxItem() + if err != nil { + return nil, fmt.Errorf("error during searching level %d: %w", level, err) + } // if distance(c, q) > distance(f, q) if closestCandidate.dist > furthestFoundNN.dist { @@ -94,7 +97,10 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev if !visited[ccFriendId] { visited[ccFriendId] = true - furthestFoundNN = foundNNToQ.Top() + furthestFoundNN, err = foundNNToQ.PeekMaxItem() + if err != nil { + return nil, fmt.Errorf("error during searching level %d: %w", level, err) + } ccFriendPoint, ok := h.points[ccFriendId] if !ok { @@ -108,7 +114,7 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev foundNNToQ.Insert(ccFriendId, ccFriendDistToQ) if foundNNToQ.Len() > numNearestToQToReturn { - if _, err = foundNNToQ.PopItem(); err != nil { + if _, err = foundNNToQ.PopMaxItem(); err != nil { return nil, fmt.Errorf("error during searching level %d: %w", level, err) } } @@ -118,7 +124,7 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev } - return FromBaseQueue(foundNNToQ, MinComparator{}), nil + return foundNNToQ, nil } func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item { @@ -141,7 +147,7 @@ func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item { return epItem } - newEpItem, err := closestNeighborsToQ.PopItem() + newEpItem, err := closestNeighborsToQ.PopMinItem() if err != nil { panic(err) } @@ -152,7 +158,7 @@ func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item { return epItem } -func (h *Hnsw) selectNeighbors(nearestNeighbors *BaseQueue) ([]*Item, error) { +func (h *Hnsw) selectNeighbors(nearestNeighbors *DistHeap) ([]*Item, error) { if nearestNeighbors.Len() <= h.M { return nearestNeighbors.items, nil } @@ -160,7 +166,7 @@ func (h *Hnsw) selectNeighbors(nearestNeighbors *BaseQueue) ([]*Item, error) { nearestItems := make([]*Item, h.M) for i := 0; i < h.M; i++ { - nearestItem, err := nearestNeighbors.PopItem() + nearestItem, err := nearestNeighbors.PopMinItem() if err != nil { return nil, err @@ -221,19 +227,17 @@ func (h *Hnsw) InsertVector(q Point) error { return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } - maxNeighborsFriendsAtLevel := FromBaseQueue(neighborFriendsAtLevel, MaxComparator{}) - - for maxNeighborsFriendsAtLevel.Len() > h.M { - _, err = maxNeighborsFriendsAtLevel.PopItem() + for neighborFriendsAtLevel.Len() > h.M { + _, err := neighborFriendsAtLevel.PopMaxItem() if err != nil { return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } } - h.friends[neighbor.id].friends[level] = FromBaseQueue(maxNeighborsFriendsAtLevel, MinComparator{}) + h.friends[neighbor.id].friends[level] = neighborFriendsAtLevel } - newEntryItem, err := nnToQAtLevel.PopItem() + newEntryItem, err := nnToQAtLevel.PopMinItem() if err != nil { return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } @@ -252,7 +256,7 @@ func (h *Hnsw) isValidPoint(point Point) bool { return len(point) == h.vectorDimensionality } -func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error) { +func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*DistHeap, error) { entryPoint, ok := h.points[h.entryPointId] if !ok { @@ -278,7 +282,10 @@ func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error) return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } - entryItem = nearestNeighborQueueAtLevel.Top() + entryItem, err = nearestNeighborQueueAtLevel.PeekMinItem() + if err != nil { + return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) + } } // level 0 @@ -287,24 +294,12 @@ func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error) return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %d", h.entryPointId, 0) } - if nearestNeighborQueueAtLevel0.Len() <= numNeighborsToReturn { - return nearestNeighborQueueAtLevel0, nil - } - - var items []*Item - - for !nearestNeighborQueueAtLevel0.IsEmpty() { - if len(items) == numNeighborsToReturn { - return FromItems(items, MinComparator{}), nil - } - - nearestNeighborAtLevel0Item, err := nearestNeighborQueueAtLevel0.PopItem() + for nearestNeighborQueueAtLevel0.Len() > numNeighborsToReturn { + _, err := nearestNeighborQueueAtLevel0.PopMaxItem() if err != nil { - return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %d", h.entryPointId, 0) + return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", h.entryPointId, err) } - - items = append(items, nearestNeighborAtLevel0Item) } - return FromItems(items, MinComparator{}), nil + return nearestNeighborQueueAtLevel0, nil } diff --git a/pkg/hnsw/hnsw_test.go b/pkg/hnsw/hnsw_test.go index ca36756..905bb1c 100644 --- a/pkg/hnsw/hnsw_test.go +++ b/pkg/hnsw/hnsw_test.go @@ -104,7 +104,7 @@ func TestHnsw_SearchLevel(t *testing.T) { t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor) } - closestItem, err := closestNeighbor.PopItem() + closestItem, err := closestNeighbor.PopMinItem() if err != nil { t.Fatal(err) @@ -149,7 +149,7 @@ func TestHnsw_SearchLevel(t *testing.T) { t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor) } - closestItem, err := closestNeighbor.PopItem() + closestItem, err := closestNeighbor.PopMinItem() if err != nil { t.Fatal(err) @@ -188,7 +188,7 @@ func TestHnsw_SearchLevel(t *testing.T) { t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor) } - closestItem, err := closestNeighbor.PopItem() + closestItem, err := closestNeighbor.PopMinItem() if err != nil { t.Fatal(err) @@ -218,7 +218,7 @@ func TestHnsw_SearchLevel(t *testing.T) { t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor.Len()) } - closestItem, err := closestNeighbor.PopItem() + closestItem, err := closestNeighbor.PopMinItem() if err != nil { t.Fatal(err) } @@ -253,7 +253,7 @@ func TestHnsw_SearchLevel(t *testing.T) { var closestIds []Id for !closestNeighbor.IsEmpty() { - closestItem, err := closestNeighbor.PopItem() + closestItem, err := closestNeighbor.PopMinItem() if err != nil { t.Fatal(err) } @@ -330,7 +330,7 @@ func TestHnsw_FindCloserEntryPoint(t *testing.T) { func TestHnsw_SelectNeighbors(t *testing.T) { t.Run("selects neighbors given overflow", func(t *testing.T) { - nearestNeighbors := NewBaseQueue(MinComparator{}) + nearestNeighbors := NewDistHeap() M := 4 @@ -352,7 +352,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) { } // for the sake of testing, let's rebuild the pq and assert ids are correct - reneighbors := NewBaseQueue(MinComparator{}) + reneighbors := NewDistHeap() for _, item := range neighbors { reneighbors.Insert(item.id, item.dist) @@ -360,7 +360,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) { expectedId := Id(0) for !reneighbors.IsEmpty() { - nn, err := reneighbors.PopItem() + nn, err := reneighbors.PopMinItem() if err != nil { t.Fatal(err) @@ -378,7 +378,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) { M := 10 h := NewHnsw(2, 10, M, Point{0, 0}) - nnQueue := NewBaseQueue(MinComparator{}) + nnQueue := NewDistHeap() for i := 0; i < 3; i++ { nnQueue.Insert(Id(i), float32(i)) @@ -394,7 +394,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) { t.Fatalf("select neighbors should have at least 3 neighbors, got: %v", len(neighbors)) } - reneighbors := NewBaseQueue(MinComparator{}) + reneighbors := NewDistHeap() for _, item := range neighbors { reneighbors.Insert(item.id, item.dist) @@ -402,7 +402,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) { expectedId := Id(0) for !reneighbors.IsEmpty() { - nn, err := reneighbors.PopItem() + nn, err := reneighbors.PopMinItem() if err != nil { t.Fatal(err) @@ -510,7 +510,7 @@ func TestHnsw_KnnSearch(t *testing.T) { expectedId := Id(3) for !nearestNeighbors.IsEmpty() { - nearestNeighbor, err := nearestNeighbors.PopItem() + nearestNeighbor, err := nearestNeighbors.PopMinItem() if err != nil { t.Fatalf("failed to pop item: %v, err: %v", nearestNeighbors, err) } @@ -546,7 +546,7 @@ func TestHnsw_KnnSearch(t *testing.T) { var gotIds []Id for !closestToQ.IsEmpty() { - closest, err := closestToQ.PopItem() + closest, err := closestToQ.PopMinItem() if err != nil { t.Fatalf("failed to pop item: %v, err: %v", closestToQ, err) } @@ -587,7 +587,7 @@ func TestHnsw_KnnSearch(t *testing.T) { var got []Id for !closestNeighbors.IsEmpty() { - closest, err := closestNeighbors.PopItem() + closest, err := closestNeighbors.PopMinItem() if err != nil { t.Fatalf("failed to pop item: %v, err: %v", closestNeighbors, err) } @@ -621,7 +621,7 @@ func TestHnsw_KnnSearch(t *testing.T) { expectedId := Id(0) for found.IsEmpty() { - nnItem, err := found.PopItem() + nnItem, err := found.PopMinItem() if err != nil { t.Fatalf("failed to pop item: %v, err: %v", found, err) } diff --git a/pkg/hnsw/minmax.go b/pkg/hnsw/minmax.go index b6658c1..fd7868c 100644 --- a/pkg/hnsw/minmax.go +++ b/pkg/hnsw/minmax.go @@ -9,14 +9,9 @@ package hnsw import ( - "container/heap" "math/bits" ) -// Interface copied from the heap package, so code that imports minmaxheap does -// not also have to import "container/heap". -type Interface = heap.Interface - func level(i int) int { // floor(log2(i + 1)) return bits.Len(uint(i)+1) - 1 @@ -50,7 +45,7 @@ func grandparent(i int) int { return parent(parent(i)) } -func down(h Interface, i, n int) bool { +func (d *DistHeap) down(i, n int) bool { min := isMinLevel(i) i0 := i for { @@ -60,18 +55,18 @@ func down(h Interface, i, n int) bool { if l >= n || l < 0 /* overflow */ { break } - if h.Less(l, m) == min { + if d.Less(l, m) == min { m = l } r := rchild(i) - if r < n && h.Less(r, m) == min { + if r < n && d.Less(r, m) == min { m = r } // grandchildren are contiguous i*4+3+{0,1,2,3} for g := lchild(l); g < n && g <= rchild(r); g++ { - if h.Less(g, m) == min { + if d.Less(g, m) == min { m = g } } @@ -80,7 +75,7 @@ func down(h Interface, i, n int) bool { break } - h.Swap(i, m) + d.Swap(i, m) if m == l || m == r { break @@ -88,21 +83,21 @@ func down(h Interface, i, n int) bool { // m is grandchild p := parent(m) - if h.Less(p, m) == min { - h.Swap(m, p) + if d.Less(p, m) == min { + d.Swap(m, p) } i = m } return i > i0 } -func up(h Interface, i int) { +func (d *DistHeap) up(i int) { min := isMinLevel(i) if hasParent(i) { p := parent(i) - if h.Less(p, i) == min { - h.Swap(i, p) + if d.Less(p, i) == min { + d.Swap(i, p) min = !min i = p } @@ -110,83 +105,11 @@ func up(h Interface, i int) { for hasGrandparent(i) { g := grandparent(i) - if h.Less(i, g) != min { + if d.Less(i, g) != min { return } - h.Swap(i, g) + d.Swap(i, g) i = g } } - -// Init establishes the heap invariants required by the other routines in this -// package. Init may be called whenever the heap invariants may have been -// invalidated. -// The complexity is O(n) where n = h.Len(). -func Init(h Interface) { - n := h.Len() - for i := n/2 - 1; i >= 0; i-- { - down(h, i, n) - } -} - -// Push pushes the element x onto the heap. -// The complexity is O(log n) where n = h.Len(). -func Push(h Interface, x interface{}) { - h.Push(x) - up(h, h.Len()-1) -} - -// Pop removes and returns the minimum element (according to Less) from the heap. -// The complexity is O(log n) where n = h.Len(). -func Pop(h Interface) interface{} { - n := h.Len() - 1 - h.Swap(0, n) - down(h, 0, n) - return h.Pop() -} - -// PopMax removes and returns the maximum element (according to Less) from the heap. -// The complexity is O(log n) where n = h.Len(). -func PopMax(h Interface) interface{} { - n := h.Len() - - i := 0 - l := lchild(0) - if l < n && !h.Less(l, i) { - i = l - } - - r := rchild(0) - if r < n && !h.Less(r, i) { - i = r - } - - h.Swap(i, n-1) - down(h, i, n-1) - return h.Pop() -} - -// Remove removes and returns the element at index i from the heap. -// The complexity is O(log n) where n = h.Len(). -func Remove(h Interface, i int) interface{} { - n := h.Len() - 1 - if n != i { - h.Swap(i, n) - if !down(h, i, n) { - up(h, i) - } - } - return h.Pop() -} - -// Fix re-establishes the heap ordering after the element at index i has -// changed its value. Changing the value of the element at index i and then -// calling Fix is equivalent to, but less expensive than, calling Remove(h, i) -// followed by a Push of the new value. -// The complexity is O(log n) where n = h.Len(). -func Fix(h Interface, i int) { - if !down(h, i, h.Len()) { - up(h, i) - } -} diff --git a/pkg/hnsw/pq.go b/pkg/hnsw/pq.go deleted file mode 100644 index 8f5c0e0..0000000 --- a/pkg/hnsw/pq.go +++ /dev/null @@ -1,173 +0,0 @@ -package hnsw - -import ( - "container/heap" - "fmt" -) - -type Comparator interface { - Less(i, j *Item) bool -} - -// MaxComparator implements the Comparator interface for a max-heap. -type MaxComparator struct{} - -func (c MaxComparator) Less(i, j *Item) bool { - return i.dist > j.dist -} - -// MinComparator implements the Comparator interface for a min-heap. -type MinComparator struct{} - -func (c MinComparator) Less(i, j *Item) bool { - return i.dist < j.dist -} - -type Item struct { - id Id - dist float32 - index int -} - -type Heapy interface { - heap.Interface - Insert(id Id, dist float32) - IsEmpty() bool - Len() int - PopItem() *Item - Top() *Item - Take(count int) (*BaseQueue, error) - update(item *Item, id Id, dist float32) -} - -// Nothing from BaseQueue should be used. Only use the Max and Min queue. -// BaseQueue isn't even a heap! It misses the Less() method which the Min/Max queue implement. -type BaseQueue struct { - visitedIds map[Id]*Item - items []*Item - comparator Comparator -} - -func (bq *BaseQueue) Take(count int, comparator Comparator) (*BaseQueue, error) { - if len(bq.items) < count { - return nil, fmt.Errorf("queue only has %v items, but want to take %v", len(bq.items), count) - } - - pq := NewBaseQueue(comparator) - - ct := 0 - for { - if ct == count { - break - } - - peeled, err := bq.PopItem() - if err != nil { - return nil, err - } - - pq.Insert(peeled.id, peeled.dist) - - ct++ - } - - return pq, nil -} - -func (bq BaseQueue) Len() int { return len(bq.items) } -func (bq BaseQueue) Swap(i, j int) { - pq := bq.items - pq[i], pq[j] = pq[j], pq[i] - pq[i].index = i - pq[j].index = j -} - -func (bq *BaseQueue) Push(x any) { - n := len(bq.items) - item := x.(*Item) - item.index = n - bq.items = append(bq.items, item) -} - -func (bq *BaseQueue) Top() *Item { - if len(bq.items) == 0 { - return nil - } - return bq.items[0] -} - -func (bq *BaseQueue) IsEmpty() bool { - return len(bq.items) == 0 -} - -func (bq *BaseQueue) Pop() any { - old := bq.items - n := len(old) - item := old[n-1] - old[n-1] = nil - item.index = -1 - bq.items = old[0 : n-1] - return item -} - -func (bq *BaseQueue) Less(i, j int) bool { - return bq.comparator.Less(bq.items[i], bq.items[j]) -} - -func (bq *BaseQueue) Insert(id Id, dist float32) { - if item, ok := bq.visitedIds[id]; ok { - bq.update(item, id, dist) - return - } - - newItem := Item{id: id, dist: dist} - heap.Push(bq, &newItem) - bq.visitedIds[id] = &newItem - -} - -func NewBaseQueue(comparator Comparator) *BaseQueue { - bq := &BaseQueue{ - visitedIds: map[Id]*Item{}, - comparator: comparator, - } - heap.Init(bq) - return bq -} - -func (bq *BaseQueue) PopItem() (*Item, error) { - if bq.Len() == 0 { - return nil, fmt.Errorf("no items to peel") - } - popped := heap.Pop(bq).(*Item) - delete(bq.visitedIds, popped.id) - return popped, nil -} - -func (bq *BaseQueue) update(item *Item, id Id, dist float32) { - item.id = id - item.dist = dist - heap.Fix(bq, item.index) -} - -func FromBaseQueue(bq *BaseQueue, comparator Comparator) *BaseQueue { - newBq := NewBaseQueue(comparator) - - for _, item := range bq.items { - newBq.Insert(item.id, item.dist) - } - - return newBq -} - -func FromItems(items []*Item, comparator Comparator) *BaseQueue { - bq := &BaseQueue{ - visitedIds: map[Id]*Item{}, - items: items, - comparator: comparator, - } - - heap.Init(bq) - - return bq -} diff --git a/pkg/hnsw/pq_test.go b/pkg/hnsw/pq_test.go deleted file mode 100644 index 5bc1539..0000000 --- a/pkg/hnsw/pq_test.go +++ /dev/null @@ -1,105 +0,0 @@ -package hnsw - -import ( - "testing" -) - -func TestPQ(t *testing.T) { - - t.Run("bricks and ladders || min heap", func(t *testing.T) { - type Case struct { - heights []int - bricks int - ladders int - expected int - } - - cases := [3]Case{ - { - heights: []int{4, 2, 7, 6, 9, 14, 12}, - bricks: 5, - ladders: 1, - expected: 4, - }, - { - heights: []int{4, 12, 2, 7, 3, 18, 20, 3, 19}, - bricks: 10, - ladders: 2, - expected: 7, - }, - { - heights: []int{14, 3, 19, 3}, - bricks: 17, - ladders: 0, - expected: 3, - }, - } - - for _, c := range cases { - res, err := furthestBuildings(c.heights, c.bricks, c.ladders) - if err != nil { - t.Fatal(err) - } - - if res != c.expected { - t.Errorf("got %d, want %d", res, c.expected) - } - } - }) - - t.Run("interchange", func(t *testing.T) { - bq := NewBaseQueue(MinComparator{}) - for i := 0; i < 100; i++ { - bq.Insert(Id(i), float32(i)) - } - - incBq := FromBaseQueue(bq, MaxComparator{}) - - i := Id(99) - for !incBq.IsEmpty() { - item, err := incBq.PopItem() - if err != nil { - t.Fatal(err) - } - - if item.id != i { - t.Fatalf("got %d, want %d", item.id, i) - } - - i -= 1 - } - }) -} - -func furthestBuildings(heights []int, bricks, ladders int) (int, error) { - - ladderJumps := NewBaseQueue(MinComparator{}) - - for idx := 0; idx < len(heights)-1; idx++ { - height := heights[idx] - nextHeight := heights[idx+1] - - if height >= nextHeight { - continue - } - - jump := nextHeight - height - - ladderJumps.Insert(Id(idx), float32(jump)) - - if ladderJumps.Len() > ladders { - minLadderJump, err := ladderJumps.PopItem() - if err != nil { - return -1, err - } - - if bricks-int(minLadderJump.dist) < 0 { - return idx, nil - } - - bricks -= int(minLadderJump.dist) - } - } - - return len(heights) - 1, nil -}