Skip to content

Commit

Permalink
Fix derived table bug (#15831)
Browse files Browse the repository at this point in the history
  • Loading branch information
GuptaManan100 authored May 3, 2024
1 parent ef9d7db commit 4b66d39
Show file tree
Hide file tree
Showing 10 changed files with 172 additions and 66 deletions.
31 changes: 31 additions & 0 deletions go/test/endtoend/vtgate/queries/tpch/tpch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,37 @@ where l_partkey = p_partkey
and l_shipdate >= '1996-12-01'
and l_shipdate < date_add('1996-12-01', interval '1' month);`,
},
{
name: "Q8",
query: `select o_year, sum(case when nation = 'BRAZIL' then volume else 0 end) / sum(volume) as mkt_share
from (select extract(year from o_orderdate) as o_year, l_extendedprice * (1 - l_discount) as volume, n2.n_name as nation
from part,
supplier,
lineitem,
orders,
customer,
nation n1,
nation n2,
region
where p_partkey = l_partkey
and s_suppkey = l_suppkey
and l_orderkey = o_orderkey
and o_custkey = c_custkey
and c_nationkey = n1.n_nationkey
and n1.n_regionkey = r_regionkey
and r_name = 'AMERICA'
and s_nationkey = n2.n_nationkey
and o_orderdate between date '1995-01-01' and date ('1996-12-31') and p_type = 'ECONOMY ANODIZED STEEL' ) as all_nations
group by o_year
order by o_year`,
},
{
name: "simple derived table",
query: `select *
from (select l.l_extendedprice * o.o_totalprice
from lineitem l
join orders o) as dt`,
},
}

for _, testcase := range testcases {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2883,7 +2883,7 @@ func TestCrossShardSubquery(t *testing.T) {
result, err := executorExec(ctx, executor, session, "select id1 from (select u1.id id1, u2.id from user u1 join user u2 on u2.id = u1.col where u1.id = 1) as t", nil)
require.NoError(t, err)
wantQueries := []*querypb.BoundQuery{{
Sql: "select id1, t.`u1.col` from (select u1.id as id1, u1.col as `u1.col` from `user` as u1 where u1.id = 1) as t",
Sql: "select t.id1, t.`u1.col` from (select u1.id as id1, u1.col as `u1.col` from `user` as u1 where u1.id = 1) as t",
BindVariables: map[string]*querypb.BindVariable{},
}}
utils.MustMatch(t, wantQueries, sbc1.Queries)
Expand Down Expand Up @@ -2959,7 +2959,7 @@ func TestCrossShardSubqueryStream(t *testing.T) {
result, err := executorStream(ctx, executor, "select id1 from (select u1.id id1, u2.id from user u1 join user u2 on u2.id = u1.col where u1.id = 1) as t")
require.NoError(t, err)
wantQueries := []*querypb.BoundQuery{{
Sql: "select id1, t.`u1.col` from (select u1.id as id1, u1.col as `u1.col` from `user` as u1 where u1.id = 1) as t",
Sql: "select t.id1, t.`u1.col` from (select u1.id as id1, u1.col as `u1.col` from `user` as u1 where u1.id = 1) as t",
BindVariables: map[string]*querypb.BindVariable{},
}}
utils.MustMatch(t, wantQueries, sbc1.Queries)
Expand Down
28 changes: 23 additions & 5 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ type (
// so they can be used for the result of this expression that is using data from both sides.
// All fields will be used for these
applyJoinColumn struct {
Original sqlparser.Expr // this is the original expression being passed through
LHSExprs []BindVarExpr
RHSExpr sqlparser.Expr
GroupBy bool // if this is true, we need to push this down to our inputs with addToGroupBy set to true
Original sqlparser.Expr // this is the original expression being passed through
LHSExprs []BindVarExpr // These are the expressions we are pushing to the left hand side which we'll receive as bind variables
RHSExpr sqlparser.Expr // This the expression that we'll evaluate on the right hand side. This is nil, if the right hand side has nothing.
DTColName *sqlparser.ColName // This is the output column name that the parent of JOIN will be seeing. If this is unset, then the colname is the String(Original). We set this when we push Projections with derived tables underneath a Join.
GroupBy bool // if this is true, we need to push this down to our inputs with addToGroupBy set to true
}

// BindVarExpr is an expression needed from one side of a join/subquery, and the argument name for it.
Expand Down Expand Up @@ -211,7 +212,8 @@ func (aj *ApplyJoin) getJoinColumnFor(ctx *plancontext.PlanningContext, orig *sq

func applyJoinCompare(ctx *plancontext.PlanningContext, expr sqlparser.Expr) func(e applyJoinColumn) bool {
return func(e applyJoinColumn) bool {
return ctx.SemTable.EqualsExprWithDeps(e.Original, expr)
// e.DTColName is how the outside world will be using this expression. So we should check for an equality with that too.
return ctx.SemTable.EqualsExprWithDeps(e.Original, expr) || ctx.SemTable.EqualsExprWithDeps(e.DTColName, expr)
}
}

Expand Down Expand Up @@ -301,6 +303,22 @@ func (aj *ApplyJoin) planOffsets(ctx *plancontext.PlanningContext) Operator {
}

func (aj *ApplyJoin) planOffsetFor(ctx *plancontext.PlanningContext, col applyJoinColumn) {
if col.DTColName != nil {
// If DTColName is set, then we already pushed the parts of the expression down while planning.
// We need to use this name and ask the correct side of the join for it. Nothing else is required.
if col.IsPureLeft() {
offset := aj.LHS.AddColumn(ctx, true, col.GroupBy, aeWrap(col.DTColName))
aj.addOffset(ToLeftOffset(offset))
} else {
for _, lhsExpr := range col.LHSExprs {
offset := aj.LHS.AddColumn(ctx, true, col.GroupBy, aeWrap(lhsExpr.Expr))
aj.Vars[lhsExpr.Name] = offset
}
offset := aj.RHS.AddColumn(ctx, true, col.GroupBy, aeWrap(col.DTColName))
aj.addOffset(ToRightOffset(offset))
}
return
}
for _, lhsExpr := range col.LHSExprs {
offset := aj.LHS.AddColumn(ctx, true, col.GroupBy, aeWrap(lhsExpr.Expr))
if col.RHSExpr == nil {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/phases.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,10 @@ func createDMLWithInput(ctx *plancontext.PlanningContext, op, src Operator, in *
dm.cols = make([][]*sqlparser.ColName, 1)
for _, col := range in.Target.VTable.PrimaryKey {
colName := sqlparser.NewColNameWithQualifier(col.String(), in.Target.Name)
ctx.SemTable.Recursive[colName] = in.Target.ID
proj.AddColumn(ctx, true, false, aeWrap(colName))
dm.cols[0] = append(dm.cols[0], colName)
leftComp = append(leftComp, colName)
ctx.SemTable.Recursive[colName] = in.Target.ID
}

dm.Source = proj
Expand Down
19 changes: 14 additions & 5 deletions go/vt/vtgate/planbuilder/operators/projection_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,18 @@ func pushProjectionInApplyJoin(
rhs.explicitColumnAliases = true
}

// We store the original join columns to reuse them.
originalJoinColumns := src.JoinColumns
src.JoinColumns = &applyJoinColumns{}
for idx, pe := range ap {
// First we check if we have already done the work to find how to push this expression.
// If we find it then we can directly use it. This is not just a performance improvement, but
// is also required for pushing a projection that is just an alias.
foundIdx := slices.IndexFunc(originalJoinColumns.columns, applyJoinCompare(ctx, pe.ColExpr))
if foundIdx != -1 {
src.JoinColumns.add(originalJoinColumns.columns[foundIdx])
continue
}
var alias string
if p.DT != nil && len(p.DT.Columns) > 0 {
if len(p.DT.Columns) <= idx {
Expand Down Expand Up @@ -295,19 +305,18 @@ func splitUnexploredExpression(
original := sqlparser.CloneRefOfAliasedExpr(pe.Original)
expr := pe.ColExpr

var colName *sqlparser.ColName
if dt != nil {
if !pe.isSameInAndOut(ctx) {
panic(vterrors.VT13001("derived table columns must be the same in and out"))
}
colName := pe.Original.ColumnName()
newExpr := sqlparser.NewColNameWithQualifier(colName, sqlparser.NewTableName(dt.Alias))
ctx.SemTable.CopySemanticInfo(expr, newExpr)
original.Expr = newExpr
expr = newExpr
colName = sqlparser.NewColNameWithQualifier(pe.Original.ColumnName(), sqlparser.NewTableName(dt.Alias))
ctx.SemTable.CopySemanticInfo(expr, colName)
}

// Get a applyJoinColumn for the current expression.
col := join.getJoinColumnFor(ctx, original, expr, false)
col.DTColName = colName

return pushDownSplitJoinCol(col, lhs, pe, alias, rhs)
}
Expand Down
32 changes: 16 additions & 16 deletions go/vt/vtgate/planbuilder/testdata/cte_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -1053,7 +1053,7 @@
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0,L:1",
"JoinColumnIndexes": "L:1,L:0",
"TableName": "`user`_user_extra",
"Inputs": [
{
Expand All @@ -1063,8 +1063,8 @@
"Name": "user",
"Sharded": true
},
"FieldQuery": "select t.col1, t.id from (select `user`.id, `user`.col1 from `user` where 1 != 1) as t where 1 != 1",
"Query": "select t.col1, t.id from (select `user`.id, `user`.col1 from `user`) as t",
"FieldQuery": "select t.id, t.col1 from (select `user`.id, `user`.col1 from `user` where 1 != 1) as t where 1 != 1",
"Query": "select t.id, t.col1 from (select `user`.id, `user`.col1 from `user`) as t",
"Table": "`user`"
},
{
Expand Down Expand Up @@ -1111,7 +1111,7 @@
"Variant": "Join",
"JoinColumnIndexes": "L:0",
"JoinVars": {
"user_col": 1
"user_col": 2
},
"TableName": "`user`_user_extra",
"Inputs": [
Expand All @@ -1122,8 +1122,8 @@
"Name": "user",
"Sharded": true
},
"FieldQuery": "select t.id, t.`user.col` from (select `user`.id, `user`.col1, `user`.col as `user.col` from `user` where 1 != 1) as t where 1 != 1",
"Query": "select t.id, t.`user.col` from (select `user`.id, `user`.col1, `user`.col as `user.col` from `user`) as t",
"FieldQuery": "select t.id, t.col1, t.`user.col` from (select `user`.id, `user`.col1, `user`.col as `user.col` from `user` where 1 != 1) as t where 1 != 1",
"Query": "select t.id, t.col1, t.`user.col` from (select `user`.id, `user`.col1, `user`.col as `user.col` from `user`) as t",
"Table": "`user`"
},
{
Expand Down Expand Up @@ -1171,7 +1171,7 @@
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0",
"JoinColumnIndexes": "L:1",
"TableName": "`user`_user_extra",
"Inputs": [
{
Expand All @@ -1181,8 +1181,8 @@
"Name": "user",
"Sharded": true
},
"FieldQuery": "select t.col1 from (select `user`.id, `user`.col1 from `user` where 1 != 1) as t where 1 != 1",
"Query": "select t.col1 from (select `user`.id, `user`.col1 from `user`) as t",
"FieldQuery": "select t.id, t.col1 from (select `user`.id, `user`.col1 from `user` where 1 != 1) as t where 1 != 1",
"Query": "select t.id, t.col1 from (select `user`.id, `user`.col1 from `user`) as t",
"Table": "`user`"
},
{
Expand Down Expand Up @@ -1236,7 +1236,7 @@
{
"OperatorType": "Join",
"Variant": "Join",
"JoinColumnIndexes": "L:0",
"JoinColumnIndexes": "L:1",
"TableName": "`user`_user_extra",
"Inputs": [
{
Expand All @@ -1246,8 +1246,8 @@
"Name": "user",
"Sharded": true
},
"FieldQuery": "select t.col1 from (select `user`.id, `user`.col1 from `user` where 1 != 1) as t where 1 != 1",
"Query": "select t.col1 from (select `user`.id, `user`.col1 from `user` where `user`.id = :ua_id) as t",
"FieldQuery": "select t.id, t.col1 from (select `user`.id, `user`.col1 from `user` where 1 != 1) as t where 1 != 1",
"Query": "select t.id, t.col1 from (select `user`.id, `user`.col1 from `user` where `user`.id = :ua_id) as t",
"Table": "`user`",
"Values": [
":ua_id"
Expand Down Expand Up @@ -1295,8 +1295,8 @@
"Name": "user",
"Sharded": true
},
"FieldQuery": "select id, t.id from (select `user`.id from `user` where 1 != 1) as t where 1 != 1",
"Query": "select id, t.id from (select `user`.id from `user`) as t",
"FieldQuery": "select t.id from (select `user`.id from `user` where 1 != 1) as t where 1 != 1",
"Query": "select t.id from (select `user`.id from `user`) as t",
"Table": "`user`"
},
{
Expand Down Expand Up @@ -1388,8 +1388,8 @@
"Name": "user",
"Sharded": true
},
"FieldQuery": "select id from (select `user`.id, `user`.col from `user` where 1 != 1) as t where 1 != 1",
"Query": "select id from (select `user`.id, `user`.col from `user` where `user`.id = 5) as t",
"FieldQuery": "select t.id, t.col from (select `user`.id, `user`.col from `user` where 1 != 1) as t where 1 != 1",
"Query": "select t.id, t.col from (select `user`.id, `user`.col from `user` where `user`.id = 5) as t",
"Table": "`user`",
"Values": [
"5"
Expand Down
Loading

0 comments on commit 4b66d39

Please sign in to comment.