diff --git a/go/test/endtoend/utils/utils.go b/go/test/endtoend/utils/utils.go index d9e94911e30..468841c23f6 100644 --- a/go/test/endtoend/utils/utils.go +++ b/go/test/endtoend/utils/utils.go @@ -88,7 +88,13 @@ func AssertMatchesAny(t testing.TB, conn *mysql.Conn, query string, expected ... return } } - t.Errorf("Query: %s (-want +got):\n%v\nGot:%s", query, expected, got) + + var err strings.Builder + _, _ = fmt.Fprintf(&err, "Query did not match:\n%s\n", query) + for i, e := range expected { + _, _ = fmt.Fprintf(&err, "Expected query %d does not match.\nwant: %v\ngot: %v\n\n", i, e, got) + } + t.Error(err.String()) } // AssertMatchesCompareMySQL executes the given query on both Vitess and MySQL and make sure diff --git a/go/test/endtoend/vtgate/queries/informationschema/informationschema_test.go b/go/test/endtoend/vtgate/queries/informationschema/informationschema_test.go index 5ba9877bf5f..887eefc7747 100644 --- a/go/test/endtoend/vtgate/queries/informationschema/informationschema_test.go +++ b/go/test/endtoend/vtgate/queries/informationschema/informationschema_test.go @@ -220,6 +220,26 @@ func TestInfrSchemaAndUnionAll(t *testing.T) { } } +func TestInfoschemaTypes(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + + require.NoError(t, + utils.WaitForAuthoritative(t, "ks", "t1", clusterInstance.VtgateProcess.ReadVSchema)) + + mcmp, closer := start(t) + defer closer() + + mcmp.Exec(` + SELECT ORDINAL_POSITION + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = 'ks' AND TABLE_NAME = 't1' + UNION + SELECT ORDINAL_POSITION + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = 'ks' AND TABLE_NAME = 't2'; + `) +} + func TestTypeORMQuery(t *testing.T) { utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") // This test checks that we can run queries similar to the ones that the TypeORM framework uses diff --git a/go/test/endtoend/vtgate/queries/orderby/orderby_test.go b/go/test/endtoend/vtgate/queries/orderby/orderby_test.go index 993f7834301..b63ecc1b004 100644 --- a/go/test/endtoend/vtgate/queries/orderby/orderby_test.go +++ b/go/test/endtoend/vtgate/queries/orderby/orderby_test.go @@ -145,6 +145,7 @@ func TestOrderByComplex(t *testing.T) { "select email, max(col) as max_col from (select email, col from user where col > 20) as filtered group by email order by max_col", "select a.email, a.max_col from (select email, max(col) as max_col from user group by email) as a order by a.max_col desc", "select email, max(col) as max_col from user where email like 'a%' group by email order by max_col, email", + `select email, max(col) as max_col from user group by email union select email, avg(col) as avg_col from user group by email order by email desc`, } for _, query := range queries { diff --git a/go/vt/vtgate/engine/concatenate.go b/go/vt/vtgate/engine/concatenate.go index 27b35c32aa8..352b190fb1d 100644 --- a/go/vt/vtgate/engine/concatenate.go +++ b/go/vt/vtgate/engine/concatenate.go @@ -96,13 +96,13 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars return nil, err } - fields, err := c.getFields(res) + fields, fieldTypes, err := c.getFieldTypes(vcursor, res) if err != nil { return nil, err } var rows [][]sqltypes.Value - err = c.coerceAndVisitResults(res, fields, func(result *sqltypes.Result) error { + err = c.coerceAndVisitResults(res, fieldTypes, func(result *sqltypes.Result) error { rows = append(rows, result.Rows...) 
return nil }, evalengine.ParseSQLMode(vcursor.SQLMode())) @@ -116,8 +116,8 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars }, nil } -func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field, sqlmode evalengine.SQLMode) error { - if len(row) != len(fields) { +func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fieldTypes []evalengine.Type, sqlmode evalengine.SQLMode) error { + if len(row) != len(fieldTypes) { return errWrongNumberOfColumnsInSelect } @@ -125,8 +125,8 @@ func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field, if _, found := c.NoNeedToTypeCheck[i]; found { continue } - if fields[i].Type != value.Type() { - newValue, err := evalengine.CoerceTo(value, fields[i].Type, sqlmode) + if fieldTypes[i].Type() != value.Type() { + newValue, err := evalengine.CoerceTo(value, fieldTypes[i], sqlmode) if err != nil { return err } @@ -136,44 +136,44 @@ func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field, return nil } -func (c *Concatenate) getFields(res []*sqltypes.Result) (resultFields []*querypb.Field, err error) { +func (c *Concatenate) getFieldTypes(vcursor VCursor, res []*sqltypes.Result) ([]*querypb.Field, []evalengine.Type, error) { if len(res) == 0 { - return nil, nil + return nil, nil, nil } - resultFields = res[0].Fields - columns := make([][]sqltypes.Type, len(resultFields)) - - addFields := func(fields []*querypb.Field) error { - if len(fields) != len(columns) { - return errWrongNumberOfColumnsInSelect - } - for idx, field := range fields { - columns[idx] = append(columns[idx], field.Type) - } - return nil - } + typers := make([]evalengine.TypeAggregator, len(res[0].Fields)) + collations := vcursor.Environment().CollationEnv() for _, r := range res { if r == nil || r.Fields == nil { continue } - err := addFields(r.Fields) - if err != nil { - return nil, err + if len(r.Fields) != len(typers) { + return nil, nil, errWrongNumberOfColumnsInSelect + } + for idx, field := range r.Fields { + if err := typers[idx].AddField(field, collations); err != nil { + return nil, nil, err + } } } - // The resulting column types need to be the coercion of all the input columns - for colIdx, t := range columns { + fields := make([]*querypb.Field, 0, len(typers)) + types := make([]evalengine.Type, 0, len(typers)) + for colIdx, typer := range typers { + f := res[0].Fields[colIdx] + if _, found := c.NoNeedToTypeCheck[colIdx]; found { + fields = append(fields, f) + types = append(types, evalengine.NewTypeFromField(f)) continue } - resultFields[colIdx].Type = evalengine.AggregateTypes(t) + t := typer.Type() + fields = append(fields, t.ToField(f.Name)) + types = append(types, t) } - - return resultFields, nil + return fields, types, nil } func (c *Concatenate) execSources(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) { @@ -250,7 +250,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, condFields = sync.NewCond(&muFields) // Condition var for field arrival wg errgroup.Group // Wait group for all streaming goroutines rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields - fields []*querypb.Field // Cached final field types + fieldTypes []evalengine.Type // Cached final field types ) // Process each result chunk, considering type coercion. 
@@ -263,7 +263,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, needsCoercion := false for idx, field := range rest[srcIdx].Fields { _, skip := c.NoNeedToTypeCheck[idx] - if !skip && fields[idx].Type != field.Type { + if !skip && fieldTypes[idx].Type() != field.Type { needsCoercion = true break } @@ -272,7 +272,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, // Apply type coercion if needed. if needsCoercion { for _, row := range res.Rows { - if err := c.coerceValuesTo(row, fields, sqlmode); err != nil { + if err := c.coerceValuesTo(row, fieldTypes, sqlmode); err != nil { return err } } @@ -299,11 +299,10 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor, // We have received fields from all sources. We can now calculate the output types var err error - fields, err = c.getFields(rest) + resultChunk.Fields, fieldTypes, err = c.getFieldTypes(vcursor, rest) if err != nil { return err } - resultChunk.Fields = fields defer condFields.Broadcast() return callback(resultChunk, currIndex) @@ -370,12 +369,12 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, firsts[i] = result[0] } - fields, err := c.getFields(firsts) + _, fieldTypes, err := c.getFieldTypes(vcursor, firsts) if err != nil { return err } for _, res := range results { - if err = c.coerceAndVisitResults(res, fields, callback, sqlmode); err != nil { + if err = c.coerceAndVisitResults(res, fieldTypes, callback, sqlmode); err != nil { return err } } @@ -385,26 +384,26 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor, func (c *Concatenate) coerceAndVisitResults( res []*sqltypes.Result, - fields []*querypb.Field, + fieldTypes []evalengine.Type, callback func(*sqltypes.Result) error, sqlmode evalengine.SQLMode, ) error { for _, r := range res { if len(r.Rows) > 0 && - len(fields) != len(r.Rows[0]) { + len(fieldTypes) != len(r.Rows[0]) { return errWrongNumberOfColumnsInSelect } needsCoercion := false for idx, field := range r.Fields { - if fields[idx].Type != field.Type { + if fieldTypes[idx].Type() != field.Type { needsCoercion = true break } } if needsCoercion { for _, row := range r.Rows { - err := c.coerceValuesTo(row, fields, sqlmode) + err := c.coerceValuesTo(row, fieldTypes, sqlmode) if err != nil { return err } @@ -420,35 +419,29 @@ func (c *Concatenate) coerceAndVisitResults( // GetFields fetches the field info. 
func (c *Concatenate) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) { - res, err := c.Sources[0].GetFields(ctx, vcursor, bindVars) - if err != nil { - return nil, err - } - - columns := make([][]sqltypes.Type, len(res.Fields)) - - addFields := func(fields []*querypb.Field) { - for idx, field := range fields { - columns[idx] = append(columns[idx], field.Type) - } - } - - addFields(res.Fields) - - for i := 1; i < len(c.Sources); i++ { - result, err := c.Sources[i].GetFields(ctx, vcursor, bindVars) + sourceFields := make([][]*querypb.Field, 0, len(c.Sources)) + for _, src := range c.Sources { + f, err := src.GetFields(ctx, vcursor, bindVars) if err != nil { return nil, err } - addFields(result.Fields) + sourceFields = append(sourceFields, f.Fields) } - // The resulting column types need to be the coercion of all the input columns - for colIdx, t := range columns { - res.Fields[colIdx].Type = evalengine.AggregateTypes(t) - } + fields := make([]*querypb.Field, 0, len(sourceFields[0])) + collations := vcursor.Environment().CollationEnv() - return res, nil + for colIdx := 0; colIdx < len(sourceFields[0]); colIdx++ { + var typer evalengine.TypeAggregator + for _, src := range sourceFields { + if err := typer.AddField(src[colIdx], collations); err != nil { + return nil, err + } + } + name := sourceFields[0][colIdx].Name + fields = append(fields, typer.Field(name)) + } + return &sqltypes.Result{Fields: fields}, nil } // NeedsTransaction returns whether a transaction is needed for this primitive diff --git a/go/vt/vtgate/engine/concatenate_test.go b/go/vt/vtgate/engine/concatenate_test.go index b886d1312af..dd2b1300e9b 100644 --- a/go/vt/vtgate/engine/concatenate_test.go +++ b/go/vt/vtgate/engine/concatenate_test.go @@ -23,6 +23,7 @@ import ( "strings" "testing" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/test/utils" "github.com/stretchr/testify/assert" @@ -32,7 +33,17 @@ import ( ) func r(names, types string, rows ...string) *sqltypes.Result { - return sqltypes.MakeTestResult(sqltypes.MakeTestFields(names, types), rows...) + fields := sqltypes.MakeTestFields(names, types) + for _, f := range fields { + if sqltypes.IsText(f.Type) { + f.Charset = collations.CollationUtf8mb4ID + } else { + f.Charset = collations.CollationBinaryID + } + _, flags := sqltypes.TypeToMySQL(f.Type) + f.Flags = uint32(flags) + } + return sqltypes.MakeTestResult(fields, rows...) 
} func TestConcatenate_NoErrors(t *testing.T) { @@ -173,12 +184,12 @@ func TestConcatenateTypes(t *testing.T) { tests := []struct { t1, t2, expected string }{ - {t1: "int32", t2: "int64", expected: "int64"}, - {t1: "int32", t2: "int32", expected: "int32"}, - {t1: "int32", t2: "varchar", expected: "varchar"}, - {t1: "int32", t2: "decimal", expected: "decimal"}, - {t1: "hexval", t2: "uint64", expected: "varchar"}, - {t1: "varchar", t2: "varbinary", expected: "varbinary"}, + {t1: "int32", t2: "int64", expected: `[name:"id" type:int64 charset:63]`}, + {t1: "int32", t2: "int32", expected: `[name:"id" type:int32 charset:63]`}, + {t1: "int32", t2: "varchar", expected: `[name:"id" type:varchar charset:255]`}, + {t1: "int32", t2: "decimal", expected: `[name:"id" type:decimal charset:63]`}, + {t1: "hexval", t2: "uint64", expected: `[name:"id" type:varchar charset:255]`}, + {t1: "varchar", t2: "varbinary", expected: `[name:"id" type:varbinary charset:63 flags:128]`}, } for _, test := range tests { @@ -196,8 +207,7 @@ func TestConcatenateTypes(t *testing.T) { res, err := concatenate.GetFields(context.Background(), &noopVCursor{}, nil) require.NoError(t, err) - expected := fmt.Sprintf(`[name:"id" type:%s]`, test.expected) - assert.Equal(t, expected, strings.ToLower(fmt.Sprintf("%v", res.Fields))) + assert.Equal(t, test.expected, strings.ToLower(fmt.Sprintf("%v", res.Fields))) }) } } diff --git a/go/vt/vtgate/engine/distinct.go b/go/vt/vtgate/engine/distinct.go index e292d516d51..c47cf6be8d1 100644 --- a/go/vt/vtgate/engine/distinct.go +++ b/go/vt/vtgate/engine/distinct.go @@ -26,6 +26,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vthash" ) // Distinct Primitive is used to uniqueify results @@ -45,127 +46,55 @@ type ( CollationEnv *collations.Environment } probeTable struct { - seenRows map[evalengine.HashCode][]sqltypes.Row + seenRows map[vthash.Hash]struct{} checkCols []CheckCol sqlmode evalengine.SQLMode collationEnv *collations.Environment } ) -func (pt *probeTable) exists(inputRow sqltypes.Row) (bool, error) { - // the two prime numbers used here (17 and 31) are used to - // calculate hashcode from all column values in the input sqltypes.Row +func (pt *probeTable) exists(inputRow sqltypes.Row) (sqltypes.Row, error) { code, err := pt.hashCodeForRow(inputRow) if err != nil { - return false, err - } - - existingRows, found := pt.seenRows[code] - if !found { - // nothing with this hash code found, we can be sure it's a not seen sqltypes.Row - pt.seenRows[code] = []sqltypes.Row{inputRow} - return false, nil + return nil, err } - // we found something in the map - still need to check all individual values - // so we don't just fall for a hash collision - for _, existingRow := range existingRows { - exists, err := pt.equal(existingRow, inputRow) - if err != nil { - return false, err - } - if exists { - return true, nil - } + if _, found := pt.seenRows[code]; found { + return nil, nil } - pt.seenRows[code] = append(existingRows, inputRow) - - return false, nil + pt.seenRows[code] = struct{}{} + return inputRow, nil } -func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (evalengine.HashCode, error) { - // Why use 17 and 31 in this method? - // Copied from an old usenet discussion on the topic: - // https://groups.google.com/g/comp.programming/c/HSurZEyrZ1E?pli=1#d887b5bdb2dac99d - // > It's a mixture of superstition and good sense. 
- // > Suppose the multiplier were 26, and consider - // > hashing a hundred-character string. How much influence does - // > the string's first character have on the final value of `h', - // > just before the mod operation? The first character's value - // > will have been multiplied by MULT 99 times, so if the arithmetic - // > were done in infinite precision the value would consist of some - // > jumble of bits followed by 99 low-order zero bits -- each time - // > you multiply by MULT you introduce another low-order zero, right? - // > The computer's finite arithmetic just chops away all the excess - // > high-order bits, so the first character's actual contribution to - // > `h' is ... precisely zero! The `h' value depends only on the - // > rightmost 32 string characters (assuming a 32-bit int), and even - // > then things are not wonderful: the first of those final 32 bytes - // > influences only the leftmost bit of `h' and has no effect on - // > the remaining 31. Clearly, an even-valued MULT is a poor idea. - // > - // > Need MULT be prime? Not as far as I know (I don't know - // > everything); any odd value ought to suffice. 31 may be attractive - // > because it is close to a power of two, and it may be easier for - // > the compiler to replace a possibly slow multiply instruction with - // > a shift and subtract (31*x == (x << 5) - x) on machines where it - // > makes a difference. Setting MULT one greater than a power of two - // > (e.g., 33) would also be easy to optimize, but might produce too - // > "simple" an arrangement: mostly a juxtaposition of two copies - // > of the original set of bits, with a little mixing in the middle. - // > So you want an odd MULT that has plenty of one-bits. - - code := evalengine.HashCode(17) +func (pt *probeTable) hashCodeForRow(inputRow sqltypes.Row) (vthash.Hash, error) { + hasher := vthash.New() for i, checkCol := range pt.checkCols { if i >= len(inputRow) { - return 0, vterrors.VT13001("index out of range in row when creating the DISTINCT hash code") + return vthash.Hash{}, vterrors.VT13001("index out of range in row when creating the DISTINCT hash code") } col := inputRow[checkCol.Col] - hashcode, err := evalengine.NullsafeHashcode(col, checkCol.Type.Collation(), col.Type(), pt.sqlmode) + err := evalengine.NullsafeHashcode128(&hasher, col, checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode) if err != nil { if err != evalengine.UnsupportedCollationHashError || checkCol.WsCol == nil { - return 0, err + return vthash.Hash{}, err } checkCol = checkCol.SwitchToWeightString() pt.checkCols[i] = checkCol - hashcode, err = evalengine.NullsafeHashcode(inputRow[checkCol.Col], checkCol.Type.Collation(), col.Type(), pt.sqlmode) + err = evalengine.NullsafeHashcode128(&hasher, inputRow[checkCol.Col], checkCol.Type.Collation(), checkCol.Type.Type(), pt.sqlmode) if err != nil { - return 0, err + return vthash.Hash{}, err } } - code = code*31 + hashcode - } - return code, nil -} - -func (pt *probeTable) equal(a, b sqltypes.Row) (bool, error) { - for i, checkCol := range pt.checkCols { - cmp, err := evalengine.NullsafeCompare(a[i], b[i], pt.collationEnv, checkCol.Type.Collation()) - if err != nil { - _, isCollErr := err.(evalengine.UnsupportedCollationError) - if !isCollErr || checkCol.WsCol == nil { - return false, err - } - checkCol = checkCol.SwitchToWeightString() - pt.checkCols[i] = checkCol - cmp, err = evalengine.NullsafeCompare(a[i], b[i], pt.collationEnv, checkCol.Type.Collation()) - if err != nil { - return false, err - } - } - if cmp 
!= 0 { - return false, nil - } } - return true, nil + return hasher.Sum128(), nil } func newProbeTable(checkCols []CheckCol, collationEnv *collations.Environment) *probeTable { cols := make([]CheckCol, len(checkCols)) copy(cols, checkCols) return &probeTable{ - seenRows: map[evalengine.HashCode][]sqltypes.Row{}, + seenRows: make(map[vthash.Hash]struct{}), checkCols: cols, collationEnv: collationEnv, } @@ -186,12 +115,12 @@ func (d *Distinct) TryExecute(ctx context.Context, vcursor VCursor, bindVars map pt := newProbeTable(d.CheckCols, vcursor.Environment().CollationEnv()) for _, row := range input.Rows { - exists, err := pt.exists(row) + appendRow, err := pt.exists(row) if err != nil { return nil, err } - if !exists { - result.Rows = append(result.Rows, row) + if appendRow != nil { + result.Rows = append(result.Rows, appendRow) } } if d.Truncate > 0 { @@ -213,12 +142,12 @@ func (d *Distinct) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVa mu.Lock() defer mu.Unlock() for _, row := range input.Rows { - exists, err := pt.exists(row) + appendRow, err := pt.exists(row) if err != nil { return err } - if !exists { - result.Rows = append(result.Rows, row) + if appendRow != nil { + result.Rows = append(result.Rows, appendRow) } } return callback(result.Truncate(len(d.CheckCols))) @@ -289,7 +218,7 @@ func (cc CheckCol) SwitchToWeightString() CheckCol { func (cc CheckCol) String() string { var collation string - if sqltypes.IsText(cc.Type.Type()) && cc.Type.Collation() != collations.Unknown { + if cc.Type.Valid() && sqltypes.IsText(cc.Type.Type()) && cc.Type.Collation() != collations.Unknown { collation = ": " + cc.CollationEnv.LookupName(cc.Type.Collation()) } diff --git a/go/vt/vtgate/engine/distinct_test.go b/go/vt/vtgate/engine/distinct_test.go index 76e46496e21..cb414d8de28 100644 --- a/go/vt/vtgate/engine/distinct_test.go +++ b/go/vt/vtgate/engine/distinct_test.go @@ -189,6 +189,7 @@ func TestWeightStringFallBack(t *testing.T) { checkCols := []CheckCol{{ Col: 0, WsCol: &offsetOne, + Type: evalengine.NewType(sqltypes.VarBinary, collations.CollationBinaryID), }} input := r("myid|weightstring(myid)", "varchar|varbinary", @@ -213,5 +214,6 @@ func TestWeightStringFallBack(t *testing.T) { utils.MustMatch(t, []CheckCol{{ Col: 0, WsCol: &offsetOne, + Type: evalengine.NewType(sqltypes.VarBinary, collations.CollationBinaryID), }}, distinct.CheckCols, "checkCols should not be updated") } diff --git a/go/vt/vtgate/engine/insert_common.go b/go/vt/vtgate/engine/insert_common.go index 014fc7681c8..8a35732dff4 100644 --- a/go/vt/vtgate/engine/insert_common.go +++ b/go/vt/vtgate/engine/insert_common.go @@ -23,6 +23,7 @@ import ( "strconv" "strings" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/sqltypes" @@ -468,7 +469,7 @@ func shouldGenerate(v sqltypes.Value, sqlmode evalengine.SQLMode) bool { // Unless the NO_AUTO_VALUE_ON_ZERO sql mode is active in mysql, it also // treats 0 as a value that should generate a new sequence. - value, err := evalengine.CoerceTo(v, sqltypes.Uint64, sqlmode) + value, err := evalengine.CoerceTo(v, evalengine.NewType(sqltypes.Uint64, collations.CollationBinaryID), sqlmode) if err != nil { return false } diff --git a/go/vt/vtgate/evalengine/api_arithmetic.go b/go/vt/vtgate/evalengine/api_arithmetic.go deleted file mode 100644 index 4da7e3450a2..00000000000 --- a/go/vt/vtgate/evalengine/api_arithmetic.go +++ /dev/null @@ -1,140 +0,0 @@ -/* -Copyright 2023 The Vitess Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package evalengine - -import ( - "vitess.io/vitess/go/sqltypes" -) - -// evalengine represents a numeric value extracted from -// a Value, used for arithmetic operations. -var zeroBytes = []byte("0") - -// Add adds two values together -// if v1 or v2 is null, then it returns null -func Add(v1, v2 sqltypes.Value) (sqltypes.Value, error) { - if v1.IsNull() || v2.IsNull() { - return sqltypes.NULL, nil - } - e1, err := valueToEval(v1, collationNumeric) - if err != nil { - return sqltypes.NULL, err - } - e2, err := valueToEval(v2, collationNumeric) - if err != nil { - return sqltypes.NULL, err - } - r, err := addNumericWithError(e1, e2) - if err != nil { - return sqltypes.NULL, err - } - return evalToSQLValue(r), nil -} - -// Subtract takes two values and subtracts them -func Subtract(v1, v2 sqltypes.Value) (sqltypes.Value, error) { - if v1.IsNull() || v2.IsNull() { - return sqltypes.NULL, nil - } - e1, err := valueToEval(v1, collationNumeric) - if err != nil { - return sqltypes.NULL, err - } - e2, err := valueToEval(v2, collationNumeric) - if err != nil { - return sqltypes.NULL, err - } - r, err := subtractNumericWithError(e1, e2) - if err != nil { - return sqltypes.NULL, err - } - return evalToSQLValue(r), nil -} - -// Multiply takes two values and multiplies it together -func Multiply(v1, v2 sqltypes.Value) (sqltypes.Value, error) { - if v1.IsNull() || v2.IsNull() { - return sqltypes.NULL, nil - } - e1, err := valueToEval(v1, collationNumeric) - if err != nil { - return sqltypes.NULL, err - } - e2, err := valueToEval(v2, collationNumeric) - if err != nil { - return sqltypes.NULL, err - } - r, err := multiplyNumericWithError(e1, e2) - if err != nil { - return sqltypes.NULL, err - } - return evalToSQLValue(r), nil -} - -// Divide (Float) for MySQL. Replicates behavior of "/" operator -func Divide(v1, v2 sqltypes.Value) (sqltypes.Value, error) { - if v1.IsNull() || v2.IsNull() { - return sqltypes.NULL, nil - } - e1, err := valueToEval(v1, collationNumeric) - if err != nil { - return sqltypes.NULL, err - } - e2, err := valueToEval(v2, collationNumeric) - if err != nil { - return sqltypes.NULL, err - } - r, err := divideNumericWithError(e1, e2, true) - if err != nil { - return sqltypes.NULL, err - } - return evalToSQLValue(r), nil -} - -// NullSafeAdd adds two Values in a null-safe manner. A null value -// is treated as 0. If both values are null, then a null is returned. -// If both values are not null, a numeric value is built -// from each input: Signed->int64, Unsigned->uint64, Float->float64. -// Otherwise the 'best type fit' is chosen for the number: int64 or float64. -// opArithAdd is performed by upgrading types as needed, or in case -// of overflow: int64->uint64, int64->float64, uint64->float64. -// Unsigned ints can only be added to positive ints. After the -// addition, if one of the input types was Decimal, then -// a Decimal is built. Otherwise, the final type of the -// result is preserved. 
-func NullSafeAdd(v1, v2 sqltypes.Value, resultType sqltypes.Type) (sqltypes.Value, error) { - if v1.IsNull() { - v1 = sqltypes.MakeTrusted(resultType, zeroBytes) - } - if v2.IsNull() { - v2 = sqltypes.MakeTrusted(resultType, zeroBytes) - } - - e1, err := valueToEval(v1, collationNumeric) - if err != nil { - return sqltypes.NULL, err - } - e2, err := valueToEval(v2, collationNumeric) - if err != nil { - return sqltypes.NULL, err - } - r, err := addNumericWithError(e1, e2) - if err != nil { - return sqltypes.NULL, err - } - return evalToSQLValueWithType(r, resultType), nil -} diff --git a/go/vt/vtgate/evalengine/api_arithmetic_test.go b/go/vt/vtgate/evalengine/api_arithmetic_test.go index 37f79d08c6c..c0a68de8f83 100644 --- a/go/vt/vtgate/evalengine/api_arithmetic_test.go +++ b/go/vt/vtgate/evalengine/api_arithmetic_test.go @@ -17,548 +17,27 @@ limitations under the License. package evalengine import ( - "encoding/binary" "fmt" - "math" "reflect" - "strconv" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/utils" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vthash" ) var ( NULL = sqltypes.NULL - NewInt32 = sqltypes.NewInt32 NewInt64 = sqltypes.NewInt64 NewUint64 = sqltypes.NewUint64 NewFloat64 = sqltypes.NewFloat64 TestValue = sqltypes.TestValue - NewDecimal = sqltypes.NewDecimal - - maxUint64 uint64 = math.MaxUint64 ) -func TestArithmetics(t *testing.T) { - type tcase struct { - v1, v2, out sqltypes.Value - err string - } - - tests := []struct { - operator string - f func(a, b sqltypes.Value) (sqltypes.Value, error) - cases []tcase - }{{ - operator: "-", - f: Subtract, - cases: []tcase{{ - // All Nulls - v1: NULL, - v2: NULL, - out: NULL, - }, { - // First value null. - v1: NewInt32(1), - v2: NULL, - out: NULL, - }, { - // Second value null. 
- v1: NULL, - v2: NewInt32(1), - out: NULL, - }, { - // case with negative value - v1: NewInt64(-1), - v2: NewInt64(-2), - out: NewInt64(1), - }, { - // testing for int64 overflow with min negative value - v1: NewInt64(math.MinInt64), - v2: NewInt64(1), - err: dataOutOfRangeError(int64(math.MinInt64), int64(1), "BIGINT", "-").Error(), - }, { - v1: NewUint64(4), - v2: NewInt64(5), - err: dataOutOfRangeError(uint64(4), int64(5), "BIGINT UNSIGNED", "-").Error(), - }, { - // testing uint - int - v1: NewUint64(7), - v2: NewInt64(5), - out: NewUint64(2), - }, { - v1: NewUint64(math.MaxUint64), - v2: NewInt64(0), - out: NewUint64(math.MaxUint64), - }, { - // testing for int64 overflow - v1: NewInt64(math.MinInt64), - v2: NewUint64(0), - err: dataOutOfRangeError(int64(math.MinInt64), uint64(0), "BIGINT UNSIGNED", "-").Error(), - }, { - v1: TestValue(sqltypes.VarChar, "c"), - v2: NewInt64(1), - out: NewFloat64(-1), - }, { - v1: NewUint64(1), - v2: TestValue(sqltypes.VarChar, "c"), - out: NewFloat64(1), - }, { - // testing for error for parsing float value to uint64 - v1: TestValue(sqltypes.Uint64, "1.2"), - v2: NewInt64(2), - err: "unparsed tail left after parsing uint64 from \"1.2\": \".2\"", - }, { - // testing for error for parsing float value to uint64 - v1: NewUint64(2), - v2: TestValue(sqltypes.Uint64, "1.2"), - err: "unparsed tail left after parsing uint64 from \"1.2\": \".2\"", - }, { - // uint64 - uint64 - v1: NewUint64(8), - v2: NewUint64(4), - out: NewUint64(4), - }, { - // testing for float subtraction: float - int - v1: NewFloat64(1.2), - v2: NewInt64(2), - out: NewFloat64(-0.8), - }, { - // testing for float subtraction: float - uint - v1: NewFloat64(1.2), - v2: NewUint64(2), - out: NewFloat64(-0.8), - }, { - v1: NewInt64(-1), - v2: NewUint64(2), - err: dataOutOfRangeError(int64(-1), int64(2), "BIGINT UNSIGNED", "-").Error(), - }, { - v1: NewInt64(2), - v2: NewUint64(1), - out: NewUint64(1), - }, { - // testing int64 - float64 method - v1: NewInt64(-2), - v2: NewFloat64(1.0), - out: NewFloat64(-3.0), - }, { - // testing uint64 - float64 method - v1: NewUint64(1), - v2: NewFloat64(-2.0), - out: NewFloat64(3.0), - }, { - // testing uint - int to return uintplusint - v1: NewUint64(1), - v2: NewInt64(-2), - out: NewUint64(3), - }, { - // testing for float - float - v1: NewFloat64(1.2), - v2: NewFloat64(3.2), - out: NewFloat64(-2), - }, { - // testing uint - uint if v2 > v1 - v1: NewUint64(2), - v2: NewUint64(4), - err: dataOutOfRangeError(uint64(2), uint64(4), "BIGINT UNSIGNED", "-").Error(), - }, { - // testing uint - (- int) - v1: NewUint64(1), - v2: NewInt64(-2), - out: NewUint64(3), - }}, - }, { - operator: "+", - f: Add, - cases: []tcase{{ - // All Nulls - v1: NULL, - v2: NULL, - out: NULL, - }, { - // First value null. - v1: NewInt32(1), - v2: NULL, - out: NULL, - }, { - // Second value null. 
- v1: NULL, - v2: NewInt32(1), - out: NULL, - }, { - // case with negatives - v1: NewInt64(-1), - v2: NewInt64(-2), - out: NewInt64(-3), - }, { - // testing for overflow int64, result will be unsigned int - v1: NewInt64(math.MaxInt64), - v2: NewUint64(2), - out: NewUint64(9223372036854775809), - }, { - v1: NewInt64(-2), - v2: NewUint64(1), - err: dataOutOfRangeError(uint64(1), int64(-2), "BIGINT UNSIGNED", "+").Error(), - }, { - v1: NewInt64(math.MaxInt64), - v2: NewInt64(-2), - out: NewInt64(9223372036854775805), - }, { - // Normal case - v1: NewUint64(1), - v2: NewUint64(2), - out: NewUint64(3), - }, { - // testing for overflow uint64 - v1: NewUint64(maxUint64), - v2: NewUint64(2), - err: dataOutOfRangeError(maxUint64, uint64(2), "BIGINT UNSIGNED", "+").Error(), - }, { - // int64 underflow - v1: NewInt64(math.MinInt64), - v2: NewInt64(-2), - err: dataOutOfRangeError(int64(math.MinInt64), int64(-2), "BIGINT", "+").Error(), - }, { - // checking int64 max value can be returned - v1: NewInt64(math.MaxInt64), - v2: NewUint64(0), - out: NewUint64(9223372036854775807), - }, { - // testing whether uint64 max value can be returned - v1: NewUint64(math.MaxUint64), - v2: NewInt64(0), - out: NewUint64(math.MaxUint64), - }, { - v1: NewUint64(math.MaxInt64), - v2: NewInt64(1), - out: NewUint64(9223372036854775808), - }, { - v1: NewUint64(1), - v2: TestValue(sqltypes.VarChar, "c"), - out: NewFloat64(1), - }, { - v1: NewUint64(1), - v2: TestValue(sqltypes.VarChar, "1.2"), - out: NewFloat64(2.2), - }, { - v1: TestValue(sqltypes.Int64, "1.2"), - v2: NewInt64(2), - err: "unparsed tail left after parsing int64 from \"1.2\": \".2\"", - }, { - v1: NewInt64(2), - v2: TestValue(sqltypes.Int64, "1.2"), - err: "unparsed tail left after parsing int64 from \"1.2\": \".2\"", - }, { - // testing for uint64 overflow with max uint64 + int value - v1: NewUint64(maxUint64), - v2: NewInt64(2), - err: dataOutOfRangeError(maxUint64, int64(2), "BIGINT UNSIGNED", "+").Error(), - }, { - v1: sqltypes.NewHexNum([]byte("0x9")), - v2: NewInt64(1), - out: NewUint64(10), - }}, - }, { - operator: "/", - f: Divide, - cases: []tcase{{ - // All Nulls - v1: NULL, - v2: NULL, - out: NULL, - }, { - // First value null. - v1: NULL, - v2: NewInt32(1), - out: NULL, - }, { - // Second value null. 
- v1: NewInt32(1), - v2: NULL, - out: NULL, - }, { - // Second arg 0 - v1: NewInt32(5), - v2: NewInt32(0), - out: NULL, - }, { - // Both arguments zero - v1: NewInt32(0), - v2: NewInt32(0), - out: NULL, - }, { - // case with negative value - v1: NewInt64(-1), - v2: NewInt64(-2), - out: NewDecimal("0.5000"), - }, { - // float64 division by zero - v1: NewFloat64(2), - v2: NewFloat64(0), - out: NULL, - }, { - // Lower bound for int64 - v1: NewInt64(math.MinInt64), - v2: NewInt64(1), - out: NewDecimal(strconv.FormatInt(math.MinInt64, 10) + ".0000"), - }, { - // upper bound for uint64 - v1: NewUint64(math.MaxUint64), - v2: NewUint64(1), - out: NewDecimal(strconv.FormatUint(math.MaxUint64, 10) + ".0000"), - }, { - // testing for error in types - v1: TestValue(sqltypes.Int64, "1.2"), - v2: NewInt64(2), - err: "unparsed tail left after parsing int64 from \"1.2\": \".2\"", - }, { - // testing for error in types - v1: NewInt64(2), - v2: TestValue(sqltypes.Int64, "1.2"), - err: "unparsed tail left after parsing int64 from \"1.2\": \".2\"", - }, { - // testing for uint/int - v1: NewUint64(4), - v2: NewInt64(5), - out: NewDecimal("0.8000"), - }, { - // testing for uint/uint - v1: NewUint64(1), - v2: NewUint64(2), - out: NewDecimal("0.5000"), - }, { - // testing for float64/int64 - v1: TestValue(sqltypes.Float64, "1.2"), - v2: NewInt64(-2), - out: NewFloat64(-0.6), - }, { - // testing for float64/uint64 - v1: TestValue(sqltypes.Float64, "1.2"), - v2: NewUint64(2), - out: NewFloat64(0.6), - }, { - // testing for overflow of float64 - v1: NewFloat64(math.MaxFloat64), - v2: NewFloat64(0.5), - err: dataOutOfRangeError(math.MaxFloat64, 0.5, "DOUBLE", "/").Error(), - }}, - }, { - operator: "*", - f: Multiply, - cases: []tcase{{ - // All Nulls - v1: NULL, - v2: NULL, - out: NULL, - }, { - // First value null. - v1: NewInt32(1), - v2: NULL, - out: NULL, - }, { - // Second value null. 
- v1: NULL, - v2: NewInt32(1), - out: NULL, - }, { - // case with negative value - v1: NewInt64(-1), - v2: NewInt64(-2), - out: NewInt64(2), - }, { - // testing for int64 overflow with min negative value - v1: NewInt64(math.MinInt64), - v2: NewInt64(1), - out: NewInt64(math.MinInt64), - }, { - // testing for error in types - v1: TestValue(sqltypes.Int64, "1.2"), - v2: NewInt64(2), - err: "unparsed tail left after parsing int64 from \"1.2\": \".2\"", - }, { - // testing for error in types - v1: NewInt64(2), - v2: TestValue(sqltypes.Int64, "1.2"), - err: "unparsed tail left after parsing int64 from \"1.2\": \".2\"", - }, { - // testing for uint*int - v1: NewUint64(4), - v2: NewInt64(5), - out: NewUint64(20), - }, { - // testing for uint*uint - v1: NewUint64(1), - v2: NewUint64(2), - out: NewUint64(2), - }, { - // testing for float64*int64 - v1: TestValue(sqltypes.Float64, "1.2"), - v2: NewInt64(-2), - out: NewFloat64(-2.4), - }, { - // testing for float64*uint64 - v1: TestValue(sqltypes.Float64, "1.2"), - v2: NewUint64(2), - out: NewFloat64(2.4), - }, { - // testing for overflow of int64 - v1: NewInt64(math.MaxInt64), - v2: NewInt64(2), - err: dataOutOfRangeError(int64(math.MaxInt64), int64(2), "BIGINT", "*").Error(), - }, { - // testing for underflow of uint64*max.uint64 - v1: NewInt64(2), - v2: NewUint64(maxUint64), - err: dataOutOfRangeError(maxUint64, int64(2), "BIGINT UNSIGNED", "*").Error(), - }, { - v1: NewUint64(math.MaxUint64), - v2: NewUint64(1), - out: NewUint64(math.MaxUint64), - }, { - // Checking whether maxInt value can be passed as uint value - v1: NewUint64(math.MaxInt64), - v2: NewInt64(3), - err: dataOutOfRangeError(uint64(math.MaxInt64), int64(3), "BIGINT UNSIGNED", "*").Error(), - }}, - }} - - for _, test := range tests { - t.Run(test.operator, func(t *testing.T) { - for _, tcase := range test.cases { - name := fmt.Sprintf("%s%s%s", tcase.v1.String(), test.operator, tcase.v2.String()) - t.Run(name, func(t *testing.T) { - got, err := test.f(tcase.v1, tcase.v2) - if tcase.err == "" { - require.NoError(t, err) - require.Equal(t, tcase.out, got) - } else { - require.EqualError(t, err, tcase.err) - } - }) - } - }) - } -} - -func TestNullSafeAdd(t *testing.T) { - tcases := []struct { - v1, v2 sqltypes.Value - out sqltypes.Value - err error - }{{ - // All nulls. - v1: NULL, - v2: NULL, - out: NewInt64(0), - }, { - // First value null. - v1: NewInt32(1), - v2: NULL, - out: NewInt64(1), - }, { - // Second value null. - v1: NULL, - v2: NewInt32(1), - out: NewInt64(1), - }, { - // Normal case. - v1: NewInt64(1), - v2: NewInt64(2), - out: NewInt64(3), - }, { - // Make sure underlying error is returned for LHS. - v1: TestValue(sqltypes.Int64, "1.2"), - v2: NewInt64(2), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unparsed tail left after parsing int64 from \"1.2\": \".2\""), - }, { - // Make sure underlying error is returned for RHS. - v1: NewInt64(2), - v2: TestValue(sqltypes.Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unparsed tail left after parsing int64 from \"1.2\": \".2\""), - }, { - // Make sure underlying error is returned while adding. - v1: NewInt64(-1), - v2: NewUint64(2), - out: NewInt64(1), - }, { - v1: NewInt64(-100), - v2: NewUint64(10), - err: dataOutOfRangeError(uint64(10), int64(-100), "BIGINT UNSIGNED", "+"), - }, { - // Make sure underlying error is returned while converting. 
- v1: NewFloat64(1), - v2: NewFloat64(2), - out: NewInt64(3), - }} - for _, tcase := range tcases { - got, err := NullSafeAdd(tcase.v1, tcase.v2, sqltypes.Int64) - - if tcase.err == nil { - require.NoError(t, err) - } else { - require.EqualError(t, err, tcase.err.Error()) - } - - if !reflect.DeepEqual(got, tcase.out) { - t.Errorf("NullSafeAdd(%v, %v): %v, want %v", printValue(tcase.v1), printValue(tcase.v2), printValue(got), printValue(tcase.out)) - } - } -} - -func TestNewIntegralNumeric(t *testing.T) { - tcases := []struct { - v sqltypes.Value - out eval - err error - }{{ - v: NewInt64(1), - out: newEvalInt64(1), - }, { - v: NewUint64(1), - out: newEvalUint64(1), - }, { - v: NewFloat64(1), - out: newEvalInt64(1), - }, { - // For non-number type, Int64 is the default. - v: TestValue(sqltypes.VarChar, "1"), - out: newEvalInt64(1), - }, { - // If Int64 can't work, we use Uint64. - v: TestValue(sqltypes.VarChar, "18446744073709551615"), - out: newEvalUint64(18446744073709551615), - }, { - // Only valid Int64 allowed if type is Int64. - v: TestValue(sqltypes.Int64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unparsed tail left after parsing int64 from \"1.2\": \".2\""), - }, { - // Only valid Uint64 allowed if type is Uint64. - v: TestValue(sqltypes.Uint64, "1.2"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "unparsed tail left after parsing uint64 from \"1.2\": \".2\""), - }, { - v: TestValue(sqltypes.VarChar, "abcd"), - err: vterrors.New(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: 'abcd'"), - }} - for _, tcase := range tcases { - got, err := valueToEvalNumeric(tcase.v) - if err != nil && !vterrors.Equals(err, tcase.err) { - t.Errorf("newIntegralNumeric(%s) error: %v, want %v", printValue(tcase.v), vterrors.Print(err), vterrors.Print(tcase.err)) - } - if tcase.err == nil { - continue - } - - utils.MustMatch(t, tcase.out, got, "newIntegralNumeric") - } -} - func TestAddNumeric(t *testing.T) { tcases := []struct { v1, v2 eval @@ -684,75 +163,79 @@ func TestPrioritize(t *testing.T) { } func TestToSqlValue(t *testing.T) { + nt := func(t sqltypes.Type) Type { + return NewType(t, collations.CollationBinaryID) + } + tcases := []struct { - typ sqltypes.Type + typ Type v eval out sqltypes.Value err error }{{ - typ: sqltypes.Int64, + typ: nt(sqltypes.Int64), v: newEvalInt64(1), out: NewInt64(1), }, { - typ: sqltypes.Int64, + typ: nt(sqltypes.Int64), v: newEvalUint64(1), out: NewInt64(1), }, { - typ: sqltypes.Int64, + typ: nt(sqltypes.Int64), v: newEvalFloat(1.2e-16), out: NewInt64(0), }, { - typ: sqltypes.Uint64, + typ: nt(sqltypes.Uint64), v: newEvalInt64(1), out: NewUint64(1), }, { - typ: sqltypes.Uint64, + typ: nt(sqltypes.Uint64), v: newEvalUint64(1), out: NewUint64(1), }, { - typ: sqltypes.Uint64, + typ: nt(sqltypes.Uint64), v: newEvalFloat(1.2e-16), out: NewUint64(0), }, { - typ: sqltypes.Float64, + typ: nt(sqltypes.Float64), v: newEvalInt64(1), out: TestValue(sqltypes.Float64, "1"), }, { - typ: sqltypes.Float64, + typ: nt(sqltypes.Float64), v: newEvalUint64(1), out: TestValue(sqltypes.Float64, "1"), }, { - typ: sqltypes.Float64, + typ: nt(sqltypes.Float64), v: newEvalFloat(1.2e-16), out: TestValue(sqltypes.Float64, "1.2e-16"), }, { - typ: sqltypes.Decimal, + typ: nt(sqltypes.Decimal), v: newEvalInt64(1), out: TestValue(sqltypes.Decimal, "1"), }, { - typ: sqltypes.Decimal, + typ: nt(sqltypes.Decimal), v: newEvalUint64(1), out: TestValue(sqltypes.Decimal, "1"), }, { // For float, we should not use scientific notation. 
- typ: sqltypes.Decimal, + typ: nt(sqltypes.Decimal), v: newEvalFloat(1.2e-16), out: TestValue(sqltypes.Decimal, "0.00000000000000012"), }, { // null in should return null out no matter what type - typ: sqltypes.Int64, + typ: nt(sqltypes.Int64), v: nil, out: sqltypes.NULL, }, { - typ: sqltypes.Uint64, + typ: nt(sqltypes.Uint64), v: nil, out: sqltypes.NULL, }, { - typ: sqltypes.Float64, + typ: nt(sqltypes.Float64), v: nil, out: sqltypes.NULL, }, { - typ: sqltypes.VarChar, + typ: nt(sqltypes.VarChar), v: nil, out: sqltypes.NULL, }} @@ -823,71 +306,3 @@ func printValue(v sqltypes.Value) string { vBytes, _ := v.ToBytes() return fmt.Sprintf("%v:%q", v.Type(), vBytes) } - -// These benchmarks show that using existing ASCII representations -// for numbers is about 6x slower than using native representations. -// However, 229ns is still a negligible time compared to the cost of -// other operations. The additional complexity of introducing native -// types is currently not worth it. So, we'll stay with the existing -// ASCII representation for now. Using interfaces is more expensive -// than native representation of values. This is probably because -// interfaces also allocate memory, and also perform type assertions. -// Actual benchmark is based on NoNative. So, the numbers are similar. -// Date: 6/4/17 -// Version: go1.8 -// BenchmarkAddActual-8 10000000 263 ns/op -// BenchmarkAddNoNative-8 10000000 228 ns/op -// BenchmarkAddNative-8 50000000 40.0 ns/op -// BenchmarkAddGoInterface-8 30000000 52.4 ns/op -// BenchmarkAddGoNonInterface-8 2000000000 1.00 ns/op -// BenchmarkAddGo-8 2000000000 1.00 ns/op -func BenchmarkAddActual(b *testing.B) { - v1 := sqltypes.MakeTrusted(sqltypes.Int64, []byte("1")) - v2 := sqltypes.MakeTrusted(sqltypes.Int64, []byte("12")) - for i := 0; i < b.N; i++ { - v1, _ = NullSafeAdd(v1, v2, sqltypes.Int64) - } -} - -func BenchmarkAddNoNative(b *testing.B) { - v1 := sqltypes.MakeTrusted(sqltypes.Int64, []byte("1")) - v2 := sqltypes.MakeTrusted(sqltypes.Int64, []byte("12")) - for i := 0; i < b.N; i++ { - iv1, _ := v1.ToInt64() - iv2, _ := v2.ToInt64() - v1 = sqltypes.MakeTrusted(sqltypes.Int64, strconv.AppendInt(nil, iv1+iv2, 10)) - } -} - -func BenchmarkAddNative(b *testing.B) { - v1 := makeNativeInt64(1) - v2 := makeNativeInt64(12) - for i := 0; i < b.N; i++ { - iv1 := int64(binary.BigEndian.Uint64(v1.Raw())) - iv2 := int64(binary.BigEndian.Uint64(v2.Raw())) - v1 = makeNativeInt64(iv1 + iv2) - } -} - -func makeNativeInt64(v int64) sqltypes.Value { - buf := make([]byte, 8) - binary.BigEndian.PutUint64(buf, uint64(v)) - return sqltypes.MakeTrusted(sqltypes.Int64, buf) -} - -func BenchmarkAddGoInterface(b *testing.B) { - var v1, v2 any - v1 = int64(1) - v2 = int64(2) - for i := 0; i < b.N; i++ { - v1 = v1.(int64) + v2.(int64) - } -} - -func BenchmarkAddGo(b *testing.B) { - v1 := int64(1) - v2 := int64(2) - for i := 0; i < b.N; i++ { - v1 += v2 - } -} diff --git a/go/vt/vtgate/evalengine/api_coerce.go b/go/vt/vtgate/evalengine/api_coerce.go index 2730cedff07..907c578df8a 100644 --- a/go/vt/vtgate/evalengine/api_coerce.go +++ b/go/vt/vtgate/evalengine/api_coerce.go @@ -23,7 +23,7 @@ import ( "vitess.io/vitess/go/vt/vterrors" ) -func CoerceTo(value sqltypes.Value, typ sqltypes.Type, sqlmode SQLMode) (sqltypes.Value, error) { +func CoerceTo(value sqltypes.Value, typ Type, sqlmode SQLMode) (sqltypes.Value, error) { cast, err := valueToEvalCast(value, value.Type(), collations.Unknown, sqlmode) if err != nil { return sqltypes.Value{}, err diff --git 
a/go/vt/vtgate/evalengine/api_hash.go b/go/vt/vtgate/evalengine/api_hash.go index 3bce100839c..2d3bc2d3b56 100644 --- a/go/vt/vtgate/evalengine/api_hash.go +++ b/go/vt/vtgate/evalengine/api_hash.go @@ -199,7 +199,7 @@ func NullsafeHashcode128(hash *vthash.Hasher, v sqltypes.Value, collation collat case sqltypes.IsText(coerceTo): coll := colldata.Lookup(collation) if coll == nil { - panic("cannot hash unsupported collation") + return UnsupportedCollationHashError } hash.Write16(hashPrefixBytes) coll.Hash(hash, v.Raw(), 0) diff --git a/go/vt/vtgate/evalengine/api_type_aggregation.go b/go/vt/vtgate/evalengine/api_type_aggregation.go index 83703d4532c..cb2b646fa67 100644 --- a/go/vt/vtgate/evalengine/api_type_aggregation.go +++ b/go/vt/vtgate/evalengine/api_type_aggregation.go @@ -19,6 +19,7 @@ package evalengine import ( "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/proto/query" ) type typeAggregation struct { @@ -49,27 +50,46 @@ type typeAggregation struct { nullable bool } -func AggregateEvalTypes(types []Type, env *collations.Environment) (Type, error) { - var typeAgg typeAggregation - var collAgg collationAggregation - var size, scale int32 - for _, typ := range types { - typeAgg.addNullable(typ.typ, typ.nullable) - if err := collAgg.add(typedCoercionCollation(typ.typ, typ.collation), env); err != nil { - return Type{}, err - } - size = max(typ.size, size) - scale = max(typ.scale, scale) +type TypeAggregator struct { + types typeAggregation + collations collationAggregation + size, scale int32 + invalid int32 +} + +func (ta *TypeAggregator) Add(typ Type, env *collations.Environment) error { + if !typ.Valid() { + ta.invalid++ + return nil } - return NewTypeEx(typeAgg.result(), collAgg.result().Collation, typeAgg.nullable, size, scale), nil + + ta.types.addNullable(typ.typ, typ.nullable) + if err := ta.collations.add(typedCoercionCollation(typ.typ, typ.collation), env); err != nil { + return err + } + ta.size = max(typ.size, ta.size) + ta.scale = max(typ.scale, ta.scale) + return nil } -func AggregateTypes(types []sqltypes.Type) sqltypes.Type { - var typeAgg typeAggregation - for _, typ := range types { - typeAgg.addNullable(typ, false) +func (ta *TypeAggregator) AddField(f *query.Field, env *collations.Environment) error { + return ta.Add(NewTypeFromField(f), env) +} + +func (ta *TypeAggregator) Type() Type { + if ta.invalid > 0 || ta.types.empty() { + return Type{} } - return typeAgg.result() + return NewTypeEx(ta.types.result(), ta.collations.result().Collation, ta.types.nullable, ta.size, ta.scale) +} + +func (ta *TypeAggregator) Field(name string) *query.Field { + typ := ta.Type() + return typ.ToField(name) +} + +func (ta *typeAggregation) empty() bool { + return ta.total == 0 } func (ta *typeAggregation) addEval(e eval) { diff --git a/go/vt/vtgate/evalengine/api_type_aggregation_test.go b/go/vt/vtgate/evalengine/api_type_aggregation_test.go index 1bf29eaffb3..257653553bd 100644 --- a/go/vt/vtgate/evalengine/api_type_aggregation_test.go +++ b/go/vt/vtgate/evalengine/api_type_aggregation_test.go @@ -51,28 +51,21 @@ var aggregationCases = []struct { {[]sqltypes.Type{sqltypes.Geometry, sqltypes.Geometry}, sqltypes.Geometry}, } -func TestTypeAggregations(t *testing.T) { - for i, tc := range aggregationCases { - t.Run(fmt.Sprintf("%d.%v", i, tc.result), func(t *testing.T) { - res := AggregateTypes(tc.types) - require.Equalf(t, tc.result, res, "expected aggregate(%v) = %v, got %v", tc.types, tc.result, res) - }) - } -} - func 
TestEvalengineTypeAggregations(t *testing.T) { for i, tc := range aggregationCases { t.Run(fmt.Sprintf("%d.%v", i, tc.result), func(t *testing.T) { - var types []Type + var typer TypeAggregator + for _, tt := range tc.types { // this test only aggregates binary collations because textual collation // aggregation is tested in the `mysql/collations` package - types = append(types, NewType(tt, collations.CollationBinaryID)) + + err := typer.Add(NewType(tt, collations.CollationBinaryID), collations.MySQL8()) + require.NoError(t, err) } - res, err := AggregateEvalTypes(types, collations.MySQL8()) - require.NoError(t, err) - require.Equalf(t, tc.result, res.Type(), "expected aggregate(%v) = %v, got %v", tc.types, tc.result, res) + res := typer.Type() + require.Equalf(t, tc.result, res.Type(), "expected aggregate(%v) = %v, got %v", tc.types, tc.result, res.Type()) }) } } diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index 8c5700d751d..c0b628b1aa8 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -22,6 +22,7 @@ import ( "vitess.io/vitess/go/mysql/collations/colldata" "vitess.io/vitess/go/mysql/json" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtenv" @@ -80,6 +81,38 @@ func NewTypeEx(t sqltypes.Type, collation collations.ID, nullable bool, size, sc } } +func NewTypeFromField(f *querypb.Field) Type { + return Type{ + typ: f.Type, + collation: collations.ID(f.Charset), + nullable: f.Flags&uint32(querypb.MySqlFlag_NOT_NULL_FLAG) == 0, + init: true, + size: int32(f.ColumnLength), + scale: int32(f.Decimals), + } +} + +func (t *Type) ToField(name string) *querypb.Field { + // need to get the proper flags for the type; usually leaving flags + // to 0 is OK, because Vitess' MySQL client will generate the right + // ones for the column's type, but here we're also setting the NotNull + // flag, so it needs to be set with the full flags for the column + _, flags := sqltypes.TypeToMySQL(t.typ) + if !t.nullable { + flags |= int64(querypb.MySqlFlag_NOT_NULL_FLAG) + } + + f := &querypb.Field{ + Name: name, + Type: t.typ, + Charset: uint32(t.collation), + ColumnLength: uint32(t.size), + Decimals: uint32(t.scale), + Flags: uint32(flags), + } + return f +} + func (t *Type) Type() sqltypes.Type { if t.init { return t.typ diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index 86d3c949b4d..36ce482d967 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -21,7 +21,6 @@ import ( "time" "unicode/utf8" - "vitess.io/vitess/go/hack" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/decimal" "vitess.io/vitess/go/mysql/fastparse" @@ -87,50 +86,45 @@ func evalToSQLValue(e eval) sqltypes.Value { return sqltypes.MakeTrusted(e.SQLType(), e.ToRawBytes()) } -func evalToSQLValueWithType(e eval, resultType sqltypes.Type) sqltypes.Value { +func evalToSQLValueWithType(e eval, resultType Type) sqltypes.Value { + tt := resultType.Type() switch { - case sqltypes.IsSigned(resultType): + case sqltypes.IsSigned(tt): switch e := e.(type) { case *evalInt64: - return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, e.i, 10)) + return sqltypes.MakeTrusted(tt, strconv.AppendInt(nil, e.i, 10)) case *evalUint64: - return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, e.u, 10)) + return sqltypes.MakeTrusted(tt, strconv.AppendUint(nil, e.u, 10)) 
case *evalFloat: - return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, int64(e.f), 10)) + return sqltypes.MakeTrusted(tt, strconv.AppendInt(nil, int64(e.f), 10)) } - case sqltypes.IsUnsigned(resultType): + case sqltypes.IsUnsigned(tt): switch e := e.(type) { case *evalInt64: - return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(e.i), 10)) + return sqltypes.MakeTrusted(tt, strconv.AppendUint(nil, uint64(e.i), 10)) case *evalUint64: - return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, e.u, 10)) + return sqltypes.MakeTrusted(tt, strconv.AppendUint(nil, e.u, 10)) case *evalFloat: - return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, uint64(e.f), 10)) + return sqltypes.MakeTrusted(tt, strconv.AppendUint(nil, uint64(e.f), 10)) } - case sqltypes.IsFloat(resultType): + case sqltypes.IsFloat(tt): switch e := e.(type) { case *evalInt64: - return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, e.i, 10)) + return sqltypes.MakeTrusted(tt, strconv.AppendInt(nil, e.i, 10)) case *evalUint64: - return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, e.u, 10)) + return sqltypes.MakeTrusted(tt, strconv.AppendUint(nil, e.u, 10)) case *evalFloat: - return sqltypes.MakeTrusted(resultType, format.FormatFloat(e.f)) + return sqltypes.MakeTrusted(tt, format.FormatFloat(e.f)) case *evalDecimal: - return sqltypes.MakeTrusted(resultType, e.dec.FormatMySQL(e.length)) + return sqltypes.MakeTrusted(tt, e.dec.FormatMySQL(e.length)) } - case sqltypes.IsDecimal(resultType): - switch e := e.(type) { - case *evalInt64: - return sqltypes.MakeTrusted(resultType, strconv.AppendInt(nil, e.i, 10)) - case *evalUint64: - return sqltypes.MakeTrusted(resultType, strconv.AppendUint(nil, e.u, 10)) - case *evalFloat: - return sqltypes.MakeTrusted(resultType, hack.StringBytes(strconv.FormatFloat(e.f, 'f', -1, 64))) - case *evalDecimal: - return sqltypes.MakeTrusted(resultType, e.dec.FormatMySQL(e.length)) + case sqltypes.IsDecimal(tt): + if numeric, ok := e.(evalNumeric); ok { + dec := numeric.toDecimal(resultType.size, resultType.scale) + return sqltypes.MakeTrusted(tt, dec.dec.FormatMySQL(dec.length)) } case e != nil: - return sqltypes.MakeTrusted(resultType, e.ToRawBytes()) + return sqltypes.MakeTrusted(tt, e.ToRawBytes()) } return sqltypes.NULL } @@ -369,34 +363,6 @@ func valueToEvalCast(v sqltypes.Value, typ sqltypes.Type, collation collations.I return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "coercion should not try to coerce this value: %v", v) } -func valueToEvalNumeric(v sqltypes.Value) (eval, error) { - switch { - case v.IsSigned(): - ival, err := v.ToInt64() - if err != nil { - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) - } - return &evalInt64{i: ival}, nil - case v.IsUnsigned(): - var uval uint64 - uval, err := v.ToUint64() - if err != nil { - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err) - } - return newEvalUint64(uval), nil - default: - uval, err := strconv.ParseUint(v.RawStr(), 10, 64) - if err == nil { - return newEvalUint64(uval), nil - } - ival, err := strconv.ParseInt(v.RawStr(), 10, 64) - if err == nil { - return &evalInt64{i: ival}, nil - } - return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not parse value: '%s'", v.RawStr()) - } -} - func valueToEval(value sqltypes.Value, collation collations.TypedCollation) (eval, error) { wrap := func(err error) error { if err == nil { diff --git a/go/vt/vtgate/evalengine/eval_numeric.go b/go/vt/vtgate/evalengine/eval_numeric.go index 
8584fa4a714..fb34caab85d 100644 --- a/go/vt/vtgate/evalengine/eval_numeric.go +++ b/go/vt/vtgate/evalengine/eval_numeric.go @@ -81,10 +81,14 @@ func newEvalFloat(f float64) *evalFloat { } func newEvalDecimal(dec decimal.Decimal, m, d int32) *evalDecimal { - if m == 0 && d == 0 { + switch { + case m == 0 && d == 0: return newEvalDecimalWithPrec(dec, -dec.Exponent()) + case m == 0: + return newEvalDecimalWithPrec(dec, d) + default: + return newEvalDecimalWithPrec(dec.Clamp(m-d, d), d) } - return newEvalDecimalWithPrec(dec.Clamp(m-d, d), d) } func newEvalDecimalWithPrec(dec decimal.Decimal, prec int32) *evalDecimal { diff --git a/go/vt/vtgate/planbuilder/operators/union_merging.go b/go/vt/vtgate/planbuilder/operators/union_merging.go index 67853e44c7f..81ca2f5623e 100644 --- a/go/vt/vtgate/planbuilder/operators/union_merging.go +++ b/go/vt/vtgate/planbuilder/operators/union_merging.go @@ -203,10 +203,17 @@ func createMergedUnion( rt, foundR := ctx.SemTable.TypeForExpr(rae.Expr) lt, foundL := ctx.SemTable.TypeForExpr(lae.Expr) if foundR && foundL { - t, err := evalengine.AggregateEvalTypes([]evalengine.Type{rt, lt}, ctx.VSchema.Environment().CollationEnv()) - if err == nil { - ctx.SemTable.ExprTypes[col] = t + collations := ctx.VSchema.Environment().CollationEnv() + var typer evalengine.TypeAggregator + + if err := typer.Add(rt, collations); err != nil { + panic(err) + } + if err := typer.Add(lt, collations); err != nil { + panic(err) } + + ctx.SemTable.ExprTypes[col] = typer.Type() } ctx.SemTable.Recursive[col] = deps diff --git a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json index c951b70a8d0..d0a4911fb74 100644 --- a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json @@ -149,10 +149,10 @@ "0: utf8mb3_general_ci", "1: utf8mb3_general_ci", "2: utf8mb3_general_ci", - "3", + "3: binary", "4: utf8mb3_general_ci", "5", - "6", + "6: binary", "7", "8", "9", @@ -166,8 +166,9 @@ "17: utf8mb3_general_ci", "18", "19: utf8mb3_general_ci", - "20: utf8mb3_general_ci" + "(20:21)" ], + "ResultColumns": 21, "Inputs": [ { "OperatorType": "Concatenate", @@ -179,8 +180,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, `ENGINE`, VERSION, `ROW_FORMAT`, TABLE_ROWS, `AVG_ROW_LENGTH`, DATA_LENGTH, MAX_DATA_LENGTH, INDEX_LENGTH, DATA_FREE, `AUTO_INCREMENT`, CREATE_TIME, UPDATE_TIME, CHECK_TIME, TABLE_COLLATION, `CHECKSUM`, CREATE_OPTIONS, TABLE_COMMENT from information_schema.`tables` where 1 != 1", - "Query": "select distinct TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, `ENGINE`, VERSION, `ROW_FORMAT`, TABLE_ROWS, `AVG_ROW_LENGTH`, DATA_LENGTH, MAX_DATA_LENGTH, INDEX_LENGTH, DATA_FREE, `AUTO_INCREMENT`, CREATE_TIME, UPDATE_TIME, CHECK_TIME, TABLE_COLLATION, `CHECKSUM`, CREATE_OPTIONS, TABLE_COMMENT from information_schema.`tables` where table_schema = :__vtschemaname /* VARCHAR */", + "FieldQuery": "select TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, `ENGINE`, VERSION, `ROW_FORMAT`, TABLE_ROWS, `AVG_ROW_LENGTH`, DATA_LENGTH, MAX_DATA_LENGTH, INDEX_LENGTH, DATA_FREE, `AUTO_INCREMENT`, CREATE_TIME, UPDATE_TIME, CHECK_TIME, TABLE_COLLATION, `CHECKSUM`, CREATE_OPTIONS, TABLE_COMMENT, weight_string(TABLE_COMMENT) from information_schema.`tables` where 1 != 1", + "Query": "select distinct TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, `ENGINE`, VERSION, `ROW_FORMAT`, TABLE_ROWS, 
diff --git a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json
index c951b70a8d0..d0a4911fb74 100644
--- a/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json
+++ b/go/vt/vtgate/planbuilder/testdata/info_schema80_cases.json
@@ -149,10 +149,10 @@
         "0: utf8mb3_general_ci",
         "1: utf8mb3_general_ci",
         "2: utf8mb3_general_ci",
-        "3",
+        "3: binary",
         "4: utf8mb3_general_ci",
         "5",
-        "6",
+        "6: binary",
         "7",
         "8",
         "9",
@@ -166,8 +166,9 @@
         "17: utf8mb3_general_ci",
         "18",
         "19: utf8mb3_general_ci",
-        "20: utf8mb3_general_ci"
+        "(20:21)"
       ],
+      "ResultColumns": 21,
       "Inputs": [
         {
           "OperatorType": "Concatenate",
@@ -179,8 +180,8 @@
                "Name": "main",
                "Sharded": false
              },
-             "FieldQuery": "select TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, `ENGINE`, VERSION, `ROW_FORMAT`, TABLE_ROWS, `AVG_ROW_LENGTH`, DATA_LENGTH, MAX_DATA_LENGTH, INDEX_LENGTH, DATA_FREE, `AUTO_INCREMENT`, CREATE_TIME, UPDATE_TIME, CHECK_TIME, TABLE_COLLATION, `CHECKSUM`, CREATE_OPTIONS, TABLE_COMMENT from information_schema.`tables` where 1 != 1",
-             "Query": "select distinct TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, `ENGINE`, VERSION, `ROW_FORMAT`, TABLE_ROWS, `AVG_ROW_LENGTH`, DATA_LENGTH, MAX_DATA_LENGTH, INDEX_LENGTH, DATA_FREE, `AUTO_INCREMENT`, CREATE_TIME, UPDATE_TIME, CHECK_TIME, TABLE_COLLATION, `CHECKSUM`, CREATE_OPTIONS, TABLE_COMMENT from information_schema.`tables` where table_schema = :__vtschemaname /* VARCHAR */",
+             "FieldQuery": "select TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, `ENGINE`, VERSION, `ROW_FORMAT`, TABLE_ROWS, `AVG_ROW_LENGTH`, DATA_LENGTH, MAX_DATA_LENGTH, INDEX_LENGTH, DATA_FREE, `AUTO_INCREMENT`, CREATE_TIME, UPDATE_TIME, CHECK_TIME, TABLE_COLLATION, `CHECKSUM`, CREATE_OPTIONS, TABLE_COMMENT, weight_string(TABLE_COMMENT) from information_schema.`tables` where 1 != 1",
+             "Query": "select distinct TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, `ENGINE`, VERSION, `ROW_FORMAT`, TABLE_ROWS, `AVG_ROW_LENGTH`, DATA_LENGTH, MAX_DATA_LENGTH, INDEX_LENGTH, DATA_FREE, `AUTO_INCREMENT`, CREATE_TIME, UPDATE_TIME, CHECK_TIME, TABLE_COLLATION, `CHECKSUM`, CREATE_OPTIONS, TABLE_COMMENT, weight_string(TABLE_COMMENT) from information_schema.`tables` where table_schema = :__vtschemaname /* VARCHAR */",
              "SysTableTableSchema": "['user']",
              "Table": "information_schema.`tables`"
            },
@@ -191,8 +192,8 @@
                "Name": "main",
                "Sharded": false
              },
-             "FieldQuery": "select TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, `ENGINE`, VERSION, `ROW_FORMAT`, TABLE_ROWS, `AVG_ROW_LENGTH`, DATA_LENGTH, MAX_DATA_LENGTH, INDEX_LENGTH, DATA_FREE, `AUTO_INCREMENT`, CREATE_TIME, UPDATE_TIME, CHECK_TIME, TABLE_COLLATION, `CHECKSUM`, CREATE_OPTIONS, TABLE_COMMENT from information_schema.`tables` where 1 != 1",
-             "Query": "select distinct TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, `ENGINE`, VERSION, `ROW_FORMAT`, TABLE_ROWS, `AVG_ROW_LENGTH`, DATA_LENGTH, MAX_DATA_LENGTH, INDEX_LENGTH, DATA_FREE, `AUTO_INCREMENT`, CREATE_TIME, UPDATE_TIME, CHECK_TIME, TABLE_COLLATION, `CHECKSUM`, CREATE_OPTIONS, TABLE_COMMENT from information_schema.`tables` where table_schema = :__vtschemaname /* VARCHAR */",
+             "FieldQuery": "select TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, `ENGINE`, VERSION, `ROW_FORMAT`, TABLE_ROWS, `AVG_ROW_LENGTH`, DATA_LENGTH, MAX_DATA_LENGTH, INDEX_LENGTH, DATA_FREE, `AUTO_INCREMENT`, CREATE_TIME, UPDATE_TIME, CHECK_TIME, TABLE_COLLATION, `CHECKSUM`, CREATE_OPTIONS, TABLE_COMMENT, weight_string(TABLE_COMMENT) from information_schema.`tables` where 1 != 1",
+             "Query": "select distinct TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, TABLE_TYPE, `ENGINE`, VERSION, `ROW_FORMAT`, TABLE_ROWS, `AVG_ROW_LENGTH`, DATA_LENGTH, MAX_DATA_LENGTH, INDEX_LENGTH, DATA_FREE, `AUTO_INCREMENT`, CREATE_TIME, UPDATE_TIME, CHECK_TIME, TABLE_COLLATION, `CHECKSUM`, CREATE_OPTIONS, TABLE_COMMENT, weight_string(TABLE_COMMENT) from information_schema.`tables` where table_schema = :__vtschemaname /* VARCHAR */",
              "SysTableTableSchema": "['main']",
              "Table": "information_schema.`tables`"
            }
diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.json b/go/vt/vtgate/planbuilder/testdata/union_cases.json
index 7c225862235..76f1fa460ca 100644
--- a/go/vt/vtgate/planbuilder/testdata/union_cases.json
+++ b/go/vt/vtgate/planbuilder/testdata/union_cases.json
@@ -322,8 +322,9 @@
     "Instructions": {
       "OperatorType": "Distinct",
       "Collations": [
-        "0: utf8mb3_general_ci"
+        "(0:1)"
       ],
+      "ResultColumns": 1,
       "Inputs": [
         {
           "OperatorType": "Concatenate",
@@ -335,8 +336,8 @@
                "Name": "main",
                "Sharded": false
              },
-             "FieldQuery": "select CHARACTER_SET_NAME from information_schema.CHARACTER_SETS where 1 != 1",
-             "Query": "select distinct CHARACTER_SET_NAME from information_schema.CHARACTER_SETS",
+             "FieldQuery": "select CHARACTER_SET_NAME, weight_string(CHARACTER_SET_NAME) from information_schema.CHARACTER_SETS where 1 != 1",
+             "Query": "select distinct CHARACTER_SET_NAME, weight_string(CHARACTER_SET_NAME) from information_schema.CHARACTER_SETS",
              "Table": "information_schema.CHARACTER_SETS"
            },
            {
@@ -346,8 +347,8 @@
                "Name": "main",
                "Sharded": false
              },
-             "FieldQuery": "select user_name from unsharded where 1 != 1",
-             "Query": "select distinct user_name from unsharded",
+             "FieldQuery": "select user_name, weight_string(user_name) from unsharded where 1 != 1",
+             "Query": "select distinct user_name, weight_string(user_name) from unsharded",
              "Table": "unsharded"
            }
          ]
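A note on the plan notation in these fixtures: a Distinct collation entry like "(0:1)" says that column 0 is de-duplicated via the weight_string value carried in column 1 rather than via a named collation, and the added "ResultColumns" trims those helper columns back off the final result. That is why each FieldQuery/Query above gains a weight_string(...) projection per textual column.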
"(0:1)" ], + "ResultColumns": 1, "Inputs": [ { "OperatorType": "Route", @@ -533,8 +535,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select 1 from dual where 1 != 1 union select null from dual where 1 != 1 union select 1.0 from dual where 1 != 1 union select '1' from dual where 1 != 1 union select 2 from dual where 1 != 1 union select 2.0 from `user` where 1 != 1", - "Query": "select 1 from dual union select null from dual union select 1.0 from dual union select '1' from dual union select 2 from dual union select 2.0 from `user`", + "FieldQuery": "select dt.`1`, weight_string(dt.`1`) from (select 1 from dual where 1 != 1 union select null from dual where 1 != 1 union select 1.0 from dual where 1 != 1 union select '1' from dual where 1 != 1 union select 2 from dual where 1 != 1 union select 2.0 from `user` where 1 != 1) as dt where 1 != 1", + "Query": "select dt.`1`, weight_string(dt.`1`) from (select 1 from dual union select null from dual union select 1.0 from dual union select '1' from dual union select 2 from dual union select 2.0 from `user`) as dt", "Table": "`user`, dual" } ] @@ -622,9 +624,10 @@ "Instructions": { "OperatorType": "Distinct", "Collations": [ - "0: utf8mb4_0900_ai_ci", - "1: utf8mb4_0900_ai_ci" + "(0:2)", + "(1:3)" ], + "ResultColumns": 2, "Inputs": [ { "OperatorType": "Concatenate", @@ -636,14 +639,14 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select 'b', 'c' from `user` where 1 != 1", - "Query": "select distinct 'b', 'c' from `user`", + "FieldQuery": "select 'b', 'c', weight_string('b'), weight_string('c') from `user` where 1 != 1", + "Query": "select distinct 'b', 'c', weight_string('b'), weight_string('c') from `user`", "Table": "`user`" }, { "OperatorType": "Join", "Variant": "Join", - "JoinColumnIndexes": "L:0,L:1", + "JoinColumnIndexes": "L:0,L:1,L:2,L:3", "TableName": "`user`_user_extra", "Inputs": [ { @@ -653,8 +656,8 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select `user`.id, `user`.`name` from `user` where 1 != 1", - "Query": "select distinct `user`.id, `user`.`name` from `user`", + "FieldQuery": "select `user`.id, `user`.`name`, weight_string(`user`.id), weight_string(`user`.`name`) from `user` where 1 != 1", + "Query": "select distinct `user`.id, `user`.`name`, weight_string(`user`.id), weight_string(`user`.`name`) from `user`", "Table": "`user`" }, { diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index e2205a6f6a8..f604f2a4ec7 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -341,6 +341,7 @@ func isParentSelectStatement(cursor *sqlparser.Cursor) bool { type originable interface { tableSetFor(t *sqlparser.AliasedTableExpr) TableSet depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, typ evalengine.Type) + collationEnv() *collations.Environment } func (a *analyzer) depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, typ evalengine.Type) { @@ -350,6 +351,10 @@ func (a *analyzer) depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, return } +func (a *analyzer) collationEnv() *collations.Environment { + return a.typer.collationEnv +} + func (a *analyzer) analyze(statement sqlparser.Statement) error { _ = sqlparser.Rewrite(statement, nil, a.earlyUp) if a.err != nil { diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 5bc160f52a6..e17a75044ba 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -129,9 +129,10 @@ 
diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go
index 5bc160f52a6..e17a75044ba 100644
--- a/go/vt/vtgate/semantics/table_collector.go
+++ b/go/vt/vtgate/semantics/table_collector.go
@@ -129,9 +129,10 @@ func (tc *tableCollector) visitUnion(union *sqlparser.Union) error {
 	size := len(firstSelect.SelectExprs)
 	info.recursive = make([]TableSet, size)
-	info.types = make([]evalengine.Type, size)
+	typers := make([]evalengine.TypeAggregator, size)
+	collations := tc.org.collationEnv()
 
-	_ = sqlparser.VisitAllSelects(union, func(s *sqlparser.Select, idx int) error {
+	err := sqlparser.VisitAllSelects(union, func(s *sqlparser.Select, idx int) error {
 		for i, expr := range s.SelectExprs {
 			ae, ok := expr.(*sqlparser.AliasedExpr)
 			if !ok {
@@ -139,13 +140,19 @@ func (tc *tableCollector) visitUnion(union *sqlparser.Union) error {
 			}
 			_, recursiveDeps, qt := tc.org.depsForExpr(ae.Expr)
 			info.recursive[i] = info.recursive[i].Merge(recursiveDeps)
-			if idx == 0 {
-				// TODO: we probably should coerce these types together somehow, but I'm not sure how
-				info.types[i] = qt
+			if err := typers[i].Add(qt, collations); err != nil {
+				return err
 			}
 		}
 		return nil
 	})
+	if err != nil {
+		return err
+	}
+
+	for _, ts := range typers {
+		info.types = append(info.types, ts.Type())
+	}
 	tc.unionInfo[union] = info
 	return nil
 }
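With that, visitUnion types each union column by aggregating over every SELECT branch instead of copying the first branch's type, which is what the removed TODO was about. A hedged sketch of that aggregation in isolation; the helper and package are invented for illustration, and only the evalengine calls mirror the diff:

package uniontypes

import (
	"fmt"

	"vitess.io/vitess/go/mysql/collations"
	"vitess.io/vitess/go/vt/vtgate/evalengine"
)

// unionColumnTypes aggregates one evalengine.Type per output column across
// all branches of a UNION: branches[b][i] is the type of column i in branch b.
func unionColumnTypes(branches [][]evalengine.Type, env *collations.Environment) ([]evalengine.Type, error) {
	if len(branches) == 0 {
		return nil, nil
	}
	typers := make([]evalengine.TypeAggregator, len(branches[0]))
	for _, branch := range branches {
		// Every branch must project the same number of columns, mirroring
		// errWrongNumberOfColumnsInSelect in the Concatenate primitive.
		if len(branch) != len(typers) {
			return nil, fmt.Errorf("wrong number of columns in SELECT")
		}
		for i, t := range branch {
			if err := typers[i].Add(t, env); err != nil {
				return nil, err
			}
		}
	}
	types := make([]evalengine.Type, 0, len(typers))
	for _, typer := range typers {
		types = append(types, typer.Type())
	}
	return types, nil
}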