Skip to content

Commit

Permalink
fix: derived table join column expression to be part of add join pred…
Browse files Browse the repository at this point in the history
…icate on rewrite

Signed-off-by: Harshit Gangal <harshit@planetscale.com>
  • Loading branch information
harshit-gangal committed May 16, 2024
1 parent 951f273 commit 1ab0fa9
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 0 deletions.
14 changes: 14 additions & 0 deletions go/test/endtoend/vtgate/queries/derived/derived_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,17 @@ func TestDerivedTablesWithLimit(t *testing.T) {
(SELECT id, user_id FROM music LIMIT 10) as m on u.id = m.user_id`,
`[[INT64(1) INT64(1)] [INT64(5) INT64(2)] [INT64(1) INT64(3)] [INT64(2) INT64(4)] [INT64(3) INT64(5)] [INT64(5) INT64(7)] [INT64(4) INT64(6)] [INT64(6) NULL]]`)
}

func TestDerivedTableColumnAliasWithJoin(t *testing.T) {
mcmp, closer := start(t)
defer closer()

mcmp.AssertMatches(`SELECT user.id FROM user join (SELECT id as uid FROM user) t on t.uid = user.id`,
`[[INT64(5)] [INT64(4)] [INT64(3)] [INT64(2)] [INT64(1)]]`)
mcmp.AssertMatches(`SELECT user.id FROM user left join (SELECT id as uid FROM user) t on t.uid = user.id`,
`[[INT64(5)] [INT64(4)] [INT64(3)] [INT64(2)] [INT64(1)]]`)
mcmp.AssertMatches(`SELECT user.id FROM user join (SELECT id FROM user) t(uid) on t.uid = user.id`,
`[[INT64(5)] [INT64(4)] [INT64(3)] [INT64(2)] [INT64(1)]]`)
mcmp.AssertMatches(`SELECT user.id FROM user left join (SELECT id FROM user) t(uid) on t.uid = user.id`,
`[[INT64(5)] [INT64(4)] [INT64(3)] [INT64(2)] [INT64(1)]]`)
}
2 changes: 2 additions & 0 deletions go/vt/vtgate/planbuilder/operators/apply_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ func (aj *ApplyJoin) AddJoinPredicate(ctx *plancontext.PlanningContext, expr sql
col := breakExpressionInLHSandRHSForApplyJoin(ctx, pred, TableID(aj.LHS))
aj.JoinPredicates.add(col)
ctx.AddJoinPredicates(pred, col.RHSExpr)
ctx.JoinPredInProgress = pred
rhs = rhs.AddPredicate(ctx, col.RHSExpr)
ctx.JoinPredInProgress = nil
}
aj.RHS = rhs
}
Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/planbuilder/operators/horizon.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ func (h *Horizon) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.
}

newExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo)
if ctx.JoinPredInProgress != nil {
ctx.AddJoinPredicates(ctx.JoinPredInProgress, newExpr)
}
if ContainsAggr(ctx, newExpr) {
return newFilter(h, expr)
}
Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/planbuilder/operators/rewriters.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ func bottomUp(
childID = childID.Merge(resolveID(oldInputs[0]))
}
in, changed := bottomUp(operator, childID, resolveID, rewriter, shouldVisit, false)
if DebugOperatorTree && changed.Changed() {
fmt.Println(ToTree(in))
}
anythingChanged = anythingChanged.Merge(changed)
newInputs[i] = in
}
Expand Down
1 change: 1 addition & 0 deletions go/vt/vtgate/planbuilder/plancontext/planning_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type PlanningContext struct {
SemTable *semantics.SemTable
VSchema VSchema

JoinPredInProgress sqlparser.Expr
// joinPredicates maps each original join predicate (key) to a slice of
// variations of the RHS predicates (value). This map is used to handle
// different scenarios in join planning, where the RHS predicates are
Expand Down
22 changes: 22 additions & 0 deletions go/vt/vtgate/planbuilder/testdata/select_cases.json
Original file line number Diff line number Diff line change
Expand Up @@ -5221,5 +5221,27 @@
"main.unsharded_a"
]
}
},
{
"comment": "join with derived table with alias and join condition - merge into route",
"query": "select 1 from user join (select id as uid from user) as t where t.uid = user.id",
"plan": {
"QueryType": "SELECT",
"Original": "select 1 from user join (select id as uid from user) as t where t.uid = user.id",
"Instructions": {
"OperatorType": "Route",
"Variant": "Scatter",
"Keyspace": {
"Name": "user",
"Sharded": true
},
"FieldQuery": "select 1 from (select id as uid from `user` where 1 != 1) as t, `user` where 1 != 1",
"Query": "select 1 from (select id as uid from `user`) as t, `user` where t.uid = `user`.id",
"Table": "`user`"
},
"TablesUsed": [
"user.user"
]
}
}
]

0 comments on commit 1ab0fa9

Please sign in to comment.