Skip to content

Commit

Permalink
Fix type coercion between the sides of an UNION (#15340)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
Signed-off-by: Vicent Marti <vmg@strn.cat>
Co-authored-by: Vicent Marti <vmg@strn.cat>
  • Loading branch information
systay and vmg authored Mar 8, 2024
1 parent 96e0a62 commit 983a3c8
Show file tree
Hide file tree
Showing 22 changed files with 310 additions and 1,034 deletions.
8 changes: 7 additions & 1 deletion go/test/endtoend/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,13 @@ func AssertMatchesAny(t testing.TB, conn *mysql.Conn, query string, expected ...
return
}
}
t.Errorf("Query: %s (-want +got):\n%v\nGot:%s", query, expected, got)

var err strings.Builder
_, _ = fmt.Fprintf(&err, "Query did not match:\n%s\n", query)
for i, e := range expected {
_, _ = fmt.Fprintf(&err, "Expected query %d does not match.\nwant: %v\ngot: %v\n\n", i, e, got)
}
t.Error(err.String())
}

// AssertMatchesCompareMySQL executes the given query on both Vitess and MySQL and make sure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,26 @@ func TestInfrSchemaAndUnionAll(t *testing.T) {
}
}

func TestInfoschemaTypes(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")

require.NoError(t,
utils.WaitForAuthoritative(t, "ks", "t1", clusterInstance.VtgateProcess.ReadVSchema))

mcmp, closer := start(t)
defer closer()

mcmp.Exec(`
SELECT ORDINAL_POSITION
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = 'ks' AND TABLE_NAME = 't1'
UNION
SELECT ORDINAL_POSITION
FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA = 'ks' AND TABLE_NAME = 't2';
`)
}

