diff --git a/go/mysql/datetime/datetime.go b/go/mysql/datetime/datetime.go index 5b6a41d8a7f..162ac970e67 100644 --- a/go/mysql/datetime/datetime.go +++ b/go/mysql/datetime/datetime.go @@ -17,6 +17,7 @@ limitations under the License. package datetime import ( + "encoding/binary" "time" "vitess.io/vitess/go/mysql/decimal" @@ -641,6 +642,20 @@ func (dt *DateTime) addInterval(itv *Interval) bool { } } +func (dt DateTime) WeightString(dst []byte) []byte { + // This logic does the inverse of what we do in the binlog parser for the datetime2 type. + year, month, day := dt.Date.Year(), dt.Date.Month(), dt.Date.Day() + ymd := uint64(year*13+month)<<5 | uint64(day) + hms := uint64(dt.Time.Hour())<<12 | uint64(dt.Time.Minute())<<6 | uint64(dt.Time.Second()) + raw := (ymd<<17|hms)<<24 + uint64(dt.Time.Nanosecond()/1000) + if dt.Time.Neg() { + raw = -raw + } + + raw = raw ^ (1 << 63) + return binary.BigEndian.AppendUint64(dst, raw) +} + func NewDateFromStd(t time.Time) Date { year, month, day := t.Date() return Date{ diff --git a/go/mysql/datetime/parse.go b/go/mysql/datetime/parse.go index 52861127cde..e8f17191f4c 100644 --- a/go/mysql/datetime/parse.go +++ b/go/mysql/datetime/parse.go @@ -399,5 +399,10 @@ func ParseTimeDecimal(d decimal.Decimal, l int32, prec int) (Time, int, bool) { } else { t = t.Round(prec) } + // We only support a maximum of nanosecond precision, + // so if the decimal has any larger precision we truncate it. + if prec > 9 { + prec = 9 + } return t, prec, ok } diff --git a/go/mysql/decimal/weights.go b/go/mysql/decimal/weights.go new file mode 100644 index 00000000000..9b8f43a0c65 --- /dev/null +++ b/go/mysql/decimal/weights.go @@ -0,0 +1,56 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package decimal + +// Our weight string format is normalizing the weight string to a fixed length, +// so it becomes byte-ordered. The byte lengths are pre-computed based on +// https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html +// and generated empirically with a manual loop: +// +// for i := 1; i <= 65; i++ { +// dec, err := NewFromMySQL(bytes.Repeat([]byte("9"), i)) +// if err != nil { +// t.Fatal(err) +// } +// +// byteLengths = append(byteLengths, len(dec.value.Bytes())) +// } +var weightStringLengths = []int{ + 0, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5, 6, 6, 7, 7, 8, 8, 8, + 9, 9, 10, 10, 10, 11, 11, 12, 12, 13, 13, 13, 14, 14, 15, 15, 15, + 16, 16, 17, 17, 18, 18, 18, 19, 19, 20, 20, 20, 21, 21, 22, 22, + 23, 23, 23, 24, 24, 25, 25, 25, 26, 26, 27, 27, 27, +} + +func (d Decimal) WeightString(dst []byte, length, precision int32) []byte { + dec := d.rescale(-precision) + dec = dec.Clamp(length-precision, precision) + + buf := make([]byte, weightStringLengths[length]+1) + dec.value.FillBytes(buf[:]) + + if dec.value.Sign() < 0 { + for i := range buf { + buf[i] ^= 0xff + } + } + // Use the same trick as used for signed numbers on the first byte. + buf[0] ^= 0x80 + + dst = append(dst, buf[:]...) + return dst +} diff --git a/go/vt/vtgate/evalengine/api_hash_test.go b/go/vt/vtgate/evalengine/api_hash_test.go index 96b7dbac424..55e39dad77f 100644 --- a/go/vt/vtgate/evalengine/api_hash_test.go +++ b/go/vt/vtgate/evalengine/api_hash_test.go @@ -250,13 +250,19 @@ func randTime() time.Time { return time.Unix(sec, 0) } -func randomNull() sqltypes.Value { return sqltypes.NULL } -func randomInt8() sqltypes.Value { return sqltypes.NewInt8(int8(rand.Intn(255))) } -func randomInt32() sqltypes.Value { return sqltypes.NewInt32(rand.Int31()) } -func randomInt64() sqltypes.Value { return sqltypes.NewInt64(rand.Int63()) } -func randomUint32() sqltypes.Value { return sqltypes.NewUint32(rand.Uint32()) } -func randomUint64() sqltypes.Value { return sqltypes.NewUint64(rand.Uint64()) } -func randomDecimal() sqltypes.Value { return sqltypes.NewDecimal(fmt.Sprintf("%d", rand.Int63())) } +func randomNull() sqltypes.Value { return sqltypes.NULL } +func randomInt8() sqltypes.Value { return sqltypes.NewInt8(int8(rand.Intn(255))) } +func randomInt32() sqltypes.Value { return sqltypes.NewInt32(rand.Int31()) } +func randomInt64() sqltypes.Value { return sqltypes.NewInt64(rand.Int63()) } +func randomUint32() sqltypes.Value { return sqltypes.NewUint32(rand.Uint32()) } +func randomUint64() sqltypes.Value { return sqltypes.NewUint64(rand.Uint64()) } +func randomDecimal() sqltypes.Value { + dec := fmt.Sprintf("%d.%d", rand.Intn(9999999999), rand.Intn(9999999999)) + if rand.Int()&0x1 == 1 { + dec = "-" + dec + } + return sqltypes.NewDecimal(dec) +} func randomVarChar() sqltypes.Value { return sqltypes.NewVarChar(fmt.Sprintf("%d", rand.Int63())) } func randomDate() sqltypes.Value { return sqltypes.NewDate(randTime().Format(time.DateOnly)) } func randomDatetime() sqltypes.Value { return sqltypes.NewDatetime(randTime().Format(time.DateTime)) } diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index e8680e6ea45..de45e6ccc84 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -1625,8 +1625,8 @@ func (cached *builtinWeightString) CachedSize(alloc bool) int64 { if alloc { size += int64(48) } - // field String vitess.io/vitess/go/vt/vtgate/evalengine.Expr - if cc, ok := cached.String.(cachedObject); ok { + // field Expr vitess.io/vitess/go/vt/vtgate/evalengine.Expr + if cc, ok := cached.Expr.(cachedObject); ok { size += cc.CachedSize(true) } // field Cast string diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 74741c7b1cc..a67c3d4a259 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -2732,22 +2732,17 @@ func (asm *assembler) Fn_TO_BASE64(t sqltypes.Type, col collations.TypedCollatio }, "FN TO_BASE64 VARCHAR(SP-1)") } -func (asm *assembler) Fn_WEIGHT_STRING_b(length int) { +func (asm *assembler) Fn_WEIGHT_STRING(length int) { asm.emit(func(env *ExpressionEnv) int { - str := env.vm.stack[env.vm.sp-1].(*evalBytes) - w := collations.Binary.WeightString(make([]byte, 0, length), str.bytes, collations.PadToMax) - env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalBinary(w) - return 1 - }, "FN WEIGHT_STRING VARBINARY(SP-1)") -} - -func (asm *assembler) Fn_WEIGHT_STRING_c(col collations.Collation, length int) { - asm.emit(func(env *ExpressionEnv) int { - str := env.vm.stack[env.vm.sp-1].(*evalBytes) - w := col.WeightString(nil, str.bytes, length) + input := env.vm.stack[env.vm.sp-1] + w, _, err := evalWeightString(nil, input, length, 0) + if err != nil { + env.vm.err = err + return 1 + } env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalBinary(w) return 1 - }, "FN WEIGHT_STRING VARCHAR(SP-1)") + }, "FN WEIGHT_STRING (SP-1)") } func (asm *assembler) In_table(not bool, table map[vthash.Hash]struct{}) { diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 8fa5a3b15c5..e27089a44fb 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -464,6 +464,10 @@ func TestCompilerSingle(t *testing.T) { expression: `concat('test', _latin1 0xff)`, result: `VARCHAR("testÿ")`, }, + { + expression: `WEIGHT_STRING('foobar' as char(3))`, + result: `VARBINARY("\x1c\xe5\x1d\xdd\x1d\xdd")`, + }, } for _, tc := range testCases { diff --git a/go/vt/vtgate/evalengine/expr_convert.go b/go/vt/vtgate/evalengine/expr_convert.go index 8c7c079228b..54961dd3774 100644 --- a/go/vt/vtgate/evalengine/expr_convert.go +++ b/go/vt/vtgate/evalengine/expr_convert.go @@ -120,6 +120,10 @@ func (c *ConvertExpr) eval(env *ExpressionEnv) (eval, error) { case "JSON": return evalToJSON(e) case "DATETIME": + switch p := c.Length; { + case p > 6: + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for 'CONVERT'. Maximum is 6.", p) + } if dt := evalToDateTime(e, c.Length); dt != nil { return dt, nil } @@ -130,6 +134,10 @@ func (c *ConvertExpr) eval(env *ExpressionEnv) (eval, error) { } return nil, nil case "TIME": + switch p := c.Length; { + case p > 6: + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for 'CONVERT'. Maximum is 6.", p) + } if t := evalToTime(e, c.Length); t != nil { return t, nil } @@ -227,6 +235,9 @@ func (conv *ConvertExpr) compile(c *compiler) (ctype, error) { case "DOUBLE", "REAL": convt = c.compileToFloat(arg, 1) + case "FLOAT": + return ctype{}, c.unsupported(conv) + case "SIGNED", "SIGNED INTEGER": convt = c.compileToInt64(arg, 1) @@ -244,9 +255,17 @@ func (conv *ConvertExpr) compile(c *compiler) (ctype, error) { convt = c.compileToDate(arg, 1) case "DATETIME": + switch p := conv.Length; { + case p > 6: + return ctype{}, c.unsupported(conv) + } convt = c.compileToDateTime(arg, 1, conv.Length) case "TIME": + switch p := conv.Length; { + case p > 6: + return ctype{}, c.unsupported(conv) + } convt = c.compileToTime(arg, 1, conv.Length) default: @@ -256,7 +275,6 @@ func (conv *ConvertExpr) compile(c *compiler) (ctype, error) { c.asm.jumpDestination(skip) convt.Flag = arg.Flag | flagNullable return convt, nil - } func (c *ConvertUsingExpr) eval(env *ExpressionEnv) (eval, error) { diff --git a/go/vt/vtgate/evalengine/fn_string.go b/go/vt/vtgate/evalengine/fn_string.go index 211e0c46fd6..2c5c0a71597 100644 --- a/go/vt/vtgate/evalengine/fn_string.go +++ b/go/vt/vtgate/evalengine/fn_string.go @@ -61,7 +61,7 @@ type ( } builtinWeightString struct { - String Expr + Expr Expr Cast string Len int HasLen bool @@ -455,76 +455,93 @@ func (expr *builtinCollation) compile(c *compiler) (ctype, error) { } func (c *builtinWeightString) callable() []Expr { - return []Expr{c.String} + return []Expr{c.Expr} } func (c *builtinWeightString) typeof(env *ExpressionEnv, fields []*querypb.Field) (sqltypes.Type, typeFlag) { - _, f := c.String.typeof(env, fields) + _, f := c.Expr.typeof(env, fields) return sqltypes.VarBinary, f } func (c *builtinWeightString) eval(env *ExpressionEnv) (eval, error) { - var ( - tc collations.TypedCollation - text []byte - weights []byte - length = c.Len - ) - - str, err := c.String.eval(env) + var weights []byte + + input, err := c.Expr.eval(env) if err != nil { return nil, err } - switch str := str.(type) { - case *evalInt64, *evalUint64: - // when calling WEIGHT_STRING with an integral value, MySQL returns the - // internal sort key that would be used in an InnoDB table... we do not - // support that - return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "%s: %s", ErrEvaluatedExprNotSupported, FormatExpr(c)) + if c.Cast == "binary" { + weights, _, err = evalWeightString(weights, evalToBinary(input), c.Len, 0) + if err != nil { + return nil, err + } + return newEvalBinary(weights), nil + } + + switch val := input.(type) { + case *evalInt64, *evalUint64, *evalTemporal: + weights, _, err = evalWeightString(weights, val, 0, 0) case *evalBytes: - text = str.bytes - tc = str.col + if val.isBinary() { + weights, _, err = evalWeightString(weights, val, 0, 0) + } else { + var strLen int + if c.Cast == "char" { + strLen = c.Len + } + weights, _, err = evalWeightString(weights, val, strLen, 0) + } default: return nil, nil } - if c.Cast == "binary" { - tc = collationBinary - weights = make([]byte, 0, c.Len) - length = collations.PadToMax + if err != nil { + return nil, err } - collation := tc.Collation.Get() - weights = collation.WeightString(weights, text, length) return newEvalBinary(weights), nil } func (call *builtinWeightString) compile(c *compiler) (ctype, error) { - str, err := call.String.compile(c) + str, err := call.Expr.compile(c) if err != nil { return ctype{}, err } - switch str.Type { - case sqltypes.Int64, sqltypes.Uint64: - return ctype{}, c.unsupported(call) - - case sqltypes.VarChar, sqltypes.VarBinary: - skip := c.compileNullCheck1(str) + var flag typeFlag + if str.Flag&flagNullable != 0 { + flag = flag | flagNullable + } - if call.Cast == "binary" { - c.asm.Fn_WEIGHT_STRING_b(call.Len) - } else { - c.asm.Fn_WEIGHT_STRING_c(str.Col.Collation.Get(), call.Len) + skip := c.compileNullCheck1(str) + if call.Cast == "binary" { + if !sqltypes.IsBinary(str.Type) { + c.asm.Convert_xb(1, sqltypes.VarBinary, 0, false) } + c.asm.Fn_WEIGHT_STRING(call.Len) c.asm.jumpDestination(skip) - return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil + return ctype{Type: sqltypes.VarBinary, Flag: flagNullable | flagNull, Col: collationBinary}, nil + } + + switch str.Type { + case sqltypes.Int64, sqltypes.Uint64, sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp, sqltypes.Time, sqltypes.VarBinary, sqltypes.Binary, sqltypes.Blob: + c.asm.Fn_WEIGHT_STRING(0) + + case sqltypes.VarChar, sqltypes.Char, sqltypes.Text: + var strLen int + if call.Cast == "char" { + strLen = call.Len + } + c.asm.Fn_WEIGHT_STRING(strLen) default: c.asm.SetNull(1) - return ctype{Type: sqltypes.VarBinary, Flag: flagNullable | flagNull, Col: collationBinary}, nil + flag = flag | flagNull | flagNullable } + + c.asm.jumpDestination(skip) + return ctype{Type: sqltypes.VarBinary, Flag: flag, Col: collationBinary}, nil } func (call builtinLeftRight) eval(env *ExpressionEnv) (eval, error) { diff --git a/go/vt/vtgate/evalengine/format.go b/go/vt/vtgate/evalengine/format.go index 4c043d399d4..fe641e0954a 100644 --- a/go/vt/vtgate/evalengine/format.go +++ b/go/vt/vtgate/evalengine/format.go @@ -170,7 +170,7 @@ func (c *CallExpr) format(w *formatter, depth int) { func (c *builtinWeightString) format(w *formatter, depth int) { w.WriteString("WEIGHT_STRING(") - c.String.format(w, depth) + c.Expr.format(w, depth) if c.Cast != "" { fmt.Fprintf(w, " AS %s(%d)", strings.ToUpper(c.Cast), c.Len) diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index d6e692b1a99..365da12a626 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -829,10 +829,24 @@ func BitwiseOperators(yield Query) { func WeightString(yield Query) { var inputs = []string{ `'foobar'`, `_latin1 'foobar'`, - `'foobar' as char(12)`, `'foobar' as binary(12)`, + `'foobar' as char(12)`, `'foobar' as char(3)`, `'foobar' as binary(12)`, `'foobar' as binary(3)`, + `'foobar' collate utf8mb4_bin as char(12)`, `'foobar' collate utf8mb4_bin as char(3)`, + `'foobar' collate binary as char(12)`, `'foobar' collate binary as char(3)`, `_latin1 'foobar' as char(12)`, `_latin1 'foobar' as binary(12)`, + `_binary 'foobar' as char(12)`, `_binary 'foobar' as binary(12)`, + `1`, `-1`, `9223372036854775807`, `18446744073709551615`, `-9223372036854775808`, + `1 as char(1)`, `-1 as char(1)`, `9223372036854775807 as char(1)`, `18446744073709551615 as char(1)`, `-9223372036854775808 as char(1)`, + `1 as char(32)`, `-1 as char(32)`, `9223372036854775807 as char(32)`, `18446744073709551615 as char(32)`, `-9223372036854775808 as char(32)`, + `1 as binary(1)`, `-1 as binary(1)`, `9223372036854775807 as binary(1)`, `18446744073709551615 as binary(1)`, `-9223372036854775808 as binary(1)`, + `1 as binary(32)`, `-1 as binary(32)`, `9223372036854775807 as binary(32)`, `18446744073709551615 as binary(32)`, `-9223372036854775808 as binary(32)`, `1234.0`, `12340e0`, `0x1234`, `0x1234 as char(12)`, `0x1234 as char(2)`, + `date'2000-01-01'`, `date'2000-01-01' as char(12)`, `date'2000-01-01' as char(2)`, `date'2000-01-01' as binary(12)`, `date'2000-01-01' as binary(2)`, + `timestamp'2000-01-01 11:22:33'`, `timestamp'2000-01-01 11:22:33' as char(12)`, `timestamp'2000-01-01 11:22:33' as char(2)`, `timestamp'2000-01-01 11:22:33' as binary(12)`, `timestamp'2000-01-01 11:22:33' as binary(2)`, + `timestamp'2000-01-01 11:22:33.123456'`, `timestamp'2000-01-01 11:22:33.123456' as char(12)`, `timestamp'2000-01-01 11:22:33.123456' as char(2)`, `timestamp'2000-01-01 11:22:33.123456' as binary(12)`, `timestamp'2000-01-01 11:22:33.123456' as binary(2)`, + `time'-11:22:33'`, `time'-11:22:33' as char(12)`, `time'-11:22:33' as char(2)`, `time'-11:22:33' as binary(12)`, `time'-11:22:33' as binary(2)`, + `time'11:22:33'`, `time'11:22:33' as char(12)`, `time'11:22:33' as char(2)`, `time'11:22:33' as binary(12)`, `time'11:22:33' as binary(2)`, + `time'101:22:33'`, `time'101:22:33' as char(12)`, `time'101:22:33' as char(2)`, `time'101:22:33' as binary(12)`, `time'101:22:33' as binary(2)`, } for _, i := range inputs { diff --git a/go/vt/vtgate/evalengine/translate_builtin.go b/go/vt/vtgate/evalengine/translate_builtin.go index 49784973180..4a4c3f1d9d2 100644 --- a/go/vt/vtgate/evalengine/translate_builtin.go +++ b/go/vt/vtgate/evalengine/translate_builtin.go @@ -594,7 +594,7 @@ func (ast *astCompiler) translateCallable(call sqlparser.Callable) (Expr, error) var ws builtinWeightString var err error - ws.String, err = ast.translateExpr(call.Expr) + ws.Expr, err = ast.translateExpr(call.Expr) if err != nil { return nil, err } diff --git a/go/vt/vtgate/evalengine/translate_simplify.go b/go/vt/vtgate/evalengine/translate_simplify.go index 3e957a943fc..d851114751d 100644 --- a/go/vt/vtgate/evalengine/translate_simplify.go +++ b/go/vt/vtgate/evalengine/translate_simplify.go @@ -123,12 +123,12 @@ func (c *CallExpr) simplify(env *ExpressionEnv) error { } func (c *builtinWeightString) constant() bool { - return c.String.constant() + return c.Expr.constant() } func (c *builtinWeightString) simplify(env *ExpressionEnv) error { var err error - c.String, err = simplifyExpr(env, c.String) + c.Expr, err = simplifyExpr(env, c.Expr) return err } diff --git a/go/vt/vtgate/evalengine/weights.go b/go/vt/vtgate/evalengine/weights.go new file mode 100644 index 00000000000..ee7fd42774d --- /dev/null +++ b/go/vt/vtgate/evalengine/weights.go @@ -0,0 +1,175 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "encoding/binary" + "math" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/mysql/collations/charset" + "vitess.io/vitess/go/mysql/decimal" + "vitess.io/vitess/go/mysql/json" + "vitess.io/vitess/go/sqltypes" + vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/vterrors" +) + +// WeightString returns the weight string for a value. +// It appends to dst if an existing slice is given, otherwise it +// returns a new one. +// The returned boolean indicates whether the weight string is a +// fixed-width weight string, such as for fixed size integer values. +// Our WeightString implementation supports more types that MySQL +// externally communicates with the `WEIGHT_STRING` function, so that we +// can also use this to order / sort other types like Float and Decimal +// as well. +func WeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int) ([]byte, bool, error) { + // We optimize here for the case where we already have the desired type. + // Otherwise, we fall back to the general evalengine conversion logic. + if v.Type() != coerceTo { + return fallbackWeightString(dst, v, coerceTo, col, length, precision) + } + + switch { + case sqltypes.IsNull(coerceTo): + return nil, true, nil + + case sqltypes.IsSigned(coerceTo): + i, err := v.ToInt64() + if err != nil { + return dst, false, err + } + raw := uint64(i) + raw = raw ^ (1 << 63) + return binary.BigEndian.AppendUint64(dst, raw), true, nil + + case sqltypes.IsUnsigned(coerceTo): + u, err := v.ToUint64() + if err != nil { + return dst, false, err + } + return binary.BigEndian.AppendUint64(dst, u), true, nil + + case sqltypes.IsFloat(coerceTo): + f, err := v.ToFloat64() + if err != nil { + return dst, false, err + } + + raw := math.Float64bits(f) + if math.Signbit(f) { + raw = ^raw + } else { + raw = raw ^ (1 << 63) + } + return binary.BigEndian.AppendUint64(dst, raw), true, nil + + case sqltypes.IsBinary(coerceTo): + b := v.Raw() + if length != 0 { + if length > cap(b) { + b = append(b, make([]byte, length-len(b))...) + } else { + b = b[:length] + } + } + return append(dst, b...), false, nil + + case sqltypes.IsText(coerceTo): + coll := col.Get() + if coll == nil { + return dst, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot hash unsupported collation") + } + b := v.Raw() + if length != 0 { + b = charset.Slice(coll.Charset(), b, 0, length) + } + return coll.WeightString(dst, b, length), false, nil + + case sqltypes.IsDecimal(coerceTo): + dec, err := decimal.NewFromMySQL(v.Raw()) + if err != nil { + return dst, false, err + } + return dec.WeightString(dst, int32(length), int32(precision)), true, nil + case coerceTo == sqltypes.TypeJSON: + j, err := json.NewFromSQL(v) + if err != nil { + return dst, false, err + } + return j.WeightString(dst), false, nil + default: + return fallbackWeightString(dst, v, coerceTo, col, length, precision) + } +} + +func fallbackWeightString(dst []byte, v sqltypes.Value, coerceTo sqltypes.Type, col collations.ID, length, precision int) ([]byte, bool, error) { + e, err := valueToEvalCast(v, coerceTo, col) + if err != nil { + return dst, false, err + } + return evalWeightString(dst, e, length, precision) +} + +func evalWeightString(dst []byte, e eval, length, precision int) ([]byte, bool, error) { + switch e := e.(type) { + case nil: + return nil, true, nil + case *evalInt64: + raw := uint64(e.i) + raw = raw ^ (1 << 63) + return binary.BigEndian.AppendUint64(dst, raw), true, nil + case *evalUint64: + return binary.BigEndian.AppendUint64(dst, e.u), true, nil + case *evalFloat: + raw := math.Float64bits(e.f) + if math.Signbit(e.f) { + raw = ^raw + } else { + raw = raw ^ (1 << 63) + } + return binary.BigEndian.AppendUint64(dst, raw), true, nil + case *evalDecimal: + return e.dec.WeightString(dst, int32(length), int32(precision)), true, nil + case *evalBytes: + if e.isBinary() { + b := e.bytes + if length != 0 { + if length > cap(b) { + b = append(b, make([]byte, length-len(b))...) + } else { + b = b[:length] + } + } + return append(dst, b...), false, nil + } + coll := e.col.Collation.Get() + if coll == nil { + return dst, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "cannot hash unsupported collation") + } + b := e.bytes + if length != 0 { + b = charset.Slice(coll.Charset(), b, 0, length) + } + return coll.WeightString(dst, b, length), false, nil + case *evalTemporal: + return e.dt.WeightString(dst), true, nil + } + + return dst, false, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "unexpected type %v", e.SQLType()) +} diff --git a/go/vt/vtgate/evalengine/weights_test.go b/go/vt/vtgate/evalengine/weights_test.go new file mode 100644 index 00000000000..f3969d9d1cf --- /dev/null +++ b/go/vt/vtgate/evalengine/weights_test.go @@ -0,0 +1,139 @@ +/* +Copyright 2023 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "math/rand" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" + + "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/sqltypes" +) + +func TestWeightStrings(t *testing.T) { + const Length = 1000 + + type item struct { + value sqltypes.Value + weight string + } + + var cases = []struct { + name string + gen func() sqltypes.Value + t sqltypes.Type + col collations.ID + len int + prec int + }{ + {name: "int64", gen: randomInt64, t: sqltypes.Int64, col: collations.CollationBinaryID}, + {name: "uint64", gen: randomUint64, t: sqltypes.Uint64, col: collations.CollationBinaryID}, + {name: "float64", gen: randomFloat64, t: sqltypes.Float64, col: collations.CollationBinaryID}, + {name: "varchar", gen: randomVarChar, t: sqltypes.VarChar, col: collations.CollationUtf8mb4ID}, + {name: "varbinary", gen: randomVarBinary, t: sqltypes.VarBinary, col: collations.CollationBinaryID}, + {name: "decimal", gen: randomDecimal, t: sqltypes.Decimal, col: collations.CollationBinaryID, len: 20, prec: 10}, + {name: "json", gen: randomJSON, t: sqltypes.TypeJSON, col: collations.CollationBinaryID}, + {name: "date", gen: randomDate, t: sqltypes.Date, col: collations.CollationBinaryID}, + {name: "datetime", gen: randomDatetime, t: sqltypes.Datetime, col: collations.CollationBinaryID}, + {name: "timestamp", gen: randomTimestamp, t: sqltypes.Timestamp, col: collations.CollationBinaryID}, + {name: "time", gen: randomTime, t: sqltypes.Time, col: collations.CollationBinaryID}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + items := make([]item, 0, Length) + for i := 0; i < Length; i++ { + v := tc.gen() + w, _, err := WeightString(nil, v, tc.t, tc.col, tc.len, tc.prec) + require.NoError(t, err) + + items = append(items, item{value: v, weight: string(w)}) + } + + slices.SortFunc(items, func(a, b item) int { + if a.weight < b.weight { + return -1 + } else if a.weight > b.weight { + return 1 + } else { + return 0 + } + }) + + for i := 0; i < Length-1; i++ { + a := items[i] + b := items[i+1] + + cmp, err := NullsafeCompare(a.value, b.value, tc.col) + require.NoError(t, err) + + if cmp > 0 { + t.Fatalf("expected %v [pos=%d] to come after %v [pos=%d]\nav = %v\nbv = %v", + a.value, i, b.value, i+1, + []byte(a.weight), []byte(b.weight), + ) + } + } + }) + } +} + +func randomVarBinary() sqltypes.Value { return sqltypes.NewVarBinary(string(randomBytes())) } +func randomFloat64() sqltypes.Value { + return sqltypes.NewFloat64(rand.NormFloat64()) +} + +func randomBytes() []byte { + const Dictionary = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + + b := make([]byte, 4+rand.Intn(256)) + for i := range b { + b[i] = Dictionary[rand.Intn(len(Dictionary))] + } + return b +} + +func randomJSON() sqltypes.Value { + var j string + switch rand.Intn(6) { + case 0: + j = "null" + case 1: + i := rand.Int63() + if rand.Int()&0x1 == 1 { + i = -i + } + j = strconv.FormatInt(i, 10) + case 2: + j = strconv.FormatFloat(rand.NormFloat64(), 'g', -1, 64) + case 3: + j = strconv.Quote(string(randomBytes())) + case 4: + j = "true" + case 5: + j = "false" + } + v, err := sqltypes.NewJSON(j) + if err != nil { + panic(err) + } + return v +}