Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Scale and length handling in CASE and JOIN bind variables #15787

Merged
merged 6 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions go/test/endtoend/vtgate/queries/tpch/tpch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ package union
import (
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/test/endtoend/cluster"
"vitess.io/vitess/go/test/endtoend/utils"

"github.com/stretchr/testify/require"
)

func start(t *testing.T) (utils.MySQLCompare, func()) {
Expand Down Expand Up @@ -161,6 +161,19 @@ group by
order by
value desc;`,
},
{
name: "Q14 without decimal literal",
query: `select sum(case
when p_type like 'PROMO%'
then l_extendedprice * (1 - l_discount)
else 0
end) / sum(l_extendedprice * (1 - l_discount)) as promo_revenue
from lineitem,
part
where l_partkey = p_partkey
and l_shipdate >= '1996-12-01'
and l_shipdate < date_add('1996-12-01', interval '1' month);`,
},
}

for _, testcase := range testcases {
Expand Down
10 changes: 5 additions & 5 deletions go/vt/vtgate/engine/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st
result := &sqltypes.Result{}
if len(lresult.Rows) == 0 && wantfields {
for k, col := range jn.Vars {
joinVars[k] = bindvarForType(lresult.Fields[col].Type)
joinVars[k] = bindvarForType(lresult.Fields[col])
}
rresult, err := jn.Right.GetFields(ctx, vcursor, combineVars(bindVars, joinVars))
if err != nil {
Expand Down Expand Up @@ -95,19 +95,19 @@ func (jn *Join) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[st
return result, nil
}

func bindvarForType(t querypb.Type) *querypb.BindVariable {
func bindvarForType(field *querypb.Field) *querypb.BindVariable {
bv := &querypb.BindVariable{
Type: t,
Type: field.Type,
Value: nil,
}
switch t {
switch field.Type {
case querypb.Type_INT8, querypb.Type_UINT8, querypb.Type_INT16, querypb.Type_UINT16,
querypb.Type_INT32, querypb.Type_UINT32, querypb.Type_INT64, querypb.Type_UINT64:
bv.Value = []byte("0")
case querypb.Type_FLOAT32, querypb.Type_FLOAT64:
bv.Value = []byte("0e0")
case querypb.Type_DECIMAL:
bv.Value = []byte("0.0")
bv.Value = []byte(fmt.Sprintf("%s.%s", strings.Repeat("0", max(1, int(field.ColumnLength-field.Decimals))), strings.Repeat("0", max(1, int(field.Decimals)))))
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
default:
return sqltypes.NullBindVariable
}
Expand Down
17 changes: 17 additions & 0 deletions go/vt/vtgate/evalengine/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/olekukonko/tablewriter"
"github.com/stretchr/testify/require"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
Expand Down Expand Up @@ -168,6 +169,7 @@ func TestCompilerSingle(t *testing.T) {
values []sqltypes.Value
result string
collation collations.ID
typeWanted evalengine.Type
}{
{
expression: "1 + column0",
Expand Down Expand Up @@ -675,6 +677,15 @@ func TestCompilerSingle(t *testing.T) {
expression: `1 * unix_timestamp(time('1.0000'))`,
result: `DECIMAL(1698098401.0000)`,
},
{
expression: `(case
when 'PROMOTION' like 'PROMO%'
then 0.01
else 0
end) * 0.01`,
result: `DECIMAL(0.0001)`,
typeWanted: evalengine.NewTypeEx(sqltypes.Decimal, collations.CollationBinaryID, false, 4, 4),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please make sure to merge / rebase since the signature of this function has changed.

},
}

tz, _ := time.LoadLocation("Europe/Madrid")
Expand Down Expand Up @@ -715,6 +726,12 @@ func TestCompilerSingle(t *testing.T) {
t.Fatalf("bad collation evaluation from eval engine: got %d, want %d", expected.Collation(), tc.collation)
}

if tc.typeWanted.Type() != sqltypes.Unknown {
typ, err := env.TypeOf(converted)
require.NoError(t, err)
require.EqualValues(t, tc.typeWanted, typ)
GuptaManan100 marked this conversation as resolved.
Show resolved Hide resolved
}

// re-run the same evaluation multiple times to ensure results are always consistent
for i := 0; i < 8; i++ {
res, err := env.Evaluate(converted)
Expand Down
7 changes: 6 additions & 1 deletion go/vt/vtgate/evalengine/expr_logical.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,7 @@ func (c *CaseExpr) simplify(env *ExpressionEnv) error {
func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
var ca collationAggregation
var ta typeAggregation
var scale, size int32

for _, wt := range cs.cases {
when, err := wt.when.compile(c)
Expand All @@ -691,6 +692,8 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
}

ta.add(then.Type, then.Flag)
scale = max(scale, then.Scale)
size = max(size, then.Size)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this logic apply to other usages of the typeAggregation as well? And should this really be part of typeAggregation therefore instead? Seems like that might be more appropriate to refactor it into that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have made the required changes. I wasn't sure what the length and scale for other data types looks like, so I just passed 0 from everywhere.

if err := ca.add(then.Col, c.env.CollationEnv()); err != nil {
return ctype{}, err
}
Expand All @@ -703,6 +706,8 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
}

ta.add(els.Type, els.Flag)
scale = max(scale, els.Scale)
size = max(size, els.Size)
if err := ca.add(els.Col, c.env.CollationEnv()); err != nil {
return ctype{}, err
}
Expand All @@ -712,7 +717,7 @@ func (cs *CaseExpr) compile(c *compiler) (ctype, error) {
if ta.nullable {
f |= flagNullable
}
ct := ctype{Type: ta.result(), Flag: f, Col: ca.result()}
ct := ctype{Type: ta.result(), Flag: f, Col: ca.result(), Scale: scale, Size: size}
c.asm.CmpCase(len(cs.cases), cs.Else != nil, ct.Type, ct.Col, c.sqlmode.AllowZeroDate())
return ct, nil
}
Expand Down
Loading