Skip to content

Commit

Permalink
In case of CDF vs random leads to after last bucket ensure that bucke…
Browse files Browse the repository at this point in the history
…t pos is correct
  • Loading branch information
filipecosta90 committed Mar 17, 2023
1 parent 7ce8935 commit 30c337a
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
5 changes: 4 additions & 1 deletion multi-query.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ import (
func sample(cdf []float32) int {
r := rand.Float32()
bucket := 0
for r > cdf[bucket] {
for (bucket < len(cdf)) && (r > cdf[bucket]) {
bucket++
}
if bucket >= len(cdf) {
bucket = bucket - 1
}
return bucket
}

Expand Down
26 changes: 26 additions & 0 deletions multi-query_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package main

import "testing"

func Test_sample(t *testing.T) {
type args struct {
cdf []float32
}
tests := []struct {
name string
args args
want int
}{
{name: "last bucket", args: args{[]float32{0.01, 0.99}}, want: 1},
{name: "last bucket with high likelyhood", args: args{[]float32{0.00001, 0.99}}, want: 1},
{name: "single bucket", args: args{[]float32{0.99}}, want: 0},
{name: "after bucket", args: args{[]float32{0.00000001}}, want: 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := sample(tt.args.cdf); got != tt.want {
t.Errorf("sample() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit 30c337a

Please sign in to comment.