From cd61d851308e7ed638274436ae17a1e4bb69ea60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Taylor?= Date: Mon, 12 Feb 2024 13:27:42 +0100 Subject: [PATCH] bugfix: wrong field type returned for SUM (#15192) Signed-off-by: Andres Taylor Signed-off-by: Vicent Marti Co-authored-by: Vicent Marti --- go/test/endtoend/utils/mysql.go | 70 ++++++++++++-------- go/vt/vtgate/engine/aggregations.go | 4 +- go/vt/vtgate/engine/opcode/constants.go | 26 +++++++- go/vt/vtgate/engine/opcode/constants_test.go | 6 +- go/vt/vtgate/engine/projection.go | 10 +-- go/vt/vtgate/evalengine/expr_arithmetic.go | 64 ++++++++++-------- go/vt/vtgate/semantics/typer.go | 10 +-- 7 files changed, 117 insertions(+), 73 deletions(-) diff --git a/go/test/endtoend/utils/mysql.go b/go/test/endtoend/utils/mysql.go index ca43ff15970..1e770b87516 100644 --- a/go/test/endtoend/utils/mysql.go +++ b/go/test/endtoend/utils/mysql.go @@ -27,15 +27,14 @@ import ( "github.com/stretchr/testify/assert" + "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" - "vitess.io/vitess/go/vt/dbconfigs" - "vitess.io/vitess/go/vt/sqlparser" - - "vitess.io/vitess/go/mysql" "vitess.io/vitess/go/test/endtoend/cluster" + "vitess.io/vitess/go/vt/dbconfigs" "vitess.io/vitess/go/vt/mysqlctl" + querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/sqlparser" ) const mysqlShutdownTimeout = 1 * time.Minute @@ -160,7 +159,9 @@ func prepareMySQLWithSchema(params mysql.ConnParams, sql string) error { return nil } -func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumns bool) error { +func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumnNames bool) error { + t.Helper() + if vtQr == nil && mysqlQr == nil { return nil } @@ -173,28 +174,29 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn return errors.New("MySQL result is 'nil' while Vitess' is not.\n") } - var errStr string - if compareColumns { - vtColCount := len(vtQr.Fields) - myColCount := len(mysqlQr.Fields) - if vtColCount > 0 && myColCount > 0 { - if vtColCount != myColCount { - t.Errorf("column count does not match: %d vs %d", vtColCount, myColCount) - errStr += fmt.Sprintf("column count does not match: %d vs %d\n", vtColCount, myColCount) - } - - var vtCols []string - var myCols []string - for i, vtField := range vtQr.Fields { - vtCols = append(vtCols, vtField.Name) - myCols = append(myCols, mysqlQr.Fields[i].Name) - } - if !assert.Equal(t, myCols, vtCols, "column names do not match - the expected values are what mysql produced") { - errStr += "column names do not match - the expected values are what mysql produced\n" - errStr += fmt.Sprintf("Not equal: \nexpected: %v\nactual: %v\n", myCols, vtCols) - } + vtColCount := len(vtQr.Fields) + myColCount := len(mysqlQr.Fields) + + if vtColCount != myColCount { + t.Errorf("column count does not match: %d vs %d", vtColCount, myColCount) + } + + if vtColCount > 0 { + var vtCols []string + var myCols []string + for i, vtField := range vtQr.Fields { + myField := mysqlQr.Fields[i] + checkFields(t, myField.Name, vtField, myField) + + vtCols = append(vtCols, vtField.Name) + myCols = append(myCols, myField.Name) + } + + if compareColumnNames && !assert.Equal(t, myCols, vtCols, "column names do not match - the expected values are what mysql produced") { + t.Errorf("column names do not match - the expected values are what mysql produced\nNot equal: \nexpected: %v\nactual: %v\n", myCols, vtCols) } } + stmt, err := sqlparser.NewTestParser().Parse(query) if err != nil { t.Error(err) @@ -209,7 +211,7 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn return nil } - errStr += "Query (" + query + ") results mismatched.\nVitess Results:\n" + errStr := "Query (" + query + ") results mismatched.\nVitess Results:\n" for _, row := range vtQr.Rows { errStr += fmt.Sprintf("%s\n", row) } @@ -229,6 +231,20 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn return errors.New(errStr) } +func checkFields(t *testing.T, columnName string, vtField, myField *querypb.Field) { + t.Helper() + if vtField.Type != myField.Type { + t.Errorf("for column %s field types do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Type.String(), vtField.Type.String()) + } + + // starting in Vitess 20, decimal types are properly sized in their field information + if BinaryIsAtLeastAtVersion(20, "vtgate") && vtField.Type == sqltypes.Decimal { + if vtField.Decimals != myField.Decimals { + t.Errorf("for column %s field decimals count do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Decimals, vtField.Decimals) + } + } +} + func compareVitessAndMySQLErrors(t *testing.T, vtErr, mysqlErr error) { if vtErr != nil && mysqlErr != nil || vtErr == nil && mysqlErr == nil { return diff --git a/go/vt/vtgate/engine/aggregations.go b/go/vt/vtgate/engine/aggregations.go index 3413234c84f..ea10267a7e6 100644 --- a/go/vt/vtgate/engine/aggregations.go +++ b/go/vt/vtgate/engine/aggregations.go @@ -91,9 +91,9 @@ func (ap *AggregateParams) String() string { func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type { if ap.OrigOpcode != AggregateUnassigned { - return ap.OrigOpcode.Type(inputType) + return ap.OrigOpcode.SQLType(inputType) } - return ap.Opcode.Type(inputType) + return ap.Opcode.SQLType(inputType) } type aggregator interface { diff --git a/go/vt/vtgate/engine/opcode/constants.go b/go/vt/vtgate/engine/opcode/constants.go index 8a70df79d0c..2fa0e9446a4 100644 --- a/go/vt/vtgate/engine/opcode/constants.go +++ b/go/vt/vtgate/engine/opcode/constants.go @@ -19,8 +19,10 @@ package opcode import ( "fmt" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + "vitess.io/vitess/go/vt/vtgate/evalengine" ) // PulloutOpcode is a number representing the opcode @@ -138,7 +140,7 @@ func (code AggregateOpcode) MarshalJSON() ([]byte, error) { } // Type returns the opcode return sql type, and a bool telling is we are sure about this type or not -func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type { +func (code AggregateOpcode) SQLType(typ querypb.Type) querypb.Type { switch code { case AggregateUnassigned: return sqltypes.Null @@ -169,6 +171,28 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type { } } +func (code AggregateOpcode) Nullable() bool { + switch code { + case AggregateCount, AggregateCountStar: + return false + default: + return true + } +} + +func (code AggregateOpcode) ResolveType(t evalengine.Type, env *collations.Environment) evalengine.Type { + sqltype := code.SQLType(t.Type()) + collation := collations.CollationForType(sqltype, env.DefaultConnectionCharset()) + nullable := code.Nullable() + size := t.Size() + + scale := t.Scale() + if code == AggregateAvg { + scale += 4 + } + return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale) +} + func (code AggregateOpcode) NeedsComparableValues() bool { switch code { case AggregateCountDistinct, AggregateSumDistinct, AggregateMin, AggregateMax: diff --git a/go/vt/vtgate/engine/opcode/constants_test.go b/go/vt/vtgate/engine/opcode/constants_test.go index 1a14c110086..5687a42433d 100644 --- a/go/vt/vtgate/engine/opcode/constants_test.go +++ b/go/vt/vtgate/engine/opcode/constants_test.go @@ -30,7 +30,7 @@ import ( func TestCheckAllAggrOpCodes(t *testing.T) { // This test is just checking that we never reach the panic when using Type() on valid opcodes for i := AggregateOpcode(0); i < _NumOfOpCodes; i++ { - i.Type(sqltypes.Null) + i.SQLType(sqltypes.Null) } } @@ -56,7 +56,7 @@ func TestType(t *testing.T) { for _, tc := range tt { t.Run(tc.opcode.String()+"_"+tc.typ.String(), func(t *testing.T) { - out := tc.opcode.Type(tc.typ) + out := tc.opcode.SQLType(tc.typ) assert.Equal(t, tc.out, out) }) } @@ -70,7 +70,7 @@ func TestType_Panic(t *testing.T) { assert.Contains(t, errMsg, "ERROR", "Expected panic message containing 'ERROR'") } }() - AggregateOpcode(999).Type(sqltypes.VarChar) + AggregateOpcode(999).SQLType(sqltypes.VarChar) } func TestNeedsListArg(t *testing.T) { diff --git a/go/vt/vtgate/engine/projection.go b/go/vt/vtgate/engine/projection.go index e9714b6a8cb..77e07203476 100644 --- a/go/vt/vtgate/engine/projection.go +++ b/go/vt/vtgate/engine/projection.go @@ -158,10 +158,12 @@ func (p *Projection) evalFields(env *evalengine.ExpressionEnv, infields []*query fl |= uint32(querypb.MySqlFlag_NOT_NULL_FLAG) } fields = append(fields, &querypb.Field{ - Name: col, - Type: typ.Type(), - Charset: uint32(typ.Collation()), - Flags: fl, + Name: col, + Type: typ.Type(), + Charset: uint32(typ.Collation()), + ColumnLength: uint32(typ.Size()), + Decimals: uint32(typ.Scale()), + Flags: fl, }) } return fields, nil diff --git a/go/vt/vtgate/evalengine/expr_arithmetic.go b/go/vt/vtgate/evalengine/expr_arithmetic.go index 938803910cb..5d82d279d69 100644 --- a/go/vt/vtgate/evalengine/expr_arithmetic.go +++ b/go/vt/vtgate/evalengine/expr_arithmetic.go @@ -94,12 +94,12 @@ func (op *opArithAdd) compile(c *compiler, left, right IR) (ctype, error) { rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true) lt, rt, swap = c.compileNumericPriority(lt, rt) - var sumtype sqltypes.Type + ct := ctype{Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric} switch lt.Type { case sqltypes.Int64: c.asm.Add_ii() - sumtype = sqltypes.Int64 + ct.Type = sqltypes.Int64 case sqltypes.Uint64: switch rt.Type { case sqltypes.Int64: @@ -107,7 +107,7 @@ func (op *opArithAdd) compile(c *compiler, left, right IR) (ctype, error) { case sqltypes.Uint64: c.asm.Add_uu() } - sumtype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Decimal: if swap { c.compileToDecimal(rt, 2) @@ -115,7 +115,8 @@ func (op *opArithAdd) compile(c *compiler, left, right IR) (ctype, error) { c.compileToDecimal(rt, 1) } c.asm.Add_dd() - sumtype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) case sqltypes.Float64: if swap { c.compileToFloat(rt, 2) @@ -123,11 +124,11 @@ func (op *opArithAdd) compile(c *compiler, left, right IR) (ctype, error) { c.compileToFloat(rt, 1) } c.asm.Add_ff() - sumtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: sumtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil + return ct, nil } func (op *opArithSub) eval(left, right eval) (eval, error) { @@ -151,66 +152,68 @@ func (op *opArithSub) compile(c *compiler, left, right IR) (ctype, error) { lt = c.compileToNumeric(lt, 2, sqltypes.Float64, true) rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true) - var subtype sqltypes.Type - + ct := ctype{Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric} switch lt.Type { case sqltypes.Int64: switch rt.Type { case sqltypes.Int64: c.asm.Sub_ii() - subtype = sqltypes.Int64 + ct.Type = sqltypes.Int64 case sqltypes.Uint64: c.asm.Sub_iu() - subtype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Float64: c.compileToFloat(lt, 2) c.asm.Sub_ff() - subtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 case sqltypes.Decimal: c.compileToDecimal(lt, 2) c.asm.Sub_dd() - subtype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) } case sqltypes.Uint64: switch rt.Type { case sqltypes.Int64: c.asm.Sub_ui() - subtype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Uint64: c.asm.Sub_uu() - subtype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Float64: c.compileToFloat(lt, 2) c.asm.Sub_ff() - subtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 case sqltypes.Decimal: c.compileToDecimal(lt, 2) c.asm.Sub_dd() - subtype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) } case sqltypes.Float64: c.compileToFloat(rt, 1) c.asm.Sub_ff() - subtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 case sqltypes.Decimal: switch rt.Type { case sqltypes.Float64: c.compileToFloat(lt, 2) c.asm.Sub_ff() - subtype = sqltypes.Float64 + ct.Type = sqltypes.Float64 default: c.compileToDecimal(rt, 1) c.asm.Sub_dd() - subtype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) } } - if subtype == 0 { + if ct.Type == 0 { panic("did not compile?") } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: subtype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil + return ct, nil } func (op *opArithMul) eval(left, right eval) (eval, error) { @@ -237,12 +240,11 @@ func (op *opArithMul) compile(c *compiler, left, right IR) (ctype, error) { rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true) lt, rt, swap = c.compileNumericPriority(lt, rt) - var multype sqltypes.Type - + ct := ctype{Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric} switch lt.Type { case sqltypes.Int64: c.asm.Mul_ii() - multype = sqltypes.Int64 + ct.Type = sqltypes.Int64 case sqltypes.Uint64: switch rt.Type { case sqltypes.Int64: @@ -250,7 +252,7 @@ func (op *opArithMul) compile(c *compiler, left, right IR) (ctype, error) { case sqltypes.Uint64: c.asm.Mul_uu() } - multype = sqltypes.Uint64 + ct.Type = sqltypes.Uint64 case sqltypes.Float64: if swap { c.compileToFloat(rt, 2) @@ -258,7 +260,7 @@ func (op *opArithMul) compile(c *compiler, left, right IR) (ctype, error) { c.compileToFloat(rt, 1) } c.asm.Mul_ff() - multype = sqltypes.Float64 + ct.Type = sqltypes.Float64 case sqltypes.Decimal: if swap { c.compileToDecimal(rt, 2) @@ -266,11 +268,12 @@ func (op *opArithMul) compile(c *compiler, left, right IR) (ctype, error) { c.compileToDecimal(rt, 1) } c.asm.Mul_dd() - multype = sqltypes.Decimal + ct.Type = sqltypes.Decimal + ct.Scale = lt.Scale + rt.Scale } c.asm.jumpDestination(skip1, skip2) - return ctype{Type: multype, Flag: nullableFlags(lt.Flag | rt.Flag), Col: collationNumeric}, nil + return ct, nil } func (op *opArithDiv) eval(left, right eval) (eval, error) { @@ -306,6 +309,7 @@ func (op *opArithDiv) compile(c *compiler, left, right IR) (ctype, error) { c.compileToDecimal(lt, 2) c.compileToDecimal(rt, 1) c.asm.Div_dd() + ct.Scale = lt.Scale + divPrecisionIncrement } c.asm.jumpDestination(skip1, skip2) return ct, nil @@ -419,7 +423,7 @@ func (op *opArithMod) compile(c *compiler, left, right IR) (ctype, error) { lt = c.compileToNumeric(lt, 2, sqltypes.Float64, true) rt = c.compileToNumeric(rt, 1, sqltypes.Float64, true) - ct := ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: flagNullable} + ct := ctype{Col: collationNumeric, Flag: flagNullable} switch lt.Type { case sqltypes.Int64: ct.Type = sqltypes.Int64 @@ -434,6 +438,7 @@ func (op *opArithMod) compile(c *compiler, left, right IR) (ctype, error) { c.asm.Mod_ff() case sqltypes.Decimal: ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) c.asm.Convert_xd(2, 0, 0) c.asm.Mod_dd() } @@ -450,6 +455,7 @@ func (op *opArithMod) compile(c *compiler, left, right IR) (ctype, error) { c.asm.Mod_ff() case sqltypes.Decimal: ct.Type = sqltypes.Decimal + ct.Scale = max(lt.Scale, rt.Scale) c.asm.Convert_xd(2, 0, 0) c.asm.Mod_dd() } diff --git a/go/vt/vtgate/semantics/typer.go b/go/vt/vtgate/semantics/typer.go index a9c783cee18..54261339114 100644 --- a/go/vt/vtgate/semantics/typer.go +++ b/go/vt/vtgate/semantics/typer.go @@ -18,7 +18,6 @@ package semantics import ( "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine/opcode" "vitess.io/vitess/go/vt/vtgate/evalengine" @@ -55,16 +54,13 @@ func (t *typer) up(cursor *sqlparser.Cursor) error { if !ok { return nil } - inputType := sqltypes.Unknown + var inputType evalengine.Type if arg := node.GetArg(); arg != nil { if tt, ok := t.m[arg]; ok { - inputType = tt.Type() + inputType = tt } } - type_ := code.Type(inputType) - _, isCount := node.(*sqlparser.Count) - _, isCountStart := node.(*sqlparser.CountStar) - t.m[node] = evalengine.NewTypeEx(type_, collations.CollationForType(type_, t.collationEnv.DefaultConnectionCharset()), !(isCount || isCountStart), 0, 0) + t.m[node] = code.ResolveType(inputType, t.collationEnv) } return nil }