Skip to content

Commit

Permalink
evalengine: Improve weight string support (#13658)
Browse files Browse the repository at this point in the history
Signed-off-by: Dirkjan Bussink <d.bussink@gmail.com>
  • Loading branch information
dbussink authored Jul 28, 2023
1 parent 8ea976d commit 700e93e
Show file tree
Hide file tree
Showing 15 changed files with 510 additions and 66 deletions.
15 changes: 15 additions & 0 deletions go/mysql/datetime/datetime.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package datetime

import (
"encoding/binary"
"time"

"vitess.io/vitess/go/mysql/decimal"
Expand Down Expand Up @@ -641,6 +642,20 @@ func (dt *DateTime) addInterval(itv *Interval) bool {
}
}

func (dt DateTime) WeightString(dst []byte) []byte {
// This logic does the inverse of what we do in the binlog parser for the datetime2 type.
year, month, day := dt.Date.Year(), dt.Date.Month(), dt.Date.Day()
ymd := uint64(year*13+month)<<5 | uint64(day)
hms := uint64(dt.Time.Hour())<<12 | uint64(dt.Time.Minute())<<6 | uint64(dt.Time.Second())
raw := (ymd<<17|hms)<<24 + uint64(dt.Time.Nanosecond()/1000)
if dt.Time.Neg() {
raw = -raw
}

raw = raw ^ (1 << 63)
return binary.BigEndian.AppendUint64(dst, raw)
}

func NewDateFromStd(t time.Time) Date {
year, month, day := t.Date()
return Date{
Expand Down
5 changes: 5 additions & 0 deletions go/mysql/datetime/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,5 +399,10 @@ func ParseTimeDecimal(d decimal.Decimal, l int32, prec int) (Time, int, bool) {
} else {
t = t.Round(prec)
}
// We only support a maximum of nanosecond precision,
// so if the decimal has any larger precision we truncate it.
if prec > 9 {
prec = 9
}
return t, prec, ok
}
56 changes: 56 additions & 0 deletions go/mysql/decimal/weights.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
Copyright 2023 The Vitess Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package decimal

// Our weight string format is normalizing the weight string to a fixed length,
// so it becomes byte-ordered. The byte lengths are pre-computed based on
// https://dev.mysql.com/doc/refman/8.0/en/fixed-point-types.html
// and generated empirically with a manual loop:
//
// for i := 1; i <= 65; i++ {
// dec, err := NewFromMySQL(bytes.Repeat([]byte("9"), i))
// if err != nil {
// t.Fatal(err)
// }
//
// byteLengths = append(byteLengths, len(dec.value.Bytes()))
// }
var weightStringLengths = []int{
0, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5, 6, 6, 7, 7, 8, 8, 8,
9, 9, 10, 10, 10, 11, 11, 12, 12, 13, 13, 13, 14, 14, 15, 15, 15,
16, 16, 17, 17, 18, 18, 18, 19, 19, 20, 20, 20, 21, 21, 22, 22,
23, 23, 23, 24, 24, 25, 25, 25, 26, 26, 27, 27, 27,
}

func (d Decimal) WeightString(dst []byte, length, precision int32) []byte {
dec := d.rescale(-precision)
dec = dec.Clamp(length-precision, precision)

buf := make([]byte, weightStringLengths[length]+1)
dec.value.FillBytes(buf[:])

if dec.value.Sign() < 0 {
for i := range buf {
buf[i] ^= 0xff
}
}
// Use the same trick as used for signed numbers on the first byte.
buf[0] ^= 0x80

dst = append(dst, buf[:]...)
return dst
}
20 changes: 13 additions & 7 deletions go/vt/vtgate/evalengine/api_hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,13 +250,19 @@ func randTime() time.Time {
return time.Unix(sec, 0)
}

func randomNull() sqltypes.Value { return sqltypes.NULL }
func randomInt8() sqltypes.Value { return sqltypes.NewInt8(int8(rand.Intn(255))) }
func randomInt32() sqltypes.Value { return sqltypes.NewInt32(rand.Int31()) }
func randomInt64() sqltypes.Value { return sqltypes.NewInt64(rand.Int63()) }
func randomUint32() sqltypes.Value { return sqltypes.NewUint32(rand.Uint32()) }
func randomUint64() sqltypes.Value { return sqltypes.NewUint64(rand.Uint64()) }
func randomDecimal() sqltypes.Value { return sqltypes.NewDecimal(fmt.Sprintf("%d", rand.Int63())) }
func randomNull() sqltypes.Value { return sqltypes.NULL }
func randomInt8() sqltypes.Value { return sqltypes.NewInt8(int8(rand.Intn(255))) }
func randomInt32() sqltypes.Value { return sqltypes.NewInt32(rand.Int31()) }
func randomInt64() sqltypes.Value { return sqltypes.NewInt64(rand.Int63()) }
func randomUint32() sqltypes.Value { return sqltypes.NewUint32(rand.Uint32()) }
func randomUint64() sqltypes.Value { return sqltypes.NewUint64(rand.Uint64()) }
func randomDecimal() sqltypes.Value {
dec := fmt.Sprintf("%d.%d", rand.Intn(9999999999), rand.Intn(9999999999))
if rand.Int()&0x1 == 1 {
dec = "-" + dec
}
return sqltypes.NewDecimal(dec)
}
func randomVarChar() sqltypes.Value { return sqltypes.NewVarChar(fmt.Sprintf("%d", rand.Int63())) }
func randomDate() sqltypes.Value { return sqltypes.NewDate(randTime().Format(time.DateOnly)) }
func randomDatetime() sqltypes.Value { return sqltypes.NewDatetime(randTime().Format(time.DateTime)) }
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/evalengine/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 8 additions & 13 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2732,22 +2732,17 @@ func (asm *assembler) Fn_TO_BASE64(t sqltypes.Type, col collations.TypedCollatio
}, "FN TO_BASE64 VARCHAR(SP-1)")
}

