From 6ba94f281f5663edb87ed4303e443087098fb318 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Tue, 23 Apr 2024 16:58:36 +0530 Subject: [PATCH 01/10] feat: add failing tpch query and fix Signed-off-by: Manan Gupta --- go/mysql/decimal/decimal.go | 4 ++++ .../endtoend/vtgate/queries/tpch/tpch_test.go | 13 +++++++++++++ go/vt/sqlparser/ast.go | 5 +++-- go/vt/sqlparser/ast_equals.go | 2 ++ go/vt/sqlparser/ast_format.go | 10 +++++++++- go/vt/sqlparser/ast_format_fast.go | 9 +++++++++ go/vt/sqlparser/ast_funcs.go | 15 +++++++++++++++ go/vt/sqlparser/cached_size.go | 2 +- go/vt/sqlparser/normalizer.go | 18 +++++++++++++++--- go/vt/vtgate/semantics/typer.go | 2 +- 10 files changed, 72 insertions(+), 8 deletions(-) diff --git a/go/mysql/decimal/decimal.go b/go/mysql/decimal/decimal.go index a2b505a1232..d320d940641 100644 --- a/go/mysql/decimal/decimal.go +++ b/go/mysql/decimal/decimal.go @@ -579,6 +579,10 @@ func (d Decimal) Exponent() int32 { return d.exp } +func (d Decimal) Size() int32 { + return int32(len(d.value.String())) +} + func (d Decimal) Int64() (int64, bool) { scaledD := d.rescale(0) return scaledD.value.Int64(), scaledD.value.IsInt64() diff --git a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go index 513aea94a86..d981bcaf743 100644 --- a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go +++ b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go @@ -161,6 +161,19 @@ group by order by value desc;`, }, + { + name: "Q14", + query: `select 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from lineitem, + part +where l_partkey = p_partkey + and l_shipdate >= '1996-12-01' + and l_shipdate < date_add('1996-12-01', interval '1' month);`, + }, } for _, testcase := range testcases { diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index cfef9923530..4016d596c62 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -2317,8 +2317,9 @@ type ( // Argument represents bindvariable expression Argument struct { - Name string - Type sqltypes.Type + Name string + Type sqltypes.Type + Size, Scale int32 } // NullVal represents a NULL value. diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index c4066218859..1b083f41170 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -1869,6 +1869,8 @@ func (cmp *Comparator) RefOfArgument(a, b *Argument) bool { return false } return a.Name == b.Name && + a.Size == b.Size && + a.Scale == b.Scale && a.Type == b.Type } diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index c823f0ae487..ea9b45e45c6 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1357,7 +1357,15 @@ func (node *Argument) Format(buf *TrackedBuffer) { // For bind variables that are statically typed, emit their type as an adjacent comment. // This comment will be ignored by older versions of Vitess (and by MySQL) but will provide // type safety when using the query as a cache key. - buf.astPrintf(node, " /* %s */", node.Type.String()) + buf.astPrintf(node, " /* %s", node.Type.String()) + if node.Size != 0 || node.Scale != 0 { + buf.astPrintf(node, "(%d", node.Size) + if node.Scale != 0 { + buf.astPrintf(node, ", %d", node.Scale) + } + buf.WriteString(")") + } + buf.WriteString(" */") } } diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 1f5a5229a20..61f9e67c4ca 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1780,6 +1780,15 @@ func (node *Argument) FormatFast(buf *TrackedBuffer) { // type safety when using the query as a cache key. buf.WriteString(" /* ") buf.WriteString(node.Type.String()) + if node.Size != 0 || node.Scale != 0 { + buf.WriteByte('(') + buf.WriteString(fmt.Sprintf("%d", node.Size)) + if node.Scale != 0 { + buf.WriteString(", ") + buf.WriteString(fmt.Sprintf("%d", node.Scale)) + } + buf.WriteString(")") + } buf.WriteString(" */") } } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 8019645c250..9c1b4d20b95 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -24,6 +24,7 @@ import ( "strconv" "strings" + "vitess.io/vitess/go/mysql/decimal" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" @@ -562,6 +563,20 @@ func NewTypedArgument(in string, t sqltypes.Type) *Argument { return &Argument{Name: in, Type: t} } +func NewTypedArgumentFromLiteral(in string, lit *Literal) (*Argument, error) { + arg := &Argument{Name: in, Type: lit.SQLType()} + switch arg.Type { + case sqltypes.Decimal: + dec, err := decimal.NewFromMySQL(lit.Bytes()) + if err != nil { + return nil, err + } + arg.Scale = -dec.Exponent() + arg.Size = dec.Size() + } + return arg, nil +} + // NewListArg builds a new ListArg. func NewListArg(in string) ListArg { return ListArg(in) diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index 361888727b2..6160eabacfa 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -351,7 +351,7 @@ func (cached *Argument) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(24) + size += int64(32) } // field Name string size += hack.RuntimeAllocSize(int64(len(cached.Name))) diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index 0254dccdfb2..e826e497afa 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -207,7 +207,11 @@ func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) { } // Modify the AST node to a bindvar. - cursor.Replace(NewTypedArgument(bvname, node.SQLType())) + arg, err := NewTypedArgumentFromLiteral(bvname, node) + if err != nil { + nz.err = err + } + cursor.Replace(arg) } // convertLiteral converts an Literal without the dedup. @@ -224,7 +228,11 @@ func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) { bvname := nz.reserved.nextUnusedVar() nz.bindVars[bvname] = bval - cursor.Replace(NewTypedArgument(bvname, node.SQLType())) + arg, err := NewTypedArgumentFromLiteral(bvname, node) + if err != nil { + nz.err = err + } + cursor.Replace(arg) } // convertComparison attempts to convert IN clauses to @@ -268,7 +276,11 @@ func (nz *normalizer) parameterize(left, right Expr) Expr { return nil } bvname := nz.decideBindVarName(lit, col, bval) - return NewTypedArgument(bvname, lit.SQLType()) + arg, err := NewTypedArgumentFromLiteral(bvname, lit) + if err != nil { + nz.err = err + } + return arg } func (nz *normalizer) decideBindVarName(lit *Literal, col *ColName, bval *querypb.BindVariable) string { diff --git a/go/vt/vtgate/semantics/typer.go b/go/vt/vtgate/semantics/typer.go index 54261339114..5b98b8eae23 100644 --- a/go/vt/vtgate/semantics/typer.go +++ b/go/vt/vtgate/semantics/typer.go @@ -47,7 +47,7 @@ func (t *typer) up(cursor *sqlparser.Cursor) error { t.m[node] = evalengine.NewType(node.SQLType(), collations.CollationForType(node.SQLType(), t.collationEnv.DefaultConnectionCharset())) case *sqlparser.Argument: if node.Type >= 0 { - t.m[node] = evalengine.NewType(node.Type, collations.CollationForType(node.Type, t.collationEnv.DefaultConnectionCharset())) + t.m[node] = evalengine.NewTypeEx(node.Type, collations.CollationForType(node.Type, t.collationEnv.DefaultConnectionCharset()), true, node.Size, node.Scale) } case sqlparser.AggrFunc: code, ok := opcode.SupportedAggregates[node.AggrName()] From 2ae9b2e77037eef8541084a72d58712d803ef580 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Tue, 23 Apr 2024 17:05:47 +0530 Subject: [PATCH 02/10] test: remove case from the query since that's a separate issue Signed-off-by: Manan Gupta --- go/test/endtoend/vtgate/queries/tpch/tpch_test.go | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go index d981bcaf743..de26908efa8 100644 --- a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go +++ b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go @@ -162,12 +162,8 @@ order by value desc;`, }, { - name: "Q14", - query: `select 100.00 * sum(case - when p_type like 'PROMO%' - then l_extendedprice * (1 - l_discount) - else 0 - end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue + name: "Q14 without case", + query: `select 100.00 * sum(l_extendedprice * (1 - l_discount)) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue from lineitem, part where l_partkey = p_partkey From 5310d1b1b8501f6297d0c4858c3ab2a4defb606d Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Tue, 23 Apr 2024 17:29:50 +0530 Subject: [PATCH 03/10] feat: fix size implementation Signed-off-by: Manan Gupta --- go/mysql/decimal/decimal.go | 6 ++++- go/mysql/decimal/decimal_test.go | 43 +++++++++++++++++++++++++++--- go/vt/sqlparser/normalizer_test.go | 4 +-- 3 files changed, 47 insertions(+), 6 deletions(-) diff --git a/go/mysql/decimal/decimal.go b/go/mysql/decimal/decimal.go index d320d940641..ddb2ffe3b5a 100644 --- a/go/mysql/decimal/decimal.go +++ b/go/mysql/decimal/decimal.go @@ -580,7 +580,11 @@ func (d Decimal) Exponent() int32 { } func (d Decimal) Size() int32 { - return int32(len(d.value.String())) + digitsCount := int32(len(d.value.String())) + if d.value.Sign() == -1 { + digitsCount-- + } + return max(digitsCount, -d.exp) } func (d Decimal) Int64() (int64, bool) { diff --git a/go/mysql/decimal/decimal_test.go b/go/mysql/decimal/decimal_test.go index 6a6cf001231..8c8ea765df9 100644 --- a/go/mysql/decimal/decimal_test.go +++ b/go/mysql/decimal/decimal_test.go @@ -28,6 +28,7 @@ import ( "testing/quick" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type testEnt struct { @@ -954,9 +955,45 @@ func TestDecimal_ScalesNotEqual(t *testing.T) { } func TestDecimal_Cmp1(t *testing.T) { - a := New(123, 3) - b := New(-1234, 2) - assert.Equal(t, 1, a.Cmp(b)) + testcases := []struct { + value string + sizeExpected int32 + }{ + { + value: "0.00003", + sizeExpected: 5, + }, + { + value: "-0.00003", + sizeExpected: 5, + }, + { + value: "12.00003", + sizeExpected: 7, + }, + { + value: "-12.00003", + sizeExpected: 7, + }, + { + value: "1000003", + sizeExpected: 7, + }, + { + value: "-1000003", + sizeExpected: 7, + }, + } + for _, testcase := range testcases { + t.Run(testcase.value, func(t *testing.T) { + val, err := NewFromString(testcase.value) + require.NoError(t, err) + assert.EqualValues(t, testcase.sizeExpected, val.Size()) + }) + } +} + +func TestDecimal_Size(t *testing.T) { } diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 42e26caa39d..1de111ddb13 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -75,14 +75,14 @@ func TestNormalize(t *testing.T) { }, { // float val in: "select * from t where foobar = 1.2", - outstmt: "select * from t where foobar = :foobar /* DECIMAL */", + outstmt: "select * from t where foobar = :foobar /* DECIMAL(2, 1) */", outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.DecimalBindVariable("1.2"), }, }, { // multiple vals in: "select * from t where foo = 1.2 and bar = 2", - outstmt: "select * from t where foo = :foo /* DECIMAL */ and bar = :bar /* INT64 */", + outstmt: "select * from t where foo = :foo /* DECIMAL(2, 1) */ and bar = :bar /* INT64 */", outbv: map[string]*querypb.BindVariable{ "foo": sqltypes.DecimalBindVariable("1.2"), "bar": sqltypes.Int64BindVariable(2), From 08ff1cb714c61e126664bef26d5a11ea79029aa0 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Wed, 24 Apr 2024 13:19:57 +0530 Subject: [PATCH 04/10] refactor: address review comments Signed-off-by: Manan Gupta --- go/vt/sqlparser/ast_format.go | 2 +- go/vt/sqlparser/ast_format_fast.go | 2 +- go/vt/sqlparser/normalizer.go | 3 +++ go/vt/sqlparser/normalizer_test.go | 4 ++-- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index ea9b45e45c6..80608b83d75 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1361,7 +1361,7 @@ func (node *Argument) Format(buf *TrackedBuffer) { if node.Size != 0 || node.Scale != 0 { buf.astPrintf(node, "(%d", node.Size) if node.Scale != 0 { - buf.astPrintf(node, ", %d", node.Scale) + buf.astPrintf(node, ",%d", node.Scale) } buf.WriteString(")") } diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 61f9e67c4ca..28fdd119f7b 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1784,7 +1784,7 @@ func (node *Argument) FormatFast(buf *TrackedBuffer) { buf.WriteByte('(') buf.WriteString(fmt.Sprintf("%d", node.Size)) if node.Scale != 0 { - buf.WriteString(", ") + buf.WriteByte(',') buf.WriteString(fmt.Sprintf("%d", node.Scale)) } buf.WriteString(")") diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index e826e497afa..1492ecfac90 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -210,6 +210,7 @@ func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) { arg, err := NewTypedArgumentFromLiteral(bvname, node) if err != nil { nz.err = err + return } cursor.Replace(arg) } @@ -231,6 +232,7 @@ func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) { arg, err := NewTypedArgumentFromLiteral(bvname, node) if err != nil { nz.err = err + return } cursor.Replace(arg) } @@ -279,6 +281,7 @@ func (nz *normalizer) parameterize(left, right Expr) Expr { arg, err := NewTypedArgumentFromLiteral(bvname, lit) if err != nil { nz.err = err + return nil } return arg } diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 1de111ddb13..b9b33347477 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -75,14 +75,14 @@ func TestNormalize(t *testing.T) { }, { // float val in: "select * from t where foobar = 1.2", - outstmt: "select * from t where foobar = :foobar /* DECIMAL(2, 1) */", + outstmt: "select * from t where foobar = :foobar /* DECIMAL(2,1) */", outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.DecimalBindVariable("1.2"), }, }, { // multiple vals in: "select * from t where foo = 1.2 and bar = 2", - outstmt: "select * from t where foo = :foo /* DECIMAL(2, 1) */ and bar = :bar /* INT64 */", + outstmt: "select * from t where foo = :foo /* DECIMAL(2,1) */ and bar = :bar /* INT64 */", outbv: map[string]*querypb.BindVariable{ "foo": sqltypes.DecimalBindVariable("1.2"), "bar": sqltypes.Int64BindVariable(2), From f9ec1b735aa8e1c8d7b2cbb41d57fa8641213850 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Wed, 24 Apr 2024 13:22:31 +0530 Subject: [PATCH 05/10] feat: fix build issues Signed-off-by: Manan Gupta --- go/vt/vtgate/semantics/typer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/vt/vtgate/semantics/typer.go b/go/vt/vtgate/semantics/typer.go index 5b98b8eae23..b56c836a740 100644 --- a/go/vt/vtgate/semantics/typer.go +++ b/go/vt/vtgate/semantics/typer.go @@ -47,7 +47,7 @@ func (t *typer) up(cursor *sqlparser.Cursor) error { t.m[node] = evalengine.NewType(node.SQLType(), collations.CollationForType(node.SQLType(), t.collationEnv.DefaultConnectionCharset())) case *sqlparser.Argument: if node.Type >= 0 { - t.m[node] = evalengine.NewTypeEx(node.Type, collations.CollationForType(node.Type, t.collationEnv.DefaultConnectionCharset()), true, node.Size, node.Scale) + t.m[node] = evalengine.NewTypeEx(node.Type, collations.CollationForType(node.Type, t.collationEnv.DefaultConnectionCharset()), true, node.Size, node.Scale, nil) } case sqlparser.AggrFunc: code, ok := opcode.SupportedAggregates[node.AggrName()] From db20c3ba6490ca6d026a738c18eeaf48629efb6e Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Wed, 24 Apr 2024 14:00:27 +0530 Subject: [PATCH 06/10] refactor: improve performance Signed-off-by: Manan Gupta --- go/mysql/decimal/decimal.go | 8 ----- go/mysql/decimal/decimal_test.go | 52 ++++++++++++++++++-------------- go/mysql/decimal/scan.go | 16 ++++++++++ go/vt/sqlparser/ast_funcs.go | 9 ++---- 4 files changed, 49 insertions(+), 36 deletions(-) diff --git a/go/mysql/decimal/decimal.go b/go/mysql/decimal/decimal.go index ddb2ffe3b5a..a2b505a1232 100644 --- a/go/mysql/decimal/decimal.go +++ b/go/mysql/decimal/decimal.go @@ -579,14 +579,6 @@ func (d Decimal) Exponent() int32 { return d.exp } -func (d Decimal) Size() int32 { - digitsCount := int32(len(d.value.String())) - if d.value.Sign() == -1 { - digitsCount-- - } - return max(digitsCount, -d.exp) -} - func (d Decimal) Int64() (int64, bool) { scaledD := d.rescale(0) return scaledD.value.Int64(), scaledD.value.IsInt64() diff --git a/go/mysql/decimal/decimal_test.go b/go/mysql/decimal/decimal_test.go index 8c8ea765df9..03619a8f272 100644 --- a/go/mysql/decimal/decimal_test.go +++ b/go/mysql/decimal/decimal_test.go @@ -28,7 +28,6 @@ import ( "testing/quick" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) type testEnt struct { @@ -955,48 +954,57 @@ func TestDecimal_ScalesNotEqual(t *testing.T) { } func TestDecimal_Cmp1(t *testing.T) { + a := New(123, 3) + b := New(-1234, 2) + assert.Equal(t, 1, a.Cmp(b)) +} + +func TestSizeAndScaleFromString(t *testing.T) { testcases := []struct { - value string - sizeExpected int32 + value string + sizeExpected int32 + scaleExpected int32 }{ { - value: "0.00003", - sizeExpected: 5, + value: "0.00003", + sizeExpected: 6, + scaleExpected: 5, }, { - value: "-0.00003", - sizeExpected: 5, + value: "-0.00003", + sizeExpected: 6, + scaleExpected: 5, }, { - value: "12.00003", - sizeExpected: 7, + value: "12.00003", + sizeExpected: 7, + scaleExpected: 5, }, { - value: "-12.00003", - sizeExpected: 7, + value: "-12.00003", + sizeExpected: 7, + scaleExpected: 5, }, { - value: "1000003", - sizeExpected: 7, + value: "1000003", + sizeExpected: 7, + scaleExpected: 0, }, { - value: "-1000003", - sizeExpected: 7, + value: "-1000003", + sizeExpected: 7, + scaleExpected: 0, }, } for _, testcase := range testcases { t.Run(testcase.value, func(t *testing.T) { - val, err := NewFromString(testcase.value) - require.NoError(t, err) - assert.EqualValues(t, testcase.sizeExpected, val.Size()) + siz, scale := SizeAndScaleFromString(testcase.value) + assert.EqualValues(t, testcase.sizeExpected, siz) + assert.EqualValues(t, testcase.scaleExpected, scale) }) } } -func TestDecimal_Size(t *testing.T) { - -} - func TestDecimal_Cmp2(t *testing.T) { a := New(123, 3) b := New(1234, 2) diff --git a/go/mysql/decimal/scan.go b/go/mysql/decimal/scan.go index 12fc73af4e2..863c0cacdc9 100644 --- a/go/mysql/decimal/scan.go +++ b/go/mysql/decimal/scan.go @@ -23,6 +23,7 @@ import ( "math" "math/big" "math/bits" + "strings" "vitess.io/vitess/go/mysql/fastparse" ) @@ -71,6 +72,21 @@ func parseDecimal64(s []byte) (Decimal, error) { }, nil } +// SizeAndScaleFromString +func SizeAndScaleFromString(s string) (int32, int32) { + sign := 0 + switch s[0] { + case '+', '-': + sign = 1 + } + lenWithoutSign := len(s) - sign + idx := strings.Index(s, ".") + if idx == -1 { + return int32(lenWithoutSign), 0 + } + return int32(lenWithoutSign - 1), int32(len(s) - 1 - idx) +} + func NewFromMySQL(s []byte) (Decimal, error) { var original = s var neg bool diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 9c1b4d20b95..1930b00e663 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -567,12 +567,9 @@ func NewTypedArgumentFromLiteral(in string, lit *Literal) (*Argument, error) { arg := &Argument{Name: in, Type: lit.SQLType()} switch arg.Type { case sqltypes.Decimal: - dec, err := decimal.NewFromMySQL(lit.Bytes()) - if err != nil { - return nil, err - } - arg.Scale = -dec.Exponent() - arg.Size = dec.Size() + siz, scale := decimal.SizeAndScaleFromString(lit.Val) + arg.Scale = scale + arg.Size = siz } return arg, nil } From a021f07f6627fa5d64f047a5490fa6999a0cba1f Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Wed, 24 Apr 2024 22:07:13 +0200 Subject: [PATCH 07/10] Add size handling for datetime / time as well Signed-off-by: Dirkjan Bussink --- go/mysql/datetime/helpers.go | 10 ++++ go/mysql/datetime/helpers_test.go | 77 ++++++++++++++++++++++++++++++ go/vt/sqlparser/ast_funcs.go | 4 ++ go/vt/sqlparser/normalizer_test.go | 14 ++++++ 4 files changed, 105 insertions(+) create mode 100644 go/mysql/datetime/helpers_test.go diff --git a/go/mysql/datetime/helpers.go b/go/mysql/datetime/helpers.go index b91114fd791..68466e320e2 100644 --- a/go/mysql/datetime/helpers.go +++ b/go/mysql/datetime/helpers.go @@ -17,6 +17,7 @@ limitations under the License. package datetime import ( + "strings" "time" ) @@ -287,3 +288,12 @@ func parseNanoseconds[bytes []byte | string](value bytes, nbytes int) (ns int, l const ( durationPerDay = 24 * time.Hour ) + +// SizeAndScaleFromString +func SizeFromString(s string) int32 { + idx := strings.LastIndex(s, ".") + if idx == -1 { + return 0 + } + return int32(len(s[idx+1:])) +} diff --git a/go/mysql/datetime/helpers_test.go b/go/mysql/datetime/helpers_test.go new file mode 100644 index 00000000000..cb46500bf45 --- /dev/null +++ b/go/mysql/datetime/helpers_test.go @@ -0,0 +1,77 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package datetime + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSizeFromString(t *testing.T) { + testcases := []struct { + value string + sizeExpected int32 + }{ + { + value: "2020-01-01 00:00:00", + sizeExpected: 0, + }, + { + value: "2020-01-01 00:00:00.1", + sizeExpected: 1, + }, + { + value: "2020-01-01 00:00:00.12", + sizeExpected: 2, + }, + { + value: "2020-01-01 00:00:00.123", + sizeExpected: 3, + }, + { + value: "2020-01-01 00:00:00.123456", + sizeExpected: 6, + }, + { + value: "00:00:00", + sizeExpected: 0, + }, + { + value: "00:00:00.1", + sizeExpected: 1, + }, + { + value: "00:00:00.12", + sizeExpected: 2, + }, + { + value: "00:00:00.123", + sizeExpected: 3, + }, + { + value: "00:00:00.123456", + sizeExpected: 6, + }, + } + for _, testcase := range testcases { + t.Run(testcase.value, func(t *testing.T) { + siz := SizeFromString(testcase.value) + assert.EqualValues(t, testcase.sizeExpected, siz) + }) + } +} diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 1930b00e663..3943aa84b1f 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -24,6 +24,7 @@ import ( "strconv" "strings" + "vitess.io/vitess/go/mysql/datetime" "vitess.io/vitess/go/mysql/decimal" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/log" @@ -570,6 +571,9 @@ func NewTypedArgumentFromLiteral(in string, lit *Literal) (*Argument, error) { siz, scale := decimal.SizeAndScaleFromString(lit.Val) arg.Scale = scale arg.Size = siz + case sqltypes.Datetime, sqltypes.Time: + siz := datetime.SizeFromString(lit.Val) + arg.Size = siz } return arg, nil } diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index b9b33347477..7ad6f06f52b 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -79,6 +79,20 @@ func TestNormalize(t *testing.T) { outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.DecimalBindVariable("1.2"), }, + }, { + // datetime val + in: "select * from t where foobar = timestamp'2012-02-29 12:34:56.123456'", + outstmt: "select * from t where foobar = :foobar /* DATETIME(6) */", + outbv: map[string]*querypb.BindVariable{ + "foobar": sqltypes.ValueBindVariable(sqltypes.NewDatetime("2012-02-29 12:34:56.123456")), + }, + }, { + // time val + in: "select * from t where foobar = time'12:34:56.123456'", + outstmt: "select * from t where foobar = :foobar /* TIME(6) */", + outbv: map[string]*querypb.BindVariable{ + "foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56.123456")), + }, }, { // multiple vals in: "select * from t where foo = 1.2 and bar = 2", From 019684bcaad3900a11419b2f7128615b7cbded72 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Thu, 25 Apr 2024 12:15:10 +0530 Subject: [PATCH 08/10] refactor: simplify code Signed-off-by: Manan Gupta --- go/mysql/decimal/scan.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/go/mysql/decimal/scan.go b/go/mysql/decimal/scan.go index 863c0cacdc9..c56fc185287 100644 --- a/go/mysql/decimal/scan.go +++ b/go/mysql/decimal/scan.go @@ -72,19 +72,18 @@ func parseDecimal64(s []byte) (Decimal, error) { }, nil } -// SizeAndScaleFromString +// SizeAndScaleFromString gets the size and scale for the decimal value without needing to parse it. func SizeAndScaleFromString(s string) (int32, int32) { - sign := 0 switch s[0] { case '+', '-': - sign = 1 + s = s[1:] } - lenWithoutSign := len(s) - sign + totalLen := len(s) idx := strings.Index(s, ".") if idx == -1 { - return int32(lenWithoutSign), 0 + return int32(totalLen), 0 } - return int32(lenWithoutSign - 1), int32(len(s) - 1 - idx) + return int32(totalLen - 1), int32(totalLen - 1 - idx) } func NewFromMySQL(s []byte) (Decimal, error) { From 3b39dc13465230bd8f7cc38a047c0809a2062fda Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Thu, 25 Apr 2024 13:23:10 +0530 Subject: [PATCH 09/10] test: fix test expectations Signed-off-by: Manan Gupta --- go/test/endtoend/vtgate/queries/normalize/normalize_test.go | 6 +++--- go/vt/vtexplain/testdata/multi-output/unsharded-output.txt | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go index 735a26fc00c..51d9f9f24bf 100644 --- a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go +++ b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go @@ -40,11 +40,11 @@ func TestNormalizeAllFields(t *testing.T) { defer conn.Close() insertQuery := `insert into t1 values (1, "chars", "variable chars", x'73757265', 0x676F, 0.33, 9.99, 1, "1976-06-08", "small", "b", "{\"key\":\"value\"}", point(1,5), b'011', 0b0101)` - normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` + normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL(3,2) */, :vtg7 /* DECIMAL(3,2) */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` vtgateVersion, err := cluster.GetMajorVersion("vtgate") require.NoError(t, err) - if vtgateVersion < 19 { - normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* HEXNUM */, :vtg16 /* HEXNUM */)` + if vtgateVersion < 20 { + normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` } selectQuery := "select * from t1" utils.Exec(t, conn, insertQuery) diff --git a/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt b/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt index aab1ab0234f..b63683ca274 100644 --- a/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt +++ b/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt @@ -24,7 +24,7 @@ update t1 set intval = 10 update t1 set floatval = 9.99 1 ks_unsharded/-: begin -1 ks_unsharded/-: update t1 set floatval = 9.99 limit 10001 /* DECIMAL */ +1 ks_unsharded/-: update t1 set floatval = 9.99 limit 10001 /* DECIMAL(3,2) */ 1 ks_unsharded/-: commit ---------------------------------------------------------------------- @@ -37,7 +37,7 @@ delete from t1 where id = 100 ---------------------------------------------------------------------- insert into t1 (id,intval,floatval) values (1,2,3.14) on duplicate key update intval=3, floatval=3.14 -1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) on duplicate key update intval = 3, floatval = 3.14 /* DECIMAL */ +1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) on duplicate key update intval = 3, floatval = 3.14 /* DECIMAL(3,2) */ ---------------------------------------------------------------------- select ID from t1 From 0588f639d5b5e496725784dccb7648d8b4c02de7 Mon Sep 17 00:00:00 2001 From: Manan Gupta Date: Thu, 25 Apr 2024 13:29:15 +0530 Subject: [PATCH 10/10] test: add complete Q14 test Signed-off-by: Manan Gupta --- go/test/endtoend/vtgate/queries/tpch/tpch_test.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go index 96ac5bf0923..255b961e15b 100644 --- a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go +++ b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go @@ -179,6 +179,19 @@ where l_partkey = p_partkey query: `select 100.00 * sum(l_extendedprice * (1 - l_discount)) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue from lineitem, part +where l_partkey = p_partkey + and l_shipdate >= '1996-12-01' + and l_shipdate < date_add('1996-12-01', interval '1' month);`, + }, + { + name: "Q14", + query: `select 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from lineitem, + part where l_partkey = p_partkey and l_shipdate >= '1996-12-01' and l_shipdate < date_add('1996-12-01', interval '1' month);`,