Skip to content

Commit

Permalink
bugfix: wrong field type returned for SUM (#15192)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
Signed-off-by: Vicent Marti <vmg@strn.cat>
Co-authored-by: Vicent Marti <vmg@strn.cat>
  • Loading branch information
systay and vmg committed Feb 12, 2024
1 parent 8586b6d commit cd61d85
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 73 deletions.
70 changes: 43 additions & 27 deletions go/test/endtoend/utils/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@ import (

"github.com/stretchr/testify/assert"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/mysql/collations"

"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/dbconfigs"
"vitess.io/vitess/go/vt/sqlparser"

"vitess.io/vitess/go/mysql"
"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/vt/dbconfigs"
"vitess.io/vitess/go/vt/mysqlctl"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/sqlparser"
)

const mysqlShutdownTimeout = 1 * time.Minute
Expand Down Expand Up @@ -160,7 +159,9 @@ func prepareMySQLWithSchema(params mysql.ConnParams, sql string) error {
return nil
}

func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumns bool) error {
func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn, vtQr, mysqlQr *sqltypes.Result, compareColumnNames bool) error {
t.Helper()

if vtQr == nil && mysqlQr == nil {
return nil
}
Expand All @@ -173,28 +174,29 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn
return errors.New("MySQL result is 'nil' while Vitess' is not.\n")
}

