From a1f8dec1b57939a78411b57d083cd336607bfe0d Mon Sep 17 00:00:00 2001 From: Sergey Grebenshchikov Date: Wed, 9 Oct 2024 18:43:48 +0200 Subject: [PATCH] cleanup --- .gitignore | 1 + lsh/model_bench_test.go | 13 +++--------- lsh/model_test.go | 44 +++++++++++++++++++++-------------------- lsh/nearest.go | 1 - nearest.go | 10 ++++++---- 5 files changed, 33 insertions(+), 36 deletions(-) diff --git a/.gitignore b/.gitignore index 76b137f..323ec4b 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ *.txt *.test testdata +*.gz diff --git a/lsh/model_bench_test.go b/lsh/model_bench_test.go index b5d9449..5a107a9 100644 --- a/lsh/model_bench_test.go +++ b/lsh/model_bench_test.go @@ -15,19 +15,12 @@ func Benchmark_Model_Predict1(b *testing.B) { dataSize []int k []int } - mh := lsh.RandomMinHashR(testrandom.Source) - rbh := lsh.RandomBlurR(3, 20, testrandom.Source) hashes := []lsh.Hash{ - lsh.NoHash{}, - lsh.HashCompose{rbh, mh}, - lsh.RandomBitSampleR(30, testrandom.Source), - rbh, - mh, - lsh.RandomMinHashesR(10, testrandom.Source), + lsh.ConstantHash{}, // should be only a bit slower than exact KNN } benches := []bench{ - // {hashes: hashes, dataSize: []int{100}, k: []int{1, 3, 10}}, - {hashes: hashes, dataSize: []int{1_000_000}, k: []int{1, 1000}}, + {hashes: hashes, dataSize: []int{100}, k: []int{1, 3, 10}}, + {hashes: hashes, dataSize: []int{1_000_000}, k: []int{3, 10, 100}}, } for _, bench := range benches { for _, dataSize := range bench.dataSize { diff --git a/lsh/model_test.go b/lsh/model_test.go index 6f72ae1..7e352f9 100644 --- a/lsh/model_test.go +++ b/lsh/model_test.go @@ -2,6 +2,8 @@ package lsh_test import ( "math" + "reflect" + "slices" "testing" "github.com/keilerkonzept/bitknn" @@ -11,14 +13,14 @@ import ( func Test_Model_NoHash_IsExact(t *testing.T) { var h lsh.NoHash - _ = h var h0 lsh.ConstantHash + id := func(a uint64) uint64 { return a } rapid.Check(t, func(t *rapid.T) { - k := rapid.IntRange(1, 1001).Draw(t, "k") - data := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 1000, func(a uint64) uint64 { return a }).Draw(t, "data") + k := rapid.IntRange(3, 1001).Draw(t, "k") + data := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 1000, id).Draw(t, "data") labels := rapid.SliceOfN(rapid.IntRange(0, 3), len(data), len(data)).Draw(t, "labels") values := rapid.SliceOfN(rapid.Float64(), len(data), len(data)).Draw(t, "values") - queries := rapid.SliceOfN(rapid.Uint64(), 3, 64).Draw(t, "queries") + queries := rapid.SliceOfNDistinct(rapid.Uint64(), 3, 64, id).Draw(t, "queries") knnVotes := make([]float64, 4) annVotes := make([]float64, 4) type pair struct { @@ -31,52 +33,53 @@ func Test_Model_NoHash_IsExact(t *testing.T) { { "V", bitknn.Fit(data, labels, bitknn.WithValues(values)), - lsh.Fit(data, labels, h0, bitknn.WithValues(values)), + lsh.Fit(data, labels, h, bitknn.WithValues(values)), lsh.Fit(data, labels, h0, bitknn.WithValues(values)), }, { "LV", bitknn.Fit(data, labels, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)), - lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)), + lsh.Fit(data, labels, h, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)), lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting(), bitknn.WithValues(values)), }, { "QV", bitknn.Fit(data, labels, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)), - lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)), + lsh.Fit(data, labels, h, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)), lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting(), bitknn.WithValues(values)), }, { "CV", bitknn.Fit(data, labels, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)), - lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)), + lsh.Fit(data, labels, h, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)), lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear), bitknn.WithValues(values)), }, { "0", bitknn.Fit(data, labels), - lsh.Fit(data, labels, h0), + lsh.Fit(data, labels, h), lsh.Fit(data, labels, h0), }, { "L", bitknn.Fit(data, labels, bitknn.WithLinearDistanceWeighting()), - lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting()), + lsh.Fit(data, labels, h, bitknn.WithLinearDistanceWeighting()), lsh.Fit(data, labels, h0, bitknn.WithLinearDistanceWeighting()), }, { "Q", bitknn.Fit(data, labels, bitknn.WithQuadraticDistanceWeighting()), - lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting()), + lsh.Fit(data, labels, h, bitknn.WithQuadraticDistanceWeighting()), lsh.Fit(data, labels, h0, bitknn.WithQuadraticDistanceWeighting()), }, { "C", bitknn.Fit(data, labels, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)), - lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)), + lsh.Fit(data, labels, h, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)), lsh.Fit(data, labels, h0, bitknn.WithDistanceWeightingFunc(bitknn.DistanceWeightingFuncLinear)), }, } + const eps = 1e-8 for _, pair := range pairs { knn := pair.KNN ann := pair.ANN @@ -87,25 +90,24 @@ func Test_Model_NoHash_IsExact(t *testing.T) { knn.Predict1(k, q, knnVotes) ann.Predict1(k, q, annVotes) - const eps = 1e-8 - for i, vk := range knnVotes { - va := annVotes[i] - if math.Abs(vk-va) > eps { - t.Fatalf("%s: %v: %v %v", pair.name, q, knnVotes, annVotes) - } + slices.Sort(knn.HeapDistances[:k]) + slices.Sort(ann.HeapDistances[:k]) + if !reflect.DeepEqual(knn.HeapDistances[:k], ann.HeapDistances[:k]) { + t.Fatal("NoHash ANN should result in the same distances for the nearest neighbors: ", knn.HeapDistances[:k], ann.HeapDistances[:k], knn.HeapIndices[:k], ann.HeapIndices[:k]) } - ann.Predict1Alloc(k, q, annVotes) + + ann0.Predict1Alloc(k, q, annVotes) for i, vk := range knnVotes { va := annVotes[i] if math.Abs(vk-va) > eps { - t.Fatalf("%s: %v: %v %v", pair.name, q, knnVotes, annVotes) + t.Fatalf("ANN: %s: %v: %v %v", pair.name, q, knnVotes, annVotes) } } ann0.Predict1(k, q, annVotes) for i, vk := range knnVotes { va := annVotes[i] if math.Abs(vk-va) > eps { - t.Fatalf("%s: %v: %v %v", pair.name, q, knnVotes, annVotes) + t.Fatalf("ANN0: %s: %v: %v %v", pair.name, q, knnVotes, annVotes) } } } diff --git a/lsh/nearest.go b/lsh/nearest.go index ee9c835..400dc76 100644 --- a/lsh/nearest.go +++ b/lsh/nearest.go @@ -133,5 +133,4 @@ func nearestBuckets(bucketIDs []uint64, k int, x uint64, distance0 *int, heap *h heap.PushPop(dist, b) maxDist = *distance0 } - return } diff --git a/nearest.go b/nearest.go index b61f842..ead3c94 100644 --- a/nearest.go +++ b/nearest.go @@ -14,25 +14,27 @@ import ( // cap(distances) = cap(indices) = k+1 >= 1 func Nearest(data []uint64, k int, x uint64, distances, indices []int) int { heap := heap.MakeMax(distances, indices) + distance0 := &distances[0] k0 := min(k, len(data)) - for i := 0; i < k0; i++ { - dist := bits.OnesCount64(x ^ data[i]) + for i, d := range data[:k0] { + dist := bits.OnesCount64(x ^ d) heap.Push(dist, i) } + if k0 < k { return k0 } - maxDist := distances[0] + maxDist := *distance0 for i := k; i < len(data); i++ { dist := bits.OnesCount64(x ^ data[i]) if dist >= maxDist { continue } heap.PushPop(dist, i) - maxDist = distances[0] + maxDist = *distance0 } return k }