Skip to content

Commit

Permalink
assert every point's friends queue has correct max size
Browse files Browse the repository at this point in the history
  • Loading branch information
friendlymatthew committed Jun 11, 2024
1 parent c6d81e6 commit 3696088
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 10 deletions.
27 changes: 18 additions & 9 deletions pkg/hnsw/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,30 +207,39 @@ func (h *Hnsw) InsertVector(q Point) error {
// add bidirectional connections from neighbors to q at layer c
for _, neighbor := range neighbors {
neighborPoint := h.points[neighbor.id]

distNeighToQ := EuclidDistance(*neighborPoint, q)

h.friends[neighbor.id].InsertFriendsAtLevel(level, qId, distNeighToQ)
h.friends[qId].InsertFriendsAtLevel(level, neighbor.id, distNeighToQ)
}

for _, neighbor := range neighbors {
neighborFriendsAtLevel, err := h.friends[neighbor.id].GetFriendsAtLevel(level)

if err != nil {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}

maxNeighborsFriendsAtLevel := FromBaseQueue(neighborFriendsAtLevel, MaxComparator{})
maxNumberOfNeighbors := h.M
if level == 0 {
maxNumberOfNeighbors = h.Mmax0
}

eConnections := neighborFriendsAtLevel
if eConnections.Len() > maxNumberOfNeighbors {
var items []*Item

for !neighborFriendsAtLevel.IsEmpty() {
nearestNeighborFriendAtLevelItem, err := neighborFriendsAtLevel.PopItem()
if err != nil {
return fmt.Errorf("failed to pop from neighborFriendsAtLevel: %w", err)
}

for maxNeighborsFriendsAtLevel.Len() > h.M {
_, err = maxNeighborsFriendsAtLevel.PopItem()
if err != nil {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
items = append(items, nearestNeighborFriendAtLevelItem)
}

eConnections = FromItems(items, MinComparator{})
}

h.friends[neighbor.id].friends[level] = FromBaseQueue(maxNeighborsFriendsAtLevel, MinComparator{})
h.friends[neighbor.id].friends[level] = eConnections
}

newEntryItem, err := nnToQAtLevel.PopItem()
Expand Down
30 changes: 29 additions & 1 deletion pkg/hnsw/hnsw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ func TestHnsw_InsertVector(t *testing.T) {
t.Run("bulk insert", func(t *testing.T) {
items := 1

h := NewHnsw(3, 4, 4, Point{0, 0, 0})
h := NewHnsw(3, 4, 10, Point{0, 0, 0})

for i := 100; i >= 1; i-- {
j := float32(i)
Expand Down Expand Up @@ -465,6 +465,34 @@ func TestHnsw_InsertVector(t *testing.T) {

items += 1
}

// Every node's friends queue must respect the per-level connection cap:
// h.Mmax0 at level 0 and h.M at every level above it.
for nodeID, nodeFriends := range h.friends {
	for level, friendsAtLevel := range nodeFriends.friends {
		// Pick the cap that applies to this level so the failure message
		// reports the limit that was actually violated.
		limit, limitName := h.M, "M"
		if level == 0 {
			limit, limitName = h.Mmax0, "Mmax0"
		}

		if got := friendsAtLevel.Len(); got > limit {
			t.Fatalf("node id %v: num friends at level %v cannot be greater than max number of connections %s = %v. Got %v", nodeID, level, limitName, limit, got)
		}
	}
}
})

t.Run("basic cluster insertion", func(t *testing.T) {
Expand Down

0 comments on commit 3696088

Please sign in to comment.