diff --git a/go/mysql/decimal/decimal.go b/go/mysql/decimal/decimal.go index a3017190609..a2b505a1232 100644 --- a/go/mysql/decimal/decimal.go +++ b/go/mysql/decimal/decimal.go @@ -677,6 +677,10 @@ func (d *Decimal) ensureInitialized() { } } +func (d Decimal) IsInitialized() bool { + return d.value != nil +} + // RescalePair rescales two decimals to common exponential value (minimal exp of both decimals) func RescalePair(d1 Decimal, d2 Decimal) (Decimal, Decimal) { d1.ensureInitialized() diff --git a/go/sqltypes/testing.go b/go/sqltypes/testing.go index f5aa58f0a72..b591cf710f0 100644 --- a/go/sqltypes/testing.go +++ b/go/sqltypes/testing.go @@ -18,8 +18,14 @@ package sqltypes import ( "bytes" + crand "crypto/rand" + "encoding/base64" + "encoding/hex" "fmt" + "math/rand" + "strconv" "strings" + "time" querypb "vitess.io/vitess/go/vt/proto/query" ) @@ -154,3 +160,124 @@ func PrintResults(results []*Result) string { func split(str string) []string { return strings.Split(str, "|") } + +func TestRandomValues() (Value, Value) { + if rand.Int()%2 == 0 { + // create a single value, and turn it into two different types + v := rand.Int() + return randomNumericType(v), randomNumericType(v) + } + + // just produce two arbitrary random values and compare + return randomNumericType(rand.Int()), randomNumericType(rand.Int()) +} + +func randomNumericType(i int) Value { + r := rand.Intn(len(numericTypes)) + return numericTypes[r](i) +} + +var numericTypes = []func(int) Value{ + func(i int) Value { return NULL }, + func(i int) Value { return NewInt8(int8(i)) }, + func(i int) Value { return NewInt32(int32(i)) }, + func(i int) Value { return NewInt64(int64(i)) }, + func(i int) Value { return NewUint64(uint64(i)) }, + func(i int) Value { return NewUint32(uint32(i)) }, + func(i int) Value { return NewFloat64(float64(i)) }, + func(i int) Value { return NewDecimal(fmt.Sprintf("%d", i)) }, + func(i int) Value { return NewVarChar(fmt.Sprintf("%d", i)) }, + func(i int) Value { return NewVarChar(fmt.Sprintf(" %f aa", float64(i))) }, +} + +type RandomGenerator func() Value + +func randomBytes() []byte { + b := make([]byte, rand.Intn(128)) + _, _ = crand.Read(b) + return b +} + +var RandomGenerators = map[Type]RandomGenerator{ + Null: func() Value { + return NULL + }, + Int8: func() Value { + return NewInt8(int8(rand.Intn(255))) + }, + Int32: func() Value { + return NewInt32(rand.Int31()) + }, + Int64: func() Value { + return NewInt64(rand.Int63()) + }, + Uint32: func() Value { + return NewUint32(rand.Uint32()) + }, + Uint64: func() Value { + return NewUint64(rand.Uint64()) + }, + Float64: func() Value { + return NewFloat64(rand.ExpFloat64()) + }, + Decimal: func() Value { + dec := fmt.Sprintf("%d.%d", rand.Intn(9999999999), rand.Intn(9999999999)) + if rand.Int()&0x1 == 1 { + dec = "-" + dec + } + return NewDecimal(dec) + }, + VarChar: func() Value { + return NewVarChar(base64.StdEncoding.EncodeToString(randomBytes())) + }, + VarBinary: func() Value { + return NewVarBinary(string(randomBytes())) + }, + Date: func() Value { + return NewDate(randTime().Format(time.DateOnly)) + }, + Datetime: func() Value { + return NewDatetime(randTime().Format(time.DateTime)) + }, + Timestamp: func() Value { + return NewTimestamp(randTime().Format(time.DateTime)) + }, + Time: func() Value { + return NewTime(randTime().Format(time.TimeOnly)) + }, + TypeJSON: func() Value { + var j string + switch rand.Intn(6) { + case 0: + j = "null" + case 1: + i := rand.Int63() + if rand.Int()&0x1 == 1 { + i = -i + } + j = strconv.FormatInt(i, 10) + 
case 2: + j = strconv.FormatFloat(rand.NormFloat64(), 'g', -1, 64) + case 3: + j = strconv.Quote(hex.EncodeToString(randomBytes())) + case 4: + j = "true" + case 5: + j = "false" + } + v, err := NewJSON(j) + if err != nil { + panic(err) + } + return v + }, +} + +func randTime() time.Time { + min := time.Date(1970, 1, 0, 0, 0, 0, 0, time.UTC).Unix() + max := time.Date(2070, 1, 0, 0, 0, 0, 0, time.UTC).Unix() + delta := max - min + + sec := rand.Int63n(delta) + min + return time.Unix(sec, 0) +} diff --git a/go/sqltypes/value.go b/go/sqltypes/value.go index a36d6fa2858..8d95a94561f 100644 --- a/go/sqltypes/value.go +++ b/go/sqltypes/value.go @@ -31,6 +31,7 @@ import ( "vitess.io/vitess/go/hack" "vitess.io/vitess/go/mysql/decimal" "vitess.io/vitess/go/mysql/fastparse" + "vitess.io/vitess/go/mysql/format" querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -186,12 +187,12 @@ func NewBoolean(v bool) Value { // NewFloat64 builds an Float64 Value. func NewFloat64(v float64) Value { - return MakeTrusted(Float64, strconv.AppendFloat(nil, v, 'g', -1, 64)) + return MakeTrusted(Float64, format.FormatFloat(v)) } // NewFloat32 builds a Float32 Value. func NewFloat32(v float32) Value { - return MakeTrusted(Float32, strconv.AppendFloat(nil, float64(v), 'g', -1, 64)) + return MakeTrusted(Float32, format.FormatFloat(float64(v))) } // NewVarChar builds a VarChar Value. diff --git a/go/vt/vtgate/engine/aggregations.go b/go/vt/vtgate/engine/aggregations.go index d8c648a4ce4..fbacf3d13c5 100644 --- a/go/vt/vtgate/engine/aggregations.go +++ b/go/vt/vtgate/engine/aggregations.go @@ -91,210 +91,360 @@ func (ap *AggregateParams) String() string { } func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type { - opCode := ap.Opcode if ap.OrigOpcode != AggregateUnassigned { - opCode = ap.OrigOpcode + return ap.OrigOpcode.Type(inputType) } - typ, _ := opCode.Type(&inputType) - return typ -} - -func convertRow( - fields []*querypb.Field, - row []sqltypes.Value, - aggregates []*AggregateParams, -) (newRow []sqltypes.Value, curDistincts []sqltypes.Value) { - newRow = append(newRow, row...) - curDistincts = make([]sqltypes.Value, len(aggregates)) - for index, aggr := range aggregates { - switch aggr.Opcode { - case AggregateCountStar: - newRow[aggr.Col] = countOne - case AggregateCount: - val := countOne - if row[aggr.Col].IsNull() { - val = countZero - } - newRow[aggr.Col] = val - case AggregateCountDistinct: - curDistincts[index] = findComparableCurrentDistinct(row, aggr) - // Type is int64. Ok to call MakeTrusted. 
- if row[aggr.KeyCol].IsNull() { - newRow[aggr.Col] = countZero - } else { - newRow[aggr.Col] = countOne - } - case AggregateSum: - if row[aggr.Col].IsNull() { - break - } - var err error - newRow[aggr.Col], err = sqltypes.Cast(row[aggr.Col], fields[aggr.Col].Type) - if err != nil { - newRow[aggr.Col] = sumZero - } - case AggregateSumDistinct: - curDistincts[index] = findComparableCurrentDistinct(row, aggr) - var err error - newRow[aggr.Col], err = sqltypes.Cast(row[aggr.Col], fields[aggr.Col].Type) + return ap.Opcode.Type(inputType) +} + +type aggregator interface { + add(row []sqltypes.Value) error + finish() sqltypes.Value + reset() +} + +type aggregatorDistinct struct { + column int + last sqltypes.Value + coll collations.ID +} + +func (a *aggregatorDistinct) shouldReturn(row []sqltypes.Value) (bool, error) { + if a.column >= 0 { + if !a.last.IsNull() { + cmp, err := evalengine.NullsafeCompare(a.last, row[a.column], a.coll) if err != nil { - newRow[aggr.Col] = sumZero + return true, err } - case AggregateGtid: - vgtid := &binlogdatapb.VGtid{} - vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{ - Keyspace: row[aggr.Col-1].ToString(), - Shard: row[aggr.Col+1].ToString(), - Gtid: row[aggr.Col].ToString(), - }) - data, _ := vgtid.MarshalVT() - val, _ := sqltypes.NewValue(sqltypes.VarBinary, data) - newRow[aggr.Col] = val - case AggregateGroupConcat: - if !row[aggr.Col].IsNull() { - newRow[aggr.Col] = sqltypes.MakeTrusted(fields[aggr.Col].Type, []byte(row[aggr.Col].ToString())) + if cmp == 0 { + return true, nil } } + a.last = row[a.column] + } + return false, nil +} + +func (a *aggregatorDistinct) reset() { + a.last = sqltypes.NULL +} + +type aggregatorCount struct { + from int + n int64 + distinct aggregatorDistinct +} + +func (a *aggregatorCount) add(row []sqltypes.Value) error { + if row[a.from].IsNull() { + return nil + } + if ret, err := a.distinct.shouldReturn(row); ret { + return err + } + a.n++ + return nil +} + +func (a *aggregatorCount) finish() sqltypes.Value { + return sqltypes.NewInt64(a.n) +} + +func (a *aggregatorCount) reset() { + a.n = 0 + a.distinct.reset() +} + +type aggregatorCountStar struct { + n int64 +} + +func (a *aggregatorCountStar) add(_ []sqltypes.Value) error { + a.n++ + return nil +} + +func (a *aggregatorCountStar) finish() sqltypes.Value { + return sqltypes.NewInt64(a.n) +} + +func (a *aggregatorCountStar) reset() { + a.n = 0 +} + +type aggregatorMinMax struct { + from int + minmax evalengine.MinMax +} + +type aggregatorMin struct { + aggregatorMinMax +} + +func (a *aggregatorMin) add(row []sqltypes.Value) (err error) { + return a.minmax.Min(row[a.from]) +} + +type aggregatorMax struct { + aggregatorMinMax +} + +func (a *aggregatorMax) add(row []sqltypes.Value) (err error) { + return a.minmax.Max(row[a.from]) +} + +func (a *aggregatorMinMax) finish() sqltypes.Value { + return a.minmax.Result() +} + +func (a *aggregatorMinMax) reset() { + a.minmax.Reset() +} + +type aggregatorSum struct { + from int + sum evalengine.Sum + distinct aggregatorDistinct +} + +func (a *aggregatorSum) add(row []sqltypes.Value) error { + if row[a.from].IsNull() { + return nil + } + if ret, err := a.distinct.shouldReturn(row); ret { + return err } - return newRow, curDistincts + return a.sum.Add(row[a.from]) } -func merge( - fields []*querypb.Field, - row1, row2 []sqltypes.Value, - curDistincts []sqltypes.Value, - aggregates []*AggregateParams, -) ([]sqltypes.Value, []sqltypes.Value, error) { - result := sqltypes.CopyRow(row1) - for index, aggr := range aggregates 
{ +func (a *aggregatorSum) finish() sqltypes.Value { + return a.sum.Result() +} + +func (a *aggregatorSum) reset() { + a.sum.Reset() + a.distinct.reset() +} + +type aggregatorScalar struct { + from int + current sqltypes.Value + init bool +} + +func (a *aggregatorScalar) add(row []sqltypes.Value) error { + if !a.init { + a.current = row[a.from] + a.init = true + } + return nil +} + +func (a *aggregatorScalar) finish() sqltypes.Value { + return a.current +} + +func (a *aggregatorScalar) reset() { + a.current = sqltypes.NULL + a.init = false +} + +type aggregatorGroupConcat struct { + from int + type_ sqltypes.Type + + concat []byte + n int +} + +func (a *aggregatorGroupConcat) add(row []sqltypes.Value) error { + if row[a.from].IsNull() { + return nil + } + if a.n > 0 { + a.concat = append(a.concat, ',') + } + a.concat = append(a.concat, row[a.from].Raw()...) + a.n++ + return nil +} + +func (a *aggregatorGroupConcat) finish() sqltypes.Value { + if a.n == 0 { + return sqltypes.NULL + } + return sqltypes.MakeTrusted(a.type_, a.concat) +} + +func (a *aggregatorGroupConcat) reset() { + a.n = 0 + a.concat = nil // not safe to reuse this byte slice as it's returned as MakeTrusted +} + +type aggregatorGtid struct { + from int + shards []*binlogdatapb.ShardGtid +} + +func (a *aggregatorGtid) add(row []sqltypes.Value) error { + a.shards = append(a.shards, &binlogdatapb.ShardGtid{ + Keyspace: row[a.from-1].ToString(), + Shard: row[a.from+1].ToString(), + Gtid: row[a.from].ToString(), + }) + return nil +} + +func (a *aggregatorGtid) finish() sqltypes.Value { + gtid := binlogdatapb.VGtid{ShardGtids: a.shards} + return sqltypes.NewVarChar(gtid.String()) +} + +func (a *aggregatorGtid) reset() { + a.shards = a.shards[:0] // safe to reuse because only the serialized form of a.shards is returned +} + +type aggregationState []aggregator + +func (a aggregationState) add(row []sqltypes.Value) error { + for _, st := range a { + if err := st.add(row); err != nil { + return err + } + } + return nil +} + +func (a aggregationState) finish() (row []sqltypes.Value) { + row = make([]sqltypes.Value, 0, len(a)) + for _, st := range a { + row = append(row, st.finish()) + } + return +} + +func (a aggregationState) reset() { + for _, st := range a { + st.reset() + } +} + +func isComparable(typ sqltypes.Type) bool { + if typ == sqltypes.Null || sqltypes.IsNumber(typ) || sqltypes.IsBinary(typ) { + return true + } + switch typ { + case sqltypes.Timestamp, + sqltypes.Date, + sqltypes.Time, + sqltypes.Datetime, + sqltypes.Enum, + sqltypes.Set, + sqltypes.TypeJSON, + sqltypes.Bit: + return true + } + return false +} + +func newAggregation(fields []*querypb.Field, aggregates []*AggregateParams) (aggregationState, []*querypb.Field, error) { + fields = slice.Map(fields, func(from *querypb.Field) *querypb.Field { + return proto.Clone(from).(*querypb.Field) + }) + + agstate := make([]aggregator, len(fields)) + for _, aggr := range aggregates { + sourceType := fields[aggr.Col].Type + targetType := aggr.typ(sourceType) + + var ag aggregator + var distinct = -1 + if aggr.Opcode.IsDistinct() { - if row2[aggr.KeyCol].IsNull() { - continue + distinct = aggr.KeyCol + if aggr.WAssigned() && !isComparable(sourceType) { + distinct = aggr.WCol } - cmp, err := evalengine.NullsafeCompare(curDistincts[index], row2[aggr.KeyCol], aggr.CollationID) - if err != nil { - return nil, nil, err - } - if cmp == 0 { - continue + } + + if aggr.Opcode == AggregateMin || aggr.Opcode == AggregateMax { + if aggr.WAssigned() && !isComparable(sourceType) { + return 
nil, nil, vterrors.VT12001("min/max on types that are not comparable is not supported") } - curDistincts[index] = findComparableCurrentDistinct(row2, aggr) } - var err error switch aggr.Opcode { case AggregateCountStar: - result[aggr.Col], err = evalengine.NullSafeAdd(row1[aggr.Col], countOne, fields[aggr.Col].Type) - case AggregateCount: - val := countOne - if row2[aggr.Col].IsNull() { - val = countZero + ag = &aggregatorCountStar{} + + case AggregateCount, AggregateCountDistinct: + ag = &aggregatorCount{ + from: aggr.Col, + distinct: aggregatorDistinct{ + column: distinct, + coll: aggr.CollationID, + }, } - result[aggr.Col], err = evalengine.NullSafeAdd(row1[aggr.Col], val, fields[aggr.Col].Type) - case AggregateSum: - value := row1[aggr.Col] - v2 := row2[aggr.Col] - if value.IsNull() && v2.IsNull() { - result[aggr.Col] = sqltypes.NULL - break + + case AggregateSum, AggregateSumDistinct: + var sum evalengine.Sum + switch aggr.OrigOpcode { + case AggregateCount, AggregateCountStar, AggregateCountDistinct: + sum = evalengine.NewSumOfCounts() + default: + sum = evalengine.NewAggregationSum(sourceType) } - result[aggr.Col], err = evalengine.NullSafeAdd(value, v2, fields[aggr.Col].Type) + + ag = &aggregatorSum{ + from: aggr.Col, + sum: sum, + distinct: aggregatorDistinct{ + column: distinct, + coll: aggr.CollationID, + }, + } + case AggregateMin: - if aggr.WAssigned() && !row2[aggr.Col].IsComparable() { - return minMaxWeightStringError() + ag = &aggregatorMin{ + aggregatorMinMax{ + from: aggr.Col, + minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationID), + }, } - result[aggr.Col], err = evalengine.Min(row1[aggr.Col], row2[aggr.Col], aggr.CollationID) + case AggregateMax: - if aggr.WAssigned() && !row2[aggr.Col].IsComparable() { - return minMaxWeightStringError() + ag = &aggregatorMax{ + aggregatorMinMax{ + from: aggr.Col, + minmax: evalengine.NewAggregationMinMax(sourceType, aggr.CollationID), + }, } - result[aggr.Col], err = evalengine.Max(row1[aggr.Col], row2[aggr.Col], aggr.CollationID) - case AggregateCountDistinct: - result[aggr.Col], err = evalengine.NullSafeAdd(row1[aggr.Col], countOne, fields[aggr.Col].Type) - case AggregateSumDistinct: - result[aggr.Col], err = evalengine.NullSafeAdd(row1[aggr.Col], row2[aggr.Col], fields[aggr.Col].Type) + case AggregateGtid: - vgtid := &binlogdatapb.VGtid{} - rowBytes, err := row1[aggr.Col].ToBytes() - if err != nil { - return nil, nil, err - } - err = vgtid.UnmarshalVT(rowBytes) - if err != nil { - return nil, nil, err - } - vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{ - Keyspace: row2[aggr.Col-1].ToString(), - Shard: row2[aggr.Col+1].ToString(), - Gtid: row2[aggr.Col].ToString(), - }) - data, _ := vgtid.MarshalVT() - val, _ := sqltypes.NewValue(sqltypes.VarBinary, data) - result[aggr.Col] = val + ag = &aggregatorGtid{from: aggr.Col} + case AggregateAnyValue: - // we just grab the first value per grouping. 
no need to do anything more complicated here - case AggregateGroupConcat: - if row2[aggr.Col].IsNull() { - break - } - if result[aggr.Col].IsNull() { - result[aggr.Col] = sqltypes.MakeTrusted(fields[aggr.Col].Type, []byte(row2[aggr.Col].ToString())) - break - } - concat := row1[aggr.Col].ToString() + "," + row2[aggr.Col].ToString() - result[aggr.Col] = sqltypes.MakeTrusted(fields[aggr.Col].Type, []byte(concat)) - default: - return nil, nil, fmt.Errorf("BUG: Unexpected opcode: %v", aggr.Opcode) - } - if err != nil { - return nil, nil, err - } - } - return result, curDistincts, nil -} + ag = &aggregatorScalar{from: aggr.Col} -func minMaxWeightStringError() ([]sqltypes.Value, []sqltypes.Value, error) { - return nil, nil, vterrors.VT12001("min/max on types that are not comparable is not supported") -} + case AggregateGroupConcat: + ag = &aggregatorGroupConcat{from: aggr.Col, type_: targetType} -func convertFinal(current []sqltypes.Value, aggregates []*AggregateParams) ([]sqltypes.Value, error) { - result := sqltypes.CopyRow(current) - for _, aggr := range aggregates { - switch aggr.Opcode { - case AggregateGtid: - vgtid := &binlogdatapb.VGtid{} - currentBytes, err := current[aggr.Col].ToBytes() - if err != nil { - return nil, err - } - err = vgtid.UnmarshalVT(currentBytes) - if err != nil { - return nil, err - } - result[aggr.Col] = sqltypes.NewVarChar(vgtid.String()) + default: + panic("BUG: unexpected Aggregation opcode") } - } - return result, nil -} -func convertFields(fields []*querypb.Field, aggrs []*AggregateParams) []*querypb.Field { - fields = slice.Map(fields, func(from *querypb.Field) *querypb.Field { - return proto.Clone(from).(*querypb.Field) - }) - for _, aggr := range aggrs { - fields[aggr.Col].Type = aggr.typ(fields[aggr.Col].Type) + agstate[aggr.Col] = ag + fields[aggr.Col].Type = targetType if aggr.Alias != "" { fields[aggr.Col].Name = aggr.Alias } } - return fields -} -func findComparableCurrentDistinct(row []sqltypes.Value, aggr *AggregateParams) sqltypes.Value { - curDistinct := row[aggr.KeyCol] - if aggr.WAssigned() && !curDistinct.IsComparable() { - aggr.KeyCol = aggr.WCol - curDistinct = row[aggr.KeyCol] + for i, a := range agstate { + if a == nil { + agstate[i] = &aggregatorScalar{from: i} + } } - return curDistinct + + return agstate, fields, nil } diff --git a/go/vt/vtgate/engine/aggregations_test.go b/go/vt/vtgate/engine/aggregations_test.go new file mode 100644 index 00000000000..55ec59f73e1 --- /dev/null +++ b/go/vt/vtgate/engine/aggregations_test.go @@ -0,0 +1,181 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package engine + +import ( + "context" + "fmt" + "math/rand" + "strings" + "testing" + + "github.com/google/uuid" + + "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" + . 
"vitess.io/vitess/go/vt/vtgate/engine/opcode" +) + +func makeTestResults(fields []*querypb.Field, gen []sqltypes.RandomGenerator, N int) []*sqltypes.Result { + result := &sqltypes.Result{Fields: fields} + + for i := 0; i < N; i++ { + row := make([]sqltypes.Value, 0, len(fields)) + for _, f := range gen { + row = append(row, f()) + } + result.Rows = append(result.Rows, row) + } + + return []*sqltypes.Result{result} +} + +func benchmarkName(fields []*querypb.Field) string { + var buf strings.Builder + for i, f := range fields { + if i > 0 { + buf.WriteByte('_') + } + fmt.Fprintf(&buf, "%s(%s)", f.Name, f.Type.String()) + } + return buf.String() +} + +func BenchmarkScalarAggregate(b *testing.B) { + var rand_i64 = sqltypes.RandomGenerators[sqltypes.Int64] + var rand_i64small = func() sqltypes.Value { + return sqltypes.NewInt64(rand.Int63n(1024)) + } + var rand_f64 = sqltypes.RandomGenerators[sqltypes.Float64] + var rand_dec = sqltypes.RandomGenerators[sqltypes.Decimal] + var rand_bin = sqltypes.RandomGenerators[sqltypes.VarBinary] + + var cases = []struct { + fields []*querypb.Field + gen []sqltypes.RandomGenerator + params []*AggregateParams + }{ + { + fields: sqltypes.MakeTestFields("count", "int64"), + gen: []sqltypes.RandomGenerator{rand_i64}, + params: []*AggregateParams{ + {Opcode: AggregateCount, Col: 0}, + }, + }, + { + fields: sqltypes.MakeTestFields("sum_small", "int64"), + gen: []sqltypes.RandomGenerator{rand_i64small}, + params: []*AggregateParams{ + {Opcode: AggregateSum, Col: 0}, + }, + }, + { + fields: sqltypes.MakeTestFields("sum", "int64"), + gen: []sqltypes.RandomGenerator{rand_i64}, + params: []*AggregateParams{ + {Opcode: AggregateSum, Col: 0}, + }, + }, + { + fields: sqltypes.MakeTestFields("sum", "float64"), + gen: []sqltypes.RandomGenerator{rand_f64}, + params: []*AggregateParams{ + {Opcode: AggregateSum, Col: 0}, + }, + }, + { + fields: sqltypes.MakeTestFields("sum", "decimal"), + gen: []sqltypes.RandomGenerator{rand_dec}, + params: []*AggregateParams{ + {Opcode: AggregateSum, Col: 0}, + }, + }, + { + fields: sqltypes.MakeTestFields("min", "int64"), + gen: []sqltypes.RandomGenerator{rand_i64}, + params: []*AggregateParams{ + {Opcode: AggregateMin, Col: 0}, + }, + }, + { + fields: sqltypes.MakeTestFields("min", "float64"), + gen: []sqltypes.RandomGenerator{rand_f64}, + params: []*AggregateParams{ + {Opcode: AggregateMin, Col: 0}, + }, + }, + { + fields: sqltypes.MakeTestFields("min", "decimal"), + gen: []sqltypes.RandomGenerator{rand_dec}, + params: []*AggregateParams{ + {Opcode: AggregateMin, Col: 0}, + }, + }, + { + fields: sqltypes.MakeTestFields("min", "varbinary"), + gen: []sqltypes.RandomGenerator{rand_bin}, + params: []*AggregateParams{ + {Opcode: AggregateMin, Col: 0}, + }, + }, + { + fields: sqltypes.MakeTestFields("keyspace|gtid|shard", "varchar|varchar|varchar"), + gen: []sqltypes.RandomGenerator{ + func() sqltypes.Value { + return sqltypes.NewVarChar("keyspace") + }, + func() sqltypes.Value { + return sqltypes.NewVarChar(uuid.New().String()) + }, + func() sqltypes.Value { + return sqltypes.NewVarChar(fmt.Sprintf("%x-%x", rand.Intn(256), rand.Intn(256))) + }, + }, + params: []*AggregateParams{ + {Opcode: AggregateGtid, Col: 1}, + }, + }, + } + + for _, tc := range cases { + b.Run(benchmarkName(tc.fields), func(b *testing.B) { + results := makeTestResults(tc.fields, tc.gen, 10000) + + fp := &fakePrimitive{ + allResultsInOneCall: true, + results: results, + } + oa := &ScalarAggregate{ + Aggregates: tc.params, + Input: fp, + } + + b.Run("TryExecute", func(b 
*testing.B) { + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + fp.rewind() + _, err := oa.TryExecute(context.Background(), &noopVCursor{}, nil, true) + if err != nil { + panic(err) + } + } + }) + }) + } +} diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index 818a9e67db6..b8df30ff01b 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -128,35 +128,26 @@ func (code AggregateOpcode) MarshalJSON() ([]byte, error) { } // Type returns the opcode return sql type, and a bool telling is we are sure about this type or not -func (code AggregateOpcode) Type(typ *querypb.Type) (querypb.Type, bool) { +func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type { switch code { case AggregateUnassigned: - return sqltypes.Null, false + return sqltypes.Null case AggregateGroupConcat: - if typ == nil { - return sqltypes.Text, false + if sqltypes.IsBinary(typ) { + return sqltypes.Blob } - if sqltypes.IsBinary(*typ) { - return sqltypes.Blob, true - } - return sqltypes.Text, true + return sqltypes.Text case AggregateMax, AggregateMin, AggregateAnyValue: - if typ == nil { - return sqltypes.Null, false - } - return *typ, true + return typ case AggregateSumDistinct, AggregateSum: - if typ == nil { - return sqltypes.Float64, false - } - if sqltypes.IsIntegral(*typ) || sqltypes.IsDecimal(*typ) { - return sqltypes.Decimal, true + if sqltypes.IsIntegral(typ) || sqltypes.IsDecimal(typ) { + return sqltypes.Decimal } - return sqltypes.Float64, true + return sqltypes.Float64 case AggregateCount, AggregateCountStar, AggregateCountDistinct: - return sqltypes.Int64, true + return sqltypes.Int64 case AggregateGtid: - return sqltypes.VarChar, true + return sqltypes.VarChar default: panic(code.String()) // we have a unit test checking we never reach here } diff --git a/go/vt/vtgate/engine/opcode/constants_test.go b/go/vt/vtgate/engine/opcode/constants_test.go index ac204a9f8dc..50cfc49a71c 100644 --- a/go/vt/vtgate/engine/opcode/constants_test.go +++ b/go/vt/vtgate/engine/opcode/constants_test.go @@ -16,11 +16,15 @@ limitations under the License. package opcode -import "testing" +import ( + "testing" + + "vitess.io/vitess/go/sqltypes" +) func TestCheckAllAggrOpCodes(t *testing.T) { // This test is just checking that we never reach the panic when using Type() on valid opcodes for i := AggregateOpcode(0); i < _NumOfOpCodes; i++ { - i.Type(nil) + i.Type(sqltypes.Null) } } diff --git a/go/vt/vtgate/engine/ordered_aggregate.go b/go/vt/vtgate/engine/ordered_aggregate.go index be142cd2a5d..acb958199d0 100644 --- a/go/vt/vtgate/engine/ordered_aggregate.go +++ b/go/vt/vtgate/engine/ordered_aggregate.go @@ -125,114 +125,98 @@ func (oa *OrderedAggregate) execute(ctx context.Context, vcursor VCursor, bindVa if err != nil { return nil, err } - fields := convertFields(result.Fields, oa.Aggregates) + agg, fields, err := newAggregation(result.Fields, oa.Aggregates) + if err != nil { + return nil, err + } + out := &sqltypes.Result{ Fields: fields, Rows: make([][]sqltypes.Value, 0, len(result.Rows)), } - // This code is similar to the one in StreamExecute. - var current []sqltypes.Value - var curDistincts []sqltypes.Value + var currentKey []sqltypes.Value for _, row := range result.Rows { - // this is the first row. set up everything - if current == nil { - current, curDistincts = convertRow(fields, row, oa.Aggregates) - continue - } + var nextGroup bool - // not the first row. 
are we still in the old group, or is this a new grouping? - equal, err := oa.keysEqual(current, row) + currentKey, nextGroup, err = oa.nextGroupBy(currentKey, row) if err != nil { return nil, err } - if equal { - // we are continuing to add values to the current grouping - current, curDistincts, err = merge(fields, current, row, curDistincts, oa.Aggregates) - if err != nil { - return nil, err - } - continue + if nextGroup { + out.Rows = append(out.Rows, agg.finish()) + agg.reset() } - // this is a new grouping. let's yield the old one, and start a new - out.Rows = append(out.Rows, current) - current, curDistincts = convertRow(fields, row, oa.Aggregates) - continue - } - - if current != nil { - final, err := convertFinal(current, oa.Aggregates) - if err != nil { + if err := agg.add(row); err != nil { return nil, err } - out.Rows = append(out.Rows, final) } + + if currentKey != nil { + out.Rows = append(out.Rows, agg.finish()) + } + return out, nil } // TryStreamExecute is a Primitive function. func (oa *OrderedAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, _ bool, callback func(*sqltypes.Result) error) error { - var current []sqltypes.Value - var curDistincts []sqltypes.Value - var fields []*querypb.Field cb := func(qr *sqltypes.Result) error { return callback(qr.Truncate(oa.TruncateColumnCount)) } + var agg aggregationState + var fields []*querypb.Field + var currentKey []sqltypes.Value + visitor := func(qr *sqltypes.Result) error { - if len(qr.Fields) != 0 { - fields = convertFields(qr.Fields, oa.Aggregates) - if err := cb(&sqltypes.Result{Fields: fields}); err != nil { + var err error + + if agg == nil && len(qr.Fields) != 0 { + agg, fields, err = newAggregation(qr.Fields, oa.Aggregates) + if err != nil { + return err + } + if err = cb(&sqltypes.Result{Fields: fields}); err != nil { return err } } + // This code is similar to the one in Execute. for _, row := range qr.Rows { - // this is the first row. set up everything - if current == nil { - current, curDistincts = convertRow(fields, row, oa.Aggregates) - continue - } + var nextGroup bool - // not the first row. are we still in the old group, or is this a new grouping? - equal, err := oa.keysEqual(current, row) + currentKey, nextGroup, err = oa.nextGroupBy(currentKey, row) if err != nil { return err } - if equal { - // we are continuing to add values to the current grouping - current, curDistincts, err = merge(fields, current, row, curDistincts, oa.Aggregates) - if err != nil { + if nextGroup { + // this is a new grouping. let's yield the old one, and start a new + if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{agg.finish()}}); err != nil { return err } - continue + + agg.reset() } - // this is a new grouping. 
let's yield the old one, and start a new - if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{current}}); err != nil { + if err := agg.add(row); err != nil { return err } - current, curDistincts = convertRow(fields, row, oa.Aggregates) - continue } return nil } - err := vcursor.StreamExecutePrimitive(ctx, - oa.Input, - bindVars, - true, /* we need the input fields types to correctly calculate the output types */ - visitor) + /* we need the input fields types to correctly calculate the output types */ + err := vcursor.StreamExecutePrimitive(ctx, oa.Input, bindVars, true, visitor) if err != nil { return err } - if current != nil { - if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{current}}); err != nil { + if currentKey != nil { + if err := cb(&sqltypes.Result{Rows: [][]sqltypes.Value{agg.finish()}}); err != nil { return err } } @@ -245,7 +229,13 @@ func (oa *OrderedAggregate) GetFields(ctx context.Context, vcursor VCursor, bind if err != nil { return nil, err } - qr = &sqltypes.Result{Fields: convertFields(qr.Fields, oa.Aggregates)} + + _, fields, err := newAggregation(qr.Fields, oa.Aggregates) + if err != nil { + return nil, err + } + + qr = &sqltypes.Result{Fields: fields} return qr.Truncate(oa.TruncateColumnCount), nil } @@ -259,26 +249,30 @@ func (oa *OrderedAggregate) NeedsTransaction() bool { return oa.Input.NeedsTransaction() } -func (oa *OrderedAggregate) keysEqual(row1, row2 []sqltypes.Value) (bool, error) { +func (oa *OrderedAggregate) nextGroupBy(currentKey, nextRow []sqltypes.Value) (nextKey []sqltypes.Value, nextGroup bool, err error) { + if currentKey == nil { + return nextRow, false, nil + } + for _, gb := range oa.GroupByKeys { - cmp, err := evalengine.NullsafeCompare(row1[gb.KeyCol], row2[gb.KeyCol], gb.CollationID) + cmp, err := evalengine.NullsafeCompare(currentKey[gb.KeyCol], nextRow[gb.KeyCol], gb.CollationID) if err != nil { _, isComparisonErr := err.(evalengine.UnsupportedComparisonError) _, isCollationErr := err.(evalengine.UnsupportedCollationError) if !isComparisonErr && !isCollationErr || gb.WeightStringCol == -1 { - return false, err + return nil, false, err } gb.KeyCol = gb.WeightStringCol - cmp, err = evalengine.NullsafeCompare(row1[gb.WeightStringCol], row2[gb.WeightStringCol], gb.CollationID) + cmp, err = evalengine.NullsafeCompare(currentKey[gb.WeightStringCol], nextRow[gb.WeightStringCol], gb.CollationID) if err != nil { - return false, err + return nil, false, err } } if cmp != 0 { - return false, nil + return nextRow, true, nil } } - return true, nil + return currentKey, false, nil } func aggregateParamsToString(in any) string { return in.(*AggregateParams).String() diff --git a/go/vt/vtgate/engine/ordered_aggregate_test.go b/go/vt/vtgate/engine/ordered_aggregate_test.go index 9f6b8afddd7..ba9e3a06ffc 100644 --- a/go/vt/vtgate/engine/ordered_aggregate_test.go +++ b/go/vt/vtgate/engine/ordered_aggregate_test.go @@ -598,34 +598,6 @@ func TestOrderedAggregateMergeFail(t *testing.T) { require.NoError(t, err) } -func TestMerge(t *testing.T) { - oa := &OrderedAggregate{ - Aggregates: []*AggregateParams{ - NewAggregateParam(AggregateSum, 1, ""), - NewAggregateParam(AggregateSum, 2, ""), - NewAggregateParam(AggregateMin, 3, ""), - NewAggregateParam(AggregateMax, 4, ""), - }} - fields := sqltypes.MakeTestFields( - "a|b|c|d|e", - "int64|int64|decimal|in32|varbinary", - ) - r := sqltypes.MakeTestResult(fields, - "1|2|3.2|3|ab", - "1|3|2.8|2|bc", - ) - - merged, _, err := merge(fields, r.Rows[0], r.Rows[1], nil, oa.Aggregates) - assert.NoError(t, err) - 
want := sqltypes.MakeTestResult(fields, "1|5|6.0|2|bc").Rows[0] - assert.Equal(t, want, merged) - - // swap and retry - merged, _, err = merge(fields, r.Rows[1], r.Rows[0], nil, oa.Aggregates) - assert.NoError(t, err) - assert.Equal(t, want, merged) -} - func TestOrderedAggregateExecuteGtid(t *testing.T) { vgtid := binlogdatapb.VGtid{} vgtid.ShardGtids = append(vgtid.ShardGtids, &binlogdatapb.ShardGtid{ diff --git a/go/vt/vtgate/engine/scalar_aggregation.go b/go/vt/vtgate/engine/scalar_aggregation.go index 8b336ddc0f7..6190e2e5fd6 100644 --- a/go/vt/vtgate/engine/scalar_aggregation.go +++ b/go/vt/vtgate/engine/scalar_aggregation.go @@ -22,9 +22,6 @@ import ( "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" - "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" - . "vitess.io/vitess/go/vt/vtgate/engine/opcode" ) var _ Primitive = (*ScalarAggregate)(nil) @@ -66,7 +63,13 @@ func (sa *ScalarAggregate) GetFields(ctx context.Context, vcursor VCursor, bindV if err != nil { return nil, err } - qr = &sqltypes.Result{Fields: convertFields(qr.Fields, sa.Aggregates)} + + _, fields, err := newAggregation(qr.Fields, sa.Aggregates) + if err != nil { + return nil, err + } + + qr = &sqltypes.Result{Fields: fields} return qr.Truncate(sa.TruncateColumnCount), nil } @@ -81,36 +84,22 @@ func (sa *ScalarAggregate) TryExecute(ctx context.Context, vcursor VCursor, bind if err != nil { return nil, err } - fields := convertFields(result.Fields, sa.Aggregates) - out := &sqltypes.Result{ - Fields: fields, + + agg, fields, err := newAggregation(result.Fields, sa.Aggregates) + if err != nil { + return nil, err } - var resultRow []sqltypes.Value - var curDistincts []sqltypes.Value for _, row := range result.Rows { - if resultRow == nil { - resultRow, curDistincts = convertRow(fields, row, sa.Aggregates) - continue - } - resultRow, curDistincts, err = merge(fields, resultRow, row, curDistincts, sa.Aggregates) - if err != nil { + if err := agg.add(row); err != nil { return nil, err } } - if resultRow == nil { - // When doing aggregation without grouping keys, we need to produce a single row containing zero-value for the - // different aggregation functions - resultRow, err = sa.createEmptyRow() - } else { - resultRow, err = convertFinal(resultRow, sa.Aggregates) - } - if err != nil { - return nil, err + out := &sqltypes.Result{ + Fields: fields, + Rows: [][]sqltypes.Value{agg.finish()}, } - - out.Rows = [][]sqltypes.Value{resultRow} return out.Truncate(sa.TruncateColumnCount), nil } @@ -119,11 +108,11 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor cb := func(qr *sqltypes.Result) error { return callback(qr.Truncate(sa.TruncateColumnCount)) } - var current []sqltypes.Value - var curDistincts []sqltypes.Value - var fields []*querypb.Field - fieldsSent := false + var mu sync.Mutex + var agg aggregationState + var fields []*querypb.Field + var fieldsSent bool err := vcursor.StreamExecutePrimitive(ctx, sa.Input, bindVars, wantfields, func(result *sqltypes.Result) error { // as the underlying primitive call is not sync @@ -131,23 +120,23 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor // for correct aggregation. 
mu.Lock() defer mu.Unlock() - if len(result.Fields) != 0 && !fieldsSent { - fields = convertFields(result.Fields, sa.Aggregates) + + if agg == nil { + var err error + agg, fields, err = newAggregation(result.Fields, sa.Aggregates) + if err != nil { + return err + } + } + if !fieldsSent { if err := cb(&sqltypes.Result{Fields: fields}); err != nil { return err } fieldsSent = true } - // this code is very similar to the TryExecute method for _, row := range result.Rows { - if current == nil { - current, curDistincts = convertRow(fields, row, sa.Aggregates) - continue - } - var err error - current, curDistincts, err = merge(fields, current, row, curDistincts, sa.Aggregates) - if err != nil { + if err := agg.add(row); err != nil { return err } } @@ -157,58 +146,7 @@ func (sa *ScalarAggregate) TryStreamExecute(ctx context.Context, vcursor VCursor return err } - if current == nil { - // When doing aggregation without grouping keys, we need to produce a single row containing zero-value for the - // different aggregation functions - current, err = sa.createEmptyRow() - if err != nil { - return err - } - } else { - current, err = convertFinal(current, sa.Aggregates) - if err != nil { - return err - } - } - - return cb(&sqltypes.Result{Rows: [][]sqltypes.Value{current}}) -} - -// creates the empty row for the case when we are missing grouping keys and have empty input table -func (sa *ScalarAggregate) createEmptyRow() ([]sqltypes.Value, error) { - out := make([]sqltypes.Value, len(sa.Aggregates)) - for i, aggr := range sa.Aggregates { - op := aggr.Opcode - if aggr.OrigOpcode != AggregateUnassigned { - op = aggr.OrigOpcode - } - value, err := createEmptyValueFor(op) - if err != nil { - return nil, err - } - out[i] = value - } - return out, nil -} - -func createEmptyValueFor(opcode AggregateOpcode) (sqltypes.Value, error) { - switch opcode { - case - AggregateCountDistinct, - AggregateCount, - AggregateCountStar: - return countZero, nil - case - AggregateSumDistinct, - AggregateSum, - AggregateMin, - AggregateMax, - AggregateAnyValue, - AggregateGroupConcat: - return sqltypes.NULL, nil - - } - return sqltypes.NULL, vterrors.Errorf(vtrpc.Code_INVALID_ARGUMENT, "unknown aggregation %v", opcode) + return cb(&sqltypes.Result{Rows: [][]sqltypes.Value{agg.finish()}}) } // Inputs implements the Primitive interface diff --git a/go/vt/vtgate/evalengine/api_aggregation.go b/go/vt/vtgate/evalengine/api_aggregation.go new file mode 100644 index 00000000000..c0d490ced22 --- /dev/null +++ b/go/vt/vtgate/evalengine/api_aggregation.go @@ -0,0 +1,497 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +package evalengine + +import ( + "strconv" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/mysql/decimal" + "vitess.io/vitess/go/mysql/fastparse" + "vitess.io/vitess/go/mysql/format" + "vitess.io/vitess/go/sqltypes" +) + +// Sum implements a SUM() aggregation +type Sum interface { + Add(value sqltypes.Value) error + Result() sqltypes.Value + Reset() +} + +// MinMax implements a MIN() or MAX() aggregation +type MinMax interface { + Min(value sqltypes.Value) error + Max(value sqltypes.Value) error + Result() sqltypes.Value + Reset() +} + +// aggregationSumCount implements a sum of count values. +// This is a Vitess-specific optimization that allows our planner to push down +// some expensive cross-shard operations by summing counts from different result sets. +// The result of this operator is always an INT64 (like for the COUNT() operator); +// if no values were provided to the operator, the result will be 0 (not NULL). +// If the sum of counts overflows, an error will be returned (instead of transparently +// calculating the larger sum using decimals). +type aggregationSumCount struct { + n int64 +} + +func (s *aggregationSumCount) Add(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + n, err := value.ToInt64() + if err != nil { + return err + } + + result := s.n + n + if (result > s.n) != (n > 0) { + return dataOutOfRangeError(s.n, n, "BIGINT", "+") + } + + s.n = result + return nil +} + +func (s *aggregationSumCount) Result() sqltypes.Value { + return sqltypes.NewInt64(s.n) +} + +func (s *aggregationSumCount) Reset() { + s.n = 0 +} + +// aggregationInt implements SUM, MIN and MAX aggregation for Signed types, +// including INT64, INT32, INT24, INT16 and INT8. +// +// For SUM, the result of the operator is always a DECIMAL (matching MySQL's behavior), +// unless no values have been aggregated, in which case the result is NULL. +// For performance reasons, although the output of a SUM is a DECIMAL, the computations +// are performed using 64-bit arithmetic as long as they don't overflow. +// +// For MIN and MAX aggregations, the result of the operator is the same type as the values that +// have been aggregated. 
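// Editor's note — illustrative sketch, not part of this patch: the integer
// aggregators below stay on fast 64-bit arithmetic and only spill into the
// arbitrary-precision decimal accumulator once an addition overflows. With
// two's-complement wrap-around, adding n > 0 can only overflow upwards
// (yielding a result *smaller* than the accumulator) and n < 0 only
// downwards, so a sign-relative comparison catches both directions:
//
//	func addWouldOverflow(a, n int64) bool {
//		result := a + n
//		return (result > a) != (n > 0)
//	}
//
// For example, math.MaxInt64 + 1 wraps to math.MinInt64: result > a is false
// while n > 0 is true, so the mismatch triggers the decimal fallback.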
+type aggregationInt struct { + current int64 + dec decimal.Decimal + t sqltypes.Type + init bool +} + +func (s *aggregationInt) Add(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + n, err := value.ToInt64() + if err != nil { + return err + } + + s.init = true + + if s.dec.IsInitialized() { + s.dec = s.dec.Add(decimal.NewFromInt(n)) + return nil + } + + result := s.current + n + if (result > s.current) != (n > 0) { + s.dec = decimal.NewFromInt(s.current).Add(decimal.NewFromInt(n)) + } else { + s.current = result + } + + return nil +} + +func (s *aggregationInt) Min(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + n, err := value.ToInt64() + if err != nil { + return err + } + if !s.init || n < s.current { + s.current = n + } + s.init = true + return nil +} + +func (s *aggregationInt) Max(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + n, err := value.ToInt64() + if err != nil { + return err + } + if !s.init || n > s.current { + s.current = n + } + s.init = true + return nil +} + +func (s *aggregationInt) Result() sqltypes.Value { + if !s.init { + return sqltypes.NULL + } + + var b []byte + if s.dec.IsInitialized() { + b = s.dec.FormatMySQL(0) + } else { + b = strconv.AppendInt(nil, s.current, 10) + } + return sqltypes.MakeTrusted(s.t, b) +} + +func (s *aggregationInt) Reset() { + s.current = 0 + s.dec = decimal.Decimal{} + s.init = false +} + +// aggregationUint implements SUM, MIN and MAX aggregation for Unsigned types, +// including UINT64, UINT32, UINT24, UINT16 and UINT8. +// +// For SUM, the result of the operator is always a DECIMAL (matching MySQL's behavior), +// unless no values have been aggregated, in which case the result is NULL. +// For performance reasons, although the output of a SUM is a DECIMAL, the computations +// are performed using 64-bit arithmetic as long as they don't overflow. +// +// For MIN and MAX aggregations, the result of the operator is the same type as the values that +// have been aggregated. +type aggregationUint struct { + current uint64 + dec decimal.Decimal + t sqltypes.Type + init bool +} + +func (s *aggregationUint) Add(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + n, err := value.ToUint64() + if err != nil { + return err + } + + s.init = true + + if s.dec.IsInitialized() { + s.dec = s.dec.Add(decimal.NewFromUint(n)) + return nil + } + + result := s.current + n + if result < s.current { + s.dec = decimal.NewFromUint(s.current).Add(decimal.NewFromUint(n)) + } else { + s.current = result + } + + return nil +} + +func (s *aggregationUint) Min(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + n, err := value.ToUint64() + if err != nil { + return err + } + if !s.init || n < s.current { + s.current = n + } + s.init = true + return nil +} + +func (s *aggregationUint) Max(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + n, err := value.ToUint64() + if err != nil { + return err + } + if !s.init || n > s.current { + s.current = n + } + s.init = true + return nil +} + +func (s *aggregationUint) Result() sqltypes.Value { + if !s.init { + return sqltypes.NULL + } + + var b []byte + if s.dec.IsInitialized() { + b = s.dec.FormatMySQL(0) + } else { + b = strconv.AppendUint(nil, s.current, 10) + } + return sqltypes.MakeTrusted(s.t, b) +} + +func (s *aggregationUint) Reset() { + s.current = 0 + s.dec = decimal.Decimal{} + s.init = false +} + +// aggregationFloat implements SUM, MIN and MAX aggregations for FLOAT32 and FLOAT64 types. 
+// For SUM aggregations, the result is always a FLOAT64, unless no values have been aggregated, +// in which case the result is NULL. +// For MIN and MAX aggregations, the result is the same type as the aggregated values. +type aggregationFloat struct { + current float64 + t sqltypes.Type + init bool +} + +func (s *aggregationFloat) Add(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + f, err := value.ToFloat64() + if err != nil { + return err + } + s.current += f + s.init = true + return nil +} + +func (s *aggregationFloat) Min(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + n, err := value.ToFloat64() + if err != nil { + return err + } + if !s.init || n < s.current { + s.current = n + } + s.init = true + return nil +} + +func (s *aggregationFloat) Max(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + n, err := value.ToFloat64() + if err != nil { + return err + } + if !s.init || n > s.current { + s.current = n + } + s.init = true + return nil +} + +func (s *aggregationFloat) Result() sqltypes.Value { + if !s.init { + return sqltypes.NULL + } + return sqltypes.MakeTrusted(s.t, format.FormatFloat(s.current)) +} + +func (s *aggregationFloat) Reset() { + s.current = 0 + s.init = false +} + +// aggregationSumAny implements SUM aggregation for non-numeric values. +// Matching MySQL's behavior, all the values are best-effort parsed as FLOAT64 +// before being aggregated. +type aggregationSumAny struct { + aggregationFloat +} + +func (s *aggregationSumAny) Add(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + f, _ := fastparse.ParseFloat64(value.RawStr()) + s.current += f + s.init = true + return nil +} + +func (s *aggregationSumAny) Result() sqltypes.Value { + if !s.init { + return sqltypes.NULL + } + return sqltypes.NewFloat64(s.current) +} + +// aggregationDecimal implements SUM, MIN and MAX aggregations for the DECIMAL type. +// The return of all aggregations is always DECIMAL, except when no values have been +// aggregated, where the return is NULL. 
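// Editor's note — illustrative sketch, not part of this patch: the decimal
// aggregator below tracks the widest fractional precision it has seen
// (prec = max(prec, -Exponent())) so mixed-scale inputs keep all of their
// digits when the result is rendered:
//
//	a, _ := decimal.NewFromMySQL([]byte("1.5"))  // exponent -1 → precision 1
//	b, _ := decimal.NewFromMySQL([]byte("2.25")) // exponent -2 → precision 2
//	sum := a.Add(b)
//	out := sum.FormatMySQL(2) // "3.75": rendered with the max precision seen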
+type aggregationDecimal struct { + dec decimal.Decimal + prec int32 +} + +func (s *aggregationDecimal) Add(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + dec, err := decimal.NewFromMySQL(value.Raw()) + if err != nil { + return err + } + if !s.dec.IsInitialized() { + s.dec = dec + s.prec = -dec.Exponent() + } else { + s.dec = s.dec.Add(dec) + s.prec = max(s.prec, -dec.Exponent()) + } + return nil +} + +func (s *aggregationDecimal) Min(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + dec, err := decimal.NewFromMySQL(value.Raw()) + if err != nil { + return err + } + if !s.dec.IsInitialized() || dec.Cmp(s.dec) < 0 { + s.dec = dec + } + return nil +} + +func (s *aggregationDecimal) Max(value sqltypes.Value) error { + if value.IsNull() { + return nil + } + dec, err := decimal.NewFromMySQL(value.Raw()) + if err != nil { + return err + } + if !s.dec.IsInitialized() || dec.Cmp(s.dec) > 0 { + s.dec = dec + } + return nil +} + +func (s *aggregationDecimal) Result() sqltypes.Value { + if !s.dec.IsInitialized() { + return sqltypes.NULL + } + return sqltypes.MakeTrusted(sqltypes.Decimal, s.dec.FormatMySQL(s.prec)) +} + +func (s *aggregationDecimal) Reset() { + s.dec = decimal.Decimal{} + s.prec = 0 +} + +func NewSumOfCounts() Sum { + return &aggregationSumCount{} +} + +func NewAggregationSum(type_ sqltypes.Type) Sum { + switch { + case sqltypes.IsSigned(type_): + return &aggregationInt{t: sqltypes.Decimal} + case sqltypes.IsUnsigned(type_): + return &aggregationUint{t: sqltypes.Decimal} + case sqltypes.IsFloat(type_): + return &aggregationFloat{t: sqltypes.Float64} + case sqltypes.IsDecimal(type_): + return &aggregationDecimal{} + default: + return &aggregationSumAny{} + } +} + +// aggregationMinMax implements MIN and MAX aggregations for all data types +// that cannot be more efficiently handled by one of the numeric aggregators. +// The aggregation is performed using the slow NullSafeComparison path of the +// evaluation engine. +type aggregationMinMax struct { + current sqltypes.Value + collation collations.ID +} + +func (a *aggregationMinMax) minmax(value sqltypes.Value, max bool) (err error) { + if value.IsNull() { + return nil + } + if a.current.IsNull() { + a.current = value + return nil + } + n, err := compare(a.current, value, a.collation) + if err != nil { + return err + } + if (n < 0) == max { + a.current = value + } + return nil +} + +func (a *aggregationMinMax) Min(value sqltypes.Value) (err error) { + return a.minmax(value, false) +} + +func (a *aggregationMinMax) Max(value sqltypes.Value) error { + return a.minmax(value, true) +} + +func (a *aggregationMinMax) Result() sqltypes.Value { + return a.current +} + +func (a *aggregationMinMax) Reset() { + a.current = sqltypes.NULL +} + +func NewAggregationMinMax(type_ sqltypes.Type, collation collations.ID) MinMax { + switch { + case sqltypes.IsSigned(type_): + return &aggregationInt{t: type_} + case sqltypes.IsUnsigned(type_): + return &aggregationUint{t: type_} + case sqltypes.IsFloat(type_): + return &aggregationFloat{t: type_} + case sqltypes.IsDecimal(type_): + return &aggregationDecimal{} + default: + return &aggregationMinMax{collation: collation} + } +} diff --git a/go/vt/vtgate/evalengine/api_aggregation_test.go b/go/vt/vtgate/evalengine/api_aggregation_test.go new file mode 100644 index 00000000000..aab49541e71 --- /dev/null +++ b/go/vt/vtgate/evalengine/api_aggregation_test.go @@ -0,0 +1,166 @@ +/* +Copyright 2023 The Vitess Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/test/utils" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) + +func TestMinMax(t *testing.T) { + tcases := []struct { + type_ sqltypes.Type + coll collations.ID + values []sqltypes.Value + min, max sqltypes.Value + err error + }{ + { + type_: sqltypes.Int64, + values: []sqltypes.Value{}, + min: sqltypes.NULL, + max: sqltypes.NULL, + }, + { + type_: sqltypes.Int64, + values: []sqltypes.Value{NULL, NULL}, + min: sqltypes.NULL, + max: sqltypes.NULL, + }, + { + type_: sqltypes.Int64, + values: []sqltypes.Value{NULL, NewInt64(1)}, + min: NewInt64(1), + max: NewInt64(1), + }, + { + type_: sqltypes.Int64, + values: []sqltypes.Value{NewInt64(1), NewInt64(2)}, + min: NewInt64(1), + max: NewInt64(2), + }, + { + type_: sqltypes.VarChar, + values: []sqltypes.Value{TestValue(sqltypes.VarChar, "aa"), TestValue(sqltypes.VarChar, "bb")}, + err: vterrors.New(vtrpcpb.Code_UNKNOWN, "cannot compare strings, collation is unknown or unsupported (collation ID: 0)"), + }, + { + type_: sqltypes.VarBinary, + values: []sqltypes.Value{sqltypes.NewVarBinary("a"), sqltypes.NewVarBinary("b")}, + min: sqltypes.NewVarBinary("a"), + max: sqltypes.NewVarBinary("b"), + }, + { + // accent insensitive + type_: sqltypes.VarChar, + coll: getCollationID("utf8mb4_0900_as_ci"), + values: []sqltypes.Value{ + sqltypes.NewVarChar("ǍḄÇ"), + sqltypes.NewVarChar("ÁḆĈ"), + }, + min: sqltypes.NewVarChar("ÁḆĈ"), + max: sqltypes.NewVarChar("ǍḄÇ"), + }, + { + // kana sensitive + type_: sqltypes.VarChar, + coll: getCollationID("utf8mb4_ja_0900_as_cs_ks"), + values: []sqltypes.Value{ + sqltypes.NewVarChar("\xE3\x81\xAB\xE3\x81\xBB\xE3\x82\x93\xE3\x81\x94"), + sqltypes.NewVarChar("\xE3\x83\x8B\xE3\x83\x9B\xE3\x83\xB3\xE3\x82\xB4"), + }, + min: sqltypes.NewVarChar("\xE3\x81\xAB\xE3\x81\xBB\xE3\x82\x93\xE3\x81\x94"), + max: sqltypes.NewVarChar("\xE3\x83\x8B\xE3\x83\x9B\xE3\x83\xB3\xE3\x82\xB4"), + }, + { + // non breaking space + type_: sqltypes.VarChar, + coll: getCollationID("utf8mb4_0900_as_cs"), + values: []sqltypes.Value{ + sqltypes.NewVarChar("abc "), + sqltypes.NewVarChar("abc\u00a0"), + }, + min: sqltypes.NewVarChar("abc "), + max: sqltypes.NewVarChar("abc\u00a0"), + }, + { + type_: sqltypes.VarChar, + coll: getCollationID("utf8mb4_hu_0900_ai_ci"), + // "cs" counts as a separate letter, where c < cs < d + values: []sqltypes.Value{ + sqltypes.NewVarChar("c"), + sqltypes.NewVarChar("cs"), + }, + min: sqltypes.NewVarChar("c"), + max: sqltypes.NewVarChar("cs"), + }, + { + type_: sqltypes.VarChar, + coll: getCollationID("utf8mb4_hu_0900_ai_ci"), + // "cs" counts as a separate letter, where c < cs < d + values: []sqltypes.Value{ + sqltypes.NewVarChar("cukor"), + sqltypes.NewVarChar("csak"), + }, + min: sqltypes.NewVarChar("cukor"), + max: sqltypes.NewVarChar("csak"), + }, + } + for i, 
tcase := range tcases { + t.Run(strconv.Itoa(i), func(t *testing.T) { + t.Run("Min", func(t *testing.T) { + agg := NewAggregationMinMax(tcase.type_, tcase.coll) + + for _, v := range tcase.values { + err := agg.Min(v) + if err != nil { + if tcase.err != nil { + return + } + require.NoError(t, err) + } + } + + utils.MustMatch(t, agg.Result(), tcase.min) + }) + + t.Run("Max", func(t *testing.T) { + agg := NewAggregationMinMax(tcase.type_, tcase.coll) + + for _, v := range tcase.values { + err := agg.Max(v) + if err != nil { + if tcase.err != nil { + return + } + require.NoError(t, err) + } + } + + utils.MustMatch(t, agg.Result(), tcase.max) + }) + }) + } +} diff --git a/go/vt/vtgate/evalengine/api_arithmetic_test.go b/go/vt/vtgate/evalengine/api_arithmetic_test.go index 10199206755..40373423aa5 100644 --- a/go/vt/vtgate/evalengine/api_arithmetic_test.go +++ b/go/vt/vtgate/evalengine/api_arithmetic_test.go @@ -24,7 +24,6 @@ import ( "strconv" "testing" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/test/utils" "vitess.io/vitess/go/vt/vthash" @@ -805,248 +804,6 @@ func TestCompareNumeric(t *testing.T) { } } -func TestMin(t *testing.T) { - tcases := []struct { - v1, v2 sqltypes.Value - min sqltypes.Value - err error - }{{ - v1: NULL, - v2: NULL, - min: NULL, - }, { - v1: NewInt64(1), - v2: NULL, - min: NewInt64(1), - }, { - v1: NULL, - v2: NewInt64(1), - min: NewInt64(1), - }, { - v1: NewInt64(1), - v2: NewInt64(2), - min: NewInt64(1), - }, { - v1: NewInt64(2), - v2: NewInt64(1), - min: NewInt64(1), - }, { - v1: NewInt64(1), - v2: NewInt64(1), - min: NewInt64(1), - }, { - v1: TestValue(sqltypes.VarChar, "aa"), - v2: TestValue(sqltypes.VarChar, "aa"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, "cannot compare strings, collation is unknown or unsupported (collation ID: 0)"), - }} - for _, tcase := range tcases { - t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { - v, err := Min(tcase.v1, tcase.v2, collations.Unknown) - if tcase.err == nil { - require.NoError(t, err) - } else { - require.Error(t, err) - } - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Min error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - return - } - - if !reflect.DeepEqual(v, tcase.min) { - t.Errorf("Min(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.min) - } - }) - } -} - -func TestMinCollate(t *testing.T) { - tcases := []struct { - v1, v2 string - collation collations.ID - out string - err error - }{ - { - // accent insensitive - v1: "ǍḄÇ", - v2: "ÁḆĈ", - out: "ǍḄÇ", - collation: getCollationID("utf8mb4_0900_as_ci"), - }, - { - // kana sensitive - v1: "\xE3\x81\xAB\xE3\x81\xBB\xE3\x82\x93\xE3\x81\x94", - v2: "\xE3\x83\x8B\xE3\x83\x9B\xE3\x83\xB3\xE3\x82\xB4", - out: "\xE3\x83\x8B\xE3\x83\x9B\xE3\x83\xB3\xE3\x82\xB4", - collation: getCollationID("utf8mb4_ja_0900_as_cs_ks"), - }, - { - // non breaking space - v1: "abc ", - v2: "abc\u00a0", - out: "abc\u00a0", - collation: getCollationID("utf8mb4_0900_as_cs"), - }, - { - // "cs" counts as a separate letter, where c < cs < d - v1: "c", - v2: "cs", - out: "cs", - collation: getCollationID("utf8mb4_hu_0900_ai_ci"), - }, - { - // "cs" counts as a separate letter, where c < cs < d - v1: "cukor", - v2: "csak", - out: "csak", - collation: getCollationID("utf8mb4_hu_0900_ai_ci"), - }, - } - for _, tcase := range tcases { - t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { - got, err := Min(TestValue(sqltypes.VarChar, tcase.v1), TestValue(sqltypes.VarChar, tcase.v2), 
tcase.collation) - if tcase.err == nil { - require.NoError(t, err) - } else { - require.Error(t, err) - } - if !vterrors.Equals(err, tcase.err) { - t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", tcase.v1, tcase.v2, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - return - } - - if got.ToString() == tcase.out { - t.Errorf("NullsafeCompare(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) - } - }) - } -} - -func TestMax(t *testing.T) { - tcases := []struct { - v1, v2 sqltypes.Value - max sqltypes.Value - err error - }{{ - v1: NULL, - v2: NULL, - max: NULL, - }, { - v1: NewInt64(1), - v2: NULL, - max: NewInt64(1), - }, { - v1: NULL, - v2: NewInt64(1), - max: NewInt64(1), - }, { - v1: NewInt64(1), - v2: NewInt64(2), - max: NewInt64(2), - }, { - v1: NewInt64(2), - v2: NewInt64(1), - max: NewInt64(2), - }, { - v1: NewInt64(1), - v2: NewInt64(1), - max: NewInt64(1), - }, { - v1: TestValue(sqltypes.VarChar, "aa"), - v2: TestValue(sqltypes.VarChar, "aa"), - err: vterrors.New(vtrpcpb.Code_UNKNOWN, "cannot compare strings, collation is unknown or unsupported (collation ID: 0)"), - }} - for _, tcase := range tcases { - t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { - v, err := Max(tcase.v1, tcase.v2, collations.Unknown) - if tcase.err == nil { - require.NoError(t, err) - } else { - require.Error(t, err) - } - if !vterrors.Equals(err, tcase.err) { - t.Errorf("Max error: %v, want %v", vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - return - } - - if !reflect.DeepEqual(v, tcase.max) { - t.Errorf("Max(%v, %v): %v, want %v", tcase.v1, tcase.v2, v, tcase.max) - } - }) - } -} - -func TestMaxCollate(t *testing.T) { - tcases := []struct { - v1, v2 string - collation collations.ID - out string - err error - }{ - { - // accent insensitive - v1: "ǍḄÇ", - v2: "ÁḆĈ", - out: "ǍḄÇ", - collation: getCollationID("utf8mb4_0900_as_ci"), - }, - { - // kana sensitive - v1: "\xE3\x81\xAB\xE3\x81\xBB\xE3\x82\x93\xE3\x81\x94", - v2: "\xE3\x83\x8B\xE3\x83\x9B\xE3\x83\xB3\xE3\x82\xB4", - out: "\xE3\x83\x8B\xE3\x83\x9B\xE3\x83\xB3\xE3\x82\xB4", - collation: getCollationID("utf8mb4_ja_0900_as_cs_ks"), - }, - { - // non breaking space - v1: "abc ", - v2: "abc\u00a0", - out: "abc\u00a0", - collation: getCollationID("utf8mb4_0900_as_cs"), - }, - { - // "cs" counts as a separate letter, where c < cs < d - v1: "c", - v2: "cs", - out: "cs", - collation: getCollationID("utf8mb4_hu_0900_ai_ci"), - }, - { - // "cs" counts as a separate letter, where c < cs < d - v1: "cukor", - v2: "csak", - out: "csak", - collation: getCollationID("utf8mb4_hu_0900_ai_ci"), - }, - } - for _, tcase := range tcases { - t.Run(fmt.Sprintf("%v/%v", tcase.v1, tcase.v2), func(t *testing.T) { - got, err := Max(TestValue(sqltypes.VarChar, tcase.v1), TestValue(sqltypes.VarChar, tcase.v2), tcase.collation) - if tcase.err == nil { - require.NoError(t, err) - } else { - require.Error(t, err) - } - if !vterrors.Equals(err, tcase.err) { - t.Errorf("NullsafeCompare(%v, %v) error: %v, want %v", tcase.v1, tcase.v2, vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err != nil { - return - } - - if got.ToString() != tcase.out { - t.Errorf("NullsafeCompare(%v, %v): %v, want %v", tcase.v1, tcase.v2, got, tcase.out) - } - }) - } -} - func printValue(v sqltypes.Value) string { vBytes, _ := v.ToBytes() return fmt.Sprintf("%v:%q", v.Type(), vBytes) diff --git a/go/vt/vtgate/evalengine/api_compare.go b/go/vt/vtgate/evalengine/api_compare.go index e5a4e7335f2..3c9e632e819 
100644
--- a/go/vt/vtgate/evalengine/api_compare.go
+++ b/go/vt/vtgate/evalengine/api_compare.go
@@ -51,61 +51,7 @@ func (err UnsupportedCollationError) Error() string {
 // UnsupportedCollationHashError is returned when we try to get the hash value and are missing the collation to use
 var UnsupportedCollationHashError = vterrors.Errorf(vtrpcpb.Code_INTERNAL, "text type with an unknown/unsupported collation cannot be hashed")
 
-// Min returns the minimum of v1 and v2. If one of the
-// values is NULL, it returns the other value. If both
-// are NULL, it returns NULL.
-func Min(v1, v2 sqltypes.Value, collation collations.ID) (sqltypes.Value, error) {
-	return minmax(v1, v2, true, collation)
-}
-
-// Max returns the maximum of v1 and v2. If one of the
-// values is NULL, it returns the other value. If both
-// are NULL, it returns NULL.
-func Max(v1, v2 sqltypes.Value, collation collations.ID) (sqltypes.Value, error) {
-	return minmax(v1, v2, false, collation)
-}
-
-func minmax(v1, v2 sqltypes.Value, min bool, collation collations.ID) (sqltypes.Value, error) {
-	if v1.IsNull() {
-		return v2, nil
-	}
-	if v2.IsNull() {
-		return v1, nil
-	}
-
-	n, err := NullsafeCompare(v1, v2, collation)
-	if err != nil {
-		return sqltypes.NULL, err
-	}
-
-	// XNOR construct. See tests.
-	v1isSmaller := n < 0
-	if min == v1isSmaller {
-		return v1, nil
-	}
-	return v2, nil
-}
-
-// NullsafeCompare returns 0 if v1==v2, -1 if v1<v2, and 1 if v1>v2.
-// NULL is the lowest value. If any value is
-// numeric, then a numeric comparison is performed after
-// necessary conversions. If none are numeric, then it's
-// a simple binary comparison. Uncomparable values return an error.
-func NullsafeCompare(v1, v2 sqltypes.Value, collationID collations.ID) (int, error) {
-	// Based on the categorization defined for the types,
-	// we're going to allow comparison of the following:
-	// Null, isNumber, IsBinary. This will exclude IsQuoted
-	// types that are not Binary, and Expression.
-	if v1.IsNull() {
-		if v2.IsNull() {
-			return 0, nil
-		}
-		return -1, nil
-	}
-	if v2.IsNull() {
-		return 1, nil
-	}
-
+func compare(v1, v2 sqltypes.Value, collationID collations.ID) (int, error) {
 	// We have a fast path here for the case where both values are
 	// the same type, and it's one of the basic types we can compare
 	// directly. This is a common case for equality checks.
@@ -202,3 +148,25 @@ func NullsafeCompare(v1, v2 sqltypes.Value, collationID collations.ID) (int, err
 	}
 	return -1, nil
 }
+
+// NullsafeCompare returns 0 if v1==v2, -1 if v1<v2, and 1 if v1>v2.
+// NULL is the lowest value. If any value is
+// numeric, then a numeric comparison is performed after
+// necessary conversions. If none are numeric, then it's
+// a simple binary comparison. Uncomparable values return an error.
+func NullsafeCompare(v1, v2 sqltypes.Value, collationID collations.ID) (int, error) {
+	// Based on the categorization defined for the types,
+	// we're going to allow comparison of the following:
+	// Null, isNumber, IsBinary. This will exclude IsQuoted
+	// types that are not Binary, and Expression.
+ if v1.IsNull() { + if v2.IsNull() { + return 0, nil + } + return -1, nil + } + if v2.IsNull() { + return 1, nil + } + return compare(v1, v2, collationID) +} diff --git a/go/vt/vtgate/evalengine/api_hash_test.go b/go/vt/vtgate/evalengine/api_hash_test.go index f27eb2c1854..832a1ed3b88 100644 --- a/go/vt/vtgate/evalengine/api_hash_test.go +++ b/go/vt/vtgate/evalengine/api_hash_test.go @@ -18,7 +18,6 @@ package evalengine import ( "fmt" - "math/rand" "testing" "time" @@ -79,7 +78,7 @@ func TestHashCodesRandom(t *testing.T) { endTime := time.Now().Add(1 * time.Second) for time.Now().Before(endTime) { tested++ - v1, v2 := randomValues() + v1, v2 := sqltypes.TestRandomValues() cmp, err := NullsafeCompare(v1, v2, collation) require.NoErrorf(t, err, "%s compared with %s", v1.String(), v2.String()) typ, err := coerceTo(v1.Type(), v2.Type()) @@ -168,7 +167,7 @@ func TestHashCodesRandom128(t *testing.T) { endTime := time.Now().Add(1 * time.Second) for time.Now().Before(endTime) { tested++ - v1, v2 := randomValues() + v1, v2 := sqltypes.TestRandomValues() cmp, err := NullsafeCompare(v1, v2, collation) require.NoErrorf(t, err, "%s compared with %s", v1.String(), v2.String()) typ, err := coerceTo(v1.Type(), v2.Type()) @@ -190,89 +189,6 @@ func TestHashCodesRandom128(t *testing.T) { t.Logf("tested %d values, with %d equalities found\n", tested, equal) } -func randomValues() (sqltypes.Value, sqltypes.Value) { - if rand.Int()%2 == 0 { - // create a single value, and turn it into two different types - v := rand.Int() - return randomNumericType(v), randomNumericType(v) - } - - // just produce two arbitrary random values and compare - return randomValue(), randomValue() -} - -func randomNumericType(i int) sqltypes.Value { - r := rand.Intn(len(numericTypes)) - return numericTypes[r](i) - -} - -var numericTypes = []func(int) sqltypes.Value{ - func(i int) sqltypes.Value { return sqltypes.NULL }, - func(i int) sqltypes.Value { return sqltypes.NewInt8(int8(i)) }, - func(i int) sqltypes.Value { return sqltypes.NewInt32(int32(i)) }, - func(i int) sqltypes.Value { return sqltypes.NewInt64(int64(i)) }, - func(i int) sqltypes.Value { return sqltypes.NewUint64(uint64(i)) }, - func(i int) sqltypes.Value { return sqltypes.NewUint32(uint32(i)) }, - func(i int) sqltypes.Value { return sqltypes.NewFloat64(float64(i)) }, - func(i int) sqltypes.Value { return sqltypes.NewDecimal(fmt.Sprintf("%d", i)) }, - func(i int) sqltypes.Value { return sqltypes.NewVarChar(fmt.Sprintf("%d", i)) }, - func(i int) sqltypes.Value { return sqltypes.NewVarChar(fmt.Sprintf(" %f aa", float64(i))) }, -} - -var randomGenerators = []func() sqltypes.Value{ - randomNull, - randomInt8, - randomInt32, - randomInt64, - randomUint64, - randomUint32, - randomVarChar, - randomComplexVarChar, - randomDecimal, - randomDate, - randomDatetime, - randomTimestamp, - randomTime, -} - -func randomValue() sqltypes.Value { - r := rand.Intn(len(randomGenerators)) - return randomGenerators[r]() -} - -func randTime() time.Time { - min := time.Date(1970, 1, 0, 0, 0, 0, 0, time.UTC).Unix() - max := time.Date(2070, 1, 0, 0, 0, 0, 0, time.UTC).Unix() - delta := max - min - - sec := rand.Int63n(delta) + min - return time.Unix(sec, 0) -} - -func randomNull() sqltypes.Value { return sqltypes.NULL } -func randomInt8() sqltypes.Value { return sqltypes.NewInt8(int8(rand.Intn(255))) } -func randomInt32() sqltypes.Value { return sqltypes.NewInt32(rand.Int31()) } -func randomInt64() sqltypes.Value { return sqltypes.NewInt64(rand.Int63()) } -func randomUint32() sqltypes.Value { 
return sqltypes.NewUint32(rand.Uint32()) } -func randomUint64() sqltypes.Value { return sqltypes.NewUint64(rand.Uint64()) } -func randomDecimal() sqltypes.Value { - dec := fmt.Sprintf("%d.%d", rand.Intn(9999999999), rand.Intn(9999999999)) - if rand.Int()&0x1 == 1 { - dec = "-" + dec - } - return sqltypes.NewDecimal(dec) -} -func randomVarChar() sqltypes.Value { return sqltypes.NewVarChar(fmt.Sprintf("%d", rand.Int63())) } -func randomDate() sqltypes.Value { return sqltypes.NewDate(randTime().Format(time.DateOnly)) } -func randomDatetime() sqltypes.Value { return sqltypes.NewDatetime(randTime().Format(time.DateTime)) } -func randomTimestamp() sqltypes.Value { return sqltypes.NewTimestamp(randTime().Format(time.DateTime)) } -func randomTime() sqltypes.Value { return sqltypes.NewTime(randTime().Format(time.TimeOnly)) } - -func randomComplexVarChar() sqltypes.Value { - return sqltypes.NewVarChar(fmt.Sprintf(" \t %f apa", float64(rand.Intn(1000))*1.10)) -} - // coerceTo takes two input types, and decides how they should be coerced before compared func coerceTo(v1, v2 sqltypes.Type) (sqltypes.Type, error) { if v1 == v2 { diff --git a/go/vt/vtgate/evalengine/weights_test.go b/go/vt/vtgate/evalengine/weights_test.go index 95f7ef980e9..50a1d91f20c 100644 --- a/go/vt/vtgate/evalengine/weights_test.go +++ b/go/vt/vtgate/evalengine/weights_test.go @@ -18,9 +18,7 @@ package evalengine import ( "fmt" - "math/rand" "slices" - "strconv" "testing" "github.com/stretchr/testify/require" @@ -45,17 +43,17 @@ func TestWeightStrings(t *testing.T) { len int prec int }{ - {name: "int64", gen: randomInt64, types: []sqltypes.Type{sqltypes.Int64, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, - {name: "uint64", gen: randomUint64, types: []sqltypes.Type{sqltypes.Uint64, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, - {name: "float64", gen: randomFloat64, types: []sqltypes.Type{sqltypes.Float64, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, - {name: "varchar", gen: randomVarChar, types: []sqltypes.Type{sqltypes.VarChar, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationUtf8mb4ID}, - {name: "varbinary", gen: randomVarBinary, types: []sqltypes.Type{sqltypes.VarBinary, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, - {name: "decimal", gen: randomDecimal, types: []sqltypes.Type{sqltypes.Decimal, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID, len: 20, prec: 10}, - {name: "json", gen: randomJSON, types: []sqltypes.Type{sqltypes.TypeJSON}, col: collations.CollationBinaryID}, - {name: "date", gen: randomDate, types: []sqltypes.Type{sqltypes.Date, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, - {name: "datetime", gen: randomDatetime, types: []sqltypes.Type{sqltypes.Datetime, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, - {name: "timestamp", gen: randomTimestamp, types: []sqltypes.Type{sqltypes.Timestamp, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, - {name: "time", gen: randomTime, types: []sqltypes.Type{sqltypes.Time, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, + {name: "int64", gen: sqltypes.RandomGenerators[sqltypes.Int64], types: []sqltypes.Type{sqltypes.Int64, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, + {name: "uint64", gen: sqltypes.RandomGenerators[sqltypes.Uint64], types: []sqltypes.Type{sqltypes.Uint64, sqltypes.VarChar, 
sqltypes.TypeJSON}, col: collations.CollationBinaryID}, + {name: "float64", gen: sqltypes.RandomGenerators[sqltypes.Float64], types: []sqltypes.Type{sqltypes.Float64, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, + {name: "varchar", gen: sqltypes.RandomGenerators[sqltypes.VarChar], types: []sqltypes.Type{sqltypes.VarChar, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationUtf8mb4ID}, + {name: "varbinary", gen: sqltypes.RandomGenerators[sqltypes.VarBinary], types: []sqltypes.Type{sqltypes.VarBinary, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, + {name: "decimal", gen: sqltypes.RandomGenerators[sqltypes.Decimal], types: []sqltypes.Type{sqltypes.Decimal, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID, len: 20, prec: 10}, + {name: "json", gen: sqltypes.RandomGenerators[sqltypes.TypeJSON], types: []sqltypes.Type{sqltypes.TypeJSON}, col: collations.CollationBinaryID}, + {name: "date", gen: sqltypes.RandomGenerators[sqltypes.Date], types: []sqltypes.Type{sqltypes.Date, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, + {name: "datetime", gen: sqltypes.RandomGenerators[sqltypes.Datetime], types: []sqltypes.Type{sqltypes.Datetime, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, + {name: "timestamp", gen: sqltypes.RandomGenerators[sqltypes.Timestamp], types: []sqltypes.Type{sqltypes.Timestamp, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, + {name: "time", gen: sqltypes.RandomGenerators[sqltypes.Time], types: []sqltypes.Type{sqltypes.Time, sqltypes.VarChar, sqltypes.TypeJSON}, col: collations.CollationBinaryID}, } for _, tc := range cases { @@ -103,45 +101,3 @@ func TestWeightStrings(t *testing.T) { } } } - -func randomVarBinary() sqltypes.Value { return sqltypes.NewVarBinary(string(randomBytes())) } -func randomFloat64() sqltypes.Value { - return sqltypes.NewFloat64(rand.NormFloat64()) -} - -func randomBytes() []byte { - const Dictionary = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - - b := make([]byte, 4+rand.Intn(256)) - for i := range b { - b[i] = Dictionary[rand.Intn(len(Dictionary))] - } - return b -} - -func randomJSON() sqltypes.Value { - var j string - switch rand.Intn(6) { - case 0: - j = "null" - case 1: - i := rand.Int63() - if rand.Int()&0x1 == 1 { - i = -i - } - j = strconv.FormatInt(i, 10) - case 2: - j = strconv.FormatFloat(rand.NormFloat64(), 'g', -1, 64) - case 3: - j = strconv.Quote(string(randomBytes())) - case 4: - j = "true" - case 5: - j = "false" - } - v, err := sqltypes.NewJSON(j) - if err != nil { - panic(err) - } - return v -} diff --git a/go/vt/vtgate/semantics/typer.go b/go/vt/vtgate/semantics/typer.go index 5802b539585..6652f1a476b 100644 --- a/go/vt/vtgate/semantics/typer.go +++ b/go/vt/vtgate/semantics/typer.go @@ -55,15 +55,15 @@ func (t *typer) up(cursor *sqlparser.Cursor) error { if !ok { return nil } - var inputType *sqltypes.Type + var inputType sqltypes.Type if arg := node.GetArg(); arg != nil { t, ok := t.exprTypes[arg] if ok { - inputType = &t.Type + inputType = t.Type } } - typ, _ := code.Type(inputType) - t.exprTypes[node] = Type{Type: typ, Collation: collations.DefaultCollationForType(typ)} + type_ := code.Type(inputType) + t.exprTypes[node] = Type{Type: type_, Collation: collations.DefaultCollationForType(type_)} } return nil }
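
Illustrative note, not part of the patch: the new TestMinMax above folds a stream of values through a stateful accumulator instead of the removed pairwise Min/Max helpers. A minimal sketch of that calling pattern, assuming NewAggregationMinMax, Min, and Result behave exactly as the test exercises them:

	package main

	import (
		"fmt"

		"vitess.io/vitess/go/mysql/collations"
		"vitess.io/vitess/go/sqltypes"
		"vitess.io/vitess/go/vt/vtgate/evalengine"
	)

	func main() {
		// Fold values into a running MIN. Per the test cases, NULL inputs
		// are skipped, and an empty or all-NULL stream yields NULL.
		agg := evalengine.NewAggregationMinMax(sqltypes.Int64, collations.CollationBinaryID)
		for _, v := range []sqltypes.Value{sqltypes.NULL, sqltypes.NewInt64(2), sqltypes.NewInt64(1)} {
			if err := agg.Min(v); err != nil {
				panic(err) // e.g. VARCHAR values under an unknown collation
			}
		}
		fmt.Println(agg.Result()) // INT64(1)
	}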
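
Also illustrative: the api_compare.go hunk splits NullsafeCompare into a NULL-handling wrapper around the unexported compare helper. A sketch of the documented semantics, with the expected outputs hedged per the doc comment rather than independently verified:

	package main

	import (
		"fmt"

		"vitess.io/vitess/go/mysql/collations"
		"vitess.io/vitess/go/sqltypes"
		"vitess.io/vitess/go/vt/vtgate/evalengine"
	)

	func main() {
		// NULL is the lowest value, so NULL < INT64(1).
		cmp, err := evalengine.NullsafeCompare(sqltypes.NULL, sqltypes.NewInt64(1), collations.CollationBinaryID)
		fmt.Println(cmp, err) // -1 <nil>

		// Two NULLs compare as equal; the wrapper answers before compare() runs.
		cmp, err = evalengine.NullsafeCompare(sqltypes.NULL, sqltypes.NULL, collations.CollationBinaryID)
		fmt.Println(cmp, err) // 0 <nil>

		// Per the doc comment, if either side is numeric the comparison is
		// numeric after conversion: INT64(10) > VARCHAR("9"), not "10" < "9".
		cmp, err = evalengine.NullsafeCompare(sqltypes.NewInt64(10), sqltypes.NewVarChar("9"), collations.CollationUtf8mb4ID)
		fmt.Println(cmp, err) // 1 <nil>
	}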
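
Finally, weights_test.go now draws its per-type generators from the exported sqltypes.RandomGenerators map rather than local helpers. A sketch of that lookup, assuming only what the test table above shows about how the map is keyed:

	package main

	import (
		"fmt"

		"vitess.io/vitess/go/sqltypes"
	)

	func main() {
		// Fetch the generator registered for a type and draw one value.
		gen, ok := sqltypes.RandomGenerators[sqltypes.Decimal]
		if !ok {
			panic("no generator registered for sqltypes.Decimal")
		}
		v := gen()
		fmt.Println(v.String()) // e.g. DECIMAL(8241637125.5203948617); random per run
	}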