func (asm *assembler) Fn_WEIGHT_STRING_b(length int) {
func (asm *assembler) Fn_WEIGHT_STRING(length int) {
asm.emit(func(env *ExpressionEnv) int {
str := env.vm.stack[env.vm.sp-1].(*evalBytes)
w := collations.Binary.WeightString(make([]byte, 0, length), str.bytes, collations.PadToMax)
env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalBinary(w)
return 1
}, "FN WEIGHT_STRING VARBINARY(SP-1)")
}

func (asm *assembler) Fn_WEIGHT_STRING_c(col collations.Collation, length int) {
asm.emit(func(env *ExpressionEnv) int {
str := env.vm.stack[env.vm.sp-1].(*evalBytes)
w := col.WeightString(nil, str.bytes, length)
input := env.vm.stack[env.vm.sp-1]
w, _, err := evalWeightString(nil, input, length, 0)
if err != nil {
env.vm.err = err
return 1
}
env.vm.stack[env.vm.sp-1] = env.vm.arena.newEvalBinary(w)
return 1
}, "FN WEIGHT_STRING VARCHAR(SP-1)")
}, "FN WEIGHT_STRING (SP-1)")
}

func (asm *assembler) In_table(not bool, table map[vthash.Hash]struct{}) {
Expand Down
4 changes: 4 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,10 @@ func TestCompilerSingle(t *testing.T) {
expression: `concat('test', _latin1 0xff)`,
result: `VARCHAR("testÿ")`,
},
{
expression: `WEIGHT_STRING('foobar' as char(3))`,
result: `VARBINARY("\x1c\xe5\x1d\xdd\x1d\xdd")`,
},
}