func TestTypeORMQuery(t *testing.T) {
utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate")
// This test checks that we can run queries similar to the ones that the TypeORM framework uses
Expand Down
1 change: 1 addition & 0 deletions go/test/endtoend/vtgate/queries/orderby/orderby_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ func TestOrderByComplex(t *testing.T) {
"select email, max(col) as max_col from (select email, col from user where col > 20) as filtered group by email order by max_col",
"select a.email, a.max_col from (select email, max(col) as max_col from user group by email) as a order by a.max_col desc",
"select email, max(col) as max_col from user where email like 'a%' group by email order by max_col, email",
`select email, max(col) as max_col from user group by email union select email, avg(col) as avg_col from user group by email order by email desc`,
}

for _, query := range queries {
Expand Down
117 changes: 55 additions & 62 deletions go/vt/vtgate/engine/concatenate.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars
return nil, err
}

fields, err := c.getFields(res)
fields, fieldTypes, err := c.getFieldTypes(vcursor, res)
if err != nil {
return nil, err
}

var rows [][]sqltypes.Value
err = c.coerceAndVisitResults(res, fields, func(result *sqltypes.Result) error {
err = c.coerceAndVisitResults(res, fieldTypes, func(result *sqltypes.Result) error {
rows = append(rows, result.Rows...)
return nil
}, evalengine.ParseSQLMode(vcursor.SQLMode()))
Expand All @@ -116,17 +116,17 @@ func (c *Concatenate) TryExecute(ctx context.Context, vcursor VCursor, bindVars
}, nil
}

func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field, sqlmode evalengine.SQLMode) error {
if len(row) != len(fields) {
func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fieldTypes []evalengine.Type, sqlmode evalengine.SQLMode) error {
if len(row) != len(fieldTypes) {
return errWrongNumberOfColumnsInSelect
}

for i, value := range row {
if _, found := c.NoNeedToTypeCheck[i]; found {
continue
}
if fields[i].Type != value.Type() {
newValue, err := evalengine.CoerceTo(value, fields[i].Type, sqlmode)
if fieldTypes[i].Type() != value.Type() {
newValue, err := evalengine.CoerceTo(value, fieldTypes[i], sqlmode)
if err != nil {
return err
}
Expand All @@ -136,44 +136,44 @@ func (c *Concatenate) coerceValuesTo(row sqltypes.Row, fields []*querypb.Field,
return nil
}

func (c *Concatenate) getFields(res []*sqltypes.Result) (resultFields []*querypb.Field, err error) {
func (c *Concatenate) getFieldTypes(vcursor VCursor, res []*sqltypes.Result) ([]*querypb.Field, []evalengine.Type, error) {
if len(res) == 0 {
return nil, nil
return nil, nil, nil
}

resultFields = res[0].Fields
columns := make([][]sqltypes.Type, len(resultFields))

addFields := func(fields []*querypb.Field) error {
if len(fields) != len(columns) {
return errWrongNumberOfColumnsInSelect
}
for idx, field := range fields {
columns[idx] = append(columns[idx], field.Type)
}
return nil
}
typers := make([]evalengine.TypeAggregator, len(res[0].Fields))
collations := vcursor.Environment().CollationEnv()

for _, r := range res {
if r == nil || r.Fields == nil {
continue
}
err := addFields(r.Fields)
if err != nil {
return nil, err
if len(r.Fields) != len(typers) {
return nil, nil, errWrongNumberOfColumnsInSelect
}
for idx, field := range r.Fields {
if err := typers[idx].AddField(field, collations); err != nil {
return nil, nil, err
}
}
}

// The resulting column types need to be the coercion of all the input columns
for colIdx, t := range columns {
fields := make([]*querypb.Field, 0, len(typers))
types := make([]evalengine.Type, 0, len(typers))
for colIdx, typer := range typers {
f := res[0].Fields[colIdx]

if _, found := c.NoNeedToTypeCheck[colIdx]; found {
fields = append(fields, f)
types = append(types, evalengine.NewTypeFromField(f))
continue
}

resultFields[colIdx].Type = evalengine.AggregateTypes(t)
t := typer.Type()
fields = append(fields, t.ToField(f.Name))
types = append(types, t)
}

return resultFields, nil
return fields, types, nil
}

func (c *Concatenate) execSources(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) ([]*sqltypes.Result, error) {
Expand Down Expand Up @@ -250,7 +250,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,
condFields = sync.NewCond(&muFields) // Condition var for field arrival
wg errgroup.Group // Wait group for all streaming goroutines
rest = make([]*sqltypes.Result, len(c.Sources)) // Collects first result from each source to derive fields
fields []*querypb.Field // Cached final field types
fieldTypes []evalengine.Type // Cached final field types
)

// Process each result chunk, considering type coercion.
Expand All @@ -263,7 +263,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,
needsCoercion := false
for idx, field := range rest[srcIdx].Fields {
_, skip := c.NoNeedToTypeCheck[idx]
if !skip && fields[idx].Type != field.Type {
if !skip && fieldTypes[idx].Type() != field.Type {
needsCoercion = true
break
}
Expand All @@ -272,7 +272,7 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,
// Apply type coercion if needed.
if needsCoercion {
for _, row := range res.Rows {
if err := c.coerceValuesTo(row, fields, sqlmode); err != nil {
if err := c.coerceValuesTo(row, fieldTypes, sqlmode); err != nil {
return err
}
}
Expand All @@ -299,11 +299,10 @@ func (c *Concatenate) parallelStreamExec(inCtx context.Context, vcursor VCursor,

// We have received fields from all sources. We can now calculate the output types
var err error
fields, err = c.getFields(rest)
resultChunk.Fields, fieldTypes, err = c.getFieldTypes(vcursor, rest)
if err != nil {
return err
}
resultChunk.Fields = fields

defer condFields.Broadcast()
return callback(resultChunk, currIndex)
Expand Down Expand Up @@ -370,12 +369,12 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor,
firsts[i] = result[0]
}

fields, err := c.getFields(firsts)
_, fieldTypes, err := c.getFieldTypes(vcursor, firsts)
if err != nil {
return err
}
for _, res := range results {
if err = c.coerceAndVisitResults(res, fields, callback, sqlmode); err != nil {
if err = c.coerceAndVisitResults(res, fieldTypes, callback, sqlmode); err != nil {
return err
}
}
Expand All @@ -385,26 +384,26 @@ func (c *Concatenate) sequentialStreamExec(ctx context.Context, vcursor VCursor,

func (c *Concatenate) coerceAndVisitResults(
res []*sqltypes.Result,
fields []*querypb.Field,
fieldTypes []evalengine.Type,
callback func(*sqltypes.Result) error,
sqlmode evalengine.SQLMode,
) error {
for _, r := range res {
if len(r.Rows) > 0 &&
len(fields) != len(r.Rows[0]) {
len(fieldTypes) != len(r.Rows[0]) {
return errWrongNumberOfColumnsInSelect
}

needsCoercion := false
for idx, field := range r.Fields {
if fields[idx].Type != field.Type {
if fieldTypes[idx].Type() != field.Type {
needsCoercion = true
break
}
}
if needsCoercion {
for _, row := range r.Rows {
err := c.coerceValuesTo(row, fields, sqlmode)
err := c.coerceValuesTo(row, fieldTypes, sqlmode)
if err != nil {
return err
}
Expand All @@ -420,35 +419,29 @@ func (c *Concatenate) coerceAndVisitResults(

// GetFields fetches the field info.
func (c *Concatenate) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
res, err := c.Sources[0].GetFields(ctx, vcursor, bindVars)
if err != nil {
return nil, err
}

columns := make([][]sqltypes.Type, len(res.Fields))

addFields := func(fields []*querypb.Field) {
for idx, field := range fields {
columns[idx] = append(columns[idx], field.Type)
}
}

addFields(res.Fields)

for i := 1; i < len(c.Sources); i++ {
result, err := c.Sources[i].GetFields(ctx, vcursor, bindVars)
sourceFields := make([][]*querypb.Field, 0, len(c.Sources))
for _, src := range c.Sources {
f, err := src.GetFields(ctx, vcursor, bindVars)
if err != nil {
return nil, err
}
addFields(result.Fields)
sourceFields = append(sourceFields, f.Fields)
}

// The resulting column types need to be the coercion of all the input columns
for colIdx, t := range columns {
res.Fields[colIdx].Type = evalengine.AggregateTypes(t)
}
fields := make([]*querypb.Field, 0, len(sourceFields[0]))
collations := vcursor.Environment().CollationEnv()

return res, nil
for colIdx := 0; colIdx < len(sourceFields[0]); colIdx++ {
var typer evalengine.TypeAggregator
for _, src := range sourceFields {
if err := typer.AddField(src[colIdx], collations); err != nil {
return nil, err
}
}
name := sourceFields[0][colIdx].Name
fields = append(fields, typer.Field(name))
}
return &sqltypes.Result{Fields: fields}, nil
}

// NeedsTransaction returns whether a transaction is needed for this primitive
Expand Down
28 changes: 19 additions & 9 deletions go/vt/vtgate/engine/concatenate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"strings"
"testing"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/test/utils"

"github.com/stretchr/testify/assert"
Expand All @@ -32,7 +33,17 @@ import (
)

func r(names, types string, rows ...string) *sqltypes.Result {
return sqltypes.MakeTestResult(sqltypes.MakeTestFields(names, types), rows...)
fields := sqltypes.MakeTestFields(names, types)
for _, f := range fields {
if sqltypes.IsText(f.Type) {
f.Charset = collations.CollationUtf8mb4ID
} else {
f.Charset = collations.CollationBinaryID
}
_, flags := sqltypes.TypeToMySQL(f.Type)
f.Flags = uint32(flags)
}
return sqltypes.MakeTestResult(fields, rows...)
}

func TestConcatenate_NoErrors(t *testing.T) {
Expand Down Expand Up @@ -173,12 +184,12 @@ func TestConcatenateTypes(t *testing.T) {
tests := []struct {
t1, t2, expected string
}{
{t1: "int32", t2: "int64", expected: "int64"},
{t1: "int32", t2: "int32", expected: "int32"},
{t1: "int32", t2: "varchar", expected: "varchar"},
{t1: "int32", t2: "decimal", expected: "decimal"},
{t1: "hexval", t2: "uint64", expected: "varchar"},
{t1: "varchar", t2: "varbinary", expected: "varbinary"},
{t1: "int32", t2: "int64", expected: `[name:"id" type:int64 charset:63]`},
{t1: "int32", t2: "int32", expected: `[name:"id" type:int32 charset:63]`},
{t1: "int32", t2: "varchar", expected: `[name:"id" type:varchar charset:255]`},
{t1: "int32", t2: "decimal", expected: `[name:"id" type:decimal charset:63]`},
{t1: "hexval", t2: "uint64", expected: `[name:"id" type:varchar charset:255]`},
{t1: "varchar", t2: "varbinary", expected: `[name:"id" type:varbinary charset:63 flags:128]`},
}

for _, test := range tests {
Expand All @@ -196,8 +207,7 @@ func TestConcatenateTypes(t *testing.T) {
res, err := concatenate.GetFields(context.Background(), &noopVCursor{}, nil)
require.NoError(t, err)

expected := fmt.Sprintf(`[name:"id" type:%s]`, test.expected)
assert.Equal(t, expected, strings.ToLower(fmt.Sprintf("%v", res.Fields)))
assert.Equal(t, test.expected, strings.ToLower(fmt.Sprintf("%v", res.Fields)))
})
}
}
Loading

0 comments on commit 983a3c8

Please sign in to comment.