diff --git a/go/mysql/datetime/helpers.go b/go/mysql/datetime/helpers.go index b91114fd791..68466e320e2 100644 --- a/go/mysql/datetime/helpers.go +++ b/go/mysql/datetime/helpers.go @@ -17,6 +17,7 @@ limitations under the License. package datetime import ( + "strings" "time" ) @@ -287,3 +288,12 @@ func parseNanoseconds[bytes []byte | string](value bytes, nbytes int) (ns int, l const ( durationPerDay = 24 * time.Hour ) + +// SizeAndScaleFromString +func SizeFromString(s string) int32 { + idx := strings.LastIndex(s, ".") + if idx == -1 { + return 0 + } + return int32(len(s[idx+1:])) +} diff --git a/go/mysql/datetime/helpers_test.go b/go/mysql/datetime/helpers_test.go new file mode 100644 index 00000000000..cb46500bf45 --- /dev/null +++ b/go/mysql/datetime/helpers_test.go @@ -0,0 +1,77 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package datetime + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSizeFromString(t *testing.T) { + testcases := []struct { + value string + sizeExpected int32 + }{ + { + value: "2020-01-01 00:00:00", + sizeExpected: 0, + }, + { + value: "2020-01-01 00:00:00.1", + sizeExpected: 1, + }, + { + value: "2020-01-01 00:00:00.12", + sizeExpected: 2, + }, + { + value: "2020-01-01 00:00:00.123", + sizeExpected: 3, + }, + { + value: "2020-01-01 00:00:00.123456", + sizeExpected: 6, + }, + { + value: "00:00:00", + sizeExpected: 0, + }, + { + value: "00:00:00.1", + sizeExpected: 1, + }, + { + value: "00:00:00.12", + sizeExpected: 2, + }, + { + value: "00:00:00.123", + sizeExpected: 3, + }, + { + value: "00:00:00.123456", + sizeExpected: 6, + }, + } + for _, testcase := range testcases { + t.Run(testcase.value, func(t *testing.T) { + siz := SizeFromString(testcase.value) + assert.EqualValues(t, testcase.sizeExpected, siz) + }) + } +} diff --git a/go/mysql/decimal/decimal_test.go b/go/mysql/decimal/decimal_test.go index 6a6cf001231..03619a8f272 100644 --- a/go/mysql/decimal/decimal_test.go +++ b/go/mysql/decimal/decimal_test.go @@ -957,7 +957,52 @@ func TestDecimal_Cmp1(t *testing.T) { a := New(123, 3) b := New(-1234, 2) assert.Equal(t, 1, a.Cmp(b)) +} +func TestSizeAndScaleFromString(t *testing.T) { + testcases := []struct { + value string + sizeExpected int32 + scaleExpected int32 + }{ + { + value: "0.00003", + sizeExpected: 6, + scaleExpected: 5, + }, + { + value: "-0.00003", + sizeExpected: 6, + scaleExpected: 5, + }, + { + value: "12.00003", + sizeExpected: 7, + scaleExpected: 5, + }, + { + value: "-12.00003", + sizeExpected: 7, + scaleExpected: 5, + }, + { + value: "1000003", + sizeExpected: 7, + scaleExpected: 0, + }, + { + value: "-1000003", + sizeExpected: 7, + scaleExpected: 0, + }, + } + for _, testcase := range testcases { + t.Run(testcase.value, func(t *testing.T) { + siz, scale := SizeAndScaleFromString(testcase.value) + assert.EqualValues(t, testcase.sizeExpected, siz) + assert.EqualValues(t, testcase.scaleExpected, scale) + }) + } } func TestDecimal_Cmp2(t *testing.T) { diff --git a/go/mysql/decimal/scan.go b/go/mysql/decimal/scan.go index 12fc73af4e2..c56fc185287 100644 --- a/go/mysql/decimal/scan.go +++ b/go/mysql/decimal/scan.go @@ -23,6 +23,7 @@ import ( "math" "math/big" "math/bits" + "strings" "vitess.io/vitess/go/mysql/fastparse" ) @@ -71,6 +72,20 @@ func parseDecimal64(s []byte) (Decimal, error) { }, nil } +// SizeAndScaleFromString gets the size and scale for the decimal value without needing to parse it. +func SizeAndScaleFromString(s string) (int32, int32) { + switch s[0] { + case '+', '-': + s = s[1:] + } + totalLen := len(s) + idx := strings.Index(s, ".") + if idx == -1 { + return int32(totalLen), 0 + } + return int32(totalLen - 1), int32(totalLen - 1 - idx) +} + func NewFromMySQL(s []byte) (Decimal, error) { var original = s var neg bool diff --git a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go index 735a26fc00c..51d9f9f24bf 100644 --- a/go/test/endtoend/vtgate/queries/normalize/normalize_test.go +++ b/go/test/endtoend/vtgate/queries/normalize/normalize_test.go @@ -40,11 +40,11 @@ func TestNormalizeAllFields(t *testing.T) { defer conn.Close() insertQuery := `insert into t1 values (1, "chars", "variable chars", x'73757265', 0x676F, 0.33, 9.99, 1, "1976-06-08", "small", "b", "{\"key\":\"value\"}", point(1,5), b'011', 0b0101)` - normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` + normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL(3,2) */, :vtg7 /* DECIMAL(3,2) */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` vtgateVersion, err := cluster.GetMajorVersion("vtgate") require.NoError(t, err) - if vtgateVersion < 19 { - normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* HEXNUM */, :vtg16 /* HEXNUM */)` + if vtgateVersion < 20 { + normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)` } selectQuery := "select * from t1" utils.Exec(t, conn, insertQuery) diff --git a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go index 70e0c5e1edd..255b961e15b 100644 --- a/go/test/endtoend/vtgate/queries/tpch/tpch_test.go +++ b/go/test/endtoend/vtgate/queries/tpch/tpch_test.go @@ -170,6 +170,28 @@ order by end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue from lineitem, part +where l_partkey = p_partkey + and l_shipdate >= '1996-12-01' + and l_shipdate < date_add('1996-12-01', interval '1' month);`, + }, + { + name: "Q14 without case", + query: `select 100.00 * sum(l_extendedprice * (1 - l_discount)) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from lineitem, + part +where l_partkey = p_partkey + and l_shipdate >= '1996-12-01' + and l_shipdate < date_add('1996-12-01', interval '1' month);`, + }, + { + name: "Q14", + query: `select 100.00 * sum(case + when p_type like 'PROMO%' + then l_extendedprice * (1 - l_discount) + else 0 + end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue +from lineitem, + part where l_partkey = p_partkey and l_shipdate >= '1996-12-01' and l_shipdate < date_add('1996-12-01', interval '1' month);`, diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index cfef9923530..4016d596c62 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -2317,8 +2317,9 @@ type ( // Argument represents bindvariable expression Argument struct { - Name string - Type sqltypes.Type + Name string + Type sqltypes.Type + Size, Scale int32 } // NullVal represents a NULL value. diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index c4066218859..1b083f41170 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -1869,6 +1869,8 @@ func (cmp *Comparator) RefOfArgument(a, b *Argument) bool { return false } return a.Name == b.Name && + a.Size == b.Size && + a.Scale == b.Scale && a.Type == b.Type } diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index c823f0ae487..80608b83d75 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -1357,7 +1357,15 @@ func (node *Argument) Format(buf *TrackedBuffer) { // For bind variables that are statically typed, emit their type as an adjacent comment. // This comment will be ignored by older versions of Vitess (and by MySQL) but will provide // type safety when using the query as a cache key. - buf.astPrintf(node, " /* %s */", node.Type.String()) + buf.astPrintf(node, " /* %s", node.Type.String()) + if node.Size != 0 || node.Scale != 0 { + buf.astPrintf(node, "(%d", node.Size) + if node.Scale != 0 { + buf.astPrintf(node, ",%d", node.Scale) + } + buf.WriteString(")") + } + buf.WriteString(" */") } } diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 1f5a5229a20..28fdd119f7b 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -1780,6 +1780,15 @@ func (node *Argument) FormatFast(buf *TrackedBuffer) { // type safety when using the query as a cache key. buf.WriteString(" /* ") buf.WriteString(node.Type.String()) + if node.Size != 0 || node.Scale != 0 { + buf.WriteByte('(') + buf.WriteString(fmt.Sprintf("%d", node.Size)) + if node.Scale != 0 { + buf.WriteByte(',') + buf.WriteString(fmt.Sprintf("%d", node.Scale)) + } + buf.WriteString(")") + } buf.WriteString(" */") } } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 8019645c250..3943aa84b1f 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -24,6 +24,8 @@ import ( "strconv" "strings" + "vitess.io/vitess/go/mysql/datetime" + "vitess.io/vitess/go/mysql/decimal" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/log" querypb "vitess.io/vitess/go/vt/proto/query" @@ -562,6 +564,20 @@ func NewTypedArgument(in string, t sqltypes.Type) *Argument { return &Argument{Name: in, Type: t} } +func NewTypedArgumentFromLiteral(in string, lit *Literal) (*Argument, error) { + arg := &Argument{Name: in, Type: lit.SQLType()} + switch arg.Type { + case sqltypes.Decimal: + siz, scale := decimal.SizeAndScaleFromString(lit.Val) + arg.Scale = scale + arg.Size = siz + case sqltypes.Datetime, sqltypes.Time: + siz := datetime.SizeFromString(lit.Val) + arg.Size = siz + } + return arg, nil +} + // NewListArg builds a new ListArg. func NewListArg(in string) ListArg { return ListArg(in) diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index 361888727b2..6160eabacfa 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -351,7 +351,7 @@ func (cached *Argument) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(24) + size += int64(32) } // field Name string size += hack.RuntimeAllocSize(int64(len(cached.Name))) diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index 0254dccdfb2..1492ecfac90 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -207,7 +207,12 @@ func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) { } // Modify the AST node to a bindvar. - cursor.Replace(NewTypedArgument(bvname, node.SQLType())) + arg, err := NewTypedArgumentFromLiteral(bvname, node) + if err != nil { + nz.err = err + return + } + cursor.Replace(arg) } // convertLiteral converts an Literal without the dedup. @@ -224,7 +229,12 @@ func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) { bvname := nz.reserved.nextUnusedVar() nz.bindVars[bvname] = bval - cursor.Replace(NewTypedArgument(bvname, node.SQLType())) + arg, err := NewTypedArgumentFromLiteral(bvname, node) + if err != nil { + nz.err = err + return + } + cursor.Replace(arg) } // convertComparison attempts to convert IN clauses to @@ -268,7 +278,12 @@ func (nz *normalizer) parameterize(left, right Expr) Expr { return nil } bvname := nz.decideBindVarName(lit, col, bval) - return NewTypedArgument(bvname, lit.SQLType()) + arg, err := NewTypedArgumentFromLiteral(bvname, lit) + if err != nil { + nz.err = err + return nil + } + return arg } func (nz *normalizer) decideBindVarName(lit *Literal, col *ColName, bval *querypb.BindVariable) string { diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 42e26caa39d..7ad6f06f52b 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -75,14 +75,28 @@ func TestNormalize(t *testing.T) { }, { // float val in: "select * from t where foobar = 1.2", - outstmt: "select * from t where foobar = :foobar /* DECIMAL */", + outstmt: "select * from t where foobar = :foobar /* DECIMAL(2,1) */", outbv: map[string]*querypb.BindVariable{ "foobar": sqltypes.DecimalBindVariable("1.2"), }, + }, { + // datetime val + in: "select * from t where foobar = timestamp'2012-02-29 12:34:56.123456'", + outstmt: "select * from t where foobar = :foobar /* DATETIME(6) */", + outbv: map[string]*querypb.BindVariable{ + "foobar": sqltypes.ValueBindVariable(sqltypes.NewDatetime("2012-02-29 12:34:56.123456")), + }, + }, { + // time val + in: "select * from t where foobar = time'12:34:56.123456'", + outstmt: "select * from t where foobar = :foobar /* TIME(6) */", + outbv: map[string]*querypb.BindVariable{ + "foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56.123456")), + }, }, { // multiple vals in: "select * from t where foo = 1.2 and bar = 2", - outstmt: "select * from t where foo = :foo /* DECIMAL */ and bar = :bar /* INT64 */", + outstmt: "select * from t where foo = :foo /* DECIMAL(2,1) */ and bar = :bar /* INT64 */", outbv: map[string]*querypb.BindVariable{ "foo": sqltypes.DecimalBindVariable("1.2"), "bar": sqltypes.Int64BindVariable(2), diff --git a/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt b/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt index aab1ab0234f..b63683ca274 100644 --- a/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt +++ b/go/vt/vtexplain/testdata/multi-output/unsharded-output.txt @@ -24,7 +24,7 @@ update t1 set intval = 10 update t1 set floatval = 9.99 1 ks_unsharded/-: begin -1 ks_unsharded/-: update t1 set floatval = 9.99 limit 10001 /* DECIMAL */ +1 ks_unsharded/-: update t1 set floatval = 9.99 limit 10001 /* DECIMAL(3,2) */ 1 ks_unsharded/-: commit ---------------------------------------------------------------------- @@ -37,7 +37,7 @@ delete from t1 where id = 100 ---------------------------------------------------------------------- insert into t1 (id,intval,floatval) values (1,2,3.14) on duplicate key update intval=3, floatval=3.14 -1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) on duplicate key update intval = 3, floatval = 3.14 /* DECIMAL */ +1 ks_unsharded/-: insert into t1(id, intval, floatval) values (1, 2, 3.14) on duplicate key update intval = 3, floatval = 3.14 /* DECIMAL(3,2) */ ---------------------------------------------------------------------- select ID from t1 diff --git a/go/vt/vtgate/semantics/typer.go b/go/vt/vtgate/semantics/typer.go index 54261339114..b56c836a740 100644 --- a/go/vt/vtgate/semantics/typer.go +++ b/go/vt/vtgate/semantics/typer.go @@ -47,7 +47,7 @@ func (t *typer) up(cursor *sqlparser.Cursor) error { t.m[node] = evalengine.NewType(node.SQLType(), collations.CollationForType(node.SQLType(), t.collationEnv.DefaultConnectionCharset())) case *sqlparser.Argument: if node.Type >= 0 { - t.m[node] = evalengine.NewType(node.Type, collations.CollationForType(node.Type, t.collationEnv.DefaultConnectionCharset())) + t.m[node] = evalengine.NewTypeEx(node.Type, collations.CollationForType(node.Type, t.collationEnv.DefaultConnectionCharset()), true, node.Size, node.Scale, nil) } case sqlparser.AggrFunc: code, ok := opcode.SupportedAggregates[node.AggrName()]