From 3696088930665864bb05ab6f4c4fcb4371f39ced Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Tue, 11 Jun 2024 14:56:20 -0400 Subject: [PATCH] assert every point's friends queue has correct max size --- pkg/hnsw/hnsw.go | 27 ++++++++++++++++++--------- pkg/hnsw/hnsw_test.go | 30 +++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index 99c9d9b..8cd14b0 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -207,30 +207,39 @@ func (h *Hnsw) InsertVector(q Point) error { // add bidirectional connections from neighbors to q at layer c for _, neighbor := range neighbors { neighborPoint := h.points[neighbor.id] - distNeighToQ := EuclidDistance(*neighborPoint, q) - h.friends[neighbor.id].InsertFriendsAtLevel(level, qId, distNeighToQ) h.friends[qId].InsertFriendsAtLevel(level, neighbor.id, distNeighToQ) } for _, neighbor := range neighbors { neighborFriendsAtLevel, err := h.friends[neighbor.id].GetFriendsAtLevel(level) - if err != nil { return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) } - maxNeighborsFriendsAtLevel := FromBaseQueue(neighborFriendsAtLevel, MaxComparator{}) + maxNumberOfNeighbors := h.M + if level == 0 { + maxNumberOfNeighbors = h.Mmax0 + } + + eConnections := neighborFriendsAtLevel + if eConnections.Len() > maxNumberOfNeighbors { + var items []*Item + + for !neighborFriendsAtLevel.IsEmpty() { + nearestNeighborFriendAtLevelItem, err := neighborFriendsAtLevel.PopItem() + if err != nil { + return fmt.Errorf("failed to pop from neighborFriendsAtLevel: %w", err) + } - for maxNeighborsFriendsAtLevel.Len() > h.M { - _, err = maxNeighborsFriendsAtLevel.PopItem() - if err != nil { - return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err) + items = append(items, nearestNeighborFriendAtLevelItem) } + + eConnections = FromItems(items, MinComparator{}) } - 
h.friends[neighbor.id].friends[level] = FromBaseQueue(maxNeighborsFriendsAtLevel, MinComparator{}) + h.friends[neighbor.id].friends[level] = eConnections } newEntryItem, err := nnToQAtLevel.PopItem() diff --git a/pkg/hnsw/hnsw_test.go b/pkg/hnsw/hnsw_test.go index ca36756..6e0d0ec 100644 --- a/pkg/hnsw/hnsw_test.go +++ b/pkg/hnsw/hnsw_test.go @@ -436,7 +436,7 @@ func TestHnsw_InsertVector(t *testing.T) { t.Run("bulk insert", func(t *testing.T) { items := 1 - h := NewHnsw(3, 4, 4, Point{0, 0, 0}) + h := NewHnsw(3, 4, 10, Point{0, 0, 0}) for i := 100; i >= 1; i-- { j := float32(i) @@ -465,6 +465,34 @@ func TestHnsw_InsertVector(t *testing.T) { items += 1 } + + // ensure no friends pq exceeds its cap: Mmax0 at level 0, M at every other level + var allNodeIds []Id + for id := range h.friends { + allNodeIds = append(allNodeIds, id) + } + + for _, nodeId := range allNodeIds { + nodeFriends, ok := h.friends[nodeId] + if !ok { + t.Fatalf("expected to find point for node %v", nodeId) + } + + for level, friendsAtLevel := range nodeFriends.friends { + if level == 0 { + if friendsAtLevel.Len() > h.Mmax0 { + t.Fatalf("node id %v, num friends at level 0 cannot be greater than max number of connections M = %v. Got %v", nodeId, h.M, friendsAtLevel.Len()) + } + + continue + } + + if friendsAtLevel.Len() > h.M { + t.Fatalf("num friends at level %v cannot be greater than max number of connections M: %v. Got: %v", level, h.M, friendsAtLevel.Len()) + } + } + + } }) t.Run("basic cluster insertion", func(t *testing.T) {