diff --git a/go/test/endtoend/vtgate/queries/derived/derived_test.go b/go/test/endtoend/vtgate/queries/derived/derived_test.go index 80ae36633e1..c41161d9bcf 100644 --- a/go/test/endtoend/vtgate/queries/derived/derived_test.go +++ b/go/test/endtoend/vtgate/queries/derived/derived_test.go @@ -113,3 +113,15 @@ func TestDerivedTablesWithLimit(t *testing.T) { (SELECT id, user_id FROM music LIMIT 10) as m on u.id = m.user_id`, `[[INT64(1) INT64(1)] [INT64(5) INT64(2)] [INT64(1) INT64(3)] [INT64(2) INT64(4)] [INT64(3) INT64(5)] [INT64(5) INT64(7)] [INT64(4) INT64(6)] [INT64(6) NULL]]`) } + +// TestDerivedTableColumnAliasWithJoin tests the derived table having alias column and using it in the join condition +func TestDerivedTableColumnAliasWithJoin(t *testing.T) { + utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate") + mcmp, closer := start(t) + defer closer() + + mcmp.Exec(`SELECT user.id FROM user join (SELECT id as uid FROM user) t on t.uid = user.id`) + mcmp.Exec(`SELECT user.id FROM user left join (SELECT id as uid FROM user) t on t.uid = user.id`) + mcmp.Exec(`SELECT user.id FROM user join (SELECT id FROM user) t(uid) on t.uid = user.id`) + mcmp.Exec(`SELECT user.id FROM user left join (SELECT id FROM user) t(uid) on t.uid = user.id`) +} diff --git a/go/vt/vtgate/planbuilder/operators/horizon.go b/go/vt/vtgate/planbuilder/operators/horizon.go index 34f6dc79217..e83255dc599 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon.go +++ b/go/vt/vtgate/planbuilder/operators/horizon.go @@ -99,8 +99,13 @@ func (h *Horizon) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser. panic(err) } +<<<<<<< HEAD newExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo) if sqlparser.ContainsAggregation(newExpr) { +======= + newExpr := ctx.RewriteDerivedTableExpression(expr, tableInfo) + if ContainsAggr(ctx, newExpr) { +>>>>>>> 7997d1eb5f (fix: derived table join column expression to be part of add join predicate on rewrite (#15956)) return newFilter(h, expr) } h.Source = h.Source.AddPredicate(ctx, newExpr) diff --git a/go/vt/vtgate/planbuilder/operators/rewriters.go b/go/vt/vtgate/planbuilder/operators/rewriters.go index 6a329860b4b..7ec8379dfab 100644 --- a/go/vt/vtgate/planbuilder/operators/rewriters.go +++ b/go/vt/vtgate/planbuilder/operators/rewriters.go @@ -218,6 +218,9 @@ func bottomUp( childID = childID.Merge(resolveID(oldInputs[0])) } in, changed := bottomUp(operator, childID, resolveID, rewriter, shouldVisit, false) + if DebugOperatorTree && changed.Changed() { + fmt.Println(ToTree(in)) + } anythingChanged = anythingChanged.Merge(changed) newInputs[i] = in } diff --git a/go/vt/vtgate/planbuilder/plancontext/planning_context.go b/go/vt/vtgate/planbuilder/plancontext/planning_context.go index 49039ddd347..3c2a1c97434 100644 --- a/go/vt/vtgate/planbuilder/plancontext/planning_context.go +++ b/go/vt/vtgate/planbuilder/plancontext/planning_context.go @@ -188,3 +188,16 @@ func (ctx *PlanningContext) execOnJoinPredicateEqual(joinPred sqlparser.Expr, fn } return false } + +func (ctx *PlanningContext) RewriteDerivedTableExpression(expr sqlparser.Expr, tableInfo semantics.TableInfo) sqlparser.Expr { + modifiedExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo) + for key, exprs := range ctx.joinPredicates { + for _, rhsExpr := range exprs { + if ctx.SemTable.EqualsExpr(expr, rhsExpr) { + ctx.joinPredicates[key] = append(ctx.joinPredicates[key], modifiedExpr) + return modifiedExpr + } + } + } + return modifiedExpr +} diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index cf13e92d4dc..3e029027da5 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -4943,5 +4943,110 @@ "user.user" ] } +<<<<<<< HEAD +======= + }, + { + "comment": "column name aliases in outer join queries", + "query": "select name as t0, name as t1 from user left outer join user_extra on user.cola = user_extra.cola", + "plan": { + "QueryType": "SELECT", + "Original": "select name as t0, name as t1 from user left outer join user_extra on user.cola = user_extra.cola", + "Instructions": { + "OperatorType": "SimpleProjection", + "ColumnNames": [ + "t0", + "t1" + ], + "Columns": [ + 0, + 0 + ], + "Inputs": [ + { + "OperatorType": "Join", + "Variant": "LeftJoin", + "JoinColumnIndexes": "L:0", + "JoinVars": { + "user_cola": 1 + }, + "TableName": "`user`_user_extra", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `name`, `user`.cola from `user` where 1 != 1", + "Query": "select `name`, `user`.cola from `user`", + "Table": "`user`" + }, + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from user_extra where 1 != 1", + "Query": "select 1 from user_extra where user_extra.cola = :user_cola", + "Table": "user_extra" + } + ] + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "Over clause works for unsharded tables", + "query": "SELECT val, CUME_DIST() OVER w, ROW_NUMBER() OVER w, DENSE_RANK() OVER w, PERCENT_RANK() OVER w, RANK() OVER w AS 'cd' FROM unsharded_a", + "plan": { + "QueryType": "SELECT", + "Original": "SELECT val, CUME_DIST() OVER w, ROW_NUMBER() OVER w, DENSE_RANK() OVER w, PERCENT_RANK() OVER w, RANK() OVER w AS 'cd' FROM unsharded_a", + "Instructions": { + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select val, cume_dist() over w, row_number() over w, dense_rank() over w, percent_rank() over w, rank() over w as cd from unsharded_a where 1 != 1", + "Query": "select val, cume_dist() over w, row_number() over w, dense_rank() over w, percent_rank() over w, rank() over w as cd from unsharded_a", + "Table": "unsharded_a" + }, + "TablesUsed": [ + "main.unsharded_a" + ] + } + }, + { + "comment": "join with derived table with alias and join condition - merge into route", + "query": "select 1 from user join (select id as uid from user) as t where t.uid = user.id", + "plan": { + "QueryType": "SELECT", + "Original": "select 1 from user join (select id as uid from user) as t where t.uid = user.id", + "Instructions": { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select 1 from (select id as uid from `user` where 1 != 1) as t, `user` where 1 != 1", + "Query": "select 1 from (select id as uid from `user`) as t, `user` where t.uid = `user`.id", + "Table": "`user`" + }, + "TablesUsed": [ + "user.user" + ] + } +>>>>>>> 7997d1eb5f (fix: derived table join column expression to be part of add join predicate on rewrite (#15956)) } ]