Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update HNSW scheme to keep track of EntryPointId #327

Merged
merged 4 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions pkg/hnsw/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ var ErrNodeNotFound = fmt.Errorf("node not found")
type Hnsw struct {
vectorDimensionality int

entryPointId Id

points map[Id]*Point
friends map[Id]*Friends

Expand All @@ -30,13 +32,16 @@ func NewHnsw(d int, efConstruction uint, M int, entryPoint Point) *Hnsw {
panic("invalid vector dimensionality")
}

defaultEntryPointId := Id(0)

friends := make(map[Id]*Friends)
friends[Id(0)] = NewFriends(0)
friends[defaultEntryPointId] = NewFriends(0)

points := make(map[Id]*Point)
points[Id(0)] = &entryPoint
points[defaultEntryPointId] = &entryPoint

return &Hnsw{
entryPointId: defaultEntryPointId,
points: points,
vectorDimensionality: d,
friends: friends,
Expand Down Expand Up @@ -113,14 +118,14 @@ func (h *Hnsw) searchLevel(q *Point, entryItem *Item, numNearestToQToReturn, lev
}

func (h *Hnsw) findCloserEntryPoint(q *Point, qFriends *Friends) *Item {
initialEntryPoint, ok := h.friends[Id(0)]
initialEntryPoint, ok := h.friends[h.entryPointId]
if !ok {
panic(ErrNodeNotFound)
}

entryPointDistToQ := EuclidDistance(*h.points[Id(0)], *q)
entryPointDistToQ := EuclidDistance(*h.points[h.entryPointId], *q)

epItem := &Item{id: Id(0), dist: entryPointDistToQ}
epItem := &Item{id: h.entryPointId, dist: entryPointDistToQ}
for level := initialEntryPoint.TopLevel(); level > qFriends.TopLevel()+1; level-- {
closestNeighborsToQ, err := h.searchLevel(q, epItem, 1, level)
if err != nil {
Expand Down
24 changes: 12 additions & 12 deletions pkg/hnsw/hnsw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ func SetupClusterHnsw(cluster []Point) (*Hnsw, error) {
g.friends[pointId] = NewFriends(0)

distEntryToClusterPoint := EuclidDistance(entryPoint, point)
g.friends[Id(0)].InsertFriendsAtLevel(0, pointId, distEntryToClusterPoint)
g.friends[pointId].InsertFriendsAtLevel(0, Id(0), distEntryToClusterPoint)
g.friends[g.entryPointId].InsertFriendsAtLevel(0, pointId, distEntryToClusterPoint)
g.friends[pointId].InsertFriendsAtLevel(0, g.entryPointId, distEntryToClusterPoint)
}

for idx, pointA := range cluster {
Expand Down Expand Up @@ -92,9 +92,9 @@ func TestHnsw_SearchLevel(t *testing.T) {
mPoint := Point{2, 2}
g.points[Id(1)] = &mPoint

g.friends[Id(0)].InsertFriendsAtLevel(0, 1, EuclidDistance(mPoint, entryPoint))
g.friends[g.entryPointId].InsertFriendsAtLevel(0, 1, EuclidDistance(mPoint, entryPoint))
g.friends[Id(1)] = NewFriends(0)
g.friends[Id(1)].InsertFriendsAtLevel(0, 0, EuclidDistance(mPoint, entryPoint))
g.friends[Id(1)].InsertFriendsAtLevel(0, g.entryPointId, EuclidDistance(mPoint, entryPoint))

qPoint := Point{4, 4}
closestNeighbor, err := g.searchLevel(&qPoint, &Item{id: 0, dist: EuclidDistance(entryPoint, qPoint)}, 1, 0)
Expand Down Expand Up @@ -124,7 +124,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
t.Fatal(err)
}

entryPoint, ok := g.points[Id(0)]
entryPoint, ok := g.points[g.entryPointId]
if !ok {
t.Fatal(ErrNodeNotFound)
}
Expand Down Expand Up @@ -169,7 +169,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
t.Fatal(err)
}

entryPoint, ok := g.points[Id(0)]
entryPoint, ok := g.points[g.entryPointId]
if !ok {
t.Fatal(ErrNodeNotFound)
}
Expand Down Expand Up @@ -276,10 +276,11 @@ func TestHnsw_FindCloserEntryPoint(t *testing.T) {
/*
Before anything, we need to pad the entry node's friends queue to include more than level 0.
This is because we only consider the following topLevels
for level := initialEntryPoint.TopLevel(); level > qFriends.TopLevel()+1; level-- {

for level := initialEntryPoint.TopLevel(); level > qFriends.TopLevel()+1; level-- {
*/

h.friends[Id(0)] = NewFriends(4)
h.friends[h.entryPointId] = NewFriends(4)

closerPointId := Id(1)
closerPoint := Point{2, 2}
Expand All @@ -288,9 +289,8 @@ func TestHnsw_FindCloserEntryPoint(t *testing.T) {
h.friends[closerPointId] = NewFriends(4)

distToEntry := EuclidDistance(Point{0, 0}, closerPoint)

h.friends[closerPointId].InsertFriendsAtLevel(4, Id(0), distToEntry)
h.friends[Id(0)].InsertFriendsAtLevel(4, closerPointId, distToEntry)
h.friends[closerPointId].InsertFriendsAtLevel(4, h.entryPointId, distToEntry)
h.friends[h.entryPointId].InsertFriendsAtLevel(4, closerPointId, distToEntry)

closestItem := h.findCloserEntryPoint(&Point{4, 4}, NewFriends(0))

Expand All @@ -307,7 +307,7 @@ func TestHnsw_FindCloserEntryPoint(t *testing.T) {
t.Run("single level means entry point is the closest", func(t *testing.T) {
h := NewHnsw(2, 4, 4, Point{0, 0})

h.friends[Id(0)] = NewFriends(4)
h.friends[h.entryPointId] = NewFriends(4)

closerPointId := Id(1)
closerPoint := Point{2, 2}
Expand Down
Loading