var errStr string
if compareColumns {
vtColCount := len(vtQr.Fields)
myColCount := len(mysqlQr.Fields)
if vtColCount > 0 && myColCount > 0 {
if vtColCount != myColCount {
t.Errorf("column count does not match: %d vs %d", vtColCount, myColCount)
errStr += fmt.Sprintf("column count does not match: %d vs %d\n", vtColCount, myColCount)
}

var vtCols []string
var myCols []string
for i, vtField := range vtQr.Fields {
vtCols = append(vtCols, vtField.Name)
myCols = append(myCols, mysqlQr.Fields[i].Name)
}
if !assert.Equal(t, myCols, vtCols, "column names do not match - the expected values are what mysql produced") {
errStr += "column names do not match - the expected values are what mysql produced\n"
errStr += fmt.Sprintf("Not equal: \nexpected: %v\nactual: %v\n", myCols, vtCols)
}
vtColCount := len(vtQr.Fields)
myColCount := len(mysqlQr.Fields)

if vtColCount != myColCount {
t.Errorf("column count does not match: %d vs %d", vtColCount, myColCount)
}

if vtColCount > 0 {
var vtCols []string
var myCols []string
for i, vtField := range vtQr.Fields {
myField := mysqlQr.Fields[i]
checkFields(t, myField.Name, vtField, myField)

vtCols = append(vtCols, vtField.Name)
myCols = append(myCols, myField.Name)
}

if compareColumnNames && !assert.Equal(t, myCols, vtCols, "column names do not match - the expected values are what mysql produced") {
t.Errorf("column names do not match - the expected values are what mysql produced\nNot equal: \nexpected: %v\nactual: %v\n", myCols, vtCols)
}
}

stmt, err := sqlparser.NewTestParser().Parse(query)
if err != nil {
t.Error(err)
Expand All @@ -209,7 +211,7 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn
return nil
}

errStr += "Query (" + query + ") results mismatched.\nVitess Results:\n"
errStr := "Query (" + query + ") results mismatched.\nVitess Results:\n"
for _, row := range vtQr.Rows {
errStr += fmt.Sprintf("%s\n", row)
}
Expand All @@ -229,6 +231,20 @@ func compareVitessAndMySQLResults(t *testing.T, query string, vtConn *mysql.Conn
return errors.New(errStr)
}

func checkFields(t *testing.T, columnName string, vtField, myField *querypb.Field) {
t.Helper()
if vtField.Type != myField.Type {
t.Errorf("for column %s field types do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Type.String(), vtField.Type.String())
}

// starting in Vitess 20, decimal types are properly sized in their field information
if BinaryIsAtLeastAtVersion(20, "vtgate") && vtField.Type == sqltypes.Decimal {
if vtField.Decimals != myField.Decimals {
t.Errorf("for column %s field decimals count do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Decimals, vtField.Decimals)
}
}
}

func compareVitessAndMySQLErrors(t *testing.T, vtErr, mysqlErr error) {
if vtErr != nil && mysqlErr != nil || vtErr == nil && mysqlErr == nil {
return
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/engine/aggregations.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ func (ap *AggregateParams) String() string {

func (ap *AggregateParams) typ(inputType querypb.Type) querypb.Type {
if ap.OrigOpcode != AggregateUnassigned {
return ap.OrigOpcode.Type(inputType)
return ap.OrigOpcode.SQLType(inputType)
}
return ap.Opcode.Type(inputType)
return ap.Opcode.SQLType(inputType)
}

type aggregator interface {
Expand Down
26 changes: 25 additions & 1 deletion go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ package opcode
import (
"fmt"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"
"vitess.io/vitess/go/vt/vtgate/evalengine"
)

// PulloutOpcode is a number representing the opcode
Expand Down Expand Up @@ -138,7 +140,7 @@ func (code AggregateOpcode) MarshalJSON() ([]byte, error) {
}

// Type returns the opcode return sql type, and a bool telling is we are sure about this type or not
func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
func (code AggregateOpcode) SQLType(typ querypb.Type) querypb.Type {
switch code {
case AggregateUnassigned:
return sqltypes.Null
Expand Down Expand Up @@ -169,6 +171,28 @@ func (code AggregateOpcode) Type(typ querypb.Type) querypb.Type {
}
}

func (code AggregateOpcode) Nullable() bool {
switch code {
case AggregateCount, AggregateCountStar:
return false
default:
return true
}
}

func (code AggregateOpcode) ResolveType(t evalengine.Type, env *collations.Environment) evalengine.Type {
sqltype := code.SQLType(t.Type())
collation := collations.CollationForType(sqltype, env.DefaultConnectionCharset())
nullable := code.Nullable()
size := t.Size()

scale := t.Scale()
if code == AggregateAvg {
scale += 4
}
return evalengine.NewTypeEx(sqltype, collation, nullable, size, scale)
}

func (code AggregateOpcode) NeedsComparableValues() bool {
switch code {
case AggregateCountDistinct, AggregateSumDistinct, AggregateMin, AggregateMax:
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/engine/opcode/constants_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
func TestCheckAllAggrOpCodes(t *testing.T) {
// This test is just checking that we never reach the panic when using Type() on valid opcodes
for i := AggregateOpcode(0); i < _NumOfOpCodes; i++ {
i.Type(sqltypes.Null)
i.SQLType(sqltypes.Null)
}
}

Expand All @@ -56,7 +56,7 @@ func TestType(t *testing.T) {

for _, tc := range tt {
t.Run(tc.opcode.String()+"_"+tc.typ.String(), func(t *testing.T) {
out := tc.opcode.Type(tc.typ)
out := tc.opcode.SQLType(tc.typ)
assert.Equal(t, tc.out, out)
})
}
Expand All @@ -70,7 +70,7 @@ func TestType_Panic(t *testing.T) {
assert.Contains(t, errMsg, "ERROR", "Expected panic message containing 'ERROR'")
}
}()
AggregateOpcode(999).Type(sqltypes.VarChar)
AggregateOpcode(999).SQLType(sqltypes.VarChar)
}

func TestNeedsListArg(t *testing.T) {
Expand Down
10 changes: 6 additions & 4 deletions go/vt/vtgate/engine/projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,12 @@ func (p *Projection) evalFields(env *evalengine.ExpressionEnv, infields []*query
fl |= uint32(querypb.MySqlFlag_NOT_NULL_FLAG)
}
fields = append(fields, &querypb.Field{
Name: col,
Type: typ.Type(),
Charset: uint32(typ.Collation()),
Flags: fl,
Name: col,
Type: typ.Type(),
Charset: uint32(typ.Collation()),
ColumnLength: uint32(typ.Size()),
Decimals: uint32(typ.Scale()),
Flags: fl,
})
}
return fields, nil
Expand Down
Loading

0 comments on commit cd61d85

Please sign in to comment.