Skip to content

Commit

Permalink
SearchLayer given Clusters A, B (#325)
Browse files Browse the repository at this point in the history
* consider 1, consider many

* fix assertiong
  • Loading branch information
friendlymatthew authored Jun 4, 2024
1 parent a8cabd0 commit 77b2f20
Showing 1 changed file with 75 additions and 10 deletions.
85 changes: 75 additions & 10 deletions pkg/hnsw/hnsw_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ var clusterA = []Point{
{0.01, 0.3},
}

/*
var clusterB = []Point{
{4.2, 3.5},
{4.2, 4.7},
Expand All @@ -30,20 +29,19 @@ var clusterB = []Point{
{4.3, 4.7},
{4.27, 3.23},
{4.1, 4.1},
{4.12, 3.1},
{4, 2},
{4.23, 4.25},
{4.3, 3.3},
{4.01, 4.3},
}
*/

func SetupClusterAHnsw() (*Hnsw, error) {
func SetupClusterHnsw(cluster []Point) (*Hnsw, error) {
efc := uint(4)

entryPoint := Point{0, 0}
g := NewHnsw(2, efc, 4, entryPoint)

for idx, point := range clusterA {
for idx, point := range cluster {
pointId := Id(idx + 1)
g.points[pointId] = &point
g.friends[pointId] = NewFriends(0)
Expand All @@ -53,8 +51,8 @@ func SetupClusterAHnsw() (*Hnsw, error) {
g.friends[pointId].InsertFriendsAtLevel(0, Id(0), distEntryToClusterPoint)
}

for idx, pointA := range clusterA {
for jdx, pointB := range clusterA {
for idx, pointA := range cluster {
for jdx, pointB := range cluster {
if idx == jdx {
continue
}
Expand All @@ -68,7 +66,7 @@ func SetupClusterAHnsw() (*Hnsw, error) {
}
}

for kdx := range clusterA {
for kdx := range cluster {
pointId := Id(kdx + 1)
friends, err := g.friends[pointId].GetFriendsAtLevel(0)
if err != nil {
Expand Down Expand Up @@ -120,7 +118,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
})

t.Run("cluster a searchLayer for existing point", func(t *testing.T) {
g, err := SetupClusterAHnsw()
g, err := SetupClusterHnsw(clusterA)

if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -165,7 +163,7 @@ func TestHnsw_SearchLevel(t *testing.T) {
})

t.Run("cluster a searchLayer for new point", func(t *testing.T) {
g, err := SetupClusterAHnsw()
g, err := SetupClusterHnsw(clusterA)

if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -203,4 +201,71 @@ func TestHnsw_SearchLevel(t *testing.T) {
}
})

t.Run("cluster a, b, selectLayer and return the closest point", func(t *testing.T) {
clusterC := append(append([]Point{}, clusterA...), clusterB...)
g, err := SetupClusterHnsw(clusterC)
if err != nil {
t.Fatal(err)
}

qPoint := Point{2, 2}

closestNeighbor, err := g.searchLevel(&qPoint, &Item{id: 0, dist: EuclidDistance(Point{0, 0}, qPoint)}, 1, 0)

if err != nil {
t.Fatal(err)
}

if closestNeighbor.IsEmpty() {
t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor.Len())
}

closestItem, err := closestNeighbor.PopItem()
if err != nil {
t.Fatal(err)
}

if closestItem.id != Id(20) {
t.Fatalf("expected the closest point which is {4, 2} and id %v, got %v", Id(20), closestItem.id)
}
})

t.Run("cluster a, b, selectLayer and return the closest points from both clusters", func(t *testing.T) {
clusterC := append(append([]Point{}, clusterA...), clusterB...)
g, err := SetupClusterHnsw(clusterC)
if err != nil {
t.Fatal(err)
}

qPoint := Point{2, 2}

closestNeighbor, err := g.searchLevel(&qPoint, &Item{id: 0, dist: EuclidDistance(Point{0, 0}, qPoint)}, 4, 0)
if err != nil {
t.Fatal(err)
}

if closestNeighbor.IsEmpty() {
t.Fatalf("expected # of neighbors to return to be 1, got %v", closestNeighbor.Len())
}

if closestNeighbor.Len() != 4 {
t.Fatalf("expected # of neighbors to return to be %v, got %v", 4, closestNeighbor.Len())
}

var closestIds []Id

for !closestNeighbor.IsEmpty() {
closestItem, err := closestNeighbor.PopItem()
if err != nil {
t.Fatal(err)
}

closestIds = append(closestIds, closestItem.id)
}

if !reflect.DeepEqual(closestIds, []Id{20, 3, 4, 6}) {
t.Fatalf("expected the following closest ids: %v, got: %v", []Id{20, 3, 4, 6}, closestIds)
}
})

}

0 comments on commit 77b2f20

Please sign in to comment.