From 6360aa85254340e257d5873c18d6412d0a3caac4 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Thu, 30 May 2024 14:17:52 -0400 Subject: [PATCH 1/6] feat: hnsw --- pkg/hnsw/eucqueue_test.go | 209 ------------ pkg/hnsw/hnsw.go | 292 ++-------------- pkg/hnsw/hnsw_test.go | 574 -------------------------------- pkg/hnsw/node.go | 119 ------- pkg/hnsw/node_test.go | 114 ------- pkg/hnsw/{eucqueue.go => pq.go} | 14 +- pkg/hnsw/vector.go | 122 +++++++ pkg/hnsw/vector_test.go | 123 +++++++ 8 files changed, 276 insertions(+), 1291 deletions(-) delete mode 100644 pkg/hnsw/eucqueue_test.go delete mode 100644 pkg/hnsw/hnsw_test.go delete mode 100644 pkg/hnsw/node.go delete mode 100644 pkg/hnsw/node_test.go rename pkg/hnsw/{eucqueue.go => pq.go} (90%) create mode 100644 pkg/hnsw/vector.go create mode 100644 pkg/hnsw/vector_test.go diff --git a/pkg/hnsw/eucqueue_test.go b/pkg/hnsw/eucqueue_test.go deleted file mode 100644 index 0d22c65..0000000 --- a/pkg/hnsw/eucqueue_test.go +++ /dev/null @@ -1,209 +0,0 @@ -package hnsw - -import ( - "testing" -) - -func TestEucQueue(t *testing.T) { - t.Run("builds min queue", func(t *testing.T) { - v0 := Vector{1.0} - - vs := [5]Vector{ - {2.3}, // id: 0, dist: 1.3, p: 4 - {1.1}, // id: 1, dist: 0.1, p: 1 - {2.0}, // id: 2, dist: 1.0, p: 3 - {3.3}, // id: 3, dist: 2.3, p: 5 - {0.8}, // id: 4, dist: 0.2, p: 2 - } - - eq := NewBaseQueue(MinComparator{}) - - if !eq.IsEmpty() || eq.Len() != 0 || len(eq.visitedIds) != 0 { - t.Fatalf("created new eq, expected empty, got %v len", eq.Len()) - } - - for i, v := range vs { - dist := EuclidDist(v0, v) - eq.Insert(NodeId(i), dist) - - if i+1 != eq.Len() { - t.Fatalf("inserting element %v means eq should have length of %v, got: %v", i, i+1, eq.Len()) - } - - if _, ok := eq.visitedIds[NodeId(i)]; !ok { - t.Fatalf("expected node id %v to be in visited set", i) - } - - } - - expected := [5]Item{ - {id: 1, dist: 0.1}, - {id: 4, dist: 0.2}, - {id: 2, dist: 1.0}, - {id: 0, dist: 1.3}, - {id: 3, dist: 2.3}, - } - - i := 0 - for eq.Len() > 0 { - item, err := eq.Peel() - if err != nil { - t.Fatal(err) - } - if item.id != expected[i].id { - t.Fatalf("expected item id %v, got %v at %v", expected[i].id, item.id, i) - } - - if !NearlyEqual(float64(item.dist), float64(expected[i].dist)) { - t.Fatalf("not equal, got %v, and %v", item.dist, expected[i].dist) - } - - if _, ok := eq.visitedIds[item.id]; ok { - t.Fatalf("expected item id %v to be popped!", item.id) - } - - i++ - } - }) - - t.Run("builds max queue", func(t *testing.T) { - v0 := Vector{1.0} - - vs := [5]Vector{ - {2.3}, // id: 0, dist: 1.3, p: 4 - {1.1}, // id: 1, dist: 0.1, p: 1 - {2.0}, // id: 2, dist: 1.0, p: 3 - {3.3}, // id: 3, dist: 2.3, p: 5 - {0.8}, // id: 4, dist: 0.2, p: 2 - } - - eq := NewBaseQueue(MaxComparator{}) - - if !eq.IsEmpty() || eq.Len() != 0 || len(eq.visitedIds) != 0 { - t.Fatalf("created new eq, expected empty, got %v len", eq.Len()) - } - - for i, v := range vs { - dist := EuclidDist(v0, v) - eq.Insert(NodeId(i), dist) - - if i+1 != eq.Len() { - t.Fatalf("inserting element %v means eq should have length of %v, got: %v", i, i+1, eq.Len()) - } - - if _, ok := eq.visitedIds[NodeId(i)]; !ok { - t.Fatalf("expected node id %v to be in visited set", i) - } - } - - expected := [5]Item{ - {id: 3, dist: 2.3}, - {id: 0, dist: 1.3}, - {id: 2, dist: 1.0}, - {id: 4, dist: 0.2}, - {id: 1, dist: 0.1}, - } - - i := 0 - for eq.Len() > 0 { - item, err := eq.Peel() - if err != nil { - t.Fatal(err) - } - if item.id != expected[i].id || !NearlyEqual(float64(item.dist), float64(expected[i].dist)) { - t.Fatalf("expected item id: %v, got id: %v at i: %v", expected[i].id, item.id, i) - } - - if _, ok := eq.visitedIds[item.id]; ok { - t.Fatalf("expected item id %v to be popped!", item.id) - } - - i++ - } - }) - - t.Run("takes correctly", func(t *testing.T) { - mq := FromBaseQueue([]*Item{ - {id: 1, dist: 33}, - {id: 2, dist: 32}, - {id: 3, dist: 69}, - {id: 4, dist: 3}, - {id: 6, dist: 0.01}, - }, MinComparator{}) - - pq, err := mq.Take(3, MinComparator{}) - if err != nil { - t.Fatalf("failed to take 3") - } - - if pq.Len() != 3 { - t.Fatalf("expected len: %v, got %v", 3, pq.Len()) - } - - if len(pq.visitedIds) != 3 { - t.Fatalf("expected # of visited ids to be %v, got %v", 3, len(pq.visitedIds)) - } - - }) - - t.Run("updates already existing id with new priority", func(t *testing.T) { - mq := FromBaseQueue([]*Item{ - {id: 1, dist: 2.2}, - {id: 2, dist: 3.0}, - }, MinComparator{}) - - mq.Insert(1, 2.1) - - if mq.Len() != 2 { - t.Fatalf("update shouldn't incur another element. expected length: %v, got: %v", 2, mq.Len()) - } - - if mq.Peek().id != 1 { - t.Fatalf("expected first id to be 1, got: %v", mq.Peek().id) - } - - if mq.Peek().dist != 2.1 { - t.Fatalf("expected distance to be updated to %v, got %v", 2.1, mq.Peek().dist) - } - - _, err := mq.Peel() - - if err != nil { - t.Fatalf("%v", err) - } - - if mq.Peek().id != 2 { - t.Fatalf("expected second id to be 2, got %v", mq.Peek().id) - } - - if mq.Peek().dist != 3.0 { - t.Fatalf("expected distance %v, got %v", 3.0, mq.Peek().dist) - } - }) - - t.Run("inserting with same id yields update, not insertion", func(t *testing.T) { - mq := FromBaseQueue([]*Item{ - {id: 1, dist: 40}, - }, MinComparator{}) - - for i := 2; i <= 100; i++ { - mq.Insert(1, float32(i)) - - if mq.Len() != 1 { - t.Fatalf("expected len to be %v, got %v", 1, mq.Len()) - } - - if !NearlyEqual(float64(mq.Peek().dist), float64(i)) { - t.Fatalf("expected distance to be the newly updated %v, got %v", i, mq.Peek().dist) - } - } - - if mq.Len() != 1 { - t.Fatalf("expected len to be %v, got %v", 1, mq.Len()) - } - - if !NearlyEqual(float64(mq.Peek().dist), float64(100)) { - t.Fatalf("expected distance to be the newly updated %v, got %v", 100, mq.Peek().dist) - } - }) -} diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index 1a6165b..08323a7 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -2,293 +2,49 @@ package hnsw import ( "fmt" - "math" - "math/rand" - "sync/atomic" ) +type Id = uint + type Hnsw struct { - vectorDim uint + vectorDimensionality int - Nodes map[NodeId]*Node + Vectors map[Id]*Vector - EntryNodeId NodeId - NextNodeId NodeId + normFactorForLevelGeneration int - MaxLevel int + // efConstruction is the size of the dynamic candIdate list + efConstruction uint // default number of connections M int - // Maximum number of connections per element per level - MMax, MMax0 int - - // Size of dynamic candidate list during construction - EfConstruction int - - // Normalization factor for level generation - levelMultiplier float64 -} - -func NewHNSW(d uint, m int, efc int, entryVector Vector) *Hnsw { - h := &Hnsw{vectorDim: d} - h.checkVectorDim(entryVector) - - nt := make(map[NodeId]*Node) - enId := NodeId(0) // special. Reserved for the entryPointNode - - entryPoint := NewNode(enId, entryVector, 0) - nt[enId] = entryPoint - - h.M = m - h.Nodes = nt - h.EntryNodeId = enId - h.NextNodeId = enId + 1 - h.levelMultiplier = 1 / math.Log(float64(m)) - h.EfConstruction = efc - h.MMax = m - h.MMax0 = m * 2 - - return h -} - -func (h *Hnsw) checkVectorDim(v Vector) { - if h.vectorDim != uint(len(v)) { - panic(fmt.Sprintf("vector (%v) is invalid, expected dim length %v, got %v", v, h.vectorDim, len(v))) - } -} - -func (h *Hnsw) getNextNodeId() NodeId { - return atomic.AddUint32(&h.NextNodeId, 1) - 1 -} - -func (h *Hnsw) spawnLevel() int { - return int(math.Floor(-math.Log(rand.Float64() * h.levelMultiplier))) -} - -func (h *Hnsw) searchLevel(q Vector, entryNodeItem *Item, ef int, levelId int) (*BaseQueue, error) { - // visited is a bitset that keeps track of all nodes that have been visited. - // we know the size of visited will never exceed len(h.Nodes) - visited := make([]bool, len(h.Nodes)) - - if entryNodeItem.id != h.EntryNodeId { - panic(fmt.Sprintf("debug: this should not occur. entry node mismatch got %v, expected: %v", entryNodeItem.id, h.EntryNodeId)) - } - - visited[entryNodeItem.id] = true - - candidates := NewBaseQueue(MinComparator{}) - candidates.Insert(entryNodeItem.id, entryNodeItem.dist) - - nearestNeighborsToQForEf := NewBaseQueue(MaxComparator{}) - nearestNeighborsToQForEf.Insert(entryNodeItem.id, entryNodeItem.dist) - - for !candidates.IsEmpty() { - // extract nearest element from C to q - closestCandidate, err := candidates.Peel() - if err != nil { - return nil, err - } - - // get the furthest element from W to q - furthestNN, err := nearestNeighborsToQForEf.Peel() - if err != nil { - return nil, err - } - - closestCandidateToQDist := closestCandidate.dist - furthestNNToQDist := furthestNN.dist - - if closestCandidateToQDist > furthestNNToQDist { - // all elements in W are evaluated - break - } - - if h.Nodes[closestCandidate.id].HasLevel(levelId) { - friends := h.Nodes[closestCandidate.id].GetFriendsAtLevel(levelId) - - for _, friend := range friends.items { - friendId := friend.id - - if !visited[friendId] { - visited[friendId] = true - furthestNNItem := nearestNeighborsToQForEf.Peek() - - friendToQDist := EuclidDist(h.Nodes[friendId].v, q) - - if nearestNeighborsToQForEf.Len() < ef { - nearestNeighborsToQForEf.Insert(friendId, friendToQDist) - candidates.Insert(friendId, friendToQDist) - } else if friendToQDist < furthestNNItem.dist { - nearestNeighborsToQForEf.Pop() - nearestNeighborsToQForEf.Insert(friendId, friendToQDist) - candidates.Insert(friendId, friendToQDist) - } - - return nearestNeighborsToQForEf, nil - } - } - } - } - - return nearestNeighborsToQForEf, nil -} - -func (h *Hnsw) selectNeighbors(candidates *BaseQueue, numNeighborsToReturn int) (*BaseQueue, error) { - if candidates.Len() <= numNeighborsToReturn { - return candidates, nil - } - - pq, err := candidates.Take(numNeighborsToReturn, MinComparator{}) - if err != nil { - return nil, fmt.Errorf("an error occured during take: %v", err) - } - - return pq, nil + // mmax, mmax0 is the maximum number of connections for each element per layer + mmax, mmax0 int } -func (h *Hnsw) KnnSearch(q Vector, kNeighborsToReturn, ef int) (*BaseQueue, error) { - currentNearestElements := NewBaseQueue(MinComparator{}) - entryPointNode := h.Nodes[h.EntryNodeId] - entryPointItem := &Item{id: h.EntryNodeId, dist: entryPointNode.VecDistFromVec(q)} - newEntryItem := h.findCloserEntryPoint(entryPointItem, q, 0) - - numNearestToQAtBase, err := h.searchLevel(q, newEntryItem, ef, 0) - - if err != nil { - return nil, err - } - - for !numNearestToQAtBase.IsEmpty() { - peeled, err := numNearestToQAtBase.Peel() - if err != nil { - return nil, err - } - currentNearestElements.Insert(peeled.id, peeled.dist) - } - - if currentNearestElements.Len() < kNeighborsToReturn { - return nil, fmt.Errorf("the currentNearestElement length %v", currentNearestElements.Len()) +func NewHnsw(d int, efConstruction uint, M, mmax, mmax0 int) *Hnsw { + if d <= 0 { + panic("vector dimensionality cannot be less than 1") } - pq, err := currentNearestElements.Take(kNeighborsToReturn, MinComparator{}) - if err != nil { - return nil, fmt.Errorf("failed to knnsearch, err: %v", err) + return &Hnsw{ + vectorDimensionality: d, + efConstruction: efConstruction, + M: M, + mmax: mmax, + mmax0: mmax0, } - - return pq, nil } -func (h *Hnsw) Link(friendItem *Item, node *Node, level int) { - dist := node.VecDistFromNode(h.Nodes[friendItem.id]) - - // update both friends - friend, ok := h.Nodes[friendItem.id] - - if !ok { - panic("should not happen") - } - - if friend.HasLevel(level) && node.HasLevel(level) { - friend.InsertFriendsAtLevel(level, node.id, dist) - node.InsertFriendsAtLevel(level, friend.id, dist) +func (h *Hnsw) InsertVector(v *Vector) error { + if !h.validateVector(v) { + return fmt.Errorf("invalidvector") } -} - -func (h *Hnsw) findCloserEntryPoint(ep *Item, q Vector, qLevel int) *Item { - for level := h.MaxLevel; level > qLevel; level-- { - friends := h.Nodes[ep.id].GetFriendsAtLevel(level) - for _, friend := range friends.items { - friendDist := h.Nodes[friend.id].VecDistFromVec(q) - - if friendDist < ep.dist { - ep = &Item{id: friend.id, dist: friend.dist} - } - } - } - return ep + return nil } -func (h *Hnsw) Insert(q Vector) error { - h.checkVectorDim(q) - - ep := h.Nodes[h.EntryNodeId] - currentTopLevel := ep.level - - // 1. build Node for vec q - qLevel := h.spawnLevel() - qNode := NewNode(h.getNextNodeId(), q, qLevel) - - epItem := &Item{id: ep.id, dist: ep.VecDistFromVec(q)} - - // 2. find the correct entry point - newEpItem := h.findCloserEntryPoint(epItem, q, qLevel) - - // 3. make the second pass, this time create connections - for level := min(currentTopLevel, qLevel); level >= 0; level-- { - nnToQAtLevel, err := h.searchLevel(q, newEpItem, h.EfConstruction, level) - if err != nil { - return fmt.Errorf("failed to make connections, %v", err) - } - - neighbors, err := h.selectNeighbors(nnToQAtLevel, h.M) - - if err != nil { - return err - } - - for !neighbors.IsEmpty() { - peeled, err := neighbors.Peel() - if err != nil { - return err - } - qNode.InsertFriendsAtLevel(level, peeled.id, peeled.dist) - } - } - - // 4. add qNode into the `Nodes` table - h.Nodes[qNode.id] = qNode - - // 5. Link connections - for level := min(currentTopLevel, qLevel); level >= 0; level-- { - friendsAtLevel := qNode.GetFriendsAtLevel(level) - - for !friendsAtLevel.IsEmpty() { - qfriend, err := friendsAtLevel.Peel() - if err != nil { - return err - } - h.Link(qfriend, qNode, level) - - qFriendNode := h.Nodes[qfriend.id] - qFriendNodeFriendsAtLevel := qFriendNode.GetFriendsAtLevel(level) - numFriendsForQFriendAtLevel := qFriendNodeFriendsAtLevel.Len() - - if (level == 0 && numFriendsForQFriendAtLevel > h.MMax0) || (level != 0 && numFriendsForQFriendAtLevel > h.MMax) { - var amt int - if level == 0 { - amt = h.MMax0 - } else { - amt = h.MMax - } - - pq, err := qFriendNodeFriendsAtLevel.Take(amt, MinComparator{}) - if err != nil { - return fmt.Errorf("failed to take friend id %v's %v at level %v", qfriend.id, amt, level) - } - - // shrink connections for a friend at level - h.Nodes[qfriend.id].friends[level] = pq - } - } - } - - // 6. update attr - if h.MaxLevel < qLevel { - h.MaxLevel = qLevel - h.EntryNodeId = qNode.id - } - - return nil +func (h *Hnsw) validateVector(v *Vector) bool { + return len(v.point) != h.vectorDimensionality } diff --git a/pkg/hnsw/hnsw_test.go b/pkg/hnsw/hnsw_test.go deleted file mode 100644 index c809fbf..0000000 --- a/pkg/hnsw/hnsw_test.go +++ /dev/null @@ -1,574 +0,0 @@ -package hnsw - -import ( - "fmt" - "reflect" - "testing" -) - -func TestHnsw(t *testing.T) { - t.Run("builds graph", func(t *testing.T) { - n := NewNode(0, []float32{0.1, 0.2}, 0) - h := NewHNSW(2, 32, 32, []float32{0.1, 0.2}) - if h.MaxLevel != n.level { - t.Fatalf("expected max level to default to %v, got %v", n.level, h.MaxLevel) - } - }) -} - -func TestHnswSelect(t *testing.T) { - - t.Run("selects m nearest elements to q", func(t *testing.T) { - candidates := FromBaseQueue([]*Item{ - {id: 1, dist: 30}, - {id: 2, dist: 29}, - {id: 3, dist: 28}, - {id: 4, dist: 27}, - {id: 5, dist: 26}, - {id: 6, dist: 25}, - {id: 7, dist: 24}, - {id: 8, dist: 23}, - {id: 9, dist: 22}, - {id: 10, dist: 21}, - {id: 11, dist: 20}, - }, MinComparator{}) - - h := NewHNSW(2, 32, 1, []float32{0, 0}) - - cn, err := h.selectNeighbors(candidates, 10) - - if err != nil { - t.Fatal(err) - } - - if cn.Len() != 10 { - t.Fatalf("did not take 10 items") - } - - expected := 11 - i := 0 - for !cn.IsEmpty() { - peeled, err := cn.Peel() - if err != nil { - t.Fatal(err) - } - if peeled.id != NodeId(expected) { - t.Fatalf("expected %v, but got %v at %v", expected, peeled.id, i) - } - - expected-- - i++ - } - }) - - t.Run("over selects! greedy", func(t *testing.T) { - candidates := FromBaseQueue([]*Item{ - {id: 1, dist: 30}, - {id: 2, dist: 0.6}, - {id: 3, dist: 8}, - }, MinComparator{}) - - h := NewHNSW(2, 32, 1, []float32{0, 0}) - - res, err := h.selectNeighbors(candidates, 10) - if err != nil || res.Len() != 3 { - t.Fatal("if num neighbors to return is greater than candidates, we should just be returning the candidates") - } - }) -} - -func TestHnsw_Insert(t *testing.T) { - - t.Run("nodes[0] is root", func(t *testing.T) { - h := NewHNSW(2, 32, 32, []float32{11, 11}) - - if len(h.Nodes) != 1 { - t.Fatalf("hnsw should be initialized with root node but got len: %v", len(h.Nodes)) - } - - if h.Nodes[0].id != 0 { - t.Fatalf("expected node id at 0 to be initialized but got %v", h.Nodes[0].id) - } - }) - - t.Run("hnsw with inserted element q", func(t *testing.T) { - h := NewHNSW(3, 32, 32, []float32{1, 1, 1}) - - if len(h.Nodes) != 1 { - t.Fatalf("hnsw should be initialized with root node but got len: %v", len(h.Nodes)) - } - - err := h.Insert([]float32{1.3, 2.5, 2.3}) - if err != nil { - return - } - - if len(h.Nodes) != 2 { - t.Fatalf("expected 2 nodes after insertion but got %v", len(h.Nodes)) - } - - if h.Nodes[1].id != 1 { - t.Fatalf("expected node id at 1 to be initialized but got %v", h.Nodes[1].id) - } - - if EuclidDist(h.Nodes[1].v, []float32{1.3, 2.5, 2.3}) != 0 { - t.Fatalf("incorrect vector inserted at %v expected vector %v but got %v", 1, []float32{1.3, 2.5, 2.3}, h.Nodes[1].v) - } - }) - - t.Run("multiple insert", func(t *testing.T) { - h := NewHNSW(2, 10, 10, []float32{0, 0}) - - for i := 0; i < 32; i++ { - if len(h.Nodes) != i+1 { - t.Fatalf("expected the number of nodes in graph to be %v, got %v", i+1, len(h.Nodes)) - } - - if err := h.Insert([]float32{float32(32 - i), float32(31 - i)}); err != nil { - t.Fatal(err) - } - - if len(h.Nodes) != i+2 { - t.Fatalf("expected the number of nodes in graph to be %v, got %v", i+2, len(h.Nodes)+2) - } - } - - items, err := h.KnnSearch([]float32{32, 31}, 10, 32) - if err != nil { - return - } - - if items.Len() != 10 { - t.Fatalf("expected to return %v neighbors, got: %v", 10, items.Len()) - } - - expectedId := NodeId(1) - - for !items.IsEmpty() { - peeled, err := items.Peel() - - if err != nil { - t.Fatal(err) - } - - if peeled.id != expectedId { - t.Fatalf("expected %v, but got %v", expectedId, peeled.id) - } - } - }) -} - -func TestHnswVectorDimension(t *testing.T) { - - t.Run("create new hnsw", func(t *testing.T) { - defer func() { - if r := recover(); r == nil { - t.Fatalf("Expected NewHNSW to panic due to mismatched dim, but it did not") - } - }() - NewHNSW(3, 10, 8, []float32{1}) - }) - - t.Run("insert mismatch vec", func(t *testing.T) { - h := NewHNSW(3, 10, 8, []float32{1, 1, 1}) - - defer func() { - if r := recover(); r == nil { - t.Fatalf("Expected NewHNSW to panic due to mismatched dim, but it did not") - } - }() - err := h.Insert([]float32{}) - if err != nil { - t.Fatal(err) - } - }) -} - -func TestHnsw_Link(t *testing.T) { - t.Run("links correctly", func(t *testing.T) { - - n1 := NewNode(1, make(Vector, 128), 3) - n2 := NewNode(2, make(Vector, 128), 0) - - p := make(Vector, 128) - h := NewHNSW(128, 4, 200, p) - - h.Nodes[1] = n1 - h.Nodes[2] = n2 - - i1 := Item{id: 1, dist: 3} - - // now h has enuogh context to test Linking - - if len(h.Nodes[1].friends) != 4 { - t.Fatalf("node1 has max layer 3 so 4 layers total, got %v", len(h.Nodes[1].friends)) - } - - h.Link(&i1, n2, 0) - - if h.Nodes[1].friends[0].Len() != 1 { - t.Fatalf("expected n1's num friends at level 1 to be 1, got %v", h.Nodes[1].friends[1].Len()) - } - - if h.Nodes[2].friends[0].Len() != 1 { - t.Fatalf("expected n2's num friends at level 1 to be 1, got %v", h.Nodes[1].friends[1].Len()) - } - - peeled, err := h.Nodes[1].friends[0].Peel() - if err != nil { - t.Fatal(err) - } - - if peeled.id != 2 { - t.Fatalf("expected n1 to be friends with n2 at level 1") - } - - peeled, err = h.Nodes[2].friends[0].Peel() - - if err != nil { - t.Fatal(err) - } - - if peeled.id != 1 { - t.Fatalf("expected n1 to be friends with n1 at level 1") - } - - }) - - t.Run("links correctly 2", func(t *testing.T) { - qNode := NewNode(1, []float32{4, 4}, 3) - - h := NewHNSW(2, 1, 23, []float32{0, 0}) - - h.Nodes[qNode.id] = qNode - - friends := [][]float32{ - {2, 2}, {3, 3}, {3.5, 3.5}, - } - - for i, v := range friends { - id := NodeId(i + 2) - h.Nodes[id] = NewNode(id, v, 2) - - if len(h.Nodes[id].friends) != 3 { - t.Fatalf("only initialized so expected qfriend to have size 0 friend map, got: %v", len(h.Nodes[id].friends)) - } - } - - // add some friends for qnode at level 2 - qNode.InsertFriendsAtLevel(2, 2, qNode.VecDistFromNode(h.Nodes[2])) - qNode.InsertFriendsAtLevel(2, 3, qNode.VecDistFromNode(h.Nodes[3])) - qNode.InsertFriendsAtLevel(2, 4, qNode.VecDistFromNode(h.Nodes[4])) - - qFriendsAtLevel2 := qNode.GetFriendsAtLevel(2) - if qFriendsAtLevel2.Len() != 3 { - t.Fatalf("expected qFriendsAtLevel2 to be 3, got %v", qFriendsAtLevel2.Len()) - } - - // we pop since link adds bidirectional - for !qFriendsAtLevel2.IsEmpty() { - peeled, err := qFriendsAtLevel2.Peel() - if err != nil { - t.Fatal(err) - } - - if peeled.id != NodeId(qFriendsAtLevel2.Len()+2) { - t.Fatalf("expected peeled id to be %v got %v", qFriendsAtLevel2.Len()+2, peeled.id) - } - } - - for i, v := range friends { - id := NodeId(i + 2) - dist := qNode.VecDistFromVec(v) - - h.Link(&Item{id: id, dist: dist}, qNode, 2) - - qFriendNode := h.Nodes[id] - friendsAtLevel2 := qFriendNode.GetFriendsAtLevel(2) - - if friendsAtLevel2.Len() != 1 { - t.Fatalf("expected friends at level 2 to be 1, got %v", friendsAtLevel2.Len()) - } - - qFriendFriend, err := friendsAtLevel2.Peel() - if err != nil { - t.Fatal(err) - } - - if qFriendFriend.id != qNode.id { - t.Fatalf("expected friend id at level 2 to be q node 1, got %v", qFriendFriend.id) - } - } - }) -} - -func TestNextNodeId(t *testing.T) { - t.Run("generate next node", func(t *testing.T) { - h := NewHNSW(0, 30, 30, []float32{}) - for i := 0; i <= 100; i++ { - nextNodeId := h.getNextNodeId() - - if nextNodeId != NodeId(i+1) { - t.Fatalf("expected %v, got %v", i+1, nextNodeId) - } - } - }) -} - -func TestFindCloserEntryPoint(t *testing.T) { - t.Run("find nothing closer", func(t *testing.T) { - epNode := NewNode(0, []float32{0, 0}, 10) - h := NewHNSW(2, 32, 32, []float32{0, 0}) - - qVector := []float32{6, 6} - qLevel := h.spawnLevel() - - epItem := &Item{id: 0, dist: epNode.VecDistFromVec(qVector)} - newEpItem := h.findCloserEntryPoint(epItem, qVector, qLevel) - - if epItem.id != newEpItem.id { - t.Fatalf("expected id to be %v, got %v", newEpItem.id, epItem.id) - } - }) - - t.Run("finds something closer traverse all layers", func(t *testing.T) { - ep := NewNode(0, []float32{0, 0}, 10) - h := NewHNSW(2, 32, 32, []float32{0, 0}) - h.Nodes[0] = ep - h.MaxLevel = 10 - - q := []float32{6, 6} - - // suppose we had m := []float{5, 5}. It is closer to q, so let's add m to the friends of ep - - m := NewNode(1, []float32{5, 5}, 9) - h.Nodes[m.id] = m - - for level := 0; level <= 9; level++ { - ep.InsertFriendsAtLevel(level, m.id, m.VecDistFromVec(q)) - } - - epItem := &Item{id: 0, dist: ep.VecDistFromVec(q)} - newEpItem := h.findCloserEntryPoint(epItem, q, 0) - - if epItem.id == newEpItem.id { - t.Fatalf("expected id to be %v, got %v", newEpItem.id, epItem.id) - } - - if newEpItem.id != 1 { - t.Fatalf("expected id to be 1, got %v", newEpItem.id) - } - }) - - t.Run("finds something closer during the insertion context", func(t *testing.T) { - ep := NewNode(0, []float32{0, 0}, 10) - - h := NewHNSW(2, 32, 32, []float32{0, 0}) - h.MaxLevel = 10 - h.Nodes[0] = ep - - q := []float32{6, 6} - qLayer := 3 - - // suppose we had m := []float{5, 5}. It is closer to q, so let's add m to the friends of ep - m := NewNode(1, []float32{5, 5}, 9) - h.Nodes[m.id] = m - mDist := m.VecDistFromVec(q) - - h.Link(&Item{id: m.id, dist: mDist}, h.Nodes[h.EntryNodeId], m.level) - - n := NewNode(2, []float32{6.1, 6.1}, 4) - h.Nodes[n.id] = n - nDist := n.VecDistFromNode(m) - h.Link(&Item{id: n.id, dist: nDist}, m, n.level) - - // verify for entry node's friends - friends := h.Nodes[h.EntryNodeId].friends - if friends[9].IsEmpty() { - t.Fatalf("expected friends to not be empty at level 4, got %v", friends[4].Len()) - } - if friends[9].Peek().id != 1 { - t.Fatalf("expected friend id at level 9 to be %v, got %v", 1, friends[9].Peek().id) - } - - nextFriends := h.Nodes[1].friends - if nextFriends[4].IsEmpty() { - t.Fatalf("expected friends to not be empty at level 4, got %v", friends[4].Len()) - } - - if nextFriends[4].Peek().id != 2 { - t.Fatalf("expected friend id at level 4 to be %v, got %v", 2, friends[4].Peek().id) - } - - epItem := &Item{id: 0, dist: ep.VecDistFromVec(q)} - newEpItem := h.findCloserEntryPoint(epItem, q, qLayer) - - if epItem.id == newEpItem.id { - t.Fatalf("expected id to be %v, got %v", newEpItem.id, epItem.id) - } - - if newEpItem.id != n.id { - t.Fatalf("expected id to be %v, got %v", n.id, newEpItem.id) - } - }) -} - -/* -func TestSpawnLevelDistribution(t *testing.T) { - t.Run("plot distribution", func(t *testing.T) { - h := NewHNSW(2, 12, 4, []float32{0, 0}) - - levels := make(map[int]int) - - for i := 0; i < 1000; i++ { - sLevel := h.spawnLevel() - - if _, ok := levels[sLevel]; ok { - levels[sLevel] += 1 - } else { - levels[sLevel] = 1 - } - } - - numLevels := len(levels) - - if numLevels <= 1 { - t.Fatalf("expected geometric distribution to increase to max layer") - } - - prevCt := levels[numLevels-1] - for level := numLevels - 2; level >= 1; level-- { - currCt := levels[level] - - if prevCt > currCt { - t.Fatalf("level %v has %v nodes. level %v has %v nodes.", level, currCt, level+1, prevCt) - } - - prevCt = currCt - } - }) - - t.Run("spawn nodes", func(t *testing.T) { - h := NewHNSW(2, 12, 4, []float32{0, 0}) - - levels := make(map[int]int) - - for i := 0; i < 1000; i++ { - q := []float32{float32(i), float32(i + 1)} - if err := h.Insert(q); err != nil { - t.Fatal(err) - } - - qNode := h.Nodes[h.NextNodeId-1] - - if !NearlyEqual(float64(qNode.VecDistFromVec(q)), 0) { - t.Fatalf("expected qnode to have id %v, got different vector: %v", qNode.id, qNode.VecDistFromVec(q)) - } - - sLevel := qNode.level - - if _, ok := levels[sLevel]; ok { - levels[sLevel] += 1 - } else { - levels[sLevel] = 1 - } - } - - numLevels := len(levels) - - if numLevels <= 1 { - t.Fatalf("expected geometric distribution to increase to max layer") - } - - prevCt := levels[numLevels-1] - for level := numLevels - 2; level >= 1; level-- { - currCt := levels[level] - - if prevCt > currCt { - t.Fatalf("level %v has %v nodes. level %v has %v nodes.", level, currCt, level+1, prevCt) - } - - prevCt = currCt - } - - fmt.Printf("levels distribution: %v\n", levels) - }) -} -*/ - -func TestHnsw_KnnCluster(t *testing.T) { - - var clusterC = []Vector{ - {0.2, 0.5}, - {0.2, 0.7}, - {0.3, 0.8}, - {0.5, 0.5}, - {0.4, 0.1}, - } - - var clusterCNodes = map[NodeId][]NodeId{ - 1: {2, 4, 3, 5}, - 2: {3, 1, 4, 5}, - 3: {2, 1, 4, 5}, - 4: {1, 2, 3, 5}, - 5: {4, 1, 2, 3}, - } - - var clusterCVisited = map[NodeId][]bool{ - 1: {false, true, true, true, true, true}, - 2: {false, false, true, true, true, true}, - 3: {false, false, false, true, true, true}, - 4: {false, true, true, true, false, true}, - 5: {false, true, true, true, true, false}, - } - - t.Run("cluster c insert", func(t *testing.T) { - h := NewHNSW(2, 4, 4, []float32{0, 0}) - - for i, q := range clusterC { - if err := h.Insert(q); err != nil { - t.Fatalf("failed to insert item %d: %v", i, err) - } - } - - fmt.Printf("%v", h.Nodes) - - if reflect.DeepEqual(h.Nodes, clusterCNodes) { - t.Fatalf("expected all node keys to be the same as clusterC") - } - - if len(h.Nodes) != 6 { - t.Fatalf("expected 6 nodes, got %d", len(h.Nodes)) - } - - for i := 1; i <= 5; i++ { - nodeId := NodeId(i) - node := h.Nodes[nodeId] - - var nodeNN []NodeId - visitedNN := make([]bool, 6) // counting entry - - for level := node.level; level >= 0; level-- { - friendsAtLevel := node.friends[level] - - for !friendsAtLevel.IsEmpty() { - peeled, err := friendsAtLevel.Peel() - if err != nil { - t.Fatal(err) - } - - if !visitedNN[peeled.id] { - nodeNN = append(nodeNN, peeled.id) - visitedNN[peeled.id] = true - } - } - } - - if reflect.DeepEqual(clusterCVisited[nodeId], visitedNN) { - t.Fatalf("expected all node keys to be the same as clusterC") - } - } - - }) - -} diff --git a/pkg/hnsw/node.go b/pkg/hnsw/node.go deleted file mode 100644 index 3de89a1..0000000 --- a/pkg/hnsw/node.go +++ /dev/null @@ -1,119 +0,0 @@ -package hnsw - -import ( - "math" -) - -type Vector []float32 - -type NodeId = uint32 - -type Node struct { - // id is very special. It is sequential, with the 0-id reserved for the entry point node. - // We need id to be sequential because we build the bitset with the assumption that every id is unique and sequential. - id NodeId - v Vector - - level int - - // for every level, we have a list of friends' NodeIds - friends []*BaseQueue -} - -func NewNode(id NodeId, v Vector, level int) *Node { - - friends := make([]*BaseQueue, level+1) - - for i := range friends { - friends[i] = NewBaseQueue(MinComparator{}) - } - - return &Node{ - id, - v, - level, - friends, - } -} - -// Must assert with HasLevel first -func (n0 *Node) InsertFriendsAtLevel(level int, id NodeId, dist float32) { - n0.friends[int(level)].Insert(id, dist) -} - -func (n0 *Node) HasLevel(level int) bool { - if level < 0 { - panic("level cannot be negative") - } - - return len(n0.friends)-1 >= level -} - -func (n0 *Node) GetFriendsAtLevel(level int) *BaseQueue { - return n0.friends[level] -} - -func (n0 *Node) VecDistFromVec(v1 Vector) float32 { - v0 := n0.v - - return EuclidDist(v0, v1) -} - -func (n0 *Node) VecDistFromNode(n1 *Node) float32 { - // pull vec from nodes - v0 := n0.v - v1 := n1.v - - return EuclidDist(v0, v1) -} - -func EuclidDist(v0, v1 Vector) float32 { - // check if vector dimensionality is correct - if len(v0) != len(v1) { - panic("invalid lengths") - } - - var sum float32 - - for i := range v0 { - delta := v0[i] - v1[i] - sum += delta * delta - } - - return float32(math.Sqrt(float64(sum))) -} - -// NearlyEqual is sourced from scalar package written by gonum -// https://pkg.go.dev/gonum.org/v1/gonum/floats/scalar#EqualWithinAbsOrRel -func NearlyEqual(a, b float64) bool { - return EqualWithinAbs(a, b) || EqualWithinRel(a, b) -} - -// EqualWithinAbs returns true when a and b have an absolute difference -// not greater than tol. -func EqualWithinAbs(a, b float64) bool { - return a == b || math.Abs(a-b) <= 1e-6 -} - -// minNormalFloat64 is the smallest normal number. For 64 bit IEEE-754 -// floats this is 2^{-1022}. -const minNormalFloat64 = 0x1p-1022 - -// EqualWithinRel returns true when the difference between a and b -// is not greater than tol times the greater absolute value of a and b, -// -// abs(a-b) <= tol * max(abs(a), abs(b)). -func EqualWithinRel(a, b float64) bool { - if a == b { - return true - } - delta := math.Abs(a - b) - if delta <= minNormalFloat64 { - return delta <= 1e-6*minNormalFloat64 - } - // We depend on the division in this relationship to identify - // infinities (we rely on the NaN to fail the test) otherwise - // we compare Infs of the same sign and evaluate Infs as equal - // independent of sign. - return delta/math.Max(math.Abs(a), math.Abs(b)) <= 1e-6 -} diff --git a/pkg/hnsw/node_test.go b/pkg/hnsw/node_test.go deleted file mode 100644 index 4fcdcd8..0000000 --- a/pkg/hnsw/node_test.go +++ /dev/null @@ -1,114 +0,0 @@ -package hnsw - -import ( - "math" - "testing" -) - -func TestWithinLevels(t *testing.T) { - t.Run("levels are in bounds", func(t *testing.T) { - n := NewNode(3, []float32{3, 6, 9}, 3) - - n.friends[0] = NewBaseQueue(MinComparator{}) - n.friends[1] = NewBaseQueue(MinComparator{}) - n.friends[2] = NewBaseQueue(MinComparator{}) - - for i := 0; i < 3; i++ { - if !n.HasLevel(i) { - t.Fatalf("since n's max level is %v, all levels less should be true", n.level) - } - } - - if n.HasLevel(3 + 1) { - t.Fatalf("since n's max level is %v, levels greater is not in bounds", n.level) - } - }) -} - -func TestVec(t *testing.T) { - - type t_case struct { - u, v []float32 - expected float64 - } - - bank := [7]t_case{ - { - u: []float32{5, 3, 0}, - v: []float32{2, -2, float32(math.Sqrt(2))}, - expected: 6, - }, - { - u: []float32{1, 0, -5}, - v: []float32{-3, 2, -1}, - expected: 6, - }, - { - u: []float32{1, 3}, - v: []float32{5, 2}, - expected: math.Sqrt(17), - }, - { - u: []float32{0, 1, 4}, - v: []float32{2, 9, 1}, - expected: math.Sqrt(77), - }, - { - u: []float32{0}, - v: []float32{0}, - expected: 0, - }, - { - u: []float32{10, 20, 30, 40}, - v: []float32{10, 20, 30, 40}, - expected: 0, - }, - } - - t.Run("correctly computes the dist from node", func(t *testing.T) { - for i, bank := range bank { - - if !NearlyEqual(bank.expected, float64(EuclidDist(bank.u, bank.v))) { - t.Fatalf("err at %v, expected %v, got %v", i, bank.expected, EuclidDist(bank.u, bank.v)) - } - } - }) - - t.Run("symmetric", func(t *testing.T) { - for i, bank := range bank { - - if !NearlyEqual(float64(EuclidDist(bank.v, bank.u)), float64(EuclidDist(bank.u, bank.v))) { - t.Fatalf("err at %v, expected %v, got %v", i, bank.expected, EuclidDist(bank.u, bank.v)) - } - } - }) -} - -func TestNodeFriends(t *testing.T) { - t.Run("initialized with correct # of levels", func(t *testing.T) { - h := NewHNSW(2, 32, 32, []float32{3, 4}) - qLayer := h.spawnLevel() - qNode := NewNode(1, []float32{3, 1}, qLayer) - - if len(qNode.friends) != qLayer+1 { - t.Fatalf("expected the friends list to initialize to %v levels, got %v", qLayer+1, len(qNode.friends)) - } - }) - - t.Run("correctly determines if has layer", func(t *testing.T) { - qNode := NewNode(10, []float32{3, 1, 0.3, 9.2}, 100) - - if !qNode.HasLevel(100) { - t.Fatalf("expected qNode to have level %v", 100) - } - - if qNode.HasLevel(101) { - t.Fatalf("expected qNode to not have level %v", 101) - } - - if !qNode.HasLevel(0) { - t.Fatalf("expected qNode to have level %v", 0) - } - }) - -} diff --git a/pkg/hnsw/eucqueue.go b/pkg/hnsw/pq.go similarity index 90% rename from pkg/hnsw/eucqueue.go rename to pkg/hnsw/pq.go index 5587d04..0a86445 100644 --- a/pkg/hnsw/eucqueue.go +++ b/pkg/hnsw/pq.go @@ -24,26 +24,26 @@ func (c MinComparator) Less(i, j *Item) bool { } type Item struct { - id NodeId + id Id dist float32 index int } type Heapy interface { heap.Interface - Insert(id NodeId, dist float32) + Insert(id Id, dist float32) IsEmpty() bool Len() int Peel() *Item Peek() *Item Take(count int) (*BaseQueue, error) - update(item *Item, id NodeId, dist float32) + update(item *Item, id Id, dist float32) } // Nothing from BaseQueue should be used. Only use the Max and Min queue. // BaseQueue isn't even a heap! It misses the Less() method which the Min/Max queue implement. type BaseQueue struct { - visitedIds map[NodeId]*Item + visitedIds map[Id]*Item items []*Item comparator Comparator } @@ -114,7 +114,7 @@ func (bq *BaseQueue) Less(i, j int) bool { return bq.comparator.Less(bq.items[i], bq.items[j]) } -func (bq *BaseQueue) Insert(id NodeId, dist float32) { +func (bq *BaseQueue) Insert(id Id, dist float32) { if item, ok := bq.visitedIds[id]; ok { bq.update(item, id, dist) return @@ -138,7 +138,7 @@ func FromBaseQueue(items []*Item, comparator Comparator) *BaseQueue { func NewBaseQueue(comparator Comparator) *BaseQueue { bq := &BaseQueue{ - visitedIds: map[NodeId]*Item{}, + visitedIds: map[Id]*Item{}, comparator: comparator, } heap.Init(bq) @@ -154,7 +154,7 @@ func (bq *BaseQueue) Peel() (*Item, error) { return popped, nil } -func (bq *BaseQueue) update(item *Item, id NodeId, dist float32) { +func (bq *BaseQueue) update(item *Item, id Id, dist float32) { item.id = id item.dist = dist heap.Fix(bq, item.index) diff --git a/pkg/hnsw/vector.go b/pkg/hnsw/vector.go new file mode 100644 index 0000000..63fbc22 --- /dev/null +++ b/pkg/hnsw/vector.go @@ -0,0 +1,122 @@ +package hnsw + +import ( + "errors" + "math" +) + +type Point []float32 + +type Vector struct { + id Id + point Point + + friends []*BaseQueue +} + +// NewVector creates a new vector, note the max level is inclusive. +func NewVector(id Id, point Point, maxLevel int) *Vector { + friends := make([]*BaseQueue, 0) + + for i := 0; i < maxLevel; i++ { + bq := NewBaseQueue(MinComparator{}) + friends = append(friends, bq) + } + + return &Vector{ + id: id, + point: point, + friends: friends, + } +} + +func (v *Vector) Levels() int { + return len(v.friends) +} + +func (v *Vector) MaxLevel() int { + return len(v.friends) - 1 +} + +func (v *Vector) HasLevel(level int) bool { + if level < 0 { + panic("level must be nonzero positive integer") + } + + return level <= v.MaxLevel() +} + +// InsertFriendsAtLevel requires level must be zero-indexed +func (v *Vector) InsertFriendsAtLevel(level int, friend *Vector) { + if !v.HasLevel(level) { + panic("failed to insert friends at level, as level is not valId") + } + + if friend.id == v.id { + panic("cannot insert yourself to friends list") + } + + dist := v.EuclidDistance(friend) + + for i := 0; i <= level; i++ { + v.friends[level].Insert(friend.id, dist) + } +} + +func (v *Vector) GetFriendsAtLevel(level int) (*BaseQueue, error) { + if !v.HasLevel(level) { + return nil, errors.New("failed to get friends at level") + } + + return v.friends[level], nil +} + +func (v *Vector) EuclidDistance(v1 *Vector) float32 { + return v.EuclidDistanceFromPoint(v1.point) +} + +func (v *Vector) EuclidDistanceFromPoint(point Point) float32 { + var sum float32 + + for i := range v.point { + delta := v.point[i] - point[i] + sum += delta * delta + } + + return float32(math.Sqrt(float64(sum))) +} + +// NearlyEqual is sourced from scalar package written by gonum +// https://pkg.go.dev/gonum.org/v1/gonum/floats/scalar#EqualWithinAbsOrRel +func NearlyEqual(a, b float32) bool { + return EqualWithinAbs(float64(a), float64(b)) || EqualWithinRel(float64(a), float64(b)) +} + +// EqualWithinAbs returns true when a and b have an absolute difference +// not greater than tol. +func EqualWithinAbs(a, b float64) bool { + return a == b || math.Abs(a-b) <= 1e-6 +} + +// minNormalFloat64 is the smallest normal number. For 64 bit IEEE-754 +// floats this is 2^{-1022}. +const minNormalFloat64 = 0x1p-1022 + +// EqualWithinRel returns true when the difference between a and b +// is not greater than tol times the greater absolute value of a and b, +// +// abs(a-b) <= tol * max(abs(a), abs(b)). +func EqualWithinRel(a, b float64) bool { + if a == b { + return true + } + delta := math.Abs(a - b) + if delta <= minNormalFloat64 { + return delta <= 1e-6*minNormalFloat64 + } + // We depend on the division in this relationship to Identify + // infinities (we rely on the NaN to fail the test) otherwise + // we compare Infs of the same sign and evaluate Infs as equal + // independent of sign. + return delta/math.Max(math.Abs(a), math.Abs(b)) <= 1e-6 +} diff --git a/pkg/hnsw/vector_test.go b/pkg/hnsw/vector_test.go new file mode 100644 index 0000000..35d836e --- /dev/null +++ b/pkg/hnsw/vector_test.go @@ -0,0 +1,123 @@ +package hnsw + +import ( + "math" + "testing" +) + +func TestVector_LevelManagement(t *testing.T) { + + /* + hex has 6 layers from [0..5] + oct has 8 layers from [0..8] + */ + t.Run("check levels for oct and hex vectors", func(t *testing.T) { + hexId := Id(1) + hex := NewVector(hexId, []float32{9, 2.0, 30}, 6) + + if hex.MaxLevel() != 5 { + t.Fatalf("since 0-indexed, the max level is 5, got: %v", hex.MaxLevel()) + } + + if hex.Levels() != 6 { + t.Fatalf("since 0-indexed, the number of levels is 6, got: %v", hex.Levels()) + } + + octId := Id(2) + oct := NewVector(octId, []float32{0, 2, 3}, 8) + + if oct.MaxLevel() != 7 { + t.Fatalf("since 0-indexed, the max level is 7, got: %v", hex.MaxLevel()) + } + + if hex.Levels() != 8 { + t.Fatalf("since 0-indexed, the number of levels is 8, got: %v", hex.Levels()) + } + + for i := 0; i <= 5; i++ { + if !hex.HasLevel(i) { + t.Fatalf("since 0-indexed, the level #%v is missing", i) + } + } + + for i := 6; i <= 8; i++ { + if hex.HasLevel(i) { + t.Fatalf("since 0-indexed, expected the level #%v to be missing", i) + } + } + + hex.InsertFriendsAtLevel(5, oct) + oct.InsertFriendsAtLevel(5, hex) + + for i := 0; i <= 5; i++ { + hexFriends, _ := hex.GetFriendsAtLevel(i) + octFriends, _ := oct.GetFriendsAtLevel(i) + + if hexFriends.Len() != 1 || octFriends.Len() != 1 { + t.Fatalf("expected hex and oct friends list at level %v to be 1, got: %v || %v", i, hexFriends.Len(), octFriends.Len()) + } + + top := hexFriends.Peek() + if top.id != octId { + t.Fatalf("expected %v, got %v", octId, top.id) + } + + top = octFriends.Peek() + if top.id != hexId { + t.Fatalf("expected %v, got %v", hexId, top.id) + } + } + }) + +} + +func TestVector_EuclidDistance(t *testing.T) { + + type vectorPair struct { + v0, v1 *Vector + expected float32 + } + + basic := []vectorPair{ + { + v0: NewVector(0, []float32{5, 3, 0}, 4), + v1: NewVector(1, []float32{2, -2, float32(math.Sqrt(2))}, 4), + expected: 6, + }, + { + v0: NewVector(1, []float32{1, 0, -5}, 3), + v1: NewVector(2, []float32{-3, 2, -1}, 3), + expected: 6, + }, + { + v0: NewVector(1, []float32{1, 3}, 20), + v1: NewVector(1, []float32{5, 2}, 120), + expected: float32(math.Sqrt(17)), + }, + { + v0: NewVector(1, []float32{0, 1, 4}, 10), + v1: NewVector(2, []float32{2, 9, 1}, 100), + expected: float32(math.Sqrt(77)), + }, + { + v0: NewVector(1, []float32{0}, 9), + v1: NewVector(2, []float32{0}, 8), + expected: 0, + }, + { + v0: NewVector(1, []float32{10, 20, 30, 40}, 4), + v1: NewVector(2, []float32{10, 20, 30, 40}, 3), + expected: 0, + }, + } + + t.Run("correctly computes the distance of two vectors", func(t *testing.T) { + for i, pair := range basic { + dist := pair.v0.EuclidDistance(pair.v1) + + if !NearlyEqual(dist, pair.expected) { + t.Fatalf("iter i: %v, expected %v and %v to be equal", i, dist, pair.expected) + } + } + }) +} From 7ee02165ffed50c6f540ea463f929708ece19364 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Thu, 30 May 2024 14:21:05 -0400 Subject: [PATCH 2/6] handle error --- pkg/hnsw/vector_test.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pkg/hnsw/vector_test.go b/pkg/hnsw/vector_test.go index 35d836e..a9d7ed0 100644 --- a/pkg/hnsw/vector_test.go +++ b/pkg/hnsw/vector_test.go @@ -50,8 +50,15 @@ func TestVector_LevelManagement(t *testing.T) { oct.InsertFriendsAtLevel(5, hex) for i := 0; i <= 5; i++ { - hexFriends, _ := hex.GetFriendsAtLevel(i) - octFriends, _ := oct.GetFriendsAtLevel(i) + hexFriends, err := hex.GetFriendsAtLevel(i) + if err != nil { + t.Fatal(err) + } + + octFriends, err := oct.GetFriendsAtLevel(i) + if err != nil { + t.Fatal(err) + } if hexFriends.Len() != 1 || octFriends.Len() != 1 { t.Fatalf("expected hex and oct friends list at level %v to be 1, got: %v || %v", i, hexFriends.Len(), octFriends.Len()) From 363015a43de24fd1fe3247e2977924dc911e8327 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Thu, 30 May 2024 19:09:10 -0400 Subject: [PATCH 3/6] feat: vector and priority queue assert --- pkg/hnsw/pq_test.go | 105 ++++++++++++++++++++++++++++++++++++++++ pkg/hnsw/vector.go | 12 ++--- pkg/hnsw/vector_test.go | 45 ++++++----------- 3 files changed, 126 insertions(+), 36 deletions(-) create mode 100644 pkg/hnsw/pq_test.go diff --git a/pkg/hnsw/pq_test.go b/pkg/hnsw/pq_test.go new file mode 100644 index 0000000..bce3413 --- /dev/null +++ b/pkg/hnsw/pq_test.go @@ -0,0 +1,105 @@ +package hnsw + +import ( + "testing" +) + +func TestPQ(t *testing.T) { + + t.Run("bricks and ladders || min heap", func(t *testing.T) { + type Case struct { + heights []int + bricks int + ladders int + expected int + } + + cases := [3]Case{ + { + heights: []int{4, 2, 7, 6, 9, 14, 12}, + bricks: 5, + ladders: 1, + expected: 4, + }, + { + heights: []int{4, 12, 2, 7, 3, 18, 20, 3, 19}, + bricks: 10, + ladders: 2, + expected: 7, + }, + { + heights: []int{14, 3, 19, 3}, + bricks: 17, + ladders: 0, + expected: 3, + }, + } + + for _, c := range cases { + res, err := furthestBuildings(c.heights, c.bricks, c.ladders) + if err != nil { + t.Fatal(err) + } + + if res != c.expected { + t.Errorf("got %d, want %d", res, c.expected) + } + } + + }) +} + +func furthestBuildings(heights []int, bricks, ladders int) (int, error) { + + ladderJumps := NewBaseQueue(MinComparator{}) + + for idx := 0; idx < len(heights)-1; idx++ { + height := heights[idx] + nextHeight := heights[idx+1] + + if height >= nextHeight { + continue + } + + jump := nextHeight - height + + ladderJumps.Insert(Id(idx), float32(jump)) + + if ladderJumps.Len() > ladders { + minLadderJump, err := ladderJumps.Peel() + if err != nil { + return -1, err + } + + if bricks-int(minLadderJump.dist) < 0 { + return idx, nil + } + + bricks -= int(minLadderJump.dist) + } + } + + return len(heights) - 1, nil +} + +/* + + + + + + + + + + + + + + + + + + + + */ diff --git a/pkg/hnsw/vector.go b/pkg/hnsw/vector.go index 63fbc22..e6583a7 100644 --- a/pkg/hnsw/vector.go +++ b/pkg/hnsw/vector.go @@ -16,11 +16,10 @@ type Vector struct { // NewVector creates a new vector, note the max level is inclusive. func NewVector(id Id, point Point, maxLevel int) *Vector { - friends := make([]*BaseQueue, 0) + friends := make([]*BaseQueue, maxLevel+1) - for i := 0; i < maxLevel; i++ { - bq := NewBaseQueue(MinComparator{}) - friends = append(friends, bq) + for i := 0; i <= maxLevel; i++ { + friends[i] = NewBaseQueue(MinComparator{}) } return &Vector{ @@ -48,7 +47,7 @@ func (v *Vector) HasLevel(level int) bool { // InsertFriendsAtLevel requires level must be zero-indexed func (v *Vector) InsertFriendsAtLevel(level int, friend *Vector) { - if !v.HasLevel(level) { + if !v.HasLevel(level) || !friend.HasLevel(level) { panic("failed to insert friends at level, as level is not valId") } @@ -59,7 +58,8 @@ func (v *Vector) InsertFriendsAtLevel(level int, friend *Vector) { dist := v.EuclidDistance(friend) for i := 0; i <= level; i++ { - v.friends[level].Insert(friend.id, dist) + v.friends[i].Insert(friend.id, dist) + friend.friends[i].Insert(v.id, dist) } } diff --git a/pkg/hnsw/vector_test.go b/pkg/hnsw/vector_test.go index a9d7ed0..b16e072 100644 --- a/pkg/hnsw/vector_test.go +++ b/pkg/hnsw/vector_test.go @@ -15,63 +15,48 @@ func TestVector_LevelManagement(t *testing.T) { hexId := Id(1) hex := NewVector(hexId, []float32{9, 2.0, 30}, 6) - if hex.MaxLevel() != 5 { + if hex.MaxLevel() != 6 { t.Fatalf("since 0-indexed, the max level is 5, got: %v", hex.MaxLevel()) } - if hex.Levels() != 6 { + if hex.Levels() != 7 { t.Fatalf("since 0-indexed, the number of levels is 6, got: %v", hex.Levels()) } octId := Id(2) oct := NewVector(octId, []float32{0, 2, 3}, 8) - if oct.MaxLevel() != 7 { + if oct.MaxLevel() != 8 { t.Fatalf("since 0-indexed, the max level is 7, got: %v", hex.MaxLevel()) } - if hex.Levels() != 8 { + if oct.Levels() != 9 { t.Fatalf("since 0-indexed, the number of levels is 8, got: %v", hex.Levels()) } - for i := 0; i <= 5; i++ { + for i := 0; i <= 6; i++ { if !hex.HasLevel(i) { t.Fatalf("since 0-indexed, the level #%v is missing", i) } } - for i := 6; i <= 8; i++ { + for i := 7; i <= 8; i++ { if hex.HasLevel(i) { t.Fatalf("since 0-indexed, expected the level #%v to be missing", i) } } hex.InsertFriendsAtLevel(5, oct) - oct.InsertFriendsAtLevel(5, hex) - for i := 0; i <= 5; i++ { - hexFriends, err := hex.GetFriendsAtLevel(i) - if err != nil { - t.Fatal(err) - } - - octFriends, err := oct.GetFriendsAtLevel(i) - if err != nil { - t.Fatal(err) - } - - if hexFriends.Len() != 1 || octFriends.Len() != 1 { - t.Fatalf("expected hex and oct friends list at level %v to be 1, got: %v || %v", i, hexFriends.Len(), octFriends.Len()) - } - - top := hexFriends.Peek() - if top.id != octId { - t.Fatalf("expected %v, got %v", octId, top.id) - } - - top = octFriends.Peek() - if top.id != hexId { - t.Fatalf("expected %v, got %v", hexId, top.id) + for level, friends := range hex.friends { + if level <= 5 { + if friends.Len() != 1 { + t.Fatalf("expected 1 item, got %v", friends.Len()) + } + } else { + if friends.Len() != 0 { + t.Fatalf("expected 0 items, got %v", friends.Len()) + } } } }) From 0c1121ab513bbc49ac034f55b912bfa41784baff Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Thu, 30 May 2024 19:12:17 -0400 Subject: [PATCH 4/6] check both vectors --- pkg/hnsw/vector_test.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/pkg/hnsw/vector_test.go b/pkg/hnsw/vector_test.go index b16e072..0ef80e5 100644 --- a/pkg/hnsw/vector_test.go +++ b/pkg/hnsw/vector_test.go @@ -53,6 +53,30 @@ func TestVector_LevelManagement(t *testing.T) { if friends.Len() != 1 { t.Fatalf("expected 1 item, got %v", friends.Len()) } + + friend := friends.Peek() + + if friend.id != octId { + t.Fatalf("expected %v, got %v", octId, friend.id) + } + + } else { + if friends.Len() != 0 { + t.Fatalf("expected 0 items, got %v", friends.Len()) + } + } + } + + for level, friends := range oct.friends { + if level <= 5 { + if friends.Len() != 1 { + t.Fatalf("expected 1 item, got %v", friends.Len()) + } + friend := friends.Peek() + + if friend.id != hexId { + t.Fatalf("expected %v, got %v", hexId, friend.id) + } } else { if friends.Len() != 0 { t.Fatalf("expected 0 items, got %v", friends.Len()) From 90c33b9ca1ba1f7560f35eb5864cecb83e575c19 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Mon, 3 Jun 2024 09:59:06 -0400 Subject: [PATCH 5/6] fmt --- pkg/hnsw/hnsw.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index 5eeddd2..98d0508 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -8,7 +8,7 @@ type Id = uint type Hnsw struct { vectorDimensionality int - + Vectors map[Id]*Friends normFactorForLevelGeneration int @@ -37,7 +37,6 @@ func NewHnsw(d int, efConstruction uint, M, mmax, mmax0 int) *Hnsw { } } - func (h *Hnsw) InsertVector(q Point) error { if !h.validatePoint(q) { return fmt.Errorf("invalidvector") From 264403893213e12d80d7fd60c0024a6a96c9f329 Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Mon, 3 Jun 2024 10:00:02 -0400 Subject: [PATCH 6/6] remove vector.go --- pkg/hnsw/vector.go | 122 ----------------------------------- pkg/hnsw/vector_test.go | 139 ---------------------------------------- 2 files changed, 261 deletions(-) delete mode 100644 pkg/hnsw/vector.go delete mode 100644 pkg/hnsw/vector_test.go diff --git a/pkg/hnsw/vector.go b/pkg/hnsw/vector.go deleted file mode 100644 index e6583a7..0000000 --- a/pkg/hnsw/vector.go +++ /dev/null @@ -1,122 +0,0 @@ -package hnsw - -import ( - "errors" - "math" -) - -type Point []float32 - -type Vector struct { - id Id - point Point - - friends []*BaseQueue -} - -// NewVector creates a new vector, note the max level is inclusive. -func NewVector(id Id, point Point, maxLevel int) *Vector { - friends := make([]*BaseQueue, maxLevel+1) - - for i := 0; i <= maxLevel; i++ { - friends[i] = NewBaseQueue(MinComparator{}) - } - - return &Vector{ - id: id, - point: point, - friends: friends, - } -} - -func (v *Vector) Levels() int { - return len(v.friends) -} - -func (v *Vector) MaxLevel() int { - return len(v.friends) - 1 -} - -func (v *Vector) HasLevel(level int) bool { - if level < 0 { - panic("level must be nonzero positive integer") - } - - return level <= v.MaxLevel() -} - -// InsertFriendsAtLevel requires level must be zero-indexed -func (v *Vector) InsertFriendsAtLevel(level int, friend *Vector) { - if !v.HasLevel(level) || !friend.HasLevel(level) { - panic("failed to insert friends at level, as level is not valId") - } - - if friend.id == v.id { - panic("cannot insert yourself to friends list") - } - - dist := v.EuclidDistance(friend) - - for i := 0; i <= level; i++ { - v.friends[i].Insert(friend.id, dist) - friend.friends[i].Insert(v.id, dist) - } -} - -func (v *Vector) GetFriendsAtLevel(level int) (*BaseQueue, error) { - if !v.HasLevel(level) { - return nil, errors.New("failed to get friends at level") - } - - return v.friends[level], nil -} - -func (v *Vector) EuclidDistance(v1 *Vector) float32 { - return v.EuclidDistanceFromPoint(v1.point) -} - -func (v *Vector) EuclidDistanceFromPoint(point Point) float32 { - var sum float32 - - for i := range v.point { - delta := v.point[i] - point[i] - sum += delta * delta - } - - return float32(math.Sqrt(float64(sum))) -} - -// NearlyEqual is sourced from scalar package written by gonum -// https://pkg.go.dev/gonum.org/v1/gonum/floats/scalar#EqualWithinAbsOrRel -func NearlyEqual(a, b float32) bool { - return EqualWithinAbs(float64(a), float64(b)) || EqualWithinRel(float64(a), float64(b)) -} - -// EqualWithinAbs returns true when a and b have an absolute difference -// not greater than tol. -func EqualWithinAbs(a, b float64) bool { - return a == b || math.Abs(a-b) <= 1e-6 -} - -// minNormalFloat64 is the smallest normal number. For 64 bit IEEE-754 -// floats this is 2^{-1022}. -const minNormalFloat64 = 0x1p-1022 - -// EqualWithinRel returns true when the difference between a and b -// is not greater than tol times the greater absolute value of a and b, -// -// abs(a-b) <= tol * max(abs(a), abs(b)). -func EqualWithinRel(a, b float64) bool { - if a == b { - return true - } - delta := math.Abs(a - b) - if delta <= minNormalFloat64 { - return delta <= 1e-6*minNormalFloat64 - } - // We depend on the division in this relationship to Identify - // infinities (we rely on the NaN to fail the test) otherwise - // we compare Infs of the same sign and evaluate Infs as equal - // independent of sign. - return delta/math.Max(math.Abs(a), math.Abs(b)) <= 1e-6 -} diff --git a/pkg/hnsw/vector_test.go b/pkg/hnsw/vector_test.go deleted file mode 100644 index 0ef80e5..0000000 --- a/pkg/hnsw/vector_test.go +++ /dev/null @@ -1,139 +0,0 @@ -package hnsw - -import ( - "math" - "testing" -) - -func TestVector_LevelManagement(t *testing.T) { - - /* - hex has 6 layers from [0..5] - oct has 8 layers from [0..8] - */ - t.Run("check levels for oct and hex vectors", func(t *testing.T) { - hexId := Id(1) - hex := NewVector(hexId, []float32{9, 2.0, 30}, 6) - - if hex.MaxLevel() != 6 { - t.Fatalf("since 0-indexed, the max level is 5, got: %v", hex.MaxLevel()) - } - - if hex.Levels() != 7 { - t.Fatalf("since 0-indexed, the number of levels is 6, got: %v", hex.Levels()) - } - - octId := Id(2) - oct := NewVector(octId, []float32{0, 2, 3}, 8) - - if oct.MaxLevel() != 8 { - t.Fatalf("since 0-indexed, the max level is 7, got: %v", hex.MaxLevel()) - } - - if oct.Levels() != 9 { - t.Fatalf("since 0-indexed, the number of levels is 8, got: %v", hex.Levels()) - } - - for i := 0; i <= 6; i++ { - if !hex.HasLevel(i) { - t.Fatalf("since 0-indexed, the level #%v is missing", i) - } - } - - for i := 7; i <= 8; i++ { - if hex.HasLevel(i) { - t.Fatalf("since 0-indexed, expected the level #%v to be missing", i) - } - } - - hex.InsertFriendsAtLevel(5, oct) - - for level, friends := range hex.friends { - if level <= 5 { - if friends.Len() != 1 { - t.Fatalf("expected 1 item, got %v", friends.Len()) - } - - friend := friends.Peek() - - if friend.id != octId { - t.Fatalf("expected %v, got %v", octId, friend.id) - } - - } else { - if friends.Len() != 0 { - t.Fatalf("expected 0 items, got %v", friends.Len()) - } - } - } - - for level, friends := range oct.friends { - if level <= 5 { - if friends.Len() != 1 { - t.Fatalf("expected 1 item, got %v", friends.Len()) - } - friend := friends.Peek() - - if friend.id != hexId { - t.Fatalf("expected %v, got %v", hexId, friend.id) - } - } else { - if friends.Len() != 0 { - t.Fatalf("expected 0 items, got %v", friends.Len()) - } - } - } - }) - -} - -func TestVector_EuclidDistance(t *testing.T) { - - type vectorPair struct { - v0, v1 *Vector - expected float32 - } - - basic := []vectorPair{ - { - v0: NewVector(0, []float32{5, 3, 0}, 4), - v1: NewVector(1, []float32{2, -2, float32(math.Sqrt(2))}, 4), - expected: 6, - }, - { - v0: NewVector(1, []float32{1, 0, -5}, 3), - v1: NewVector(2, []float32{-3, 2, -1}, 3), - expected: 6, - }, - { - v0: NewVector(1, []float32{1, 3}, 20), - v1: NewVector(1, []float32{5, 2}, 120), - expected: float32(math.Sqrt(17)), - }, - { - v0: NewVector(1, []float32{0, 1, 4}, 10), - v1: NewVector(2, []float32{2, 9, 1}, 100), - expected: float32(math.Sqrt(77)), - }, - { - v0: NewVector(1, []float32{0}, 9), - v1: NewVector(2, []float32{0}, 8), - expected: 0, - }, - { - v0: NewVector(1, []float32{10, 20, 30, 40}, 4), - v1: NewVector(2, []float32{10, 20, 30, 40}, 3), - expected: 0, - }, - } - - t.Run("correctly computes the distance of two vectors", func(t *testing.T) { - for i, pair := range basic { - dist := pair.v0.EuclidDistance(pair.v1) - - if !NearlyEqual(dist, pair.expected) { - t.Fatalf("iter i: %v, expected %v and %v to be equal", i, dist, pair.expected) - } - } - }) -}