Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Store Decimal precision and size while normalising #15785

Merged
merged 12 commits into from
Apr 25, 2024
10 changes: 10 additions & 0 deletions go/mysql/datetime/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package datetime

import (
"strings"
"time"
)

Expand Down Expand Up @@ -287,3 +288,12 @@ func parseNanoseconds[bytes []byte | string](value bytes, nbytes int) (ns int, l
const (
durationPerDay = 24 * time.Hour
)

// SizeAndScaleFromString
func SizeFromString(s string) int32 {
idx := strings.LastIndex(s, ".")
if idx == -1 {
return 0
}
return int32(len(s[idx+1:]))
}
77 changes: 77 additions & 0 deletions go/mysql/datetime/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
Copyright 2024 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package datetime

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestSizeFromString(t *testing.T) {
testcases := []struct {
value string
sizeExpected int32
}{
{
value: "2020-01-01 00:00:00",
sizeExpected: 0,
},
{
value: "2020-01-01 00:00:00.1",
sizeExpected: 1,
},
{
value: "2020-01-01 00:00:00.12",
sizeExpected: 2,
},
{
value: "2020-01-01 00:00:00.123",
sizeExpected: 3,
},
{
value: "2020-01-01 00:00:00.123456",
sizeExpected: 6,
},
{
value: "00:00:00",
sizeExpected: 0,
},
{
value: "00:00:00.1",
sizeExpected: 1,
},
{
value: "00:00:00.12",
sizeExpected: 2,
},
{
value: "00:00:00.123",
sizeExpected: 3,
},
{
value: "00:00:00.123456",
sizeExpected: 6,
},
}
for _, testcase := range testcases {
t.Run(testcase.value, func(t *testing.T) {
siz := SizeFromString(testcase.value)
assert.EqualValues(t, testcase.sizeExpected, siz)
})
}
}
45 changes: 45 additions & 0 deletions go/mysql/decimal/decimal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,52 @@ func TestDecimal_Cmp1(t *testing.T) {
a := New(123, 3)
b := New(-1234, 2)
assert.Equal(t, 1, a.Cmp(b))
}

func TestSizeAndScaleFromString(t *testing.T) {
testcases := []struct {
value string
sizeExpected int32
scaleExpected int32
}{
{
value: "0.00003",
sizeExpected: 6,
scaleExpected: 5,
},
{
value: "-0.00003",
sizeExpected: 6,
scaleExpected: 5,
},
{
value: "12.00003",
sizeExpected: 7,
scaleExpected: 5,
},
{
value: "-12.00003",
sizeExpected: 7,
scaleExpected: 5,
},
{
value: "1000003",
sizeExpected: 7,
scaleExpected: 0,
},
{
value: "-1000003",
sizeExpected: 7,
scaleExpected: 0,
},
}
for _, testcase := range testcases {
t.Run(testcase.value, func(t *testing.T) {
siz, scale := SizeAndScaleFromString(testcase.value)
assert.EqualValues(t, testcase.sizeExpected, siz)
assert.EqualValues(t, testcase.scaleExpected, scale)
})
}
}

func TestDecimal_Cmp2(t *testing.T) {
Expand Down
15 changes: 15 additions & 0 deletions go/mysql/decimal/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"math"
"math/big"
"math/bits"
"strings"

"vitess.io/vitess/go/mysql/fastparse"
)
Expand Down Expand Up @@ -71,6 +72,20 @@ func parseDecimal64(s []byte) (Decimal, error) {
}, nil
}

// SizeAndScaleFromString gets the size and scale for the decimal value without needing to parse it.
func SizeAndScaleFromString(s string) (int32, int32) {
switch s[0] {
case '+', '-':
s = s[1:]
}
totalLen := len(s)
idx := strings.Index(s, ".")
if idx == -1 {
return int32(totalLen), 0
}
return int32(totalLen - 1), int32(totalLen - 1 - idx)
}

