Skip to content

Commit

Permalink
feat: assert search layer
Browse files Browse the repository at this point in the history
  • Loading branch information
friendlymatthew committed Jun 3, 2024
1 parent 667ee83 commit 861afee
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 42 deletions.
18 changes: 7 additions & 11 deletions pkg/hnsw/friends.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ type Friends struct {
}

// NewFriends creates a new vector, note the max level is inclusive.
func NewFriends(maxLevel int) *Friends {
friends := make([]*BaseQueue, maxLevel+1)
func NewFriends(topLevel int) *Friends {
friends := make([]*BaseQueue, topLevel+1)

for i := 0; i <= maxLevel; i++ {
for i := 0; i <= topLevel; i++ {
friends[i] = NewBaseQueue(MinComparator{})
}

Expand All @@ -24,11 +24,11 @@ func NewFriends(maxLevel int) *Friends {
}
}

func (v *Friends) Levels() int {
func (v *Friends) NumLevels() int {
return len(v.friends)
}

func (v *Friends) MaxLevel() int {
func (v *Friends) TopLevel() int {
return len(v.friends) - 1
}

Expand All @@ -37,19 +37,15 @@ func (v *Friends) HasLevel(level int) bool {
panic("level must be nonzero positive integer")
}

return level <= v.MaxLevel()
return level <= v.TopLevel()
}

// InsertFriendsAtLevel requires level must be zero-indexed
func (v *Friends) InsertFriendsAtLevel(level int, vectorId, friendId Id, dist float32) {
func (v *Friends) InsertFriendsAtLevel(level int, friendId Id, dist float32) {
if !v.HasLevel(level) {
panic("failed to insert friends at level, as level is not valId")
}

if friendId == vectorId {
panic("cannot insert yourself to friends list")
}

for i := 0; i <= level; i++ {
v.friends[i].Insert(friendId, dist)
}
Expand Down
24 changes: 12 additions & 12 deletions pkg/hnsw/friends_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@ func TestVector_LevelManagement(t *testing.T) {

hexFriends := NewFriends(6)

if hexFriends.MaxLevel() != 6 {
t.Fatalf("since 0-indexed, the max level is 5, got: %v", hexFriends.MaxLevel())
if hexFriends.TopLevel() != 6 {
t.Fatalf("since 0-indexed, the max level is 5, got: %v", hexFriends.TopLevel())
}

if hexFriends.Levels() != 7 {
t.Fatalf("since 0-indexed, the number of levels is 6, got: %v", hexFriends.Levels())
if hexFriends.NumLevels() != 7 {
t.Fatalf("since 0-indexed, the number of levels is 6, got: %v", hexFriends.NumLevels())
}

octId := Id(2)
oct := []float32{0, 2, 3}
octFriends := NewFriends(8)

if octFriends.MaxLevel() != 8 {
t.Fatalf("since 0-indexed, the max level is 7, got: %v", octFriends.MaxLevel())
if octFriends.TopLevel() != 8 {
t.Fatalf("since 0-indexed, the max level is 7, got: %v", octFriends.TopLevel())
}

if octFriends.Levels() != 9 {
t.Fatalf("since 0-indexed, the number of levels is 8, got: %v", octFriends.Levels())
if octFriends.NumLevels() != 9 {
t.Fatalf("since 0-indexed, the number of levels is 8, got: %v", octFriends.NumLevels())
}

for i := 0; i <= 6; i++ {
Expand All @@ -51,8 +51,8 @@ func TestVector_LevelManagement(t *testing.T) {

hexOctDist := EuclidDistance(oct, hex)

hexFriends.InsertFriendsAtLevel(5, hexId, octId, hexOctDist)
octFriends.InsertFriendsAtLevel(5, octId, hexId, hexOctDist)
hexFriends.InsertFriendsAtLevel(5, octId, hexOctDist)
octFriends.InsertFriendsAtLevel(5, hexId, hexOctDist)

for i := 0; i <= 5; i++ {
hexFriends, err := hexFriends.GetFriendsAtLevel(i)
Expand All @@ -69,12 +69,12 @@ func TestVector_LevelManagement(t *testing.T) {
t.Fatalf("expected hex and oct friends list at level %v to be 1, got: %v || %v", i, hexFriends.Len(), octFriends.Len())
}

top := hexFriends.Peek()
top := hexFriends.Top()
if top.id != octId {
t.Fatalf("expected %v, got %v", octId, top.id)
}

top = octFriends.Peek()
top = octFriends.Top()
if top.id != hexId {
t.Fatalf("expected %v, got %v", hexId, top.id)
}
Expand Down
135 changes: 123 additions & 12 deletions pkg/hnsw/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,157 @@ package hnsw

import (
"fmt"
"math"
"math/rand"
)

type Id = uint

var ErrNodeNotFound = fmt.Errorf("node not found")

type Hnsw struct {
vectorDimensionality int

Vectors map[Id]*Friends
points map[Id]*Point
friends map[Id]*Friends

normFactorForLevelGeneration int
levelMultiplier float64

// efConstruction is the size of the dynamic candIdate list
efConstruction uint

// default number of connections
M int

// mmax, mmax0 is the maximum number of connections for each element per layer
mmax, mmax0 int
M, Mmax0 int
}

func NewHnsw(d int, efConstruction uint, M, mmax, mmax0 int) *Hnsw {
if d <= 0 {
panic("vector dimensionality cannot be less than 1")
func NewHnsw(d int, efConstruction uint, M int, entryPoint Point) *Hnsw {
if d <= 0 || len(entryPoint) != d {
panic("invalid vector dimensionality")
}

friends := make(map[Id]*Friends)
friends[Id(0)] = NewFriends(0)

points := make(map[Id]*Point)
points[Id(0)] = &entryPoint

return &Hnsw{
points: points,
vectorDimensionality: d,
friends: friends,
efConstruction: efConstruction,
M: M,
mmax: mmax,
mmax0: mmax0,
Mmax0: 2 * M,
levelMultiplier: 1 / math.Log(float64(M)),
}
}

func (h *Hnsw) SpawnLevel() int {
return int(math.Floor(-math.Log(rand.Float64() * h.levelMultiplier)))
}

func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, level int) (*BaseQueue, error) {
visited := make([]bool, len(h.friends)+1)

candidatesForQ := NewBaseQueue(MinComparator{})
foundNNToQ := NewBaseQueue(MaxComparator{})

// note entryItem.dist should be the distance to Q
candidatesForQ.Insert(entryItem.id, entryItem.dist)
foundNNToQ.Insert(entryItem.id, entryItem.dist)

for !candidatesForQ.IsEmpty() {
closestCandidate, err := candidatesForQ.PopItem()
if err != nil {
return nil, fmt.Errorf("error during searching level %d: %w", level, err)
}

furthestFoundNN := foundNNToQ.Top()

// if distance(c, q) > distance(f, q)
if closestCandidate.dist > furthestFoundNN.dist {
// all items in furthest found nn are evaluated
break
}

closestCandidateFriends, err := h.friends[closestCandidate.id].GetFriendsAtLevel(level)
if err != nil {
return nil, fmt.Errorf("error during searching level %d: %w", level, err)
}

for _, ccFriendItem := range closestCandidateFriends.items {
ccFriendId := ccFriendItem.id
if !visited[ccFriendId] {
visited[ccFriendId] = true

furthestFoundNN = foundNNToQ.Top()

ccFriendPoint, ok := h.points[ccFriendId]
if !ok {
return nil, ErrNodeNotFound
}

// if distance(ccFriend, q) < distance(f, q)
ccFriendDistToQ := EuclidDistance(*ccFriendPoint, *q)
if ccFriendDistToQ < furthestFoundNN.dist || foundNNToQ.Len() < numNearestToQToReturn {
candidatesForQ.Insert(ccFriendId, ccFriendDistToQ)
foundNNToQ.Insert(ccFriendId, ccFriendDistToQ)

if foundNNToQ.Len() > numNearestToQToReturn {
_, err = foundNNToQ.PopItem()
if err != nil {
return nil, fmt.Errorf("error during searching level %d: %w", level, err)
}
}
}
}
}

}

return FromBaseQueue(foundNNToQ, MinComparator{}), nil
}

func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item {
initialEntryPoint, ok := h.friends[Id(0)]
if !ok {
panic(ErrNodeNotFound)
}

entryPointDistToQ := EuclidDistance(*h.points[Id(0)], *q)

epItem := &Item{id: Id(0), dist: entryPointDistToQ}
for level := initialEntryPoint.TopLevel(); level > qFriends.TopLevel()+1; level-- {
closestNeighborsToQ, err := h.searchLevel(q, epItem, 1, level)
if err != nil {
panic(err)
}

if closestNeighborsToQ.IsEmpty() {
// return the existing epItem. it's the closest to q.
return epItem
}

newEpItem, err := closestNeighborsToQ.PopItem()
if err != nil {
panic(err)
}

epItem = newEpItem
}

return epItem
}

func (h *Hnsw) InsertVector(q Point) error {
if !h.validatePoint(q) {
return fmt.Errorf("invalidvector")
return fmt.Errorf("invalid vector dimensionality")
}

qTopLevel := h.SpawnLevel()
qFriends := NewFriends(qTopLevel)

_ = h.findCloserEntryPoint(&q, qFriends)
return nil
}

Expand Down
67 changes: 67 additions & 0 deletions pkg/hnsw/hnsw_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package hnsw

import "testing"

/*
var clusterA = []Point{
{0.2, 0.5},
{0.2, 0.7},
{0.3, 0.8},
{0.5, 0.5},
{0.4, 0.1},
{0.3, 0.7},
{0.27, 0.23},
{0.12, 0.1},
{0.23, 0.25},
{0.3, 0.3},
{0.01, 0.3},
}
var clusterB = []Point{
{4.2, 3.5},
{4.2, 4.7},
{4.3, 3.8},
{4.5, 4.5},
{4.4, 3.1},
{4.3, 4.7},
{4.27, 3.23},
{4.1, 4.1},
{4.12, 3.1},
{4.23, 4.25},
{4.3, 3.3},
{4.01, 4.3},
}
*/

func TestHnsw_SearchLevel(t *testing.T) {
t.Run("search level 0", func(t *testing.T) {
entryPoint := Point{0, 0}
g := NewHnsw(2, 4, 4, entryPoint)
mPoint := Point{2, 2}
g.points[Id(1)] = &mPoint

g.friends[Id(0)].InsertFriendsAtLevel(0, 1, EuclidDistance(mPoint, entryPoint))
g.friends[Id(1)] = NewFriends(0)
g.friends[Id(1)].InsertFriendsAtLevel(0, 0, EuclidDistance(mPoint, entryPoint))

qPoint := Point{4, 4}
closestNeighbor, err := g.searchLevel(&qPoint, &Item{id: 0, dist: EuclidDistance(entryPoint, qPoint)}, 1, 0)
if err != nil {
t.Fatal(err)
}

if closestNeighbor.IsEmpty() {
t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor)
}

closestItem, err := closestNeighbor.PopItem()

if err != nil {
t.Fatal(err)
}

if Id(1) != closestItem.id {
t.Fatalf("expected item id to be %v, got %v", 1, closestItem.id)
}
})
}
10 changes: 5 additions & 5 deletions pkg/hnsw/pq.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ type Heapy interface {
Insert(id Id, dist float32)
IsEmpty() bool
Len() int
Peel() *Item
Peek() *Item
PopItem() *Item
Top() *Item
Take(count int) (*BaseQueue, error)
update(item *Item, id Id, dist float32)
}
Expand All @@ -61,7 +61,7 @@ func (bq *BaseQueue) Take(count int, comparator Comparator) (*BaseQueue, error)
break
}

peeled, err := bq.Peel()
peeled, err := bq.PopItem()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -89,7 +89,7 @@ func (bq *BaseQueue) Push(x any) {
bq.items = append(bq.items, item)
}

func (bq *BaseQueue) Peek() *Item {
func (bq *BaseQueue) Top() *Item {
if len(bq.items) == 0 {
return nil
}
Expand Down Expand Up @@ -135,7 +135,7 @@ func NewBaseQueue(comparator Comparator) *BaseQueue {
return bq
}

func (bq *BaseQueue) Peel() (*Item, error) {
func (bq *BaseQueue) PopItem() (*Item, error) {
if bq.Len() == 0 {
return nil, fmt.Errorf("no items to peel")
}
Expand Down
Loading

0 comments on commit 861afee

Please sign in to comment.