Skip to content

Commit

Permalink
complete refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
friendlymatthew committed Jun 13, 2024
1 parent 954b87d commit 00a7273
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 334 deletions.
8 changes: 4 additions & 4 deletions pkg/hnsw/friends.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ import (
type Point []float32

type Friends struct {
friends []*BaseQueue
friends []*DistHeap
}

// NewFriends creates a new vector, note the max level is inclusive.
func NewFriends(topLevel int) *Friends {
friends := make([]*BaseQueue, topLevel+1)
friends := make([]*DistHeap, topLevel+1)

for i := 0; i <= topLevel; i++ {
friends[i] = NewBaseQueue(MinComparator{})
friends[i] = NewDistHeap()
}

return &Friends{
Expand Down Expand Up @@ -51,7 +51,7 @@ func (v *Friends) InsertFriendsAtLevel(level int, friendId Id, dist float32) {
}
}

func (v *Friends) GetFriendsAtLevel(level int) (*BaseQueue, error) {
func (v *Friends) GetFriendsAtLevel(level int) (*DistHeap, error) {
if !v.HasLevel(level) {
return nil, errors.New("failed to get friends at level")
}
Expand Down
10 changes: 8 additions & 2 deletions pkg/hnsw/friends_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,18 @@ func TestVector_LevelManagement(t *testing.T) {
t.Fatalf("expected hex and oct friends list at level %v to be 1, got: %v || %v", i, hexFriends.Len(), octFriends.Len())
}

top := hexFriends.Top()
top, err := hexFriends.PeekMinItem()
if err != nil {
t.Fatal(err)
}
if top.id != octId {
t.Fatalf("expected %v, got %v", octId, top.id)
}

top = octFriends.Top()
top, err = octFriends.PeekMinItem()
if err != nil {
t.Fatal(err)
}
if top.id != hexId {
t.Fatalf("expected %v, got %v", hexId, top.id)
}
Expand Down
65 changes: 30 additions & 35 deletions pkg/hnsw/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,23 +60,26 @@ func (h *Hnsw) GenerateId() Id {
return Id(len(h.points))
}

func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, level int) (*BaseQueue, error) {
func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, level int) (*DistHeap, error) {
visited := make([]bool, len(h.friends)+1)

candidatesForQ := NewBaseQueue(MinComparator{})
foundNNToQ := NewBaseQueue(MaxComparator{})
candidatesForQ := NewDistHeap()
foundNNToQ := NewDistHeap() // this is a max

// note entryItem.dist should be the distance to Q
candidatesForQ.Insert(entryItem.id, entryItem.dist)
foundNNToQ.Insert(entryItem.id, entryItem.dist)

for !candidatesForQ.IsEmpty() {
closestCandidate, err := candidatesForQ.PopItem()
closestCandidate, err := candidatesForQ.PopMinItem()
if err != nil {
return nil, fmt.Errorf("error during searching level %d: %w", level, err)
}

furthestFoundNN := foundNNToQ.Top()
furthestFoundNN, err := foundNNToQ.PeekMaxItem()
if err != nil {
return nil, fmt.Errorf("error during searching level %d: %w", level, err)
}

// if distance(c, q) > distance(f, q)
if closestCandidate.dist > furthestFoundNN.dist {
Expand All @@ -94,7 +97,10 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev
if !visited[ccFriendId] {
visited[ccFriendId] = true

furthestFoundNN = foundNNToQ.Top()
furthestFoundNN, err = foundNNToQ.PeekMaxItem()
if err != nil {
return nil, fmt.Errorf("error during searching level %d: %w", level, err)
}

ccFriendPoint, ok := h.points[ccFriendId]
if !ok {
Expand All @@ -108,7 +114,7 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev
foundNNToQ.Insert(ccFriendId, ccFriendDistToQ)

if foundNNToQ.Len() > numNearestToQToReturn {
if _, err = foundNNToQ.PopItem(); err != nil {
if _, err = foundNNToQ.PopMaxItem(); err != nil {
return nil, fmt.Errorf("error during searching level %d: %w", level, err)
}
}
Expand All @@ -118,7 +124,7 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev

}

return FromBaseQueue(foundNNToQ, MinComparator{}), nil
return foundNNToQ, nil
}

func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item {
Expand All @@ -141,7 +147,7 @@ func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item {
return epItem
}

newEpItem, err := closestNeighborsToQ.PopItem()
newEpItem, err := closestNeighborsToQ.PopMinItem()
if err != nil {
panic(err)
}
Expand All @@ -152,15 +158,15 @@ func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item {
return epItem
}

func (h *Hnsw) selectNeighbors(nearestNeighbors *BaseQueue) ([]*Item, error) {
func (h *Hnsw) selectNeighbors(nearestNeighbors *DistHeap) ([]*Item, error) {
if nearestNeighbors.Len() <= h.M {
return nearestNeighbors.items, nil
}

nearestItems := make([]*Item, h.M)

for i := 0; i < h.M; i++ {
nearestItem, err := nearestNeighbors.PopItem()
nearestItem, err := nearestNeighbors.PopMinItem()

if err != nil {
return nil, err
Expand Down Expand Up @@ -221,19 +227,17 @@ func (h *Hnsw) InsertVector(q Point) error {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}

maxNeighborsFriendsAtLevel := FromBaseQueue(neighborFriendsAtLevel, MaxComparator{})

for maxNeighborsFriendsAtLevel.Len() > h.M {
_, err = maxNeighborsFriendsAtLevel.PopItem()
for neighborFriendsAtLevel.Len() > h.M {
_, err := neighborFriendsAtLevel.PopMaxItem()
if err != nil {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}
}

h.friends[neighbor.id].friends[level] = FromBaseQueue(maxNeighborsFriendsAtLevel, MinComparator{})
h.friends[neighbor.id].friends[level] = neighborFriendsAtLevel
}

newEntryItem, err := nnToQAtLevel.PopItem()
newEntryItem, err := nnToQAtLevel.PopMinItem()
if err != nil {
return fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}
Expand All @@ -252,7 +256,7 @@ func (h *Hnsw) isValidPoint(point Point) bool {
return len(point) == h.vectorDimensionality
}

func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error) {
func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*DistHeap, error) {
entryPoint, ok := h.points[h.entryPointId]

if !ok {
Expand All @@ -278,7 +282,10 @@ func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error)
return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}

entryItem = nearestNeighborQueueAtLevel.Top()
entryItem, err = nearestNeighborQueueAtLevel.PeekMinItem()
if err != nil {
return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", level, err)
}
}

// level 0
Expand All @@ -287,24 +294,12 @@ func (h *Hnsw) KnnSearch(q Point, numNeighborsToReturn int) (*BaseQueue, error)
return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %d", h.entryPointId, 0)
}

if nearestNeighborQueueAtLevel0.Len() <= numNeighborsToReturn {
return nearestNeighborQueueAtLevel0, nil
}

var items []*Item

for !nearestNeighborQueueAtLevel0.IsEmpty() {
if len(items) == numNeighborsToReturn {
return FromItems(items, MinComparator{}), nil
}

nearestNeighborAtLevel0Item, err := nearestNeighborQueueAtLevel0.PopItem()
for nearestNeighborQueueAtLevel0.Len() > numNeighborsToReturn {
_, err := nearestNeighborQueueAtLevel0.PopMaxItem()
if err != nil {
return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %d", h.entryPointId, 0)
return nil, fmt.Errorf("failed to find nearest neighbor to Q at level %v: %w", h.entryPointId, err)
}

items = append(items, nearestNeighborAtLevel0Item)
}

return FromItems(items, MinComparator{}), nil
return nearestNeighborQueueAtLevel0, nil
}
30 changes: 15 additions & 15 deletions pkg/hnsw/hnsw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor)
}

closestItem, err := closestNeighbor.PopItem()
closestItem, err := closestNeighbor.PopMinItem()

if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -149,7 +149,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor)
}

closestItem, err := closestNeighbor.PopItem()
closestItem, err := closestNeighbor.PopMinItem()

if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -188,7 +188,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor)
}

closestItem, err := closestNeighbor.PopItem()
closestItem, err := closestNeighbor.PopMinItem()

if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -218,7 +218,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor.Len())
}

closestItem, err := closestNeighbor.PopItem()
closestItem, err := closestNeighbor.PopMinItem()
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -253,7 +253,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
var closestIds []Id

for !closestNeighbor.IsEmpty() {
closestItem, err := closestNeighbor.PopItem()
closestItem, err := closestNeighbor.PopMinItem()
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -330,7 +330,7 @@ func TestHnsw_FindCloserEntryPoint(t *testing.T) {
func TestHnsw_SelectNeighbors(t *testing.T) {

t.Run("selects neighbors given overflow", func(t *testing.T) {
nearestNeighbors := NewBaseQueue(MinComparator{})
nearestNeighbors := NewDistHeap()

M := 4

Expand All @@ -352,15 +352,15 @@ func TestHnsw_SelectNeighbors(t *testing.T) {
}

// for the sake of testing, let's rebuild the pq and assert ids are correct
reneighbors := NewBaseQueue(MinComparator{})
reneighbors := NewDistHeap()

for _, item := range neighbors {
reneighbors.Insert(item.id, item.dist)
}

expectedId := Id(0)
for !reneighbors.IsEmpty() {
nn, err := reneighbors.PopItem()
nn, err := reneighbors.PopMinItem()

if err != nil {
t.Fatal(err)
Expand All @@ -378,7 +378,7 @@ func TestHnsw_SelectNeighbors(t *testing.T) {
M := 10
h := NewHnsw(2, 10, M, Point{0, 0})

nnQueue := NewBaseQueue(MinComparator{})
nnQueue := NewDistHeap()

for i := 0; i < 3; i++ {
nnQueue.Insert(Id(i), float32(i))
Expand All @@ -394,15 +394,15 @@ func TestHnsw_SelectNeighbors(t *testing.T) {
t.Fatalf("select neighbors should have at least 3 neighbors, got: %v", len(neighbors))
}

reneighbors := NewBaseQueue(MinComparator{})
reneighbors := NewDistHeap()

for _, item := range neighbors {
reneighbors.Insert(item.id, item.dist)
}

expectedId := Id(0)
for !reneighbors.IsEmpty() {
nn, err := reneighbors.PopItem()
nn, err := reneighbors.PopMinItem()

if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -510,7 +510,7 @@ func TestHnsw_KnnSearch(t *testing.T) {
expectedId := Id(3)

for !nearestNeighbors.IsEmpty() {
nearestNeighbor, err := nearestNeighbors.PopItem()
nearestNeighbor, err := nearestNeighbors.PopMinItem()
if err != nil {
t.Fatalf("failed to pop item: %v, err: %v", nearestNeighbors, err)
}
Expand Down Expand Up @@ -546,7 +546,7 @@ func TestHnsw_KnnSearch(t *testing.T) {
var gotIds []Id

for !closestToQ.IsEmpty() {
closest, err := closestToQ.PopItem()
closest, err := closestToQ.PopMinItem()
if err != nil {
t.Fatalf("failed to pop item: %v, err: %v", closestToQ, err)
}
Expand Down Expand Up @@ -587,7 +587,7 @@ func TestHnsw_KnnSearch(t *testing.T) {
var got []Id

for !closestNeighbors.IsEmpty() {
closest, err := closestNeighbors.PopItem()
closest, err := closestNeighbors.PopMinItem()
if err != nil {
t.Fatalf("failed to pop item: %v, err: %v", closestNeighbors, err)
}
Expand Down Expand Up @@ -621,7 +621,7 @@ func TestHnsw_KnnSearch(t *testing.T) {
expectedId := Id(0)

for found.IsEmpty() {
nnItem, err := found.PopItem()
nnItem, err := found.PopMinItem()
if err != nil {
t.Fatalf("failed to pop item: %v, err: %v", found, err)
}
Expand Down
Loading

0 comments on commit 00a7273

Please sign in to comment.