Skip to content

Commit

Permalink
Merge pull request #27 from RedisGraph/fix.sample.cdf
Browse files Browse the repository at this point in the history
In case of CDF vs random leads to after last bucket ensure that bucket pos is correct
  • Loading branch information
filipecosta90 authored Mar 17, 2023
2 parents 7ce8935 + 30c337a commit 98c7c14
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
5 changes: 4 additions & 1 deletion multi-query.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ import (
func sample(cdf []float32) int {
r := rand.Float32()
bucket := 0
for r > cdf[bucket] {
for (bucket < len(cdf)) && (r > cdf[bucket]) {
bucket++
}
if bucket >= len(cdf) {
bucket = bucket - 1
}
return bucket
}

Expand Down
26 changes: 26 additions & 0 deletions multi-query_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package main

import "testing"

func Test_sample(t *testing.T) {
type args struct {
cdf []float32
}
tests := []struct {
name string
args args
want int
}{
{name: "last bucket", args: args{[]float32{0.01, 0.99}}, want: 1},
{name: "last bucket with high likelyhood", args: args{[]float32{0.00001, 0.99}}, want: 1},
{name: "single bucket", args: args{[]float32{0.99}}, want: 0},
{name: "after bucket", args: args{[]float32{0.00000001}}, want: 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := sample(tt.args.cdf); got != tt.want {
t.Errorf("sample() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit 98c7c14

Please sign in to comment.