for _, tc := range testCases {
Expand Down
20 changes: 19 additions & 1 deletion go/vt/vtgate/evalengine/expr_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ func (c *ConvertExpr) eval(env *ExpressionEnv) (eval, error) {
case "JSON":
return evalToJSON(e)
case "DATETIME":
switch p := c.Length; {
case p > 6:
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for 'CONVERT'. Maximum is 6.", p)
}
if dt := evalToDateTime(e, c.Length); dt != nil {
return dt, nil
}
Expand All @@ -130,6 +134,10 @@ func (c *ConvertExpr) eval(env *ExpressionEnv) (eval, error) {
}
return nil, nil
case "TIME":
switch p := c.Length; {
case p > 6:
return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Too-big precision %d specified for 'CONVERT'. Maximum is 6.", p)
}
if t := evalToTime(e, c.Length); t != nil {
return t, nil
}
Expand Down Expand Up @@ -227,6 +235,9 @@ func (conv *ConvertExpr) compile(c *compiler) (ctype, error) {
case "DOUBLE", "REAL":
convt = c.compileToFloat(arg, 1)

case "FLOAT":
return ctype{}, c.unsupported(conv)

case "SIGNED", "SIGNED INTEGER":
convt = c.compileToInt64(arg, 1)

Expand All @@ -244,9 +255,17 @@ func (conv *ConvertExpr) compile(c *compiler) (ctype, error) {
convt = c.compileToDate(arg, 1)

case "DATETIME":
switch p := conv.Length; {
case p > 6:
return ctype{}, c.unsupported(conv)
}
convt = c.compileToDateTime(arg, 1, conv.Length)

case "TIME":
switch p := conv.Length; {
case p > 6:
return ctype{}, c.unsupported(conv)
}
convt = c.compileToTime(arg, 1, conv.Length)

default:
Expand All @@ -256,7 +275,6 @@ func (conv *ConvertExpr) compile(c *compiler) (ctype, error) {
c.asm.jumpDestination(skip)
convt.Flag = arg.Flag | flagNullable
return convt, nil

}

func (c *ConvertUsingExpr) eval(env *ExpressionEnv) (eval, error) {
Expand Down
93 changes: 55 additions & 38 deletions go/vt/vtgate/evalengine/fn_string.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ type (
}

builtinWeightString struct {
String Expr
Expr Expr
Cast string
Len int
HasLen bool
Expand Down Expand Up @@ -455,76 +455,93 @@ func (expr *builtinCollation) compile(c *compiler) (ctype, error) {
}

func (c *builtinWeightString) callable() []Expr {
return []Expr{c.String}
return []Expr{c.Expr}
}

func (c *builtinWeightString) typeof(env *ExpressionEnv, fields []*querypb.Field) (sqltypes.Type, typeFlag) {
_, f := c.String.typeof(env, fields)
_, f := c.Expr.typeof(env, fields)
return sqltypes.VarBinary, f
}

func (c *builtinWeightString) eval(env *ExpressionEnv) (eval, error) {
var (
tc collations.TypedCollation
text []byte
weights []byte
length = c.Len
)

str, err := c.String.eval(env)
var weights []byte

input, err := c.Expr.eval(env)
if err != nil {
return nil, err
}

switch str := str.(type) {
case *evalInt64, *evalUint64:
// when calling WEIGHT_STRING with an integral value, MySQL returns the
// internal sort key that would be used in an InnoDB table... we do not
// support that
return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "%s: %s", ErrEvaluatedExprNotSupported, FormatExpr(c))
if c.Cast == "binary" {
weights, _, err = evalWeightString(weights, evalToBinary(input), c.Len, 0)
if err != nil {
return nil, err
}
return newEvalBinary(weights), nil
}

switch val := input.(type) {
case *evalInt64, *evalUint64, *evalTemporal:
weights, _, err = evalWeightString(weights, val, 0, 0)
case *evalBytes:
text = str.bytes
tc = str.col
if val.isBinary() {
weights, _, err = evalWeightString(weights, val, 0, 0)
} else {
var strLen int
if c.Cast == "char" {
strLen = c.Len
}
weights, _, err = evalWeightString(weights, val, strLen, 0)
}
default:
return nil, nil
}

if c.Cast == "binary" {
tc = collationBinary
weights = make([]byte, 0, c.Len)
length = collations.PadToMax
if err != nil {
return nil, err
}

collation := tc.Collation.Get()
weights = collation.WeightString(weights, text, length)
return newEvalBinary(weights), nil
}

func (call *builtinWeightString) compile(c *compiler) (ctype, error) {
str, err := call.String.compile(c)
str, err := call.Expr.compile(c)
if err != nil {
return ctype{}, err
}

switch str.Type {
case sqltypes.Int64, sqltypes.Uint64:
return ctype{}, c.unsupported(call)

case sqltypes.VarChar, sqltypes.VarBinary:
skip := c.compileNullCheck1(str)
var flag typeFlag
if str.Flag&flagNullable != 0 {
flag = flag | flagNullable
}

if call.Cast == "binary" {
c.asm.Fn_WEIGHT_STRING_b(call.Len)
} else {
c.asm.Fn_WEIGHT_STRING_c(str.Col.Collation.Get(), call.Len)
skip := c.compileNullCheck1(str)
if call.Cast == "binary" {
if !sqltypes.IsBinary(str.Type) {
c.asm.Convert_xb(1, sqltypes.VarBinary, 0, false)
}
c.asm.Fn_WEIGHT_STRING(call.Len)
c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.VarBinary, Col: collationBinary}, nil
return ctype{Type: sqltypes.VarBinary, Flag: flagNullable | flagNull, Col: collationBinary}, nil
}

switch str.Type {
case sqltypes.Int64, sqltypes.Uint64, sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp, sqltypes.Time, sqltypes.VarBinary, sqltypes.Binary, sqltypes.Blob:
c.asm.Fn_WEIGHT_STRING(0)

case sqltypes.VarChar, sqltypes.Char, sqltypes.Text:
var strLen int
if call.Cast == "char" {
strLen = call.Len
}
c.asm.Fn_WEIGHT_STRING(strLen)

default:
c.asm.SetNull(1)
return ctype{Type: sqltypes.VarBinary, Flag: flagNullable | flagNull, Col: collationBinary}, nil
flag = flag | flagNull | flagNullable
}

c.asm.jumpDestination(skip)
return ctype{Type: sqltypes.VarBinary, Flag: flag, Col: collationBinary}, nil
}

func (call builtinLeftRight) eval(env *ExpressionEnv) (eval, error) {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (c *CallExpr) format(w *formatter, depth int) {

func (c *builtinWeightString) format(w *formatter, depth int) {
w.WriteString("WEIGHT_STRING(")
c.String.format(w, depth)
c.Expr.format(w, depth)

if c.Cast != "" {
fmt.Fprintf(w, " AS %s(%d)", strings.ToUpper(c.Cast), c.Len)
Expand Down
Loading

0 comments on commit 700e93e

Please sign in to comment.