Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Scale and length handling in CASE and JOIN bind variables #15787

Merged
merged 6 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions go/mysql/json/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,14 @@ type Value struct {
n NumberType
}

func (v *Value) Size() int32 {
return 0
}

func (v *Value) Scale() int32 {
return 0
}

func (v *Value) MarshalDate() string {
if d, ok := v.Date(); ok {
return d.ToStdTime(time.Local).Format("2006-01-02")
Expand Down
17 changes: 15 additions & 2 deletions go/test/endtoend/vtgate/queries/tpch/tpch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ package union
import (
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"

"github.com/stretchr/testify/require"
)

func start(t *testing.T) (utils.MySQLCompare, func()) {
Expand Down Expand Up @@ -161,6 +161,19 @@ group by
order by
value desc;`,
},
{
name: "Q14 without decimal literal",
query: `select sum(case
when p_type like 'PROMO%'
then l_extendedprice * (1 - l_discount)
else 0
end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
from lineitem,
part
where l_partkey = p_partkey
and l_shipdate >= '1996-12-01'
and l_shipdate < date_add('1996-12-01', interval '1' month);`,
},
}

for _, testcase := range testcases {
Expand Down
11 changes: 6 additions & 5 deletions go/vt/vtgate/engine/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package engine

import (
"bytes"
"context"
"fmt"
"strings"
Expand Down Expand Up @@ -61,7 +62,7 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st
result := &sqltypes.Result{}
if len(lresult.Rows) == 0 && wantfields {
for k, col := range jn.Vars {
joinVars[k] = bindvarForType(lresult.Fields[col].Type)
joinVars[k] = bindvarForType(lresult.Fields[col])
}
rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars))
if err != nil {
Expand Down Expand Up @@ -95,19 +96,19 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st
return result, nil
}

func bindvarForType(t querypb.Type) *querypb.BindVariable {
func bindvarForType(field *querypb.Field) *querypb.BindVariable {
bv := &querypb.BindVariable{
Type: t,
Type: field.Type,
Value: nil,
}
switch t {
switch field.Type {
case querypb.Type_INT8, querypb.Type_UINT8, querypb.Type_INT16, querypb.Type_UINT16,
querypb.Type_INT32, querypb.Type_UINT32, querypb.Type_INT64, querypb.Type_UINT64:
bv.Value = []byte("0")
case querypb.Type_FLOAT32, querypb.Type_FLOAT64:
bv.Value = []byte("0e0")
case querypb.Type_DECIMAL:
bv.Value = []byte("0.0")
bv.Value = append(append(bytes.Repeat([]byte{'0'}, max(1, int(field.ColumnLength-field.Decimals))), byte('.')), bytes.Repeat([]byte{'0'}, max(1, int(field.Decimals)))...)
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
default:
return sqltypes.NullBindVariable
}
Expand Down
23 changes: 17 additions & 6 deletions go/vt/vtgate/evalengine/api_type_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ type typeAggregation struct {
blob uint16
total uint16

nullable bool
nullable bool
scale, size int32
}

type TypeAggregator struct {
Expand All @@ -63,7 +64,7 @@ func (ta *TypeAggregator) Add(typ Type, env *collations.Environment) error {
return nil
}

ta.types.addNullable(typ.typ, typ.nullable)
ta.types.addNullable(typ.typ, typ.nullable, typ.size, typ.scale)
if err := ta.collations.add(typedCoercionCollation(typ.typ, typ.collation), env); err != nil {
return err
}
Expand Down Expand Up @@ -105,10 +106,10 @@ func (ta *typeAggregation) addEval(e eval) {
default:
t = e.SQLType()
}
ta.add(t, f)
ta.add(t, f, e.Size(), e.Scale())
}

func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) {
func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool, size, scale int32) {
var flag typeFlag
if typ == sqltypes.HexVal || typ == sqltypes.HexNum {
typ = sqltypes.Binary
Expand All @@ -117,13 +118,15 @@ func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) {
if nullable {
flag |= flagNullable
}
ta.add(typ, flag)
ta.add(typ, flag, size, scale)
}

func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) {
func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag, size, scale int32) {
if f&flagNullable != 0 {
ta.nullable = true
}
ta.size = max(ta.size, size)
ta.scale = max(ta.scale, scale)
switch tt {
case sqltypes.Float32, sqltypes.Float64:
ta.double++
Expand Down Expand Up @@ -190,6 +193,14 @@ func nextSignedTypeForUnsigned(t sqltypes.Type) sqltypes.Type {
}
}

func (ta *typeAggregation) Size() int32 {
return ta.size
}

func (ta *typeAggregation) Scale() int32 {
return ta.scale
}

func (ta *typeAggregation) result() sqltypes.Type {
/*
If all types are numeric, the aggregated type is also numeric:
Expand Down
17 changes: 17 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/olekukonko/tablewriter"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -168,6 +169,7 @@ func TestCompilerSingle(t *testing.T) {
values []sqltypes.Value
result string
collation collations.ID
typeWanted evalengine.Type
}{
{
expression: "1 + column0",
Expand Down Expand Up @@ -675,6 +677,15 @@ func TestCompilerSingle(t *testing.T) {
expression: `1 * unix_timestamp(time('1.0000'))`,
result: `DECIMAL(1698098401.0000)`,
},
{
expression: `(case
when 'PROMOTION' like 'PROMO%'
then 0.01
else 0
end) * 0.01`,
result: `DECIMAL(0.0001)`,
typeWanted: evalengine.NewTypeEx(sqltypes.Decimal, collations.CollationBinaryID, false, 4, 4, nil),
},
}

tz, _ := time.LoadLocation("Europe/Madrid")
Expand Down Expand Up @@ -715,6 +726,12 @@ func TestCompilerSingle(t *testing.T) {
t.Fatalf("bad collation evaluation from eval engine: got %d, want %d", expected.Collation(), tc.collation)
}

if tc.typeWanted.Type() != sqltypes.Unknown {
typ, err := env.TypeOf(converted)
require.NoError(t, err)
require.True(t, tc.typeWanted.Equal(&typ))
}

// re-run the same evaluation multiple times to ensure results are always consistent
for i := 0; i < 8; i++ {
res, err := env.Evaluate(converted)
Expand Down
2 changes: 2 additions & 0 deletions go/vt/vtgate/evalengine/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ func (f typeFlag) Nullable() bool {
type eval interface {
ToRawBytes() []byte
SQLType() sqltypes.Type
Size() int32
Scale() int32
}

type hashable interface {
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ func (e *evalBytes) SQLType() sqltypes.Type {
return sqltypes.Type(e.tt)
}

func (e *evalBytes) Size() int32 {
return 0
}

func (e *evalBytes) Scale() int32 {
return 0
}

func (e *evalBytes) ToRawBytes() []byte {
return e.bytes
}
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ func (e *evalEnum) SQLType() sqltypes.Type {
return sqltypes.Enum
}

func (e *evalEnum) Size() int32 {
return 0
}

func (e *evalEnum) Scale() int32 {
return 0
}

func valueIdx(values *EnumSetValues, value string) int {
if values == nil {
return -1
Expand Down
32 changes: 32 additions & 0 deletions go/vt/vtgate/evalengine/eval_numeric.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,14 @@ func (e *evalInt64) SQLType() sqltypes.Type {
return sqltypes.Int64
}

func (e *evalInt64) Size() int32 {
return 0
}

func (e *evalInt64) Scale() int32 {
return 0
}

func (e *evalInt64) ToRawBytes() []byte {
return strconv.AppendInt(nil, e.i, 10)
}
Expand Down Expand Up @@ -409,6 +417,14 @@ func (e *evalUint64) SQLType() sqltypes.Type {
return sqltypes.Uint64
}

func (e *evalUint64) Size() int32 {
return 0
}

func (e *evalUint64) Scale() int32 {
return 0
}

func (e *evalUint64) ToRawBytes() []byte {
return strconv.AppendUint(nil, e.u, 10)
}
Expand Down Expand Up @@ -452,6 +468,14 @@ func (e *evalFloat) SQLType() sqltypes.Type {
return sqltypes.Float64
}

func (e *evalFloat) Size() int32 {
return 0
}

func (e *evalFloat) Scale() int32 {
return 0
}

func (e *evalFloat) ToRawBytes() []byte {
return format.FormatFloat(e.f)
}
Expand Down Expand Up @@ -528,6 +552,14 @@ func (e *evalDecimal) SQLType() sqltypes.Type {
return sqltypes.Decimal
}

func (e *evalDecimal) Size() int32 {
return e.length
}

func (e *evalDecimal) Scale() int32 {
return -e.dec.Exponent()
}

func (e *evalDecimal) ToRawBytes() []byte {
return e.dec.FormatMySQL(e.length)
}
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ func (e *evalSet) SQLType() sqltypes.Type {
return sqltypes.Set
}

func (e *evalSet) Size() int32 {
return 0
}

func (e *evalSet) Scale() int32 {
return 0
}

func evalSetBits(values *EnumSetValues, value string) uint64 {
if values != nil && len(*values) > 64 {
// This never would happen as MySQL limits SET
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_temporal.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ func (e *evalTemporal) SQLType() sqltypes.Type {
return e.t
}

func (e *evalTemporal) Size() int32 {
return 0
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
}

func (e *evalTemporal) Scale() int32 {
return 0
}

func (e *evalTemporal) toInt64() int64 {
switch e.SQLType() {
case sqltypes.Date:
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_tuple.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,11 @@ func (e *evalTuple) ToRawBytes() []byte {
func (e *evalTuple) SQLType() sqltypes.Type {
return sqltypes.Tuple
}

func (e *evalTuple) Size() int32 {
return 0
}

func (e *evalTuple) Scale() int32 {
return 0
}
6 changes: 3 additions & 3 deletions go/vt/vtgate/evalengine/expr_logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
return ctype{}, err
}

ta.add(then.Type, then.Flag)
ta.add(then.Type, then.Flag, then.Size, then.Scale)
if err := ca.add(then.Col, c.env.CollationEnv()); err != nil {
return ctype{}, err
}
Expand All @@ -702,7 +702,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
return ctype{}, err
}

ta.add(els.Type, els.Flag)
ta.add(els.Type, els.Flag, els.Size, els.Scale)
if err := ca.add(els.Col, c.env.CollationEnv()); err != nil {
return ctype{}, err
}
Expand All @@ -712,7 +712,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
if ta.nullable {
f |= flagNullable
}
ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()}
ct := ctype{Type: ta.result(), Flag: f, Col: ca.result(), Scale: ta.Scale(), Size: ta.Size()}
c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col, c.sqlmode.AllowZeroDate())
return ct, nil
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/evalengine/fn_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (b *builtinCoalesce) compile(c *compiler) (ctype, error) {
if !tt.nullable() {
f = 0
}
ta.add(tt.Type, tt.Flag)
ta.add(tt.Type, tt.Flag, tt.Size, tt.Scale)
if err := ca.add(tt.Col, c.env.CollationEnv()); err != nil {
return ctype{}, err
}
Expand Down
Loading