From 861afeeaf6a2cca176693c92e15e2a65be2d647e Mon Sep 17 00:00:00 2001 From: Matthew Kim <38759997+friendlymatthew@users.noreply.github.com> Date: Mon, 3 Jun 2024 13:41:27 -0400 Subject: [PATCH] feat: assert search layer --- pkg/hnsw/friends.go | 18 ++---- pkg/hnsw/friends_test.go | 24 +++---- pkg/hnsw/hnsw.go | 135 +++++++++++++++++++++++++++++++++++---- pkg/hnsw/hnsw_test.go | 67 +++++++++++++++++++ pkg/hnsw/pq.go | 10 +-- pkg/hnsw/pq_test.go | 4 +- 6 files changed, 216 insertions(+), 42 deletions(-) create mode 100644 pkg/hnsw/hnsw_test.go diff --git a/pkg/hnsw/friends.go b/pkg/hnsw/friends.go index a790e0f..d8f2544 100644 --- a/pkg/hnsw/friends.go +++ b/pkg/hnsw/friends.go @@ -12,10 +12,10 @@ type Friends struct { } // NewFriends creates a new vector, note the max level is inclusive. -func NewFriends(maxLevel int) *Friends { - friends := make([]*BaseQueue, maxLevel+1) +func NewFriends(topLevel int) *Friends { + friends := make([]*BaseQueue, topLevel+1) - for i := 0; i <= maxLevel; i++ { + for i := 0; i <= topLevel; i++ { friends[i] = NewBaseQueue(MinComparator{}) } @@ -24,11 +24,11 @@ func NewFriends(maxLevel int) *Friends { } } -func (v *Friends) Levels() int { +func (v *Friends) NumLevels() int { return len(v.friends) } -func (v *Friends) MaxLevel() int { +func (v *Friends) TopLevel() int { return len(v.friends) - 1 } @@ -37,19 +37,15 @@ func (v *Friends) HasLevel(level int) bool { panic("level must be nonzero positive integer") } - return level <= v.MaxLevel() + return level <= v.TopLevel() } // InsertFriendsAtLevel requires level must be zero-indexed -func (v *Friends) InsertFriendsAtLevel(level int, vectorId, friendId Id, dist float32) { +func (v *Friends) InsertFriendsAtLevel(level int, friendId Id, dist float32) { if !v.HasLevel(level) { panic("failed to insert friends at level, as level is not valId") } - if friendId == vectorId { - panic("cannot insert yourself to friends list") - } - for i := 0; i <= level; i++ { v.friends[i].Insert(friendId, dist) } diff --git a/pkg/hnsw/friends_test.go b/pkg/hnsw/friends_test.go index 37f2b00..ea52d67 100644 --- a/pkg/hnsw/friends_test.go +++ b/pkg/hnsw/friends_test.go @@ -17,24 +17,24 @@ func TestVector_LevelManagement(t *testing.T) { hexFriends := NewFriends(6) - if hexFriends.MaxLevel() != 6 { - t.Fatalf("since 0-indexed, the max level is 5, got: %v", hexFriends.MaxLevel()) + if hexFriends.TopLevel() != 6 { + t.Fatalf("since 0-indexed, the max level is 5, got: %v", hexFriends.TopLevel()) } - if hexFriends.Levels() != 7 { - t.Fatalf("since 0-indexed, the number of levels is 6, got: %v", hexFriends.Levels()) + if hexFriends.NumLevels() != 7 { + t.Fatalf("since 0-indexed, the number of levels is 6, got: %v", hexFriends.NumLevels()) } octId := Id(2) oct := []float32{0, 2, 3} octFriends := NewFriends(8) - if octFriends.MaxLevel() != 8 { - t.Fatalf("since 0-indexed, the max level is 7, got: %v", octFriends.MaxLevel()) + if octFriends.TopLevel() != 8 { + t.Fatalf("since 0-indexed, the max level is 7, got: %v", octFriends.TopLevel()) } - if octFriends.Levels() != 9 { - t.Fatalf("since 0-indexed, the number of levels is 8, got: %v", octFriends.Levels()) + if octFriends.NumLevels() != 9 { + t.Fatalf("since 0-indexed, the number of levels is 8, got: %v", octFriends.NumLevels()) } for i := 0; i <= 6; i++ { @@ -51,8 +51,8 @@ func TestVector_LevelManagement(t *testing.T) { hexOctDist := EuclidDistance(oct, hex) - hexFriends.InsertFriendsAtLevel(5, hexId, octId, hexOctDist) - octFriends.InsertFriendsAtLevel(5, octId, hexId, hexOctDist) + hexFriends.InsertFriendsAtLevel(5, octId, hexOctDist) + octFriends.InsertFriendsAtLevel(5, hexId, hexOctDist) for i := 0; i <= 5; i++ { hexFriends, err := hexFriends.GetFriendsAtLevel(i) @@ -69,12 +69,12 @@ func TestVector_LevelManagement(t *testing.T) { t.Fatalf("expected hex and oct friends list at level %v to be 1, got: %v || %v", i, hexFriends.Len(), octFriends.Len()) } - top := hexFriends.Peek() + top := hexFriends.Top() if top.id != octId { t.Fatalf("expected %v, got %v", octId, top.id) } - top = octFriends.Peek() + top = octFriends.Top() if top.id != hexId { t.Fatalf("expected %v, got %v", hexId, top.id) } diff --git a/pkg/hnsw/hnsw.go b/pkg/hnsw/hnsw.go index 98d0508..76514dc 100644 --- a/pkg/hnsw/hnsw.go +++ b/pkg/hnsw/hnsw.go @@ -2,46 +2,157 @@ package hnsw import ( "fmt" + "math" + "math/rand" ) type Id = uint +var ErrNodeNotFound = fmt.Errorf("node not found") + type Hnsw struct { vectorDimensionality int - Vectors map[Id]*Friends + points map[Id]*Point + friends map[Id]*Friends - normFactorForLevelGeneration int + levelMultiplier float64 // efConstruction is the size of the dynamic candIdate list efConstruction uint // default number of connections - M int - - // mmax, mmax0 is the maximum number of connections for each element per layer - mmax, mmax0 int + M, Mmax0 int } -func NewHnsw(d int, efConstruction uint, M, mmax, mmax0 int) *Hnsw { - if d <= 0 { - panic("vector dimensionality cannot be less than 1") +func NewHnsw(d int, efConstruction uint, M int, entryPoint Point) *Hnsw { + if d <= 0 || len(entryPoint) != d { + panic("invalid vector dimensionality") } + friends := make(map[Id]*Friends) + friends[Id(0)] = NewFriends(0) + + points := make(map[Id]*Point) + points[Id(0)] = &entryPoint + return &Hnsw{ + points: points, vectorDimensionality: d, + friends: friends, efConstruction: efConstruction, M: M, - mmax: mmax, - mmax0: mmax0, + Mmax0: 2 * M, + levelMultiplier: 1 / math.Log(float64(M)), } } +func (h *Hnsw) SpawnLevel() int { + return int(math.Floor(-math.Log(rand.Float64() * h.levelMultiplier))) +} + +func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, level int) (*BaseQueue, error) { + visited := make([]bool, len(h.friends)+1) + + candidatesForQ := NewBaseQueue(MinComparator{}) + foundNNToQ := NewBaseQueue(MaxComparator{}) + + // note entryItem.dist should be the distance to Q + candidatesForQ.Insert(entryItem.id, entryItem.dist) + foundNNToQ.Insert(entryItem.id, entryItem.dist) + + for !candidatesForQ.IsEmpty() { + closestCandidate, err := candidatesForQ.PopItem() + if err != nil { + return nil, fmt.Errorf("error during searching level %d: %w", level, err) + } + + furthestFoundNN := foundNNToQ.Top() + + // if distance(c, q) > distance(f, q) + if closestCandidate.dist > furthestFoundNN.dist { + // all items in furthest found nn are evaluated + break + } + + closestCandidateFriends, err := h.friends[closestCandidate.id].GetFriendsAtLevel(level) + if err != nil { + return nil, fmt.Errorf("error during searching level %d: %w", level, err) + } + + for _, ccFriendItem := range closestCandidateFriends.items { + ccFriendId := ccFriendItem.id + if !visited[ccFriendId] { + visited[ccFriendId] = true + + furthestFoundNN = foundNNToQ.Top() + + ccFriendPoint, ok := h.points[ccFriendId] + if !ok { + return nil, ErrNodeNotFound + } + + // if distance(ccFriend, q) < distance(f, q) + ccFriendDistToQ := EuclidDistance(*ccFriendPoint, *q) + if ccFriendDistToQ < furthestFoundNN.dist || foundNNToQ.Len() < numNearestToQToReturn { + candidatesForQ.Insert(ccFriendId, ccFriendDistToQ) + foundNNToQ.Insert(ccFriendId, ccFriendDistToQ) + + if foundNNToQ.Len() > numNearestToQToReturn { + _, err = foundNNToQ.PopItem() + if err != nil { + return nil, fmt.Errorf("error during searching level %d: %w", level, err) + } + } + } + } + } + + } + + return FromBaseQueue(foundNNToQ, MinComparator{}), nil +} + +func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item { + initialEntryPoint, ok := h.friends[Id(0)] + if !ok { + panic(ErrNodeNotFound) + } + + entryPointDistToQ := EuclidDistance(*h.points[Id(0)], *q) + + epItem := &Item{id: Id(0), dist: entryPointDistToQ} + for level := initialEntryPoint.TopLevel(); level > qFriends.TopLevel()+1; level-- { + closestNeighborsToQ, err := h.searchLevel(q, epItem, 1, level) + if err != nil { + panic(err) + } + + if closestNeighborsToQ.IsEmpty() { + // return the existing epItem. it's the closest to q. + return epItem + } + + newEpItem, err := closestNeighborsToQ.PopItem() + if err != nil { + panic(err) + } + + epItem = newEpItem + } + + return epItem +} + func (h *Hnsw) InsertVector(q Point) error { if !h.validatePoint(q) { - return fmt.Errorf("invalidvector") + return fmt.Errorf("invalid vector dimensionality") } + qTopLevel := h.SpawnLevel() + qFriends := NewFriends(qTopLevel) + + _ = h.findCloserEntryPoint(&q, qFriends) return nil } diff --git a/pkg/hnsw/hnsw_test.go b/pkg/hnsw/hnsw_test.go new file mode 100644 index 0000000..a7b6bda --- /dev/null +++ b/pkg/hnsw/hnsw_test.go @@ -0,0 +1,67 @@ +package hnsw + +import "testing" + +/* +var clusterA = []Point{ + {0.2, 0.5}, + {0.2, 0.7}, + {0.3, 0.8}, + {0.5, 0.5}, + {0.4, 0.1}, + {0.3, 0.7}, + {0.27, 0.23}, + {0.12, 0.1}, + {0.23, 0.25}, + {0.3, 0.3}, + {0.01, 0.3}, +} + +var clusterB = []Point{ + {4.2, 3.5}, + {4.2, 4.7}, + {4.3, 3.8}, + {4.5, 4.5}, + {4.4, 3.1}, + {4.3, 4.7}, + {4.27, 3.23}, + {4.1, 4.1}, + {4.12, 3.1}, + {4.23, 4.25}, + {4.3, 3.3}, + {4.01, 4.3}, +} +*/ + +func TestHnsw_SearchLevel(t *testing.T) { + t.Run("search level 0", func(t *testing.T) { + entryPoint := Point{0, 0} + g := NewHnsw(2, 4, 4, entryPoint) + mPoint := Point{2, 2} + g.points[Id(1)] = &mPoint + + g.friends[Id(0)].InsertFriendsAtLevel(0, 1, EuclidDistance(mPoint, entryPoint)) + g.friends[Id(1)] = NewFriends(0) + g.friends[Id(1)].InsertFriendsAtLevel(0, 0, EuclidDistance(mPoint, entryPoint)) + + qPoint := Point{4, 4} + closestNeighbor, err := g.searchLevel(&qPoint, &Item{id: 0, dist: EuclidDistance(entryPoint, qPoint)}, 1, 0) + if err != nil { + t.Fatal(err) + } + + if closestNeighbor.IsEmpty() { + t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor) + } + + closestItem, err := closestNeighbor.PopItem() + + if err != nil { + t.Fatal(err) + } + + if Id(1) != closestItem.id { + t.Fatalf("expected item id to be %v, got %v", 1, closestItem.id) + } + }) +} diff --git a/pkg/hnsw/pq.go b/pkg/hnsw/pq.go index 7eb21cd..cc62372 100644 --- a/pkg/hnsw/pq.go +++ b/pkg/hnsw/pq.go @@ -34,8 +34,8 @@ type Heapy interface { Insert(id Id, dist float32) IsEmpty() bool Len() int - Peel() *Item - Peek() *Item + PopItem() *Item + Top() *Item Take(count int) (*BaseQueue, error) update(item *Item, id Id, dist float32) } @@ -61,7 +61,7 @@ func (bq *BaseQueue) Take(count int, comparator Comparator) (*BaseQueue, error) break } - peeled, err := bq.Peel() + peeled, err := bq.PopItem() if err != nil { return nil, err } @@ -89,7 +89,7 @@ func (bq *BaseQueue) Push(x any) { bq.items = append(bq.items, item) } -func (bq *BaseQueue) Peek() *Item { +func (bq *BaseQueue) Top() *Item { if len(bq.items) == 0 { return nil } @@ -135,7 +135,7 @@ func NewBaseQueue(comparator Comparator) *BaseQueue { return bq } -func (bq *BaseQueue) Peel() (*Item, error) { +func (bq *BaseQueue) PopItem() (*Item, error) { if bq.Len() == 0 { return nil, fmt.Errorf("no items to peel") } diff --git a/pkg/hnsw/pq_test.go b/pkg/hnsw/pq_test.go index c98e223..5bc1539 100644 --- a/pkg/hnsw/pq_test.go +++ b/pkg/hnsw/pq_test.go @@ -57,7 +57,7 @@ func TestPQ(t *testing.T) { i := Id(99) for !incBq.IsEmpty() { - item, err := incBq.Peel() + item, err := incBq.PopItem() if err != nil { t.Fatal(err) } @@ -88,7 +88,7 @@ func furthestBuildings(heights []int, bricks, ladders int) (int, error) { ladderJumps.Insert(Id(idx), float32(jump)) if ladderJumps.Len() > ladders { - minLadderJump, err := ladderJumps.Peel() + minLadderJump, err := ladderJumps.PopItem() if err != nil { return -1, err }