diff --git a/examples/client/server.go b/examples/client/server.go
index 726d359a..385dd167 100644
--- a/examples/client/server.go
+++ b/examples/client/server.go
@@ -12,6 +12,10 @@ func main() {
// Handle all requests by serving a file of the same name
http.Handle("/", fs)
+ http.HandleFunc("/search", func(w http.ResponseWriter, r *http.Request) {
+ http.ServeFile(w, r, "./search.html")
+ })
+
// Define the port to listen on
port := "3001"
log.Printf("Listening on http://localhost:%s/", port)
diff --git a/examples/search/index.html b/examples/search/index.html
deleted file mode 100644
index f360530f..00000000
--- a/examples/search/index.html
+++ /dev/null
@@ -1,285 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
diff --git a/examples/search/server.go b/examples/search/server.go
deleted file mode 100644
index 35185397..00000000
--- a/examples/search/server.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package main
-
-import (
- "log"
- "net/http"
-)
-
-func main() {
- // Set the directory to serve
- fs := http.FileServer(http.Dir("./"))
-
- // Handle all requests by serving a file of the same name
- http.Handle("/", fs)
-
- // Define the port to listen on
- port := "3000"
- log.Printf("Listening on http://localhost:%s/", port)
-
- // Start the server
- err := http.ListenAndServe(":"+port, nil)
- if err != nil {
- log.Fatal(err)
- }
-}
diff --git a/package-lock.json b/package-lock.json
index f88293de..e9c2ba58 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1850,12 +1850,12 @@
}
},
"node_modules/braces": {
- "version": "3.0.2",
- "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz",
- "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==",
+ "version": "3.0.3",
+ "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.3.tgz",
+ "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==",
"dev": true,
"dependencies": {
- "fill-range": "^7.0.1"
+ "fill-range": "^7.1.1"
},
"engines": {
"node": ">=8"
@@ -2346,9 +2346,9 @@
}
},
"node_modules/fill-range": {
- "version": "7.0.1",
- "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz",
- "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==",
+ "version": "7.1.1",
+ "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.1.1.tgz",
+ "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==",
"dev": true,
"dependencies": {
"to-regex-range": "^5.0.1"
diff --git a/package.json b/package.json
index 09ebaca8..f66e10c5 100644
--- a/package.json
+++ b/package.json
@@ -5,9 +5,8 @@
"main": "index.js",
"scripts": {
"build": "esbuild src/index.ts --bundle --minify --sourcemap --outfile=dist/appendable.min.js",
- "warp": "rm -rf dist examples/client/appendable.min.js examples/client/appendable.min.js.map examples/search/appendable.min.js examples/search/appendable.min.js.map && esbuild src/index.ts --bundle --minify --sourcemap --outfile=dist/appendable.min.js",
+ "warp": "rm -rf dist examples/client/appendable.min.js examples/client/appendable.min.js.map && esbuild src/index.ts --bundle --minify --sourcemap --outfile=dist/appendable.min.js",
"client": "cd examples/client && go run server.go",
- "search": "cd examples/search && go run server.go",
"test": "jest"
},
"repository": {
diff --git a/pkg/hnsw/friends.go b/pkg/hnsw/friends.go
index bc782708..061b2fd0 100644
--- a/pkg/hnsw/friends.go
+++ b/pkg/hnsw/friends.go
@@ -8,15 +8,15 @@ import (
type Point []float32
type Friends struct {
- friends []*BaseQueue
+ friends []*DistHeap
}
// NewFriends creates a new vector, note the max level is inclusive.
func NewFriends(topLevel int) *Friends {
- friends := make([]*BaseQueue, topLevel+1)
+ friends := make([]*DistHeap, topLevel+1)
for i := 0; i <= topLevel; i++ {
- friends[i] = NewBaseQueue(MinComparator{})
+ friends[i] = NewDistHeap()
}
return &Friends{
@@ -51,7 +51,7 @@ func (v *Friends) InsertFriendsAtLevel(level int, friendId Id, dist float32) {
}
}
-func (v *Friends) GetFriendsAtLevel(level int) (*BaseQueue, error) {
+func (v *Friends) GetFriendsAtLevel(level int) (*DistHeap, error) {
if !v.HasLevel(level) {
return nil, errors.New("failed to get friends at level")
}
diff --git a/pkg/hnsw/friends_test.go b/pkg/hnsw/friends_test.go
index ea52d670..a39f1466 100644
--- a/pkg/hnsw/friends_test.go
+++ b/pkg/hnsw/friends_test.go
@@ -69,12 +69,18 @@ func TestVector_LevelManagement(t *testing.T) {
t.Fatalf("expected hex and oct friends list at level %v to be 1, got: %v || %v", i, hexFriends.Len(), octFriends.Len())
}
- top := hexFriends.Top()
+ top, err := hexFriends.PeekMinItem()
+ if err != nil {
+ t.Fatal(err)
+ }
if top.id != octId {
t.Fatalf("expected %v, got %v", octId, top.id)
}
- top = octFriends.Top()
+ top, err = octFriends.PeekMinItem()
+ if err != nil {
+ t.Fatal(err)
+ }
if top.id != hexId {
t.Fatalf("expected %v, got %v", hexId, top.id)
}
diff --git a/pkg/hnsw/heap.go b/pkg/hnsw/heap.go
new file mode 100644
index 00000000..b7ef650a
--- /dev/null
+++ b/pkg/hnsw/heap.go
@@ -0,0 +1,140 @@
+package hnsw
+
+import (
+ "fmt"
+)
+
+type Item struct {
+ id Id
+ dist float32
+}
+
+var EmptyHeapError = fmt.Errorf("Empty Heap")
+
+type DistHeap struct {
+ items []*Item
+ visited map[Id]bool
+}
+
+func NewDistHeap() *DistHeap {
+ d := &DistHeap{
+ items: make([]*Item, 0),
+ visited: make(map[Id]bool),
+ }
+ return d
+}
+func FromItems(items []*Item) *DistHeap {
+ visited := make(map[Id]bool)
+ for _, item := range items {
+ visited[item.id] = true
+ }
+
+ d := &DistHeap{items: items, visited: visited}
+ d.Init()
+
+ return d
+}
+
+func (d *DistHeap) Init() {
+ n := d.Len()
+ for i := n/2 - 1; i >= 0; i-- {
+ d.down(i, n)
+ }
+}
+
+func (d *DistHeap) PeekMinItem() (*Item, error) {
+ if d.IsEmpty() {
+ return nil, EmptyHeapError
+ }
+
+ return d.items[0], nil
+}
+func (d *DistHeap) PeekMaxItem() (*Item, error) {
+ if d.Len() == 0 {
+ return nil, EmptyHeapError
+ }
+
+ // Find the maximum element without removing it
+ n := d.Len()
+
+ i := 0
+ l := lchild(0)
+ if l < n && !d.Less(l, i) {
+ i = l
+ }
+
+ r := rchild(0)
+ if r < n && !d.Less(r, i) {
+ i = r
+ }
+
+ return d.items[i], nil
+}
+func (d *DistHeap) PopMinItem() (*Item, error) {
+ if d.IsEmpty() {
+ return nil, EmptyHeapError
+ }
+
+ n := d.Len() - 1
+ d.Swap(0, n)
+ d.down(0, n)
+ return d.Pop(), nil
+}
+func (d *DistHeap) PopMaxItem() (*Item, error) {
+ if d.IsEmpty() {
+ return nil, EmptyHeapError
+ }
+
+ n := d.Len()
+ i := 0
+ l := lchild(0)
+
+ if l < n && !d.Less(l, i) {
+ i = l
+ }
+
+ r := rchild(0)
+ if r < n && !d.Less(r, i) {
+ i = r
+ }
+
+ d.Swap(i, n-1)
+ d.down(i, n-1)
+
+ return d.Pop(), nil
+}
+func (d *DistHeap) Insert(id Id, dist float32) {
+ if d.visited[id] {
+ for idx, item := range d.items {
+ if item.id == id {
+ item.dist = dist
+ d.Fix(idx)
+ return
+ }
+ }
+ } else {
+ d.Push(&Item{id: id, dist: dist})
+ d.up(d.Len() - 1)
+ d.visited[id] = true
+ }
+}
+func (d *DistHeap) Fix(i int) {
+ if !d.down(i, d.Len()) {
+ d.up(i)
+ }
+}
+
+func (d DistHeap) IsEmpty() bool { return len(d.items) == 0 }
+func (d DistHeap) Len() int { return len(d.items) }
+func (d DistHeap) Less(i, j int) bool { return d.items[i].dist < d.items[j].dist }
+func (d DistHeap) Swap(i, j int) { d.items[i], d.items[j] = d.items[j], d.items[i] }
+func (d *DistHeap) Push(x *Item) {
+ (*d).items = append((*d).items, x)
+}
+func (d *DistHeap) Pop() *Item {
+ old := (*d).items
+ n := len(old)
+ x := old[n-1]
+ (*d).items = old[0 : n-1]
+ return x
+}
diff --git a/pkg/hnsw/heap_test.go b/pkg/hnsw/heap_test.go
new file mode 100644
index 00000000..ec131c78
--- /dev/null
+++ b/pkg/hnsw/heap_test.go
@@ -0,0 +1,174 @@
+package hnsw
+
+import "testing"
+
+func TestHeap(t *testing.T) {
+
+ t.Run("basic min max properties", func(t *testing.T) {
+ h := NewDistHeap()
+
+ for i := 10; i > 0; i-- {
+ h.Insert(Id(i), float32(10-i))
+ }
+
+ if h.Len() != 10 {
+ t.Fatalf("heap length should be 10, got %v", h.Len())
+ }
+
+ expectedId := Id(10)
+ for !h.IsEmpty() {
+ peekMinItem, err := h.PeekMinItem()
+ if err != nil {
+ t.Fatalf("failed to peek min item: %v", err)
+ }
+
+ minItem, err := h.PopMinItem()
+ if err != nil {
+ t.Fatalf("failed to pop min item, err: %v", err)
+ }
+
+ if peekMinItem.id != minItem.id {
+ t.Fatalf("mismatched item id, expected %v, got %v", expectedId, peekMinItem.id)
+ }
+
+ if minItem.id != expectedId {
+ t.Fatalf("mismatched ids, expected %v, got: %v", expectedId, minItem.id)
+ }
+
+ expectedId -= 1
+ }
+ })
+
+ t.Run("basic min max properties 2", func(t *testing.T) {
+ h := NewDistHeap()
+
+ for i := 0; i <= 10; i++ {
+ h.Insert(Id(i), float32(10-i))
+ }
+
+ maxExpectedId := Id(0)
+ minExpectedId := Id(10)
+
+ for !h.IsEmpty() {
+ peekMaxItem, err := h.PeekMaxItem()
+
+ if err != nil {
+ t.Fatalf("failed to peek max item, err: %v", err)
+ }
+
+ maxItem, err := h.PopMaxItem()
+
+ if err != nil {
+ t.Fatalf("failed to pop max item, err: %v", err)
+ }
+
+ if peekMaxItem.id != maxItem.id {
+ t.Fatalf("mismatched max ids, expected %v, got: %v", maxItem.id, peekMaxItem.id)
+ }
+
+ if maxItem.id != maxExpectedId {
+ t.Fatalf("expected id to be %v, got %v", maxExpectedId, maxItem.id)
+ }
+
+ if h.IsEmpty() {
+ continue
+ }
+
+ peekMinItem, err := h.PeekMinItem()
+ if err != nil {
+ t.Fatalf("failed to peek min item, err: %v", err)
+ }
+
+ minItem, err := h.PopMinItem()
+
+ if err != nil {
+ t.Fatalf("failed to pop min item, err: %v", err)
+ }
+
+ if peekMinItem.id != minItem.id {
+ t.Fatalf("mismatched min ids, expected %v, got: %v", maxItem.id, peekMaxItem.id)
+ }
+
+ if minItem.id != minExpectedId {
+ t.Fatalf("expected id to be %v, got %v", minExpectedId, minItem.id)
+ }
+
+ minExpectedId -= 1
+ maxExpectedId += 1
+ }
+ })
+
+ t.Run("bricks and ladders || min heap", func(t *testing.T) {
+ type Case struct {
+ heights []int
+ bricks int
+ ladders int
+ expected int
+ }
+
+ cases := [3]Case{
+ {
+ heights: []int{4, 2, 7, 6, 9, 14, 12},
+ bricks: 5,
+ ladders: 1,
+ expected: 4,
+ },
+ {
+ heights: []int{4, 12, 2, 7, 3, 18, 20, 3, 19},
+ bricks: 10,
+ ladders: 2,
+ expected: 7,
+ },
+ {
+ heights: []int{14, 3, 19, 3},
+ bricks: 17,
+ ladders: 0,
+ expected: 3,
+ },
+ }
+
+ for _, c := range cases {
+ res, err := furthestBuildings(c.heights, c.bricks, c.ladders)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if res != c.expected {
+ t.Errorf("got %d, want %d", res, c.expected)
+ }
+ }
+ })
+}
+
+func furthestBuildings(heights []int, bricks, ladders int) (int, error) {
+
+ ladderJumps := NewDistHeap()
+
+ for idx := 0; idx < len(heights)-1; idx++ {
+ height := heights[idx]
+ nextHeight := heights[idx+1]
+
+ if height >= nextHeight {
+ continue
+ }
+
+ jump := nextHeight - height
+
+ ladderJumps.Insert(Id(idx), float32(jump))
+
+ if ladderJumps.Len() > ladders {
+ minLadderJump, err := ladderJumps.PopMinItem()
+ if err != nil {
+ return -1, err
+ }
+
+ if bricks-int(minLadderJump.dist) < 0 {
+ return idx, nil
+ }
+
+ bricks -= int(minLadderJump.dist)
+ }
+ }
+
+ return len(heights) - 1, nil
+}
diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go
index 8cd14b02..750d0dfa 100644
--- a/pkg/hnsw/hnsw.go
+++ b/pkg/hnsw/hnsw.go
@@ -60,23 +60,26 @@ func (h *Hnsw) GenerateId() Id {
return Id(len(h.points))
}
-func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, level int) (*BaseQueue, error) {
+func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, level int) (*DistHeap, error) {
visited := make([]bool, len(h.friends)+1)
- candidatesForQ := NewBaseQueue(MinComparator{})
- foundNNToQ := NewBaseQueue(MaxComparator{})
+ candidatesForQ := NewDistHeap()
+ foundNNToQ := NewDistHeap() // this is a max
// note entryItem.dist should be the distance to Q
candidatesForQ.Insert(entryItem.id, entryItem.dist)
foundNNToQ.Insert(entryItem.id, entryItem.dist)
for !candidatesForQ.IsEmpty() {
- closestCandidate, err := candidatesForQ.PopItem()
+ closestCandidate, err := candidatesForQ.PopMinItem()
if err != nil {
return nil, fmt.Errorf("error during searching level %d: %w", level, err)
}
- furthestFoundNN := foundNNToQ.Top()
+ furthestFoundNN, err := foundNNToQ.PeekMaxItem()
+ if err != nil {
+ return nil, fmt.Errorf("error during searching level %d: %w", level, err)
+ }
// if distance(c, q) > distance(f, q)
if closestCandidate.dist > furthestFoundNN.dist {
@@ -94,7 +97,10 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev
if !visited[ccFriendId] {
visited[ccFriendId] = true
- furthestFoundNN = foundNNToQ.Top()
+ furthestFoundNN, err = foundNNToQ.PeekMaxItem()
+ if err != nil {
+ return nil, fmt.Errorf("error during searching level %d: %w", level, err)
+ }
ccFriendPoint, ok := h.points[ccFriendId]
if !ok {
@@ -108,7 +114,7 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev
foundNNToQ.Insert(ccFriendId, ccFriendDistToQ)
if foundNNToQ.Len() > numNearestToQToReturn {
- if _, err = foundNNToQ.PopItem(); err != nil {
+ if _, err = foundNNToQ.PopMaxItem(); err != nil {
return nil, fmt.Errorf("error during searching level %d: %w", level, err)
}
}
@@ -118,7 +124,7 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev
}
- return FromBaseQueue(foundNNToQ, MinComparator{}), nil
+ return foundNNToQ, nil
}
func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item {
@@ -141,7 +147,7 @@ func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item {
return epItem
}
- newEpItem, err := closestNeighborsToQ.PopItem()
+ newEpItem, err := closestNeighborsToQ.PopMinItem()
if err != nil {
panic(err)
}
@@ -152,7 +158,7 @@ func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item {
return epItem
}
-func (h *Hnsw) selectNeighbors(nearestNeighbors *BaseQueue) ([]*Item, error) {
+func (h *Hnsw) selectNeighbors(nearestNeighbors *DistHeap) ([]*Item, error) {
if nearestNeighbors.Len() <= h.M {
return nearestNeighbors.items, nil
}
@@ -160,7 +166,7 @@ func (h *Hnsw) selectNeighbors(nearestNeighbors *BaseQueue) ([]*Item, error) {
nearestItems := make([]*Item, h.M)
for i := 0; i < h.M; i++ {
- nearestItem, err := nearestNeighbors.PopItem()
+ nearestItem, err := nearestNeighbors.PopMinItem()
if err != nil {
return nil, err
@@ -218,31 +224,19 @@ func (h *Hnsw) InsertVector(q Point) error {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}
- maxNumberOfNeighbors := h.M
- if level == 0 {
- maxNumberOfNeighbors = h.Mmax0
- }
-
- eConnections := neighborFriendsAtLevel
- if eConnections.Len() > maxNumberOfNeighbors {
- var items []*Item
-
- for !neighborFriendsAtLevel.IsEmpty() {
- nearestNeighborFriendAtLevelItem, err := neighborFriendsAtLevel.PopItem()
- if err != nil {
- return fmt.Errorf("failed to pop from neighborFriendsAtLevel: %w", err)
- }
-
- items = append(items, nearestNeighborFriendAtLevelItem)
+ for neighborFriendsAtLevel.Len() > h.M {
+ _, err := neighborFriendsAtLevel.PopMaxItem()
+ if err != nil {
+ return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}
eConnections = FromItems(items, MinComparator{})
}
- h.friends[neighbor.id].friends[level] = eConnections
+ h.friends[neighbor.id].friends[level] = neighborFriendsAtLevel
}
- newEntryItem, err := nnToQAtLevel.PopItem()
+ newEntryItem, err := nnToQAtLevel.PopMinItem()
if err != nil {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}
@@ -261,7 +255,7 @@ func (h *Hnsw) isValidPoint(point Point) bool {
return len(point) == h.vectorDimensionality
}
-func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error) {
+func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*DistHeap, error) {
entryPoint, ok := h.points[h.entryPointId]
if !ok {
@@ -287,7 +281,10 @@ func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error)
return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}
- entryItem = nearestNeighborQueueAtLevel.Top()
+ entryItem, err = nearestNeighborQueueAtLevel.PeekMinItem()
+ if err != nil {
+ return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
+ }
}
// level 0
@@ -296,24 +293,12 @@ func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error)
return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %d", h.entryPointId, 0)
}
- if nearestNeighborQueueAtLevel0.Len() <= numNeighborsToReturn {
- return nearestNeighborQueueAtLevel0, nil
- }
-
- var items []*Item
-
- for !nearestNeighborQueueAtLevel0.IsEmpty() {
- if len(items) == numNeighborsToReturn {
- return FromItems(items, MinComparator{}), nil
- }
-
- nearestNeighborAtLevel0Item, err := nearestNeighborQueueAtLevel0.PopItem()
+ for nearestNeighborQueueAtLevel0.Len() > numNeighborsToReturn {
+ _, err := nearestNeighborQueueAtLevel0.PopMaxItem()
if err != nil {
- return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %d", h.entryPointId, 0)
+ return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", h.entryPointId, err)
}
-
- items = append(items, nearestNeighborAtLevel0Item)
}
- return FromItems(items, MinComparator{}), nil
+ return nearestNeighborQueueAtLevel0, nil
}
diff --git a/pkg/hnsw/hnsw_test.go b/pkg/hnsw/hnsw_test.go
index 6e0d0eca..2573c4a8 100644
--- a/pkg/hnsw/hnsw_test.go
+++ b/pkg/hnsw/hnsw_test.go
@@ -104,7 +104,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor)
}
- closestItem, err := closestNeighbor.PopItem()
+ closestItem, err := closestNeighbor.PopMinItem()
if err != nil {
t.Fatal(err)
@@ -149,7 +149,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor)
}
- closestItem, err := closestNeighbor.PopItem()
+ closestItem, err := closestNeighbor.PopMinItem()
if err != nil {
t.Fatal(err)
@@ -188,7 +188,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor)
}
- closestItem, err := closestNeighbor.PopItem()
+ closestItem, err := closestNeighbor.PopMinItem()
if err != nil {
t.Fatal(err)
@@ -218,7 +218,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor.Len())
}
- closestItem, err := closestNeighbor.PopItem()
+ closestItem, err := closestNeighbor.PopMinItem()
if err != nil {
t.Fatal(err)
}
@@ -253,7 +253,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
var closestIds []Id
for !closestNeighbor.IsEmpty() {
- closestItem, err := closestNeighbor.PopItem()
+ closestItem, err := closestNeighbor.PopMinItem()
if err != nil {
t.Fatal(err)
}
@@ -330,7 +330,7 @@ func TestHnsw_FindCloserEntryPoint(t *testing.T) {
func TestHnsw_SelectNeighbors(t *testing.T) {
t.Run("selects neighbors given overflow", func(t *testing.T) {
- nearestNeighbors := NewBaseQueue(MinComparator{})
+ nearestNeighbors := NewDistHeap()
M := 4
@@ -352,7 +352,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) {
}
// for the sake of testing, let's rebuild the pq and assert ids are correct
- reneighbors := NewBaseQueue(MinComparator{})
+ reneighbors := NewDistHeap()
for _, item := range neighbors {
reneighbors.Insert(item.id, item.dist)
@@ -360,7 +360,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) {
expectedId := Id(0)
for !reneighbors.IsEmpty() {
- nn, err := reneighbors.PopItem()
+ nn, err := reneighbors.PopMinItem()
if err != nil {
t.Fatal(err)
@@ -378,7 +378,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) {
M := 10
h := NewHnsw(2, 10, M, Point{0, 0})
- nnQueue := NewBaseQueue(MinComparator{})
+ nnQueue := NewDistHeap()
for i := 0; i < 3; i++ {
nnQueue.Insert(Id(i), float32(i))
@@ -394,7 +394,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) {
t.Fatalf("select neighbors should have at least 3 neighbors, got: %v", len(neighbors))
}
- reneighbors := NewBaseQueue(MinComparator{})
+ reneighbors := NewDistHeap()
for _, item := range neighbors {
reneighbors.Insert(item.id, item.dist)
@@ -402,7 +402,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) {
expectedId := Id(0)
for !reneighbors.IsEmpty() {
- nn, err := reneighbors.PopItem()
+ nn, err := reneighbors.PopMinItem()
if err != nil {
t.Fatal(err)
@@ -538,7 +538,7 @@ func TestHnsw_KnnSearch(t *testing.T) {
expectedId := Id(3)
for !nearestNeighbors.IsEmpty() {
- nearestNeighbor, err := nearestNeighbors.PopItem()
+ nearestNeighbor, err := nearestNeighbors.PopMinItem()
if err != nil {
t.Fatalf("failed to pop item: %v, err: %v", nearestNeighbors, err)
}
@@ -574,7 +574,7 @@ func TestHnsw_KnnSearch(t *testing.T) {
var gotIds []Id
for !closestToQ.IsEmpty() {
- closest, err := closestToQ.PopItem()
+ closest, err := closestToQ.PopMinItem()
if err != nil {
t.Fatalf("failed to pop item: %v, err: %v", closestToQ, err)
}
@@ -615,7 +615,7 @@ func TestHnsw_KnnSearch(t *testing.T) {
var got []Id
for !closestNeighbors.IsEmpty() {
- closest, err := closestNeighbors.PopItem()
+ closest, err := closestNeighbors.PopMinItem()
if err != nil {
t.Fatalf("failed to pop item: %v, err: %v", closestNeighbors, err)
}
@@ -649,7 +649,7 @@ func TestHnsw_KnnSearch(t *testing.T) {
expectedId := Id(0)
for found.IsEmpty() {
- nnItem, err := found.PopItem()
+ nnItem, err := found.PopMinItem()
if err != nil {
t.Fatalf("failed to pop item: %v, err: %v", found, err)
}
diff --git a/pkg/hnsw/minmax.go b/pkg/hnsw/minmax.go
new file mode 100644
index 00000000..fd7868c1
--- /dev/null
+++ b/pkg/hnsw/minmax.go
@@ -0,0 +1,115 @@
+// Package minmaxheap provides min-max heap operations for any type that
+// implements heap.Interface. A min-max heap can be used to implement a
+// double-ended priority queue.
+//
+// Min-max heap implementation from the 1986 paper "Min-Max Heaps and
+// Generalized Priority Queues" by Atkinson et. al.
+// https://doi.org/10.1145/6617.6621.
+
+package hnsw
+
+import (
+ "math/bits"
+)
+
+func level(i int) int {
+ // floor(log2(i + 1))
+ return bits.Len(uint(i)+1) - 1
+}
+
+func isMinLevel(i int) bool {
+ return level(i)%2 == 0
+}
+
+func lchild(i int) int {
+ return i*2 + 1
+}
+
+func rchild(i int) int {
+ return i*2 + 2
+}
+
+func parent(i int) int {
+ return (i - 1) / 2
+}
+
+func hasParent(i int) bool {
+ return i > 0
+}
+
+func hasGrandparent(i int) bool {
+ return i > 2
+}
+
+func grandparent(i int) int {
+ return parent(parent(i))
+}
+
+func (d *DistHeap) down(i, n int) bool {
+ min := isMinLevel(i)
+ i0 := i
+ for {
+ m := i
+
+ l := lchild(i)
+ if l >= n || l < 0 /* overflow */ {
+ break
+ }
+ if d.Less(l, m) == min {
+ m = l
+ }
+
+ r := rchild(i)
+ if r < n && d.Less(r, m) == min {
+ m = r
+ }
+
+ // grandchildren are contiguous i*4+3+{0,1,2,3}
+ for g := lchild(l); g < n && g <= rchild(r); g++ {
+ if d.Less(g, m) == min {
+ m = g
+ }
+ }
+
+ if m == i {
+ break
+ }
+
+ d.Swap(i, m)
+
+ if m == l || m == r {
+ break
+ }
+
+ // m is grandchild
+ p := parent(m)
+ if d.Less(p, m) == min {
+ d.Swap(m, p)
+ }
+ i = m
+ }
+ return i > i0
+}
+
+func (d *DistHeap) up(i int) {
+ min := isMinLevel(i)
+
+ if hasParent(i) {
+ p := parent(i)
+ if d.Less(p, i) == min {
+ d.Swap(i, p)
+ min = !min
+ i = p
+ }
+ }
+
+ for hasGrandparent(i) {
+ g := grandparent(i)
+ if d.Less(i, g) != min {
+ return
+ }
+
+ d.Swap(i, g)
+ i = g
+ }
+}
diff --git a/pkg/hnsw/pq.go b/pkg/hnsw/pq.go
deleted file mode 100644
index 8f5c0e04..00000000
--- a/pkg/hnsw/pq.go
+++ /dev/null
@@ -1,173 +0,0 @@
-package hnsw
-
-import (
- "container/heap"
- "fmt"
-)
-
-type Comparator interface {
- Less(i, j *Item) bool
-}
-
-// MaxComparator implements the Comparator interface for a max-heap.
-type MaxComparator struct{}
-
-func (c MaxComparator) Less(i, j *Item) bool {
- return i.dist > j.dist
-}
-
-// MinComparator implements the Comparator interface for a min-heap.
-type MinComparator struct{}
-
-func (c MinComparator) Less(i, j *Item) bool {
- return i.dist < j.dist
-}
-
-type Item struct {
- id Id
- dist float32
- index int
-}
-
-type Heapy interface {
- heap.Interface
- Insert(id Id, dist float32)
- IsEmpty() bool
- Len() int
- PopItem() *Item
- Top() *Item
- Take(count int) (*BaseQueue, error)
- update(item *Item, id Id, dist float32)
-}
-
-// Nothing from BaseQueue should be used. Only use the Max and Min queue.
-// BaseQueue isn't even a heap! It misses the Less() method which the Min/Max queue implement.
-type BaseQueue struct {
- visitedIds map[Id]*Item
- items []*Item
- comparator Comparator
-}
-
-func (bq *BaseQueue) Take(count int, comparator Comparator) (*BaseQueue, error) {
- if len(bq.items) < count {
- return nil, fmt.Errorf("queue only has %v items, but want to take %v", len(bq.items), count)
- }
-
- pq := NewBaseQueue(comparator)
-
- ct := 0
- for {
- if ct == count {
- break
- }
-
- peeled, err := bq.PopItem()
- if err != nil {
- return nil, err
- }
-
- pq.Insert(peeled.id, peeled.dist)
-
- ct++
- }
-
- return pq, nil
-}
-
-func (bq BaseQueue) Len() int { return len(bq.items) }
-func (bq BaseQueue) Swap(i, j int) {
- pq := bq.items
- pq[i], pq[j] = pq[j], pq[i]
- pq[i].index = i
- pq[j].index = j
-}
-
-func (bq *BaseQueue) Push(x any) {
- n := len(bq.items)
- item := x.(*Item)
- item.index = n
- bq.items = append(bq.items, item)
-}
-
-func (bq *BaseQueue) Top() *Item {
- if len(bq.items) == 0 {
- return nil
- }
- return bq.items[0]
-}
-
-func (bq *BaseQueue) IsEmpty() bool {
- return len(bq.items) == 0
-}
-
-func (bq *BaseQueue) Pop() any {
- old := bq.items
- n := len(old)
- item := old[n-1]
- old[n-1] = nil
- item.index = -1
- bq.items = old[0 : n-1]
- return item
-}
-
-func (bq *BaseQueue) Less(i, j int) bool {
- return bq.comparator.Less(bq.items[i], bq.items[j])
-}
-
-func (bq *BaseQueue) Insert(id Id, dist float32) {
- if item, ok := bq.visitedIds[id]; ok {
- bq.update(item, id, dist)
- return
- }
-
- newItem := Item{id: id, dist: dist}
- heap.Push(bq, &newItem)
- bq.visitedIds[id] = &newItem
-
-}
-
-func NewBaseQueue(comparator Comparator) *BaseQueue {
- bq := &BaseQueue{
- visitedIds: map[Id]*Item{},
- comparator: comparator,
- }
- heap.Init(bq)
- return bq
-}
-
-func (bq *BaseQueue) PopItem() (*Item, error) {
- if bq.Len() == 0 {
- return nil, fmt.Errorf("no items to peel")
- }
- popped := heap.Pop(bq).(*Item)
- delete(bq.visitedIds, popped.id)
- return popped, nil
-}
-
-func (bq *BaseQueue) update(item *Item, id Id, dist float32) {
- item.id = id
- item.dist = dist
- heap.Fix(bq, item.index)
-}
-
-func FromBaseQueue(bq *BaseQueue, comparator Comparator) *BaseQueue {
- newBq := NewBaseQueue(comparator)
-
- for _, item := range bq.items {
- newBq.Insert(item.id, item.dist)
- }
-
- return newBq
-}
-
-func FromItems(items []*Item, comparator Comparator) *BaseQueue {
- bq := &BaseQueue{
- visitedIds: map[Id]*Item{},
- items: items,
- comparator: comparator,
- }
-
- heap.Init(bq)
-
- return bq
-}
diff --git a/pkg/hnsw/pq_test.go b/pkg/hnsw/pq_test.go
deleted file mode 100644
index 5bc1539e..00000000
--- a/pkg/hnsw/pq_test.go
+++ /dev/null
@@ -1,105 +0,0 @@
-package hnsw
-
-import (
- "testing"
-)
-
-func TestPQ(t *testing.T) {
-
- t.Run("bricks and ladders || min heap", func(t *testing.T) {
- type Case struct {
- heights []int
- bricks int
- ladders int
- expected int
- }
-
- cases := [3]Case{
- {
- heights: []int{4, 2, 7, 6, 9, 14, 12},
- bricks: 5,
- ladders: 1,
- expected: 4,
- },
- {
- heights: []int{4, 12, 2, 7, 3, 18, 20, 3, 19},
- bricks: 10,
- ladders: 2,
- expected: 7,
- },
- {
- heights: []int{14, 3, 19, 3},
- bricks: 17,
- ladders: 0,
- expected: 3,
- },
- }
-
- for _, c := range cases {
- res, err := furthestBuildings(c.heights, c.bricks, c.ladders)
- if err != nil {
- t.Fatal(err)
- }
-
- if res != c.expected {
- t.Errorf("got %d, want %d", res, c.expected)
- }
- }
- })
-
- t.Run("interchange", func(t *testing.T) {
- bq := NewBaseQueue(MinComparator{})
- for i := 0; i < 100; i++ {
- bq.Insert(Id(i), float32(i))
- }
-
- incBq := FromBaseQueue(bq, MaxComparator{})
-
- i := Id(99)
- for !incBq.IsEmpty() {
- item, err := incBq.PopItem()
- if err != nil {
- t.Fatal(err)
- }
-
- if item.id != i {
- t.Fatalf("got %d, want %d", item.id, i)
- }
-
- i -= 1
- }
- })
-}
-
-func furthestBuildings(heights []int, bricks, ladders int) (int, error) {
-
- ladderJumps := NewBaseQueue(MinComparator{})
-
- for idx := 0; idx < len(heights)-1; idx++ {
- height := heights[idx]
- nextHeight := heights[idx+1]
-
- if height >= nextHeight {
- continue
- }
-
- jump := nextHeight - height
-
- ladderJumps.Insert(Id(idx), float32(jump))
-
- if ladderJumps.Len() > ladders {
- minLadderJump, err := ladderJumps.PopItem()
- if err != nil {
- return -1, err
- }
-
- if bricks-int(minLadderJump.dist) < 0 {
- return idx, nil
- }
-
- bricks -= int(minLadderJump.dist)
- }
- }
-
- return len(heights) - 1, nil
-}