func NewFromMySQL(s []byte) (Decimal, error) {
var original = s
var neg bool
Expand Down
6 changes: 3 additions & 3 deletions go/test/endtoend/vtgate/queries/normalize/normalize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ func TestNormalizeAllFields(t *testing.T) {
defer conn.Close()

insertQuery := `insert into t1 values (1, "chars", "variable chars", x'73757265', 0x676F, 0.33, 9.99, 1, "1976-06-08", "small", "b", "{\"key\":\"value\"}", point(1,5), b'011', 0b0101)`
normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)`
normalizedInsertQuery := `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL(3,2) */, :vtg7 /* DECIMAL(3,2) */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)`
vtgateVersion, err := cluster.GetMajorVersion("vtgate")
require.NoError(t, err)
if vtgateVersion < 19 {
normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* HEXNUM */, :vtg16 /* HEXNUM */)`
if vtgateVersion < 20 {
normalizedInsertQuery = `insert into t1 values (:vtg1 /* INT64 */, :vtg2 /* VARCHAR */, :vtg3 /* VARCHAR */, :vtg4 /* HEXVAL */, :vtg5 /* HEXNUM */, :vtg6 /* DECIMAL */, :vtg7 /* DECIMAL */, :vtg8 /* INT64 */, :vtg9 /* VARCHAR */, :vtg10 /* VARCHAR */, :vtg11 /* VARCHAR */, :vtg12 /* VARCHAR */, point(:vtg13 /* INT64 */, :vtg14 /* INT64 */), :vtg15 /* BITNUM */, :vtg16 /* BITNUM */)`
}
selectQuery := "select * from t1"
utils.Exec(t, conn, insertQuery)
Expand Down
22 changes: 22 additions & 0 deletions go/test/endtoend/vtgate/queries/tpch/tpch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,28 @@ order by
end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
from lineitem,
part
where l_partkey = p_partkey
and l_shipdate >= '1996-12-01'
and l_shipdate < date_add('1996-12-01', interval '1' month);`,
},
{
name: "Q14 without case",
query: `select 100.00 * sum(l_extendedprice * (1 - l_discount)) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
from lineitem,
part
where l_partkey = p_partkey
and l_shipdate >= '1996-12-01'
and l_shipdate < date_add('1996-12-01', interval '1' month);`,
},
{
name: "Q14",
query: `select 100.00 * sum(case
when p_type like 'PROMO%'
then l_extendedprice * (1 - l_discount)
else 0
end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
from lineitem,
part
where l_partkey = p_partkey
and l_shipdate >= '1996-12-01'
and l_shipdate < date_add('1996-12-01', interval '1' month);`,
Expand Down
5 changes: 3 additions & 2 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2317,8 +2317,9 @@ type (

// Argument represents bindvariable expression
Argument struct {
Name string
Type sqltypes.Type
Name string
Type sqltypes.Type
Size, Scale int32
}

// NullVal represents a NULL value.
Expand Down
2 changes: 2 additions & 0 deletions go/vt/sqlparser/ast_equals.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion go/vt/sqlparser/ast_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -1357,7 +1357,15 @@ func (node *Argument) Format(buf *TrackedBuffer) {
// For bind variables that are statically typed, emit their type as an adjacent comment.
// This comment will be ignored by older versions of Vitess (and by MySQL) but will provide
// type safety when using the query as a cache key.
buf.astPrintf(node, " /* %s */", node.Type.String())
buf.astPrintf(node, " /* %s", node.Type.String())
if node.Size != 0 || node.Scale != 0 {
buf.astPrintf(node, "(%d", node.Size)
if node.Scale != 0 {
buf.astPrintf(node, ",%d", node.Scale)
}
buf.WriteString(")")
}
buf.WriteString(" */")
}
}

Expand Down
9 changes: 9 additions & 0 deletions go/vt/sqlparser/ast_format_fast.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"strconv"
"strings"

"vitess.io/vitess/go/mysql/datetime"
"vitess.io/vitess/go/mysql/decimal"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/log"
querypb "vitess.io/vitess/go/vt/proto/query"
Expand Down Expand Up @@ -562,6 +564,20 @@ func NewTypedArgument(in string, t sqltypes.Type) *Argument {
return &Argument{Name: in, Type: t}
}

func NewTypedArgumentFromLiteral(in string, lit *Literal) (*Argument, error) {
arg := &Argument{Name: in, Type: lit.SQLType()}
switch arg.Type {
case sqltypes.Decimal:
siz, scale := decimal.SizeAndScaleFromString(lit.Val)
arg.Scale = scale
arg.Size = siz
case sqltypes.Datetime, sqltypes.Time:
siz := datetime.SizeFromString(lit.Val)
arg.Size = siz
}
return arg, nil
}

// NewListArg builds a new ListArg.
func NewListArg(in string) ListArg {
return ListArg(in)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/sqlparser/cached_size.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 18 additions & 3 deletions go/vt/sqlparser/normalizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,12 @@ func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) {
}

// Modify the AST node to a bindvar.
cursor.Replace(NewTypedArgument(bvname, node.SQLType()))
arg, err := NewTypedArgumentFromLiteral(bvname, node)
if err != nil {
nz.err = err
return
}
cursor.Replace(arg)
}

// convertLiteral converts an Literal without the dedup.
Expand All @@ -224,7 +229,12 @@ func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) {

bvname := nz.reserved.nextUnusedVar()
nz.bindVars[bvname] = bval
cursor.Replace(NewTypedArgument(bvname, node.SQLType()))
arg, err := NewTypedArgumentFromLiteral(bvname, node)
if err != nil {
nz.err = err
return
}
cursor.Replace(arg)
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
}

// convertComparison attempts to convert IN clauses to
Expand Down Expand Up @@ -268,7 +278,12 @@ func (nz *normalizer) parameterize(left, right Expr) Expr {
return nil
}
bvname := nz.decideBindVarName(lit, col, bval)
return NewTypedArgument(bvname, lit.SQLType())
arg, err := NewTypedArgumentFromLiteral(bvname, lit)
if err != nil {
nz.err = err
return nil
}
return arg
}

func (nz *normalizer) decideBindVarName(lit *Literal, col *ColName, bval *querypb.BindVariable) string {
Expand Down
18 changes: 16 additions & 2 deletions go/vt/sqlparser/normalizer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,28 @@ func TestNormalize(t *testing.T) {
}, {
// float val
in: "select * from t where foobar = 1.2",
outstmt: "select * from t where foobar = :foobar /* DECIMAL */",
outstmt: "select * from t where foobar = :foobar /* DECIMAL(2,1) */",
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
outbv: map[string]*querypb.BindVariable{
"foobar": sqltypes.DecimalBindVariable("1.2"),
},
}, {
// datetime val
in: "select * from t where foobar = timestamp'2012-02-29 12:34:56.123456'",
outstmt: "select * from t where foobar = :foobar /* DATETIME(6) */",
outbv: map[string]*querypb.BindVariable{
"foobar": sqltypes.ValueBindVariable(sqltypes.NewDatetime("2012-02-29 12:34:56.123456")),
},
}, {
// time val
in: "select * from t where foobar = time'12:34:56.123456'",
outstmt: "select * from t where foobar = :foobar /* TIME(6) */",
outbv: map[string]*querypb.BindVariable{
"foobar": sqltypes.ValueBindVariable(sqltypes.NewTime("12:34:56.123456")),
},
}, {
// multiple vals
in: "select * from t where foo = 1.2 and bar = 2",
outstmt: "select * from t where foo = :foo /* DECIMAL */ and bar = :bar /* INT64 */",
outstmt: "select * from t where foo = :foo /* DECIMAL(2,1) */ and bar = :bar /* INT64 */",
outbv: map[string]*querypb.BindVariable{
"foo": sqltypes.DecimalBindVariable("1.2"),
"bar": sqltypes.Int64BindVariable(2),
Expand Down
Loading
Loading