From f422372c3a0ddb60a4bc4dffddab6d542c522c69 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Wed, 24 Apr 2024 12:38:25 +0530 Subject: [PATCH 1/5] feat: add failing case test and fix it Signed-off-by: Manan Gupta --- .../endtoend/vtgate/queries/tpch/tpch_test.go | 17 +++++++++++++++-- go/vt/vtgate/engine/join.go | 10 +++++----- go/vt/vtgate/evalengine/compiler_test.go | 17 +++++++++++++++++ go/vt/vtgate/evalengine/expr_logical.go | 7 ++++++- 4 files changed, 43 insertions(+), 8 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go index 513aea94a86..70e0c5e1edd 100644 --- a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go +++ b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go @@ -19,10 +19,10 @@ package union import ( "testing" + "github.com/stretchr/testify/require" + "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/test/endtoend/utils" - - "github.com/stretchr/testify/require" ) func start(t *testing.T) (utils.MySQLCompare, func()) { @@ -161,6 +161,19 @@ group by order by value desc;`, }, + { + name: "Q14 without decimal literal", + query: `select sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from lineitem, + part +where l_partkey = p_partkey + and l_shipdate >= '1996-12-01' + and l_shipdate < date_add('1996-12-01', interval '1' month);`, + }, } for _, testcase := range testcases { diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index 45b0d182dd7..5c5259dfa83 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -61,7 +61,7 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st result := &sqltypes.Result{} if len(lresult.Rows) == 0 && wantfields { for k, col := range jn.Vars { - joinVars[k] = bindvarForType(lresult.Fields[col].Type) + joinVars[k] = bindvarForType(lresult.Fields[col]) } rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars)) if err != nil { @@ -95,19 +95,19 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st return result, nil } -func bindvarForType(t querypb.Type) *querypb.BindVariable { +func bindvarForType(field *querypb.Field) *querypb.BindVariable { bv := &querypb.BindVariable{ - Type: t, + Type: field.Type, Value: nil, } - switch t { + switch field.Type { case querypb.Type_INT8, querypb.Type_UINT8, querypb.Type_INT16, querypb.Type_UINT16, querypb.Type_INT32, querypb.Type_UINT32, querypb.Type_INT64, querypb.Type_UINT64: bv.Value = []byte("0") case querypb.Type_FLOAT32, querypb.Type_FLOAT64: bv.Value = []byte("0e0") case querypb.Type_DECIMAL: - bv.Value = []byte("0.0") + bv.Value = []byte(fmt.Sprintf("%s.%s", strings.Repeat("0", max(1, int(field.ColumnLength-field.Decimals))), strings.Repeat("0", max(1, int(field.Decimals))))) default: return sqltypes.NullBindVariable } diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 3d5283db415..ebcfbf84a7c 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/olekukonko/tablewriter" + "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" @@ -168,6 +169,7 @@ func TestCompilerSingle(t *testing.T) { values []sqltypes.Value result string collation collations.ID + typeWanted evalengine.Type }{ { expression: "1 + column0", @@ -675,6 +677,15 @@ func TestCompilerSingle(t *testing.T) { expression: `1 * unix_timestamp(time('1.0000'))`, result: `DECIMAL(1698098401.0000)`, }, + { + expression: `(case + when 'PROMOTION' like 'PROMO%' + then 0.01 + else 0 + end) * 0.01`, + result: `DECIMAL(0.0001)`, + typeWanted: evalengine.NewTypeEx(sqltypes.Decimal, collations.CollationBinaryID, false, 4, 4), + }, } tz, _ := time.LoadLocation("Europe/Madrid") @@ -715,6 +726,12 @@ func TestCompilerSingle(t *testing.T) { t.Fatalf("bad collation evaluation from eval engine: got %d, want %d", expected.Collation(), tc.collation) } + if tc.typeWanted.Type() != sqltypes.Unknown { + typ, err := env.TypeOf(converted) + require.NoError(t, err) + require.EqualValues(t, tc.typeWanted, typ) + } + // re-run the same evaluation multiple times to ensure results are always consistent for i := 0; i < 8; i++ { res, err := env.Evaluate(converted) diff --git a/go/vt/vtgate/evalengine/expr_logical.go b/go/vt/vtgate/evalengine/expr_logical.go index ef59616b97c..c4ddba4ce0c 100644 --- a/go/vt/vtgate/evalengine/expr_logical.go +++ b/go/vt/vtgate/evalengine/expr_logical.go @@ -674,6 +674,7 @@ func (c *CaseExpr) simplify(env *ExpressionEnv) error { func (cs *CaseExpr) compile(c *compiler) (ctype, error) { var ca collationAggregation var ta typeAggregation + var scale, size int32 for _, wt := range cs.cases { when, err := wt.when.compile(c) @@ -691,6 +692,8 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { } ta.add(then.Type, then.Flag) + scale = max(scale, then.Scale) + size = max(size, then.Size) if err := ca.add(then.Col, c.env.CollationEnv()); err != nil { return ctype{}, err } @@ -703,6 +706,8 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { } ta.add(els.Type, els.Flag) + scale = max(scale, els.Scale) + size = max(size, els.Size) if err := ca.add(els.Col, c.env.CollationEnv()); err != nil { return ctype{}, err } @@ -712,7 +717,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { if ta.nullable { f |= flagNullable } - ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()} + ct := ctype{Type: ta.result(), Flag: f, Col: ca.result(), Scale: scale, Size: size} c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col, c.sqlmode.AllowZeroDate()) return ct, nil } From e65487d31edce2e39b083ad982ae31d7b42ee4df Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Wed, 24 Apr 2024 14:09:59 +0530 Subject: [PATCH 2/5] refactor: address review comments Signed-off-by: Manan Gupta --- go/vt/vtgate/engine/join.go | 3 ++- go/vt/vtgate/evalengine/compiler_test.go | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index 5c5259dfa83..e6dd00b687b 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -17,6 +17,7 @@ limitations under the License. package engine import ( + "bytes" "context" "fmt" "strings" @@ -107,7 +108,7 @@ func bindvarForType(field *querypb.Field) *querypb.BindVariable { case querypb.Type_FLOAT32, querypb.Type_FLOAT64: bv.Value = []byte("0e0") case querypb.Type_DECIMAL: - bv.Value = []byte(fmt.Sprintf("%s.%s", strings.Repeat("0", max(1, int(field.ColumnLength-field.Decimals))), strings.Repeat("0", max(1, int(field.Decimals))))) + bv.Value = append(append(bytes.Repeat([]byte{'0'}, max(1, int(field.ColumnLength-field.Decimals))), byte('.')), bytes.Repeat([]byte{'0'}, max(1, int(field.Decimals)))...) default: return sqltypes.NullBindVariable } diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index ebcfbf84a7c..00dc538bf73 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -684,7 +684,7 @@ func TestCompilerSingle(t *testing.T) { else 0 end) * 0.01`, result: `DECIMAL(0.0001)`, - typeWanted: evalengine.NewTypeEx(sqltypes.Decimal, collations.CollationBinaryID, false, 4, 4), + typeWanted: evalengine.NewTypeEx(sqltypes.Decimal, collations.CollationBinaryID, false, 4, 4, nil), }, } @@ -729,7 +729,7 @@ func TestCompilerSingle(t *testing.T) { if tc.typeWanted.Type() != sqltypes.Unknown { typ, err := env.TypeOf(converted) require.NoError(t, err) - require.EqualValues(t, tc.typeWanted, typ) + require.True(t, tc.typeWanted.Equal(&typ)) } // re-run the same evaluation multiple times to ensure results are always consistent From 4786466921d8faa3f079d6682070a7a0d02c56b8 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Wed, 24 Apr 2024 14:50:29 +0530 Subject: [PATCH 3/5] refactor: move scale and size calculation inside type aggregation Signed-off-by: Manan Gupta --- go/mysql/json/parser.go | 8 +++++ .../vtgate/evalengine/api_type_aggregation.go | 23 +++++++++---- go/vt/vtgate/evalengine/eval.go | 2 ++ go/vt/vtgate/evalengine/eval_bytes.go | 8 +++++ go/vt/vtgate/evalengine/eval_enum.go | 8 +++++ go/vt/vtgate/evalengine/eval_numeric.go | 32 +++++++++++++++++++ go/vt/vtgate/evalengine/eval_set.go | 8 +++++ go/vt/vtgate/evalengine/eval_temporal.go | 8 +++++ go/vt/vtgate/evalengine/eval_tuple.go | 8 +++++ go/vt/vtgate/evalengine/expr_logical.go | 11 ++----- go/vt/vtgate/evalengine/fn_compare.go | 2 +- 11 files changed, 103 insertions(+), 15 deletions(-) diff --git a/go/mysql/json/parser.go b/go/mysql/json/parser.go index 707d890df93..b7a87c25756 100644 --- a/go/mysql/json/parser.go +++ b/go/mysql/json/parser.go @@ -669,6 +669,14 @@ type Value struct { n NumberType } +func (v *Value) Size() int32 { + return 0 +} + +func (v *Value) Scale() int32 { + return 0 +} + func (v *Value) MarshalDate() string { if d, ok := v.Date(); ok { return d.ToStdTime(time.Local).Format("2006-01-02") diff --git a/go/vt/vtgate/evalengine/api_type_aggregation.go b/go/vt/vtgate/evalengine/api_type_aggregation.go index 04622e5a212..b618f391953 100644 --- a/go/vt/vtgate/evalengine/api_type_aggregation.go +++ b/go/vt/vtgate/evalengine/api_type_aggregation.go @@ -47,7 +47,8 @@ type typeAggregation struct { blob uint16 total uint16 - nullable bool + nullable bool + scale, size int32 } type TypeAggregator struct { @@ -63,7 +64,7 @@ func (ta *TypeAggregator) Add(typ Type, env *collations.Environment) error { return nil } - ta.types.addNullable(typ.typ, typ.nullable) + ta.types.addNullable(typ.typ, typ.nullable, typ.size, typ.scale) if err := ta.collations.add(typedCoercionCollation(typ.typ, typ.collation), env); err != nil { return err } @@ -105,10 +106,10 @@ func (ta *typeAggregation) addEval(e eval) { default: t = e.SQLType() } - ta.add(t, f) + ta.add(t, f, e.Size(), e.Scale()) } -func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) { +func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool, size, scale int32) { var flag typeFlag if typ == sqltypes.HexVal || typ == sqltypes.HexNum { typ = sqltypes.Binary @@ -117,13 +118,15 @@ func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool) { if nullable { flag |= flagNullable } - ta.add(typ, flag) + ta.add(typ, flag, size, scale) } -func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag) { +func (ta *typeAggregation) add(tt sqltypes.Type, f typeFlag, size, scale int32) { if f&flagNullable != 0 { ta.nullable = true } + ta.size = max(ta.size, size) + ta.scale = max(ta.scale, scale) switch tt { case sqltypes.Float32, sqltypes.Float64: ta.double++ @@ -190,6 +193,14 @@ func nextSignedTypeForUnsigned(t sqltypes.Type) sqltypes.Type { } } +func (ta *typeAggregation) Size() int32 { + return ta.size +} + +func (ta *typeAggregation) Scale() int32 { + return ta.scale +} + func (ta *typeAggregation) result() sqltypes.Type { /* If all types are numeric, the aggregated type is also numeric: diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index 90b1add541a..34ad4f4f008 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -72,6 +72,8 @@ func (f typeFlag) Nullable() bool { type eval interface { ToRawBytes() []byte SQLType() sqltypes.Type + Size() int32 + Scale() int32 } type hashable interface { diff --git a/go/vt/vtgate/evalengine/eval_bytes.go b/go/vt/vtgate/evalengine/eval_bytes.go index caa516acbe4..027c4bb652d 100644 --- a/go/vt/vtgate/evalengine/eval_bytes.go +++ b/go/vt/vtgate/evalengine/eval_bytes.go @@ -138,6 +138,14 @@ func (e *evalBytes) SQLType() sqltypes.Type { return sqltypes.Type(e.tt) } +func (e *evalBytes) Size() int32 { + return 0 +} + +func (e *evalBytes) Scale() int32 { + return 0 +} + func (e *evalBytes) ToRawBytes() []byte { return e.bytes } diff --git a/go/vt/vtgate/evalengine/eval_enum.go b/go/vt/vtgate/evalengine/eval_enum.go index a0d349314da..fa9675d7c0e 100644 --- a/go/vt/vtgate/evalengine/eval_enum.go +++ b/go/vt/vtgate/evalengine/eval_enum.go @@ -26,6 +26,14 @@ func (e *evalEnum) SQLType() sqltypes.Type { return sqltypes.Enum } +func (e *evalEnum) Size() int32 { + return 0 +} + +func (e *evalEnum) Scale() int32 { + return 0 +} + func valueIdx(values *EnumSetValues, value string) int { if values == nil { return -1 diff --git a/go/vt/vtgate/evalengine/eval_numeric.go b/go/vt/vtgate/evalengine/eval_numeric.go index 64f5477a3fc..04f844566b1 100644 --- a/go/vt/vtgate/evalengine/eval_numeric.go +++ b/go/vt/vtgate/evalengine/eval_numeric.go @@ -366,6 +366,14 @@ func (e *evalInt64) SQLType() sqltypes.Type { return sqltypes.Int64 } +func (e *evalInt64) Size() int32 { + return 0 +} + +func (e *evalInt64) Scale() int32 { + return 0 +} + func (e *evalInt64) ToRawBytes() []byte { return strconv.AppendInt(nil, e.i, 10) } @@ -409,6 +417,14 @@ func (e *evalUint64) SQLType() sqltypes.Type { return sqltypes.Uint64 } +func (e *evalUint64) Size() int32 { + return 0 +} + +func (e *evalUint64) Scale() int32 { + return 0 +} + func (e *evalUint64) ToRawBytes() []byte { return strconv.AppendUint(nil, e.u, 10) } @@ -452,6 +468,14 @@ func (e *evalFloat) SQLType() sqltypes.Type { return sqltypes.Float64 } +func (e *evalFloat) Size() int32 { + return 0 +} + +func (e *evalFloat) Scale() int32 { + return 0 +} + func (e *evalFloat) ToRawBytes() []byte { return format.FormatFloat(e.f) } @@ -528,6 +552,14 @@ func (e *evalDecimal) SQLType() sqltypes.Type { return sqltypes.Decimal } +func (e *evalDecimal) Size() int32 { + return e.length +} + +func (e *evalDecimal) Scale() int32 { + return -e.dec.Exponent() +} + func (e *evalDecimal) ToRawBytes() []byte { return e.dec.FormatMySQL(e.length) } diff --git a/go/vt/vtgate/evalengine/eval_set.go b/go/vt/vtgate/evalengine/eval_set.go index 6a9de2eff14..bc75a527edc 100644 --- a/go/vt/vtgate/evalengine/eval_set.go +++ b/go/vt/vtgate/evalengine/eval_set.go @@ -29,6 +29,14 @@ func (e *evalSet) SQLType() sqltypes.Type { return sqltypes.Set } +func (e *evalSet) Size() int32 { + return 0 +} + +func (e *evalSet) Scale() int32 { + return 0 +} + func evalSetBits(values *EnumSetValues, value string) uint64 { if values != nil && len(*values) > 64 { // This never would happen as MySQL limits SET diff --git a/go/vt/vtgate/evalengine/eval_temporal.go b/go/vt/vtgate/evalengine/eval_temporal.go index 7706ec36e64..2766c1dfb56 100644 --- a/go/vt/vtgate/evalengine/eval_temporal.go +++ b/go/vt/vtgate/evalengine/eval_temporal.go @@ -42,6 +42,14 @@ func (e *evalTemporal) SQLType() sqltypes.Type { return e.t } +func (e *evalTemporal) Size() int32 { + return 0 +} + +func (e *evalTemporal) Scale() int32 { + return 0 +} + func (e *evalTemporal) toInt64() int64 { switch e.SQLType() { case sqltypes.Date: diff --git a/go/vt/vtgate/evalengine/eval_tuple.go b/go/vt/vtgate/evalengine/eval_tuple.go index 73e7fcc2051..81fa3317977 100644 --- a/go/vt/vtgate/evalengine/eval_tuple.go +++ b/go/vt/vtgate/evalengine/eval_tuple.go @@ -33,3 +33,11 @@ func (e *evalTuple) ToRawBytes() []byte { func (e *evalTuple) SQLType() sqltypes.Type { return sqltypes.Tuple } + +func (e *evalTuple) Size() int32 { + return 0 +} + +func (e *evalTuple) Scale() int32 { + return 0 +} diff --git a/go/vt/vtgate/evalengine/expr_logical.go b/go/vt/vtgate/evalengine/expr_logical.go index c4ddba4ce0c..9a38405f05c 100644 --- a/go/vt/vtgate/evalengine/expr_logical.go +++ b/go/vt/vtgate/evalengine/expr_logical.go @@ -674,7 +674,6 @@ func (c *CaseExpr) simplify(env *ExpressionEnv) error { func (cs *CaseExpr) compile(c *compiler) (ctype, error) { var ca collationAggregation var ta typeAggregation - var scale, size int32 for _, wt := range cs.cases { when, err := wt.when.compile(c) @@ -691,9 +690,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { return ctype{}, err } - ta.add(then.Type, then.Flag) - scale = max(scale, then.Scale) - size = max(size, then.Size) + ta.add(then.Type, then.Flag, then.Size, then.Scale) if err := ca.add(then.Col, c.env.CollationEnv()); err != nil { return ctype{}, err } @@ -705,9 +702,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { return ctype{}, err } - ta.add(els.Type, els.Flag) - scale = max(scale, els.Scale) - size = max(size, els.Size) + ta.add(els.Type, els.Flag, els.Size, els.Scale) if err := ca.add(els.Col, c.env.CollationEnv()); err != nil { return ctype{}, err } @@ -717,7 +712,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { if ta.nullable { f |= flagNullable } - ct := ctype{Type: ta.result(), Flag: f, Col: ca.result(), Scale: scale, Size: size} + ct := ctype{Type: ta.result(), Flag: f, Col: ca.result(), Scale: ta.Scale(), Size: ta.Size()} c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col, c.sqlmode.AllowZeroDate()) return ct, nil } diff --git a/go/vt/vtgate/evalengine/fn_compare.go b/go/vt/vtgate/evalengine/fn_compare.go index c102f5e5ef5..1deec6752ef 100644 --- a/go/vt/vtgate/evalengine/fn_compare.go +++ b/go/vt/vtgate/evalengine/fn_compare.go @@ -71,7 +71,7 @@ func (b *builtinCoalesce) compile(c *compiler) (ctype, error) { if !tt.nullable() { f = 0 } - ta.add(tt.Type, tt.Flag) + ta.add(tt.Type, tt.Flag, tt.Size, tt.Scale) if err := ca.add(tt.Col, c.env.CollationEnv()); err != nil { return ctype{}, err } From 0ec0c78ca7961216595c19b5256dc38a8ed4f3dd Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 24 Apr 2024 21:34:31 +0200 Subject: [PATCH 4/5] Fix crash and handle datetime size correctly Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/api_type_aggregation.go | 15 ++++++--------- go/vt/vtgate/evalengine/compiler_asm.go | 6 +++--- go/vt/vtgate/evalengine/compiler_test.go | 13 +++++++++++++ go/vt/vtgate/evalengine/eval.go | 8 ++++---- go/vt/vtgate/evalengine/eval_temporal.go | 2 +- go/vt/vtgate/evalengine/expr_logical.go | 6 +++--- 6 files changed, 30 insertions(+), 20 deletions(-) diff --git a/go/vt/vtgate/evalengine/api_type_aggregation.go b/go/vt/vtgate/evalengine/api_type_aggregation.go index b618f391953..45c0377bca4 100644 --- a/go/vt/vtgate/evalengine/api_type_aggregation.go +++ b/go/vt/vtgate/evalengine/api_type_aggregation.go @@ -96,6 +96,7 @@ func (ta *typeAggregation) empty() bool { func (ta *typeAggregation) addEval(e eval) { var t sqltypes.Type var f typeFlag + var size, scale int32 switch e := e.(type) { case nil: t = sqltypes.Null @@ -103,10 +104,14 @@ func (ta *typeAggregation) addEval(e eval) { case *evalBytes: t = sqltypes.Type(e.tt) f = e.flag + size = e.Size() + scale = e.Scale() default: t = e.SQLType() + size = e.Size() + scale = e.Scale() } - ta.add(t, f, e.Size(), e.Scale()) + ta.add(t, f, size, scale) } func (ta *typeAggregation) addNullable(typ sqltypes.Type, nullable bool, size, scale int32) { @@ -193,14 +198,6 @@ func nextSignedTypeForUnsigned(t sqltypes.Type) sqltypes.Type { } } -func (ta *typeAggregation) Size() int32 { - return ta.size -} - -func (ta *typeAggregation) Scale() int32 { - return ta.scale -} - func (ta *typeAggregation) result() sqltypes.Type { /* If all types are numeric, the aggregated type is also numeric: diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 07c302ac6ec..2cda3ecb348 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -516,7 +516,7 @@ func (asm *assembler) Cmp_ne_n() { }, "CMPFLAG NE [NULL]") } -func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc collations.TypedCollation, allowZeroDate bool) { +func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, size, scale int32, cc collations.TypedCollation, allowZeroDate bool) { elseOffset := 0 if hasElse { elseOffset = 1 @@ -529,12 +529,12 @@ func (asm *assembler) CmpCase(cases int, hasElse bool, tt sqltypes.Type, cc coll end := env.vm.sp - elseOffset for sp := env.vm.sp - stackDepth; sp < end; sp += 2 { if env.vm.stack[sp] != nil && env.vm.stack[sp].(*evalInt64).i != 0 { - env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, cc.Collation, env.now, allowZeroDate) + env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[sp+1], tt, size, scale, cc.Collation, env.now, allowZeroDate) goto done } } if elseOffset != 0 { - env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, cc.Collation, env.now, allowZeroDate) + env.vm.stack[env.vm.sp-stackDepth], env.vm.err = evalCoerce(env.vm.stack[env.vm.sp-1], tt, size, scale, cc.Collation, env.now, allowZeroDate) } else { env.vm.stack[env.vm.sp-stackDepth] = nil } diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 00dc538bf73..b2d4ff0c2f0 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -686,6 +686,19 @@ func TestCompilerSingle(t *testing.T) { result: `DECIMAL(0.0001)`, typeWanted: evalengine.NewTypeEx(sqltypes.Decimal, collations.CollationBinaryID, false, 4, 4, nil), }, + { + expression: `case when true then 0.02 else 1.000 end`, + result: `DECIMAL(0.02)`, + }, + { + expression: `case + when false + then timestamp'2023-10-24 12:00:00.123456' + else timestamp'2023-10-24 12:00:00' + end`, + result: `DATETIME("2023-10-24 12:00:00.000000")`, + typeWanted: evalengine.NewTypeEx(sqltypes.Datetime, collations.CollationBinaryID, false, 6, 0, nil), + }, } tz, _ := time.LoadLocation("Europe/Madrid") diff --git a/go/vt/vtgate/evalengine/eval.go b/go/vt/vtgate/evalengine/eval.go index 34ad4f4f008..49423979379 100644 --- a/go/vt/vtgate/evalengine/eval.go +++ b/go/vt/vtgate/evalengine/eval.go @@ -172,7 +172,7 @@ func evalIsTruthy(e eval) boolean { } } -func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, allowZero bool) (eval, error) { +func evalCoerce(e eval, typ sqltypes.Type, size, scale int32, col collations.ID, now time.Time, allowZero bool) (eval, error) { if e == nil { return nil, nil } @@ -183,7 +183,7 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, all // if we have an explicit VARCHAR coercion, always force it so the collation is replaced in the target return evalToVarchar(e, col, false) } - if e.SQLType() == typ { + if e.SQLType() == typ && e.Size() == size && e.Scale() == scale { // nothing to be done here return e, nil } @@ -206,9 +206,9 @@ func evalCoerce(e eval, typ sqltypes.Type, col collations.ID, now time.Time, all case sqltypes.Date: return evalToDate(e, now, allowZero), nil case sqltypes.Datetime, sqltypes.Timestamp: - return evalToDateTime(e, -1, now, allowZero), nil + return evalToDateTime(e, int(size), now, allowZero), nil case sqltypes.Time: - return evalToTime(e, -1), nil + return evalToTime(e, int(size)), nil default: return nil, vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Unsupported type conversion: %s", typ.String()) } diff --git a/go/vt/vtgate/evalengine/eval_temporal.go b/go/vt/vtgate/evalengine/eval_temporal.go index 2766c1dfb56..d73485441c3 100644 --- a/go/vt/vtgate/evalengine/eval_temporal.go +++ b/go/vt/vtgate/evalengine/eval_temporal.go @@ -43,7 +43,7 @@ func (e *evalTemporal) SQLType() sqltypes.Type { } func (e *evalTemporal) Size() int32 { - return 0 + return int32(e.prec) } func (e *evalTemporal) Scale() int32 { diff --git a/go/vt/vtgate/evalengine/expr_logical.go b/go/vt/vtgate/evalengine/expr_logical.go index 9a38405f05c..561915f600c 100644 --- a/go/vt/vtgate/evalengine/expr_logical.go +++ b/go/vt/vtgate/evalengine/expr_logical.go @@ -631,7 +631,7 @@ func (c *CaseExpr) eval(env *ExpressionEnv) (eval, error) { if !matched { return nil, nil } - return evalCoerce(result, ta.result(), ca.result().Collation, env.now, env.sqlmode.AllowZeroDate()) + return evalCoerce(result, ta.result(), ta.size, ta.scale, ca.result().Collation, env.now, env.sqlmode.AllowZeroDate()) } func (c *CaseExpr) constant() bool { @@ -712,8 +712,8 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) { if ta.nullable { f |= flagNullable } - ct := ctype{Type: ta.result(), Flag: f, Col: ca.result(), Scale: ta.Scale(), Size: ta.Size()} - c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col, c.sqlmode.AllowZeroDate()) + ct := ctype{Type: ta.result(), Flag: f, Col: ca.result(), Scale: ta.scale, Size: ta.size} + c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Size, ct.Scale, ct.Col, c.sqlmode.AllowZeroDate()) return ct, nil } From 94fe5b1278ef51d4c15ee79658c42551f13e4300 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Thu, 25 Apr 2024 12:11:17 +0530 Subject: [PATCH 5/5] refactor: split the line into multiple parts for readability Signed-off-by: Manan Gupta --- go/vt/vtgate/engine/join.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index e6dd00b687b..dc952673cfe 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -108,7 +108,9 @@ func bindvarForType(field *querypb.Field) *querypb.BindVariable { case querypb.Type_FLOAT32, querypb.Type_FLOAT64: bv.Value = []byte("0e0") case querypb.Type_DECIMAL: - bv.Value = append(append(bytes.Repeat([]byte{'0'}, max(1, int(field.ColumnLength-field.Decimals))), byte('.')), bytes.Repeat([]byte{'0'}, max(1, int(field.Decimals)))...) + size := max(1, int(field.ColumnLength-field.Decimals)) + scale := max(1, int(field.Decimals)) + bv.Value = append(append(bytes.Repeat([]byte{'0'}, size), byte('.')), bytes.Repeat([]byte{'0'}, scale)...) default: return sqltypes.NullBindVariable }