Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Scale and length handling in CASE and JOIN bind variables #15787

Merged
merged 6 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions go/mysql/json/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,14 @@ type Value struct {
n NumberType
}

func (v *Value) Size() int32 {
return 0
}

func (v *Value) Scale() int32 {
return 0
}

func (v *Value) MarshalDate() string {
if d, ok := v.Date(); ok {
return d.ToStdTime(time.Local).Format("2006-01-02")
Expand Down
17 changes: 15 additions & 2 deletions go/test/endtoend/vtgate/queries/tpch/tpch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ package union
import (
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"

"github.com/stretchr/testify/require"
)

func start(t *testing.T) (utils.MySQLCompare, func()) {
Expand Down Expand Up @@ -161,6 +161,19 @@ group by
order by
value desc;`,
},
{
name: "Q14 without decimal literal",
query: `select sum(case
when p_type like 'PROMO%'
then l_extendedprice * (1 - l_discount)
else 0
end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
from lineitem,
part
where l_partkey = p_partkey
and l_shipdate >= '1996-12-01'
and l_shipdate < date_add('1996-12-01', interval '1' month);`,
},
}

for _, testcase := range testcases {
Expand Down
13 changes: 8 additions & 5 deletions go/vt/vtgate/engine/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package engine

import (
"bytes"
"context"
"fmt"
"strings"
Expand Down Expand Up @@ -61,7 +62,7 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st
result := &sqltypes.Result{}
if len(lresult.Rows) == 0 && wantfields {
for k, col := range jn.Vars {
joinVars[k] = bindvarForType(lresult.Fields[col].Type)
joinVars[k] = bindvarForType(lresult.Fields[col])
}
rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars))
if err != nil {
Expand Down Expand Up @@ -95,19 +96,21 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st
return result, nil
}

func bindvarForType(t querypb.Type) *querypb.BindVariable {
func bindvarForType(field *querypb.Field) *querypb.BindVariable {
bv := &querypb.BindVariable{
Type: t,
Type: field.Type,
Value: nil,
}
switch t {
switch field.Type {
case querypb.Type_INT8, querypb.Type_UINT8, querypb.Type_INT16, querypb.Type_UINT16,
querypb.Type_INT32, querypb.Type_UINT32, querypb.Type_INT64, querypb.Type_UINT64:
bv.Value = []byte("0")
case querypb.Type_FLOAT32, querypb.Type_FLOAT64:
bv.Value = []byte("0e0")
case querypb.Type_DECIMAL:
bv.Value = []byte("0.0")
size := max(1, int(field.ColumnLength-field.Decimals))
scale := max(1, int(field.Decimals))
bv.Value = append(append(bytes.Repeat([]byte{'0'}, size), byte('.')), bytes.Repeat([]byte{'0'}, scale)...)
default:
return sqltypes.NullBindVariable
}
Expand Down
20 changes: 14 additions & 6 deletions go/vt/vtgate/evalengine/api_type_aggregation.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ type typeAggregation struct {
blob uint16
total uint16

nullable bool
nullable bool
scale, size int32
}

type TypeAggregator struct {
Expand All @@ -63,7 +64,7 @@ func (ta *TypeAggregator) Add(typ Type, env *collations.Environment) error {
return nil
}

ta.types.addNullable(typ.typ, typ.nullable)
ta.types.addNullable(typ.typ, typ.nullable, typ.size, typ.scale)
if err := ta.collations.add(typedCoercionCollation(typ.typ, typ.collation), env); err != nil {
return err
}
Expand Down Expand Up @@ -95,20 +96,25 @@ func (ta *typeAggregation) empty() bool {
func (ta *typeAggregation) addEval(e eval) {
var t sqltypes.Type
var f typeFlag
var size, scale int32
switch e := e.(type) {
case nil:
t = sqltypes.Null
ta.nullable = true
case *evalBytes:
t = sqltypes.Type(e.tt)
f = e.flag
size = e.Size()
scale = e.Scale()
default:
t = e.SQLType()
size = e.Size()
scale = e.Scale()
}
ta.add(t, f)
ta.add(t, f, size, scale)
}

func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) {
func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool, size, scale int32) {
var flag typeFlag
if typ == sqltypes.HexVal || typ == sqltypes.HexNum {
typ = sqltypes.Binary
Expand All @@ -117,13 +123,15 @@ func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) {
if nullable {
flag |= flagNullable
}
ta.add(typ, flag)
ta.add(typ, flag, size, scale)
}

func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) {
func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag, size, scale int32) {
if f&flagNullable != 0 {
ta.nullable = true
}
ta.size = max(ta.size, size)
ta.scale = max(ta.scale, scale)
switch tt {
case sqltypes.Float32, sqltypes.Float64:
ta.double++
Expand Down
6 changes: 3 additions & 3 deletions go/vt/vtgate/evalengine/compiler_asm.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ func (asm *assembler) Cmp_ne_n() {
}, "CMPFLAG NE [NULL]")
}

func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc collations.TypedCollation, allowZeroDate bool) {
func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, size, scale int32, cc collations.TypedCollation, allowZeroDate bool) {
elseOffset := 0
if hasElse {
elseOffset = 1
Expand All @@ -529,12 +529,12 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll
end := env.vm.sp - elseOffset
for sp := env.vm.sp - stackDepth; sp < end; sp += 2 {
if env.vm.stack[sp] != nil && env.vm.stack[sp].(*evalInt64).i != 0 {
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now, allowZeroDate)
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, size, scale, cc.Collation, env.now, allowZeroDate)
goto done
}
}
if elseOffset != 0 {
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation, env.now, allowZeroDate)
env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, size, scale, cc.Collation, env.now, allowZeroDate)
} else {
env.vm.stack[env.vm.sp-stackDepth] = nil
}
Expand Down
30 changes: 30 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/olekukonko/tablewriter"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -168,6 +169,7 @@ func TestCompilerSingle(t *testing.T) {
values []sqltypes.Value
result string
collation collations.ID
typeWanted evalengine.Type
}{
{
expression: "1 + column0",
Expand Down Expand Up @@ -675,6 +677,28 @@ func TestCompilerSingle(t *testing.T) {
expression: `1 * unix_timestamp(time('1.0000'))`,
result: `DECIMAL(1698098401.0000)`,
},
{
expression: `(case
when 'PROMOTION' like 'PROMO%'
then 0.01
else 0
end) * 0.01`,
result: `DECIMAL(0.0001)`,
typeWanted: evalengine.NewTypeEx(sqltypes.Decimal, collations.CollationBinaryID, false, 4, 4, nil),
},
{
expression: `case when true then 0.02 else 1.000 end`,
result: `DECIMAL(0.02)`,
},
{
expression: `case
when false
then timestamp'2023-10-24 12:00:00.123456'
else timestamp'2023-10-24 12:00:00'
end`,
result: `DATETIME("2023-10-24 12:00:00.000000")`,
typeWanted: evalengine.NewTypeEx(sqltypes.Datetime, collations.CollationBinaryID, false, 6, 0, nil),
},
}

tz, _ := time.LoadLocation("Europe/Madrid")
Expand Down Expand Up @@ -715,6 +739,12 @@ func TestCompilerSingle(t *testing.T) {
t.Fatalf("bad collation evaluation from eval engine: got %d, want %d", expected.Collation(), tc.collation)
}

if tc.typeWanted.Type() != sqltypes.Unknown {
typ, err := env.TypeOf(converted)
require.NoError(t, err)
require.True(t, tc.typeWanted.Equal(&typ))
}

// re-run the same evaluation multiple times to ensure results are always consistent
for i := 0; i < 8; i++ {
res, err := env.Evaluate(converted)
Expand Down
10 changes: 6 additions & 4 deletions go/vt/vtgate/evalengine/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ func (f typeFlag) Nullable() bool {
type eval interface {
ToRawBytes() []byte
SQLType() sqltypes.Type
Size() int32
Scale() int32
}

type hashable interface {
Expand Down Expand Up @@ -170,7 +172,7 @@ func evalIsTruthy(e eval) boolean {
}
}

func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, allowZero bool) (eval, error) {
func evalCoerce(e eval, typ sqltypes.Type, size, scale int32, col collations.ID, now time.Time, allowZero bool) (eval, error) {
if e == nil {
return nil, nil
}
Expand All @@ -181,7 +183,7 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, all
// if we have an explicit VARCHAR coercion, always force it so the collation is replaced in the target
return evalToVarchar(e, col, false)
}
if e.SQLType() == typ {
if e.SQLType() == typ && e.Size() == size && e.Scale() == scale {
// nothing to be done here
return e, nil
}
Expand All @@ -204,9 +206,9 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, all
case sqltypes.Date:
return evalToDate(e, now, allowZero), nil
case sqltypes.Datetime, sqltypes.Timestamp:
return evalToDateTime(e, -1, now, allowZero), nil
return evalToDateTime(e, int(size), now, allowZero), nil
case sqltypes.Time:
return evalToTime(e, -1), nil
return evalToTime(e, int(size)), nil
default:
return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Unsupported type conversion: %s", typ.String())
}
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_bytes.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ func (e *evalBytes) SQLType() sqltypes.Type {
return sqltypes.Type(e.tt)
}

func (e *evalBytes) Size() int32 {
return 0
}

func (e *evalBytes) Scale() int32 {
return 0
}

func (e *evalBytes) ToRawBytes() []byte {
return e.bytes
}
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ func (e *evalEnum) SQLType() sqltypes.Type {
return sqltypes.Enum
}

func (e *evalEnum) Size() int32 {
return 0
}

func (e *evalEnum) Scale() int32 {
return 0
}

func valueIdx(values *EnumSetValues, value string) int {
if values == nil {
return -1
Expand Down
32 changes: 32 additions & 0 deletions go/vt/vtgate/evalengine/eval_numeric.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,14 @@ func (e *evalInt64) SQLType() sqltypes.Type {
return sqltypes.Int64
}

func (e *evalInt64) Size() int32 {
return 0
}

func (e *evalInt64) Scale() int32 {
return 0
}

func (e *evalInt64) ToRawBytes() []byte {
return strconv.AppendInt(nil, e.i, 10)
}
Expand Down Expand Up @@ -409,6 +417,14 @@ func (e *evalUint64) SQLType() sqltypes.Type {
return sqltypes.Uint64
}

func (e *evalUint64) Size() int32 {
return 0
}

func (e *evalUint64) Scale() int32 {
return 0
}

func (e *evalUint64) ToRawBytes() []byte {
return strconv.AppendUint(nil, e.u, 10)
}
Expand Down Expand Up @@ -452,6 +468,14 @@ func (e *evalFloat) SQLType() sqltypes.Type {
return sqltypes.Float64
}

func (e *evalFloat) Size() int32 {
return 0
}

func (e *evalFloat) Scale() int32 {
return 0
}

func (e *evalFloat) ToRawBytes() []byte {
return format.FormatFloat(e.f)
}
Expand Down Expand Up @@ -528,6 +552,14 @@ func (e *evalDecimal) SQLType() sqltypes.Type {
return sqltypes.Decimal
}

func (e *evalDecimal) Size() int32 {
return e.length
}

func (e *evalDecimal) Scale() int32 {
return -e.dec.Exponent()
}

func (e *evalDecimal) ToRawBytes() []byte {
return e.dec.FormatMySQL(e.length)
}
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@ func (e *evalSet) SQLType() sqltypes.Type {
return sqltypes.Set
}

func (e *evalSet) Size() int32 {
return 0
}

func (e *evalSet) Scale() int32 {
return 0
}

func evalSetBits(values *EnumSetValues, value string) uint64 {
if values != nil && len(*values) > 64 {
// This never would happen as MySQL limits SET
Expand Down
8 changes: 8 additions & 0 deletions go/vt/vtgate/evalengine/eval_temporal.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ func (e *evalTemporal) SQLType() sqltypes.Type {
return e.t
}

func (e *evalTemporal) Size() int32 {
return int32(e.prec)
}

func (e *evalTemporal) Scale() int32 {
return 0
}

func (e *evalTemporal) toInt64() int64 {
switch e.SQLType() {
case sqltypes.Date:
Expand Down
Loading
Loading