diff --git a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test index de38a21cd78..173dcaeb6ba 100644 --- a/go/test/endtoend/vtgate/vitess_tester/cte/queries.test +++ b/go/test/endtoend/vtgate/vitess_tester/cte/queries.test @@ -6,6 +6,11 @@ CREATE TABLE employees manager_id INT ); +# Simple recursive CTE using a real table. Select everything from empty table +with recursive cte as (select * from employees union all select * from cte) +select * +from cte; + # Insert data into the tables INSERT INTO employees (id, name, manager_id) VALUES (1, 'CEO', NULL), @@ -107,4 +112,115 @@ GROUP BY manager_id; --error infinite recursion with recursive cte as (select 1 as n union all select n+1 from cte) select * -from cte; \ No newline at end of file +from cte; + +# Define recursive CTE and then use it on the RHS of UNION +WITH RECURSIVE foo AS (SELECT id + FROM employees + WHERE id = 1 + UNION ALL + SELECT id + 1 + FROM foo + WHERE id < 5) +SELECT id +FROM foo; + +# Recursive CTE with UNION DISTINCT +WITH RECURSIVE hierarchy AS (SELECT id, name, manager_id + FROM employees + UNION ALL + SELECT id, name, manager_id + FROM employees + UNION + DISTINCT + SELECT id * 2, name, manager_id + from hierarchy + WHERE id < 10) +SELECT * +FROM hierarchy; + +# Select with false condition +with recursive cte as (select * from employees where false union all select * from cte) +select * +from cte; + +# Select with no matching rows +with recursive cte as (select * from employees where id > 100 union all select * from cte) +select * +from cte; + +# Recursive CTE joined with a normal table. Predicate on the outside should not be pushed in +WITH RECURSIVE emp_cte AS (SELECT id, name, manager_id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id + FROM employees e + INNER JOIN emp_cte cte ON e.manager_id = cte.id) +SELECT * +FROM emp_cte +where name = 'Engineer1'; + +# Query with a recursive CTE in a subquery +SELECT * +FROM (SELECT 1 UNION ALL SELECT 2) AS dt(a) +WHERE EXISTS(WITH RECURSIVE qn AS (SELECT a * 0 AS b UNION ALL SELECT b + 1 FROM qn WHERE b = 0) + SELECT * + FROM qn + WHERE b = a); + +# Join with recursive CTE inside a derived table using data from DUAL +SELECT e.id, e.name, e.manager_id, d.id AS cte_id +FROM employees e + JOIN (WITH RECURSIVE foo AS (SELECT 1 AS id + UNION ALL + SELECT id + 1 + FROM foo + WHERE id < 5) + SELECT id + FROM foo) d ON e.id = d.id; + +# Join with recursive CTE inside a derived table using data from employees table +SELECT e.id, e.name, e.manager_id, d.id AS cte_id +FROM employees e + JOIN (WITH RECURSIVE foo AS (SELECT id + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id + FROM employees e + JOIN foo f ON e.manager_id = f.id) + SELECT id + FROM foo) d ON e.id = d.id; + +# Recursive CTE within an uncorrelated subquery as a select expression +SELECT e.id, + e.name, + e.manager_id, + (SELECT MAX(cte_id) + FROM (WITH RECURSIVE foo AS (SELECT 1 AS cte_id + UNION ALL + SELECT cte_id + 1 + FROM foo + WHERE cte_id < e.id) + SELECT cte_id + FROM foo) AS recursive_result) AS max_cte_id +FROM employees e; + +# Recursive CTE used twice in the same query +WITH RECURSIVE employee_hierarchy AS (SELECT id, name, manager_id, 1 AS level + FROM employees + WHERE manager_id IS NULL + UNION ALL + SELECT e.id, e.name, e.manager_id, h.level + 1 + FROM employees e + JOIN employee_hierarchy h ON e.manager_id = h.id) +SELECT h1.id AS employee_id, + h1.name AS employee_name, + h1.level AS employee_level, + h2.name AS manager_name, + h2.level AS manager_level +FROM employee_hierarchy h1 + LEFT JOIN + employee_hierarchy h2 ON h1.manager_id = h2.id +ORDER BY h1.level, h1.id; \ No newline at end of file diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index 8cc23c57ae7..3972ac8290a 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -134,6 +134,12 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) { addPred = stmt.AddWhere case *sqlparser.Delete: addPred = stmt.AddWhere + case nil: + // this would happen if we are adding a predicate on a dual query. + // we use this when building recursive CTE queries + sel := &sqlparser.Select{} + addPred = sel.AddWhere + qb.stmt = sel default: panic(fmt.Sprintf("cant add WHERE to %T", qb.stmt)) } @@ -236,10 +242,11 @@ func (qb *queryBuilder) unionWith(other *queryBuilder, distinct bool) { } } -func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string) { +func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string, distinct bool) { cteUnion := &sqlparser.Union{ - Left: qb.stmt.(sqlparser.SelectStatement), - Right: other.stmt.(sqlparser.SelectStatement), + Left: qb.stmt.(sqlparser.SelectStatement), + Right: other.stmt.(sqlparser.SelectStatement), + Distinct: distinct, } qb.stmt = &sqlparser.Select{ @@ -719,7 +726,7 @@ func buildRecursiveCTE(op *RecurseCTE, qb *queryBuilder) { panic(err) } - qb.recursiveCteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String()) + qb.recursiveCteWith(qbR, op.Def.Name, infoFor.GetAliasedTableExpr().As.String(), op.Distinct) } func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where { diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 4f0ab742935..4c075f480d3 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -337,7 +337,7 @@ func createRecursiveCTE(ctx *plancontext.PlanningContext, def *semantics.CTE, ou panic(err) } - return newRecurse(ctx, def, seed, term, activeCTE.Predicates, horizon, idForRecursiveTable(ctx, def), outerID) + return newRecurse(ctx, def, seed, term, activeCTE.Predicates, horizon, idForRecursiveTable(ctx, def), outerID, union.Distinct) } func idForRecursiveTable(ctx *plancontext.PlanningContext, def *semantics.CTE) semantics.TableSet { diff --git a/go/vt/vtgate/planbuilder/operators/cte_merging.go b/go/vt/vtgate/planbuilder/operators/cte_merging.go index 9ca453f39c6..a6830cbe12b 100644 --- a/go/vt/vtgate/planbuilder/operators/cte_merging.go +++ b/go/vt/vtgate/planbuilder/operators/cte_merging.go @@ -31,14 +31,22 @@ func tryMergeRecurse(ctx *plancontext.PlanningContext, in *RecurseCTE) (Operator } func tryMergeCTE(ctx *plancontext.PlanningContext, seed, term Operator, in *RecurseCTE) *Route { - seedRoute, termRoute, _, routingB, a, b, sameKeyspace := prepareInputRoutes(seed, term) - if seedRoute == nil || !sameKeyspace { + seedRoute, termRoute, routingA, routingB, a, b, sameKeyspace := prepareInputRoutes(seed, term) + if seedRoute == nil { return nil } switch { case a == dual: return mergeCTE(ctx, seedRoute, termRoute, routingB, in) + case b == dual: + return mergeCTE(ctx, seedRoute, termRoute, routingA, in) + case !sameKeyspace: + return nil + case a == anyShard: + return mergeCTE(ctx, seedRoute, termRoute, routingB, in) + case b == anyShard: + return mergeCTE(ctx, seedRoute, termRoute, routingA, in) case a == sharded && b == sharded: return tryMergeCTESharded(ctx, seedRoute, termRoute, in) default: @@ -80,6 +88,7 @@ func mergeCTE(ctx *plancontext.PlanningContext, seed, term *Route, r Routing, in Term: newTerm, LeftID: in.LeftID, OuterID: in.OuterID, + Distinct: in.Distinct, }, MergedWith: []*Route{term}, } diff --git a/go/vt/vtgate/planbuilder/operators/join_merging.go b/go/vt/vtgate/planbuilder/operators/join_merging.go index c124cefd73c..672da551fa6 100644 --- a/go/vt/vtgate/planbuilder/operators/join_merging.go +++ b/go/vt/vtgate/planbuilder/operators/join_merging.go @@ -111,7 +111,6 @@ func prepareInputRoutes(lhs Operator, rhs Operator) (*Route, *Route, Routing, Ro lhsRoute, rhsRoute, routingA, routingB, sameKeyspace := getRoutesOrAlternates(lhsRoute, rhsRoute) a, b := getRoutingType(routingA), getRoutingType(routingB) - return lhsRoute, rhsRoute, routingA, routingB, a, b, sameKeyspace } diff --git a/go/vt/vtgate/planbuilder/operators/recurse_cte.go b/go/vt/vtgate/planbuilder/operators/recurse_cte.go index 7a8c9dcd355..ebb7dc54765 100644 --- a/go/vt/vtgate/planbuilder/operators/recurse_cte.go +++ b/go/vt/vtgate/planbuilder/operators/recurse_cte.go @@ -18,6 +18,7 @@ package operators import ( "fmt" + "slices" "strings" "golang.org/x/exp/maps" @@ -56,6 +57,9 @@ type RecurseCTE struct { // The OuterID is the id for this use of the CTE OuterID semantics.TableSet + + // Distinct is used to determine if the result set should be distinct + Distinct bool } var _ Operator = (*RecurseCTE)(nil) @@ -67,6 +71,7 @@ func newRecurse( predicates []*plancontext.RecurseExpression, horizon *Horizon, leftID, outerID semantics.TableSet, + distinct bool, ) *RecurseCTE { for _, pred := range predicates { ctx.AddJoinPredicates(pred.Original, pred.RightExpr) @@ -79,21 +84,18 @@ func newRecurse( Horizon: horizon, LeftID: leftID, OuterID: outerID, + Distinct: distinct, } } func (r *RecurseCTE) Clone(inputs []Operator) Operator { - return &RecurseCTE{ - Seed: inputs[0], - Term: inputs[1], - Def: r.Def, - Predicates: r.Predicates, - Projections: r.Projections, - Vars: maps.Clone(r.Vars), - Horizon: r.Horizon, - LeftID: r.LeftID, - OuterID: r.OuterID, - } + klone := *r + klone.Seed = inputs[0] + klone.Term = inputs[1] + klone.Vars = maps.Clone(r.Vars) + klone.Predicates = slices.Clone(r.Predicates) + klone.Projections = slices.Clone(r.Projections) + return &klone } func (r *RecurseCTE) Inputs() []Operator { @@ -106,8 +108,7 @@ func (r *RecurseCTE) SetInputs(operators []Operator) { } func (r *RecurseCTE) AddPredicate(_ *plancontext.PlanningContext, e sqlparser.Expr) Operator { - r.Term = newFilter(r, e) - return r + return newFilter(r, e) } func (r *RecurseCTE) AddColumn(ctx *plancontext.PlanningContext, _, _ bool, expr *sqlparser.AliasedExpr) int { @@ -162,13 +163,17 @@ func (r *RecurseCTE) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser. } func (r *RecurseCTE) ShortDescription() string { + distinct := "" + if r.Distinct { + distinct = "distinct " + } if len(r.Vars) > 0 { - return fmt.Sprintf("%v", r.Vars) + return fmt.Sprintf("%s%v", distinct, r.Vars) } expressions := slice.Map(r.expressions(), func(expr *plancontext.RecurseExpression) string { return sqlparser.String(expr.Original) }) - return fmt.Sprintf("%v %v", r.Def.Name, strings.Join(expressions, ", ")) + return fmt.Sprintf("%s%v %v", distinct, r.Def.Name, strings.Join(expressions, ", ")) } func (r *RecurseCTE) GetOrdering(*plancontext.PlanningContext) []OrderBy { diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index 35470ce77d0..b00dc0a060f 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -2438,5 +2438,128 @@ "user.user" ] } + }, + { + "comment": "Query that can be merged, dual on the RHS of the UNION", + "query": "with recursive cte as (select id from user where id = 72 union all select id+1 from cte where id < 100) select * from cte", + "plan": { + "QueryType": "SELECT", + "Original": "with recursive cte as (select id from user where id = 72 union all select id+1 from cte where id < 100) select * from cte", + "Instructions": { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "with recursive cte as (select id from `user` where 1 != 1 union all select id + 1 from cte where 1 != 1) select id from cte where 1 != 1", + "Query": "with recursive cte as (select id from `user` where id = 72 union all select id + 1 from cte where id < 100) select id from cte", + "Table": "`user`, dual", + "Values": [ + "72" + ], + "Vindex": "user_index" + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": "Merge CTE with reference tables", + "query": "with recursive cte as (select ue.id, ue.foo from user u join user_extra ue on u.id = ue.user_id union all select sr.id, sr.foo from cte join main.source_of_ref sr on sr.foo = cte.foo join main.rerouted_ref rr on rr.bar = sr.bar) select * from cte", + "plan": { + "QueryType": "SELECT", + "Original": "with recursive cte as (select ue.id, ue.foo from user u join user_extra ue on u.id = ue.user_id union all select sr.id, sr.foo from cte join main.source_of_ref sr on sr.foo = cte.foo join main.rerouted_ref rr on rr.bar = sr.bar) select * from cte", + "Instructions": { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "with recursive cte as (select ue.id, ue.foo from `user` as u, user_extra as ue where 1 != 1 union all select sr.id, sr.foo from ref_with_source as sr, ref as rr where 1 != 1) select id, foo from cte where 1 != 1", + "Query": "with recursive cte as (select ue.id, ue.foo from `user` as u, user_extra as ue where u.id = ue.user_id union all select sr.id, sr.foo from ref_with_source as sr, ref as rr where rr.bar = sr.bar and sr.foo = cte.foo) select id, foo from cte", + "Table": "`user`, ref, ref_with_source, user_extra" + }, + "TablesUsed": [ + "user.ref", + "user.ref_with_source", + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "Merge CTE with reference tables 2", + "query": "with recursive cte as (select sr.id, sr.foo from main.source_of_ref sr join main.rerouted_ref rr on rr.bar = sr.bar union all select ue.id, ue.foo from cte join user_extra ue on cte.foo = ue.foo join user u on ue.user_id = u.id) select * from cte", + "plan": { + "QueryType": "SELECT", + "Original": "with recursive cte as (select sr.id, sr.foo from main.source_of_ref sr join main.rerouted_ref rr on rr.bar = sr.bar union all select ue.id, ue.foo from cte join user_extra ue on cte.foo = ue.foo join user u on ue.user_id = u.id) select * from cte", + "Instructions": { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "with recursive cte as (select 1 from ref_with_source as sr, ref as rr where 1 != 1 union all select ue.id, ue.foo from cte, user_extra as ue, `user` as u where 1 != 1) select id, foo from cte where 1 != 1", + "Query": "with recursive cte as (select 1 from ref_with_source as sr, ref as rr where rr.bar = sr.bar union all select ue.id, ue.foo from cte, user_extra as ue, `user` as u where ue.user_id = u.id and cte.foo = ue.foo) select id, foo from cte", + "Table": "`user`, dual, ref, ref_with_source, user_extra" + }, + "TablesUsed": [ + "main.dual", + "user.ref", + "user.ref_with_source", + "user.user", + "user.user_extra" + ] + } + }, + { + "comment": "Merged recursive CTE with DISTINCT", + "query": "WITH RECURSIVE hierarchy AS (SELECT id, name, manager_id FROM user UNION ALL SELECT id, name, manager_id FROM user UNION DISTINCT SELECT id*2, name, manager_id from hierarchy WHERE id < 10 ) SELECT * FROM hierarchy", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE hierarchy AS (SELECT id, name, manager_id FROM user UNION ALL SELECT id, name, manager_id FROM user UNION DISTINCT SELECT id*2, name, manager_id from hierarchy WHERE id < 10 ) SELECT * FROM hierarchy", + "Instructions": { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "with recursive hierarchy as (select id, `name`, manager_id from `user` where 1 != 1 union all select id, `name`, manager_id from `user` where 1 != 1 union select id * 2, `name`, manager_id from hierarchy where 1 != 1) select id, `name`, manager_id from hierarchy where 1 != 1", + "Query": "with recursive hierarchy as (select id, `name`, manager_id from `user` union all select id, `name`, manager_id from `user` union select id * 2, `name`, manager_id from hierarchy where id < 10) select id, `name`, manager_id from hierarchy", + "Table": "`user`, dual" + }, + "TablesUsed": [ + "main.dual", + "user.user" + ] + } + }, + { + "comment": "Query that caused planner to stack overflow", + "query": "SELECT * FROM (SELECT 1 UNION ALL SELECT 2) AS dt(a) WHERE EXISTS(WITH RECURSIVE qn AS (SELECT a * 0 AS b UNION ALL SELECT b + 1 FROM qn WHERE b = 0) SELECT * FROM qn WHERE b = a)", + "plan": { + "QueryType": "SELECT", + "Original": "SELECT * FROM (SELECT 1 UNION ALL SELECT 2) AS dt(a) WHERE EXISTS(WITH RECURSIVE qn AS (SELECT a * 0 AS b UNION ALL SELECT b + 1 FROM qn WHERE b = 0) SELECT * FROM qn WHERE b = a)", + "Instructions": { + "OperatorType": "Route", + "Variant": "Reference", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "select a from (select 1 from dual where 1 != 1 union all select 2 from dual where 1 != 1) as dt(a) where 1 != 1", + "Query": "select a from (select 1 from dual union all select 2 from dual) as dt(a) where exists (with recursive qn as (select a * 0 as b from dual union all select b + 1 from qn where b = 0) select 1 from qn where b = a)", + "Table": "dual" + }, + "TablesUsed": [ + "main.dual" + ] + } } ]