diff --git a/go/vt/sqlparser/impossible_query.go b/go/vt/sqlparser/impossible_query.go index 512931f1db7..a6bf1ea8736 100644 --- a/go/vt/sqlparser/impossible_query.go +++ b/go/vt/sqlparser/impossible_query.go @@ -27,6 +27,9 @@ package sqlparser func FormatImpossibleQuery(buf *TrackedBuffer, node SQLNode) { switch node := node.(type) { case *Select: + if node.With != nil { + node.With.Format(buf) + } buf.Myprintf("select %v from ", node.SelectExprs) var prefix string for _, n := range node.From { diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 6c2171cf689..256372c172f 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -80,10 +80,7 @@ func (a *Aggregator) SetInputs(operators []Operator) { } func (a *Aggregator) AddPredicate(_ *plancontext.PlanningContext, expr sqlparser.Expr) Operator { - return &Filter{ - Source: a, - Predicates: []sqlparser.Expr{expr}, - } + return newFilter(a, expr) } func (a *Aggregator) addColumnWithoutPushing(_ *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int { diff --git a/go/vt/vtgate/planbuilder/operators/apply_join.go b/go/vt/vtgate/planbuilder/operators/apply_join.go index 9294311c00f..c182bb2fb83 100644 --- a/go/vt/vtgate/planbuilder/operators/apply_join.go +++ b/go/vt/vtgate/planbuilder/operators/apply_join.go @@ -109,7 +109,7 @@ func (aj *ApplyJoin) Clone(inputs []Operator) Operator { } func (aj *ApplyJoin) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { - return AddPredicate(ctx, aj, expr, false, newFilter) + return AddPredicate(ctx, aj, expr, false, newFilterSinglePredicate) } // Inputs implements the Operator interface diff --git a/go/vt/vtgate/planbuilder/operators/filter.go b/go/vt/vtgate/planbuilder/operators/filter.go index 17e5ac64b71..c2432a40da9 100644 --- a/go/vt/vtgate/planbuilder/operators/filter.go +++ b/go/vt/vtgate/planbuilder/operators/filter.go @@ -39,9 +39,13 @@ type Filter struct { Truncate int } -func newFilter(op Operator, expr sqlparser.Expr) Operator { +func newFilterSinglePredicate(op Operator, expr sqlparser.Expr) Operator { + return newFilter(op, expr) +} + +func newFilter(op Operator, expr ...sqlparser.Expr) Operator { return &Filter{ - Source: op, Predicates: []sqlparser.Expr{expr}, + Source: op, Predicates: expr, } } diff --git a/go/vt/vtgate/planbuilder/operators/hash_join.go b/go/vt/vtgate/planbuilder/operators/hash_join.go index e5a87cf4ef8..0ad46bcbc82 100644 --- a/go/vt/vtgate/planbuilder/operators/hash_join.go +++ b/go/vt/vtgate/planbuilder/operators/hash_join.go @@ -106,7 +106,7 @@ func (hj *HashJoin) SetInputs(operators []Operator) { } func (hj *HashJoin) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { - return AddPredicate(ctx, hj, expr, false, newFilter) + return AddPredicate(ctx, hj, expr, false, newFilterSinglePredicate) } func (hj *HashJoin) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, expr *sqlparser.AliasedExpr) int { diff --git a/go/vt/vtgate/planbuilder/operators/horizon.go b/go/vt/vtgate/planbuilder/operators/horizon.go index f05abb0311b..34f6dc79217 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon.go +++ b/go/vt/vtgate/planbuilder/operators/horizon.go @@ -94,17 +94,14 @@ func (h *Horizon) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser. tableInfo, err := ctx.SemTable.TableInfoForExpr(expr) if err != nil { if errors.Is(err, semantics.ErrNotSingleTable) { - return &Filter{ - Source: h, - Predicates: []sqlparser.Expr{expr}, - } + return newFilter(h, expr) } panic(err) } newExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo) if sqlparser.ContainsAggregation(newExpr) { - return &Filter{Source: h, Predicates: []sqlparser.Expr{expr}} + return newFilter(h, expr) } h.Source = h.Source.AddPredicate(ctx, newExpr) return h diff --git a/go/vt/vtgate/planbuilder/operators/join.go b/go/vt/vtgate/planbuilder/operators/join.go index 0796d237b88..787d7fedfcc 100644 --- a/go/vt/vtgate/planbuilder/operators/join.go +++ b/go/vt/vtgate/planbuilder/operators/join.go @@ -128,7 +128,7 @@ func createInnerJoin(ctx *plancontext.PlanningContext, tableExpr *sqlparser.Join } func (j *Join) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Operator { - return AddPredicate(ctx, j, expr, false, newFilter) + return AddPredicate(ctx, j, expr, false, newFilterSinglePredicate) } var _ JoinOp = (*Join)(nil) diff --git a/go/vt/vtgate/planbuilder/operators/route_planning.go b/go/vt/vtgate/planbuilder/operators/route_planning.go index ddb2f0d1210..f7276ea48c7 100644 --- a/go/vt/vtgate/planbuilder/operators/route_planning.go +++ b/go/vt/vtgate/planbuilder/operators/route_planning.go @@ -533,7 +533,7 @@ func pushJoinPredicates(ctx *plancontext.PlanningContext, exprs []sqlparser.Expr } for _, expr := range exprs { - AddPredicate(ctx, op, expr, true, newFilter) + AddPredicate(ctx, op, expr, true, newFilterSinglePredicate) } return op diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index 24417cfab21..537737363c8 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -269,10 +269,7 @@ func (sq *SubQuery) settleFilter(ctx *plancontext.PlanningContext, outer Operato predicates = append(predicates, rhsPred) sq.SubqueryValueName = sq.ArgName } - return &Filter{ - Source: outer, - Predicates: predicates, - } + return newFilter(outer, predicates...) } func dontEnterSubqueries(node, _ sqlparser.SQLNode) bool { diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index 1727f7bedcb..960cde99acc 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -458,7 +458,7 @@ func tryMergeSubqueriesRecursively( finalResult = finalResult.Merge(res) } - op.Source = &Filter{Source: outer.Source, Predicates: []sqlparser.Expr{subQuery.Original}} + op.Source = newFilter(outer.Source, subQuery.Original) return op, finalResult.Merge(Rewrote("merge outer of two subqueries")) } @@ -477,7 +477,7 @@ func tryMergeSubqueryWithOuter(ctx *plancontext.PlanningContext, subQuery *SubQu return outer, NoRewrite } if !subQuery.IsProjection { - op.Source = &Filter{Source: outer.Source, Predicates: []sqlparser.Expr{subQuery.Original}} + op.Source = newFilter(outer.Source, subQuery.Original) } ctx.MergedSubqueries = append(ctx.MergedSubqueries, subQuery.originalSubquery) return op, Rewrote("merged subquery with outer") @@ -582,10 +582,7 @@ func (s *subqueryRouteMerger) merge(ctx *plancontext.PlanningContext, inner, out if isSharded { src = s.outer.Source if !s.subq.IsProjection { - src = &Filter{ - Source: s.outer.Source, - Predicates: []sqlparser.Expr{s.original}, - } + src = newFilter(s.outer.Source, s.original) } } else { src = s.rewriteASTExpression(ctx, inner) @@ -655,10 +652,7 @@ func (s *subqueryRouteMerger) rewriteASTExpression(ctx *plancontext.PlanningCont cursor.Replace(subq) } }, ctx.SemTable.CopySemanticInfo).(sqlparser.Expr) - src = &Filter{ - Source: s.outer.Source, - Predicates: []sqlparser.Expr{sQuery}, - } + src = newFilter(s.outer.Source, sQuery) } return src } diff --git a/go/vt/vtgate/planbuilder/operators/union.go b/go/vt/vtgate/planbuilder/operators/union.go index 6ce5fe9a9f8..1d739c9f01c 100644 --- a/go/vt/vtgate/planbuilder/operators/union.go +++ b/go/vt/vtgate/planbuilder/operators/union.go @@ -105,10 +105,7 @@ func (u *Union) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.Ex needsFilter, exprPerSource := u.predicatePerSource(expr, offsets) if needsFilter { - return &Filter{ - Source: u, - Predicates: []sqlparser.Expr{expr}, - } + return newFilter(u, expr) } for i, src := range u.Sources { diff --git a/go/vt/vtgate/planbuilder/testdata/cte_cases.json b/go/vt/vtgate/planbuilder/testdata/cte_cases.json index c51a6f9144d..e43b6320340 100644 --- a/go/vt/vtgate/planbuilder/testdata/cte_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/cte_cases.json @@ -1369,8 +1369,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select u.* from (select * from unsharded where 1 != 1) as u where 1 != 1", - "Query": "select u.* from (select * from unsharded) as u", + "FieldQuery": "with u as (select * from unsharded where 1 != 1) select u.* from u where 1 != 1", + "Query": "with u as (select * from unsharded) select u.* from u", "Table": "unsharded" }, "TablesUsed": [ @@ -1709,8 +1709,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select col from (select col from unsharded join unsharded_b where 1 != 1) as u join unsharded_a as ua where 1 != 1", - "Query": "select col from (select col from unsharded join unsharded_b) as u join unsharded_a as ua limit 1", + "FieldQuery": "with u as (select col from unsharded join unsharded_b where 1 != 1) select col from u join unsharded_a as ua where 1 != 1", + "Query": "with u as (select col from unsharded join unsharded_b) select col from u join unsharded_a as ua limit 1", "Table": "unsharded, unsharded_a, unsharded_b" }, "TablesUsed": [ @@ -1840,5 +1840,28 @@ "user.user" ] } + }, + { + "comment": "recursive WITH against an unsharded database", + "query": "WITH RECURSIVE cte (n) AS ( SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5 ) SELECT cte.n FROM unsharded join cte on unsharded.id = cte.n ", + "plan": { + "QueryType": "SELECT", + "Original": "WITH RECURSIVE cte (n) AS ( SELECT 1 UNION ALL SELECT n + 1 FROM cte WHERE n < 5 ) SELECT cte.n FROM unsharded join cte on unsharded.id = cte.n ", + "Instructions": { + "OperatorType": "Route", + "Variant": "Unsharded", + "Keyspace": { + "Name": "main", + "Sharded": false + }, + "FieldQuery": "with recursive cte(n) as (select 1 from dual where 1 != 1 union all select n + 1 from cte where 1 != 1) select cte.n from unsharded join cte on unsharded.id = cte.n where 1 != 1", + "Query": "with recursive cte(n) as (select 1 from dual union all select n + 1 from cte where n < 5) select cte.n from unsharded join cte on unsharded.id = cte.n", + "Table": "dual, unsharded" + }, + "TablesUsed": [ + "main.dual", + "main.unsharded" + ] + } } ] diff --git a/go/vt/vtgate/planbuilder/testdata/ddl_cases.json b/go/vt/vtgate/planbuilder/testdata/ddl_cases.json index e31cc3e29e1..41aded18c5d 100644 --- a/go/vt/vtgate/planbuilder/testdata/ddl_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/ddl_cases.json @@ -260,7 +260,7 @@ "Name": "main", "Sharded": false }, - "Query": "create view view_a as select col1, col2 from (select col1, col2 from unsharded where id = 1 union select col1, col2 from unsharded where id = 3) as a" + "Query": "create view view_a as select * from (select col1, col2 from unsharded where id = 1 union select col1, col2 from unsharded where id = 3) as a" }, "TablesUsed": [ "main.view_a" diff --git a/go/vt/vtgate/planbuilder/testdata/filter_cases.json b/go/vt/vtgate/planbuilder/testdata/filter_cases.json index b4bca3f8830..4353f31fd48 100644 --- a/go/vt/vtgate/planbuilder/testdata/filter_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/filter_cases.json @@ -4086,7 +4086,7 @@ "Sharded": false }, "FieldQuery": "select col + 2 as a from unsharded where 1 != 1", - "Query": "select col + 2 as a from unsharded having col + 2 = 42", + "Query": "select col + 2 as a from unsharded having a = 42", "Table": "unsharded" }, "TablesUsed": [ diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json index b304f9323e4..1c1bcc9144a 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_cases.json @@ -816,7 +816,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl2 left join u_tbl1 on u_tbl1.col1 = cast(u_tbl2.col1 + 'bar' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl2 left join u_tbl1 on u_tbl1.col1 = cast(u_tbl2.col1 + 'bar' as CHAR) where cast(u_tbl2.col1 + 'bar' as CHAR) is not null and not (u_tbl2.col2) <=> (cast(u_tbl2.col1 + 'bar' as CHAR)) and u_tbl2.id = 1 and u_tbl1.col1 is null limit 1 for share", + "Query": "select 1 from u_tbl2 left join u_tbl1 on u_tbl1.col1 = cast(u_tbl2.col1 + 'bar' as CHAR) where u_tbl1.col1 is null and cast(u_tbl2.col1 + 'bar' as CHAR) is not null and not (u_tbl2.col2) <=> (cast(u_tbl2.col1 + 'bar' as CHAR)) and u_tbl2.id = 1 limit 1 for share", "Table": "u_tbl1, u_tbl2" }, { @@ -1523,7 +1523,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals and u_tbl9.col9 is null limit 1 for share nowait", + "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and cast('foo' as CHAR) is not null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", "Table": "u_tbl8, u_tbl9" }, { @@ -1599,7 +1599,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where not (u_tbl4.col4) <=> (cast('foo' as CHAR)) and (u_tbl4.col4) in ::fkc_vals and u_tbl3.col3 is null limit 1 for share", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where u_tbl3.col3 is null and cast('foo' as CHAR) is not null and not (u_tbl4.col4) <=> (cast('foo' as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", "Table": "u_tbl3, u_tbl4" }, { @@ -1611,7 +1611,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", - "Query": "select 1 from u_tbl4, u_tbl9 where (u_tbl4.col4) in ::fkc_vals and (u_tbl9.col9) not in ((cast('foo' as CHAR))) and u_tbl4.col4 = u_tbl9.col9 limit 1 for share", + "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (cast('foo' as CHAR) is null or (u_tbl9.col9) not in ((cast('foo' as CHAR)))) limit 1 for share", "Table": "u_tbl4, u_tbl9" }, { @@ -1688,7 +1688,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast(:v1 as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast(:v1 as CHAR) where not (u_tbl4.col4) <=> (cast(:v1 as CHAR)) and (u_tbl4.col4) in ::fkc_vals and cast(:v1 as CHAR) is not null and u_tbl3.col3 is null limit 1 for share", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast(:v1 as CHAR) where u_tbl3.col3 is null and cast(:v1 as CHAR) is not null and not (u_tbl4.col4) <=> (cast(:v1 as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", "Table": "u_tbl3, u_tbl4" }, { @@ -1700,7 +1700,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", - "Query": "select 1 from u_tbl4, u_tbl9 where (u_tbl4.col4) in ::fkc_vals and (cast(:v1 as CHAR) is null or (u_tbl9.col9) not in ((cast(:v1 as CHAR)))) and u_tbl4.col4 = u_tbl9.col9 limit 1 for share", + "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (cast(:v1 as CHAR) is null or (u_tbl9.col9) not in ((cast(:v1 as CHAR)))) limit 1 for share", "Table": "u_tbl4, u_tbl9" }, { @@ -2362,7 +2362,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast(:fkc_upd as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast(:fkc_upd as CHAR) where not (u_tbl4.col4) <=> (cast(:fkc_upd as CHAR)) and (u_tbl4.col4) in ::fkc_vals and cast(:fkc_upd as CHAR) is not null and u_tbl3.col3 is null limit 1 for share", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast(:fkc_upd as CHAR) where u_tbl3.col3 is null and cast(:fkc_upd as CHAR) is not null and not (u_tbl4.col4) <=> (cast(:fkc_upd as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", "Table": "u_tbl3, u_tbl4" }, { @@ -2374,7 +2374,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", - "Query": "select 1 from u_tbl4, u_tbl9 where (u_tbl4.col4) in ::fkc_vals and (cast(:fkc_upd as CHAR) is null or (u_tbl9.col9) not in ((cast(:fkc_upd as CHAR)))) and u_tbl4.col4 = u_tbl9.col9 limit 1 for share", + "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (cast(:fkc_upd as CHAR) is null or (u_tbl9.col9) not in ((cast(:fkc_upd as CHAR)))) limit 1 for share", "Table": "u_tbl4, u_tbl9" }, { @@ -2537,7 +2537,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_multicol_tbl2 left join u_multicol_tbl1 on u_multicol_tbl1.cola = 2 and u_multicol_tbl1.colb = u_multicol_tbl2.colc - 2 where 1 != 1", - "Query": "select 1 from u_multicol_tbl2 left join u_multicol_tbl1 on u_multicol_tbl1.cola = 2 and u_multicol_tbl1.colb = u_multicol_tbl2.colc - 2 where u_multicol_tbl2.colc - 2 is not null and not (u_multicol_tbl2.cola, u_multicol_tbl2.colb) <=> (2, u_multicol_tbl2.colc - 2) and u_multicol_tbl2.id = 7 and u_multicol_tbl1.cola is null and u_multicol_tbl1.colb is null limit 1 for share", + "Query": "select 1 from u_multicol_tbl2 left join u_multicol_tbl1 on u_multicol_tbl1.cola = 2 and u_multicol_tbl1.colb = u_multicol_tbl2.colc - 2 where u_multicol_tbl1.cola is null and 2 is not null and u_multicol_tbl1.colb is null and u_multicol_tbl2.colc - 2 is not null and not (u_multicol_tbl2.cola, u_multicol_tbl2.colb) <=> (2, u_multicol_tbl2.colc - 2) and u_multicol_tbl2.id = 7 limit 1 for share", "Table": "u_multicol_tbl1, u_multicol_tbl2" }, { @@ -3139,7 +3139,7 @@ "Sharded": false }, "FieldQuery": "select u_tbl6.col6 from u_tbl6 as u, u_tbl5 as m where 1 != 1", - "Query": "select u_tbl6.col6 from u_tbl6 as u, u_tbl5 as m where u.col2 = 4 and m.col3 = 6 and u.col = m.col for update", + "Query": "select u_tbl6.col6 from u_tbl6 as u, u_tbl5 as m where u.col = m.col and u.col2 = 4 and m.col3 = 6 for update", "Table": "u_tbl5, u_tbl6" }, { @@ -3197,7 +3197,7 @@ "Sharded": false }, "FieldQuery": "select u_tbl10.col from u_tbl10, u_tbl11 where 1 != 1", - "Query": "select u_tbl10.col from u_tbl10, u_tbl11 where u_tbl10.id = 5 and u_tbl10.id = u_tbl11.id for update", + "Query": "select u_tbl10.col from u_tbl10, u_tbl11 where u_tbl10.id = u_tbl11.id and u_tbl10.id = 5 for update", "Table": "u_tbl10, u_tbl11" }, { @@ -3479,7 +3479,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl1 on u_tbl1.col14 = cast(:__sq1 as SIGNED) where 1 != 1", - "Query": "select /*+ SET_VAR(foreign_key_checks=OFF) */ 1 from u_tbl4 left join u_tbl1 on u_tbl1.col14 = cast(:__sq1 as SIGNED) where not (u_tbl4.col41) <=> (cast(:__sq1 as SIGNED)) and u_tbl4.col4 = 3 and cast(:__sq1 as SIGNED) is not null and u_tbl1.col14 is null limit 1 lock in share mode", + "Query": "select /*+ SET_VAR(foreign_key_checks=OFF) */ 1 from u_tbl4 left join u_tbl1 on u_tbl1.col14 = cast(:__sq1 as SIGNED) where u_tbl1.col14 is null and cast(:__sq1 as SIGNED) is not null and not (u_tbl4.col41) <=> (cast(:__sq1 as SIGNED)) and u_tbl4.col4 = 3 limit 1 lock in share mode", "Table": "u_tbl1, u_tbl4" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json b/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json index 413f3246e38..5404be4e1dc 100644 --- a/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/foreignkey_checks_on_cases.json @@ -816,7 +816,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl2 left join u_tbl1 on u_tbl1.col1 = cast(u_tbl2.col1 + 'bar' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl2 left join u_tbl1 on u_tbl1.col1 = cast(u_tbl2.col1 + 'bar' as CHAR) where cast(u_tbl2.col1 + 'bar' as CHAR) is not null and not (u_tbl2.col2) <=> (cast(u_tbl2.col1 + 'bar' as CHAR)) and u_tbl2.id = 1 and u_tbl1.col1 is null limit 1 for share", + "Query": "select 1 from u_tbl2 left join u_tbl1 on u_tbl1.col1 = cast(u_tbl2.col1 + 'bar' as CHAR) where u_tbl1.col1 is null and cast(u_tbl2.col1 + 'bar' as CHAR) is not null and not (u_tbl2.col2) <=> (cast(u_tbl2.col1 + 'bar' as CHAR)) and u_tbl2.id = 1 limit 1 for share", "Table": "u_tbl1, u_tbl2" }, { @@ -1523,7 +1523,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals and u_tbl9.col9 is null limit 1 for share nowait", + "Query": "select 1 from u_tbl8 left join u_tbl9 on u_tbl9.col9 = cast('foo' as CHAR) where u_tbl9.col9 is null and cast('foo' as CHAR) is not null and not (u_tbl8.col8) <=> (cast('foo' as CHAR)) and (u_tbl8.col8) in ::fkc_vals limit 1 for share nowait", "Table": "u_tbl8, u_tbl9" }, { @@ -1599,7 +1599,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where not (u_tbl4.col4) <=> (cast('foo' as CHAR)) and (u_tbl4.col4) in ::fkc_vals and u_tbl3.col3 is null limit 1 for share", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast('foo' as CHAR) where u_tbl3.col3 is null and cast('foo' as CHAR) is not null and not (u_tbl4.col4) <=> (cast('foo' as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", "Table": "u_tbl3, u_tbl4" }, { @@ -1611,7 +1611,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", - "Query": "select 1 from u_tbl4, u_tbl9 where (u_tbl4.col4) in ::fkc_vals and (u_tbl9.col9) not in ((cast('foo' as CHAR))) and u_tbl4.col4 = u_tbl9.col9 limit 1 for share", + "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (cast('foo' as CHAR) is null or (u_tbl9.col9) not in ((cast('foo' as CHAR)))) limit 1 for share", "Table": "u_tbl4, u_tbl9" }, { @@ -1688,7 +1688,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast(:v1 as CHAR) where 1 != 1", - "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast(:v1 as CHAR) where not (u_tbl4.col4) <=> (cast(:v1 as CHAR)) and (u_tbl4.col4) in ::fkc_vals and cast(:v1 as CHAR) is not null and u_tbl3.col3 is null limit 1 for share", + "Query": "select 1 from u_tbl4 left join u_tbl3 on u_tbl3.col3 = cast(:v1 as CHAR) where u_tbl3.col3 is null and cast(:v1 as CHAR) is not null and not (u_tbl4.col4) <=> (cast(:v1 as CHAR)) and (u_tbl4.col4) in ::fkc_vals limit 1 for share", "Table": "u_tbl3, u_tbl4" }, { @@ -1700,7 +1700,7 @@ "Sharded": false }, "FieldQuery": "select 1 from u_tbl4, u_tbl9 where 1 != 1", - "Query": "select 1 from u_tbl4, u_tbl9 where (u_tbl4.col4) in ::fkc_vals and (cast(:v1 as CHAR) is null or (u_tbl9.col9) not in ((cast(:v1 as CHAR)))) and u_tbl4.col4 = u_tbl9.col9 limit 1 for share", + "Query": "select 1 from u_tbl4, u_tbl9 where u_tbl4.col4 = u_tbl9.col9 and (u_tbl4.col4) in ::fkc_vals and (cast(:v1 as CHAR) is null or (u_tbl9.col9) not in ((cast(:v1 as CHAR)))) limit 1 for share", "Table": "u_tbl4, u_tbl9" }, { diff --git a/go/vt/vtgate/planbuilder/testdata/from_cases.json b/go/vt/vtgate/planbuilder/testdata/from_cases.json index 6ec0d1ab135..044036a4590 100644 --- a/go/vt/vtgate/planbuilder/testdata/from_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/from_cases.json @@ -711,8 +711,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select m1.col from unsharded as m1 join unsharded as m2 where 1 != 1", - "Query": "select m1.col from unsharded as m1 join unsharded as m2", + "FieldQuery": "select m1.col from unsharded as m1 straight_join unsharded as m2 where 1 != 1", + "Query": "select m1.col from unsharded as m1 straight_join unsharded as m2", "Table": "unsharded" }, "TablesUsed": [ @@ -3989,8 +3989,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1 where 1 != 1", - "Query": "select A.col1 as col1, A.col2 as col2, B.col2 as col2 from unsharded_authoritative as A left join unsharded_authoritative as B on A.col1 = B.col1", + "FieldQuery": "select * from unsharded_authoritative as A left join unsharded_authoritative as B using (col1) where 1 != 1", + "Query": "select * from unsharded_authoritative as A left join unsharded_authoritative as B using (col1)", "Table": "unsharded_authoritative" }, "TablesUsed": [ diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 1466135dd6c..9e75a2a2f32 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -1506,8 +1506,8 @@ "Name": "main", "Sharded": false }, - "FieldQuery": "select col1, col2 from (select col1, col2 from unsharded where 1 != 1 union select col1, col2 from unsharded where 1 != 1) as a where 1 != 1", - "Query": "select col1, col2 from (select col1, col2 from unsharded where id = 1 union select col1, col2 from unsharded where id = 3) as a", + "FieldQuery": "select * from (select col1, col2 from unsharded where 1 != 1 union select col1, col2 from unsharded where 1 != 1) as a where 1 != 1", + "Query": "select * from (select col1, col2 from unsharded where id = 1 union select col1, col2 from unsharded where id = 3) as a", "Table": "unsharded" }, "TablesUsed": [ @@ -2693,7 +2693,7 @@ "Sharded": false }, "FieldQuery": "select 1 from (select col, count(*) as a from unsharded where 1 != 1 group by col) as f left join unsharded as u on f.col = u.id where 1 != 1", - "Query": "select 1 from (select col, count(*) as a from unsharded group by col having count(*) > 0 limit 0, 12) as f left join unsharded as u on f.col = u.id", + "Query": "select 1 from (select col, count(*) as a from unsharded group by col having a > 0 limit 0, 12) as f left join unsharded as u on f.col = u.id", "Table": "unsharded" }, "TablesUsed": [ diff --git a/go/vt/vtgate/planbuilder/testdata/union_cases.json b/go/vt/vtgate/planbuilder/testdata/union_cases.json index 12a709d023f..9ac8db73be7 100644 --- a/go/vt/vtgate/planbuilder/testdata/union_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/union_cases.json @@ -874,7 +874,7 @@ "Sharded": false }, "FieldQuery": "(select 1 from unsharded where 1 != 1 union select 1 from unsharded where 1 != 1 union all select 1 from unsharded where 1 != 1) union select 1 from unsharded where 1 != 1 union all select 1 from unsharded where 1 != 1", - "Query": "(select 1 from unsharded union select 1 from unsharded union all select 1 from unsharded order by `1` asc) union select 1 from unsharded union all select 1 from unsharded order by `1` asc", + "Query": "(select 1 from unsharded union select 1 from unsharded union all select 1 from unsharded order by 1 asc) union select 1 from unsharded union all select 1 from unsharded order by 1 asc", "Table": "unsharded" }, "TablesUsed": [ diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index f52bd983104..cf2b3300208 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -28,48 +28,61 @@ import ( // analyzer controls the flow of the analysis. // It starts the tree walking and controls which part of the analysis sees which parts of the tree type analyzer struct { - scoper *scoper - tables *tableCollector - binder *binder - typer *typer - rewriter *earlyRewriter - sig QuerySignature + scoper *scoper + earlyTables *earlyTableCollector + tables *tableCollector + binder *binder + typer *typer + rewriter *earlyRewriter + sig QuerySignature + si SchemaInformation + currentDb string err error inProjection int - projErr error - unshardedErr error - warning string + projErr error + unshardedErr error + warning string + singleUnshardedKeyspace bool + fullAnalysis bool } // newAnalyzer create the semantic analyzer -func newAnalyzer(dbName string, si SchemaInformation) *analyzer { +func newAnalyzer(dbName string, si SchemaInformation, fullAnalysis bool) *analyzer { // TODO dependencies between these components are a little tangled. We should try to clean up s := newScoper() a := &analyzer{ - scoper: s, - tables: newTableCollector(s, si, dbName), - typer: newTyper(si.Environment().CollationEnv()), + scoper: s, + earlyTables: newEarlyTableCollector(si, dbName), + typer: newTyper(si.Environment().CollationEnv()), + si: si, + currentDb: dbName, + fullAnalysis: fullAnalysis, } s.org = a - a.tables.org = a + return a +} - b := newBinder(s, a, a.tables, a.typer) - a.binder = b +func (a *analyzer) lateInit() { + a.tables = a.earlyTables.newTableCollector(a.scoper, a) + a.binder = newBinder(a.scoper, a, a.tables, a.typer) + a.scoper.binder = a.binder a.rewriter = &earlyRewriter{ - env: si.Environment(), - scoper: s, - binder: b, + env: a.si.Environment(), + scoper: a.scoper, + binder: a.binder, expandedColumns: map[sqlparser.TableName][]*sqlparser.ColName{}, } - s.binder = b - return a } // Analyze analyzes the parsed query. func Analyze(statement sqlparser.Statement, currentDb string, si SchemaInformation) (*SemTable, error) { - analyzer := newAnalyzer(currentDb, newSchemaInfo(si)) + return analyseAndGetSemTable(statement, currentDb, si, false) +} + +func analyseAndGetSemTable(statement sqlparser.Statement, currentDb string, si SchemaInformation, fullAnalysis bool) (*SemTable, error) { + analyzer := newAnalyzer(currentDb, newSchemaInfo(si), fullAnalysis) // Analysis for initial scope err := analyzer.analyze(statement) @@ -83,7 +96,7 @@ func Analyze(statement sqlparser.Statement, currentDb string, si SchemaInformati // AnalyzeStrict analyzes the parsed query, and fails the analysis for any possible errors func AnalyzeStrict(statement sqlparser.Statement, currentDb string, si SchemaInformation) (*SemTable, error) { - st, err := Analyze(statement, currentDb, si) + st, err := analyseAndGetSemTable(statement, currentDb, si, true) if err != nil { return nil, err } @@ -109,6 +122,32 @@ func (a *analyzer) newSemTable( if isCommented { comments = commentedStmt.GetParsedComments() } + + if a.singleUnshardedKeyspace { + return &SemTable{ + Tables: a.earlyTables.Tables, + Comments: comments, + Warning: a.warning, + Collation: coll, + ExprTypes: map[sqlparser.Expr]evalengine.Type{}, + NotSingleRouteErr: a.projErr, + NotUnshardedErr: a.unshardedErr, + Recursive: ExprDependencies{}, + Direct: ExprDependencies{}, + Targets: map[sqlparser.IdentifierCS]TableSet{}, + ColumnEqualities: map[columnName][]sqlparser.Expr{}, + ExpandedColumns: map[sqlparser.TableName][]*sqlparser.ColName{}, + columns: map[*sqlparser.Union]sqlparser.SelectExprs{}, + comparator: nil, + StatementIDs: a.scoper.statementIDs, + QuerySignature: QuerySignature{}, + childForeignKeysInvolved: map[TableSet][]vindexes.ChildFKInfo{}, + parentForeignKeysInvolved: map[TableSet][]vindexes.ParentFKInfo{}, + childFkToUpdExprs: map[string]sqlparser.UpdateExprs{}, + collEnv: env, + }, nil + } + columns := map[*sqlparser.Union]sqlparser.SelectExprs{} for union, info := range a.tables.unionInfo { columns[union] = info.exprs @@ -298,10 +337,66 @@ func (a *analyzer) depsForExpr(expr sqlparser.Expr) (direct, recursive TableSet, } func (a *analyzer) analyze(statement sqlparser.Statement) error { + _ = sqlparser.Rewrite(statement, nil, a.earlyUp) + if a.err != nil { + return a.err + } + + if a.canShortCut(statement) { + return nil + } + + a.lateInit() + _ = sqlparser.Rewrite(statement, a.analyzeDown, a.analyzeUp) return a.err } +// canShortCut checks if we are dealing with a single unsharded keyspace and no tables that have managed foreign keys +// if so, we can stop the analyzer early +func (a *analyzer) canShortCut(statement sqlparser.Statement) (canShortCut bool) { + if a.fullAnalysis { + return false + } + ks, _ := singleUnshardedKeyspace(a.earlyTables.Tables) + if ks == nil { + return false + } + + defer func() { + a.singleUnshardedKeyspace = canShortCut + }() + + if !sqlparser.IsDMLStatement(statement) { + return true + } + + fkMode, err := a.si.ForeignKeyMode(ks.Name) + if err != nil { + a.err = err + return false + } + if fkMode != vschemapb.Keyspace_managed { + return true + } + + for _, table := range a.earlyTables.Tables { + vtbl := table.GetVindexTable() + if len(vtbl.ChildForeignKeys) > 0 || len(vtbl.ParentForeignKeys) > 0 { + return false + } + } + + return true +} + +// earlyUp collects tables in the query, so we can check +// if this a single unsharded query we are dealing with +func (a *analyzer) earlyUp(cursor *sqlparser.Cursor) bool { + a.earlyTables.up(cursor) + return true +} + func (a *analyzer) shouldContinue() bool { return a.err == nil } @@ -500,7 +595,7 @@ func (a *analyzer) getAllManagedForeignKeys() (map[TableSet][]vindexes.ChildFKIn continue } // Check whether Vitess needs to manage the foreign keys in this keyspace or not. - fkMode, err := a.tables.si.ForeignKeyMode(vi.Keyspace.Name) + fkMode, err := a.si.ForeignKeyMode(vi.Keyspace.Name) if err != nil { return nil, nil, err } @@ -508,7 +603,7 @@ func (a *analyzer) getAllManagedForeignKeys() (map[TableSet][]vindexes.ChildFKIn continue } // Cyclic foreign key constraints error is stored in the keyspace. - ksErr := a.tables.si.KeyspaceError(vi.Keyspace.Name) + ksErr := a.si.KeyspaceError(vi.Keyspace.Name) if ksErr != nil { return nil, nil, ksErr } diff --git a/go/vt/vtgate/semantics/analyzer_fk_test.go b/go/vt/vtgate/semantics/analyzer_fk_test.go index 17d1674fff8..05a5991b49f 100644 --- a/go/vt/vtgate/semantics/analyzer_fk_test.go +++ b/go/vt/vtgate/semantics/analyzer_fk_test.go @@ -147,11 +147,11 @@ func TestGetAllManagedForeignKeys(t *testing.T) { tbl["t1"], &DerivedTable{}, }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - "ks_unmanaged": vschemapb.Keyspace_unmanaged, - }, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, }, }, }, @@ -176,10 +176,10 @@ func TestGetAllManagedForeignKeys(t *testing.T) { tbl["t2"], tbl["t3"], }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - }, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, }, }, }, @@ -193,14 +193,14 @@ func TestGetAllManagedForeignKeys(t *testing.T) { tbl["t0"], tbl["t1"], &DerivedTable{}, }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - "ks_unmanaged": vschemapb.Keyspace_unmanaged, - }, - KsError: map[string]error{ - "ks": fmt.Errorf("VT09019: keyspace 'ks' has cyclic foreign keys"), - }, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, + }, + KsError: map[string]error{ + "ks": fmt.Errorf("VT09019: keyspace 'ks' has cyclic foreign keys"), }, }, }, @@ -355,11 +355,11 @@ func TestGetInvolvedForeignKeys(t *testing.T) { tbl["t0"], tbl["t1"], }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - "ks_unmanaged": vschemapb.Keyspace_unmanaged, - }, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, }, }, }, @@ -394,10 +394,10 @@ func TestGetInvolvedForeignKeys(t *testing.T) { tbl["t4"], tbl["t5"], }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - }, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, }, }, }, @@ -438,11 +438,11 @@ func TestGetInvolvedForeignKeys(t *testing.T) { tbl["t0"], tbl["t1"], }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - "ks_unmanaged": vschemapb.Keyspace_unmanaged, - }, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, }, }, }, @@ -470,11 +470,11 @@ func TestGetInvolvedForeignKeys(t *testing.T) { tbl["t0"], tbl["t1"], }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - "ks_unmanaged": vschemapb.Keyspace_unmanaged, - }, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, }, }, }, @@ -507,11 +507,11 @@ func TestGetInvolvedForeignKeys(t *testing.T) { tbl["t6"], tbl["t1"], }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - "ks_unmanaged": vschemapb.Keyspace_unmanaged, - }, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, + "ks_unmanaged": vschemapb.Keyspace_unmanaged, }, }, }, @@ -541,10 +541,10 @@ func TestGetInvolvedForeignKeys(t *testing.T) { tbl["t2"], tbl["t3"], }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - }, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, }, }, }, @@ -559,10 +559,10 @@ func TestGetInvolvedForeignKeys(t *testing.T) { tbl["t2"], tbl["t3"], }, - si: &FakeSI{ - KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ - "ks": vschemapb.Keyspace_managed, - }, + }, + si: &FakeSI{ + KsForeignKeyMode: map[string]vschemapb.Keyspace_ForeignKeyMode{ + "ks": vschemapb.Keyspace_managed, }, }, }, diff --git a/go/vt/vtgate/semantics/analyzer_test.go b/go/vt/vtgate/semantics/analyzer_test.go index dd8db6f5cd1..a7c173ccc96 100644 --- a/go/vt/vtgate/semantics/analyzer_test.go +++ b/go/vt/vtgate/semantics/analyzer_test.go @@ -28,7 +28,7 @@ import ( "vitess.io/vitess/go/vt/vtgate/vindexes" ) -var T0 TableSet +var NoTables TableSet var ( // Just here to make outputs more readable @@ -598,7 +598,7 @@ func TestOrderByBindingTable(t *testing.T) { TS0, }, { "select 1 as c from tabl order by c", - T0, + NoTables, }, { "select name, name from t1, t2 order by name", TS1, @@ -652,15 +652,15 @@ func TestVindexHints(t *testing.T) { sql string expectedErr string }{{ - sql: "select col from t use vindex (does_not_exist)", - expectedErr: "Vindex 'does_not_exist' does not exist in table 'ks1.t'", + sql: "select col from t1 use vindex (does_not_exist)", + expectedErr: "Vindex 'does_not_exist' does not exist in table 'ks2.t1'", }, { - sql: "select col from t ignore vindex (does_not_exist)", - expectedErr: "Vindex 'does_not_exist' does not exist in table 'ks1.t'", + sql: "select col from t1 ignore vindex (does_not_exist)", + expectedErr: "Vindex 'does_not_exist' does not exist in table 'ks2.t1'", }, { - sql: "select col from t use vindex (id_vindex)", + sql: "select id from t1 use vindex (id_vindex)", }, { - sql: "select col from t ignore vindex (id_vindex)", + sql: "select id from t1 ignore vindex (id_vindex)", }} for _, tc := range tcases { t.Run(tc.sql, func(t *testing.T) { @@ -710,7 +710,7 @@ func TestGroupByBinding(t *testing.T) { TS0, }, { "select 1 as c from tabl group by c", - T0, + NoTables, }, { "select t1.id from t1, t2 group by id", TS0, @@ -759,13 +759,13 @@ func TestHavingBinding(t *testing.T) { TS0, }, { "select col from tabl having 1 = 1", - T0, + NoTables, }, { "select col as c from tabl having c = 1", TS0, }, { "select 1 as c from tabl having c = 1", - T0, + NoTables, }, { "select t1.id from t1, t2 having id = 1", TS0, @@ -932,109 +932,6 @@ func TestUnionWithOrderBy(t *testing.T) { assert.Equal(t, TS1, d2) } -func TestScopingWDerivedTables(t *testing.T) { - queries := []struct { - query string - errorMessage string - recursiveExpectation TableSet - expectation TableSet - }{ - { - query: "select id from (select x as id from user) as t", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select id from (select foo as id from user) as t", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select id from (select foo as id from (select x as foo from user) as c) as t", - recursiveExpectation: TS0, - expectation: TS2, - }, { - query: "select t.id from (select foo as id from user) as t", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select t.id2 from (select foo as id from user) as t", - errorMessage: "column 't.id2' not found", - }, { - query: "select id from (select 42 as id) as t", - recursiveExpectation: T0, - expectation: TS1, - }, { - query: "select t.id from (select 42 as id) as t", - recursiveExpectation: T0, - expectation: TS1, - }, { - query: "select ks.t.id from (select 42 as id) as t", - errorMessage: "column 'ks.t.id' not found", - }, { - query: "select * from (select id, id from user) as t", - errorMessage: "Duplicate column name 'id'", - }, { - query: "select t.baz = 1 from (select id as baz from user) as t", - expectation: TS1, - recursiveExpectation: TS0, - }, { - query: "select t.id from (select * from user, music) as t", - expectation: TS2, - recursiveExpectation: MergeTableSets(TS0, TS1), - }, { - query: "select t.id from (select * from user, music) as t order by t.id", - expectation: TS2, - recursiveExpectation: MergeTableSets(TS0, TS1), - }, { - query: "select t.id from (select * from user) as t join user as u on t.id = u.id", - expectation: TS1, - recursiveExpectation: TS0, - }, { - query: "select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t", - expectation: TS3, - recursiveExpectation: TS1, - }, { - query: "select uu.test from (select id from t1) uu", - errorMessage: "column 'uu.test' not found", - }, { - query: "select uu.id from (select id as col from t1) uu", - errorMessage: "column 'uu.id' not found", - }, { - query: "select uu.id from (select id as col from t1) uu", - errorMessage: "column 'uu.id' not found", - }, { - query: "select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)", - expectation: TS1, - recursiveExpectation: TS0, - }, { - query: "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", - expectation: T0, - recursiveExpectation: T0, - }} - for _, query := range queries { - t.Run(query.query, func(t *testing.T) { - parse, err := sqlparser.NewTestParser().Parse(query.query) - require.NoError(t, err) - st, err := Analyze(parse, "user", &FakeSI{ - Tables: map[string]*vindexes.Table{ - "t": {Name: sqlparser.NewIdentifierCS("t")}, - }, - }) - - switch { - case query.errorMessage != "" && err != nil: - require.EqualError(t, err, query.errorMessage) - case query.errorMessage != "": - require.EqualError(t, st.NotUnshardedErr, query.errorMessage) - default: - require.NoError(t, err) - sel := parse.(*sqlparser.Select) - assert.Equal(t, query.recursiveExpectation, st.RecursiveDeps(extract(sel, 0)), "RecursiveDeps") - assert.Equal(t, query.expectation, st.DirectDeps(extract(sel, 0)), "DirectDeps") - } - }) - } -} - func TestScopingWithWITH(t *testing.T) { queries := []struct { query string @@ -1052,7 +949,7 @@ func TestScopingWithWITH(t *testing.T) { }, { query: "with c as (select x as foo from user), t as (select foo as id from c) select id from t", recursive: TS0, - direct: TS2, + direct: TS3, }, { query: "with t as (select foo as id from user) select t.id from t", recursive: TS0, @@ -1062,11 +959,11 @@ func TestScopingWithWITH(t *testing.T) { errorMessage: "column 't.id2' not found", }, { query: "with t as (select 42 as id) select id from t", - recursive: T0, + recursive: NoTables, direct: TS1, }, { query: "with t as (select 42 as id) select t.id from t", - recursive: T0, + recursive: NoTables, direct: TS1, }, { query: "with t as (select 42 as id) select ks.t.id from t", @@ -1088,12 +985,12 @@ func TestScopingWithWITH(t *testing.T) { recursive: MergeTableSets(TS0, TS1), }, { query: "with t as (select * from user) select t.id from t join user as u on t.id = u.id", - direct: TS1, + direct: TS2, recursive: TS0, }, { query: "with t as (select t1.id, t1.col1 from t1 join t2) select t.col1 from t3 ua join t", direct: TS3, - recursive: TS1, + recursive: TS0, }, { query: "with uu as (select id from t1) select uu.test from uu", errorMessage: "column 'uu.test' not found", @@ -1105,12 +1002,12 @@ func TestScopingWithWITH(t *testing.T) { errorMessage: "column 'uu.id' not found", }, { query: "select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)", - direct: TS1, + direct: TS2, recursive: TS0, }, { query: "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", - direct: T0, - recursive: T0, + direct: NoTables, + recursive: NoTables, }} for _, query := range queries { t.Run(query.query, func(t *testing.T) { @@ -1152,15 +1049,15 @@ func TestJoinPredicateDependencies(t *testing.T) { directExpect: MergeTableSets(TS0, TS1), }, { query: "select 1 from (select * from t1) x join t2 on x.id = t2.uid", - recursiveExpect: MergeTableSets(TS0, TS2), + recursiveExpect: MergeTableSets(TS0, TS1), directExpect: MergeTableSets(TS1, TS2), }, { query: "select 1 from (select id from t1) x join t2 on x.id = t2.uid", - recursiveExpect: MergeTableSets(TS0, TS2), + recursiveExpect: MergeTableSets(TS0, TS1), directExpect: MergeTableSets(TS1, TS2), }, { query: "select 1 from (select id from t1 union select id from t) x join t2 on x.id = t2.uid", - recursiveExpect: MergeTableSets(TS0, TS1, TS3), + recursiveExpect: MergeTableSets(TS0, TS1, TS2), directExpect: MergeTableSets(TS2, TS3), }} for _, query := range queries { @@ -1179,107 +1076,6 @@ func TestJoinPredicateDependencies(t *testing.T) { } } -func TestDerivedTablesOrderClause(t *testing.T) { - queries := []struct { - query string - recursiveExpectation TableSet - expectation TableSet - }{{ - query: "select 1 from (select id from user) as t order by id", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select id from (select id from user) as t order by id", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select id from (select id from user) as t order by t.id", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select id as foo from (select id from user) as t order by foo", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select bar from (select id as bar from user) as t order by bar", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select bar as foo from (select id as bar from user) as t order by bar", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select bar as foo from (select id as bar from user) as t order by foo", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select bar as foo from (select id as bar, oo from user) as t order by oo", - recursiveExpectation: TS0, - expectation: TS1, - }, { - query: "select bar as foo from (select id, oo from user) as t(bar,oo) order by bar", - recursiveExpectation: TS0, - expectation: TS1, - }} - si := &FakeSI{Tables: map[string]*vindexes.Table{"t": {Name: sqlparser.NewIdentifierCS("t")}}} - for _, query := range queries { - t.Run(query.query, func(t *testing.T) { - parse, err := sqlparser.NewTestParser().Parse(query.query) - require.NoError(t, err) - - st, err := Analyze(parse, "user", si) - require.NoError(t, err) - - sel := parse.(*sqlparser.Select) - assert.Equal(t, query.recursiveExpectation, st.RecursiveDeps(sel.OrderBy[0].Expr), "RecursiveDeps") - assert.Equal(t, query.expectation, st.DirectDeps(sel.OrderBy[0].Expr), "DirectDeps") - - }) - } -} - -func TestScopingWComplexDerivedTables(t *testing.T) { - queries := []struct { - query string - errorMessage string - rightExpectation TableSet - leftExpectation TableSet - }{ - { - query: "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", - rightExpectation: TS0, - leftExpectation: TS0, - }, - { - query: "select 1 from user.user uu where exists (select 1 from user.user as uu where exists (select 1 from (select 1 from user.t1) uu where uu.user_id = uu.id))", - rightExpectation: TS1, - leftExpectation: TS1, - }, - } - for _, query := range queries { - t.Run(query.query, func(t *testing.T) { - parse, err := sqlparser.NewTestParser().Parse(query.query) - require.NoError(t, err) - st, err := Analyze(parse, "user", &FakeSI{ - Tables: map[string]*vindexes.Table{ - "t": {Name: sqlparser.NewIdentifierCS("t")}, - }, - }) - if query.errorMessage != "" { - require.EqualError(t, err, query.errorMessage) - } else { - require.NoError(t, err) - sel := parse.(*sqlparser.Select) - comparisonExpr := sel.Where.Expr.(*sqlparser.ExistsExpr).Subquery.Select.(*sqlparser.Select).Where.Expr.(*sqlparser.ExistsExpr).Subquery.Select.(*sqlparser.Select).Where.Expr.(*sqlparser.ComparisonExpr) - left := comparisonExpr.Left - right := comparisonExpr.Right - assert.Equal(t, query.leftExpectation, st.RecursiveDeps(left), "Left RecursiveDeps") - assert.Equal(t, query.rightExpectation, st.RecursiveDeps(right), "Right RecursiveDeps") - } - }) - } -} - func TestScopingWVindexTables(t *testing.T) { queries := []struct { query string @@ -1399,36 +1195,6 @@ func BenchmarkAnalyzeSubQueries(b *testing.B) { } } -func BenchmarkAnalyzeDerivedTableQueries(b *testing.B) { - queries := []string{ - "select id from (select x as id from user) as t", - "select id from (select foo as id from user) as t", - "select id from (select foo as id from (select x as foo from user) as c) as t", - "select t.id from (select foo as id from user) as t", - "select t.id2 from (select foo as id from user) as t", - "select id from (select 42 as id) as t", - "select t.id from (select 42 as id) as t", - "select ks.t.id from (select 42 as id) as t", - "select * from (select id, id from user) as t", - "select t.baz = 1 from (select id as baz from user) as t", - "select t.id from (select * from user, music) as t", - "select t.id from (select * from user, music) as t order by t.id", - "select t.id from (select * from user) as t join user as u on t.id = u.id", - "select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t", - "select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)", - "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", - } - - for i := 0; i < b.N; i++ { - for _, query := range queries { - parse, err := sqlparser.NewTestParser().Parse(query) - require.NoError(b, err) - - _, _ = Analyze(parse, "d", fakeSchemaInfo()) - } - } -} - func BenchmarkAnalyzeHavingQueries(b *testing.B) { queries := []string{ "select col from tabl having col = 1", @@ -1521,43 +1287,30 @@ func TestSingleUnshardedKeyspace(t *testing.T) { tests := []struct { query string unsharded *vindexes.Keyspace - tables []*vindexes.Table }{ { query: "select 1 from t, t1", unsharded: nil, // both tables are unsharded, but from different keyspaces - tables: nil, }, { query: "select 1 from t2", unsharded: nil, - tables: nil, }, { query: "select 1 from t, t2", unsharded: nil, - tables: nil, }, { query: "select 1 from t as A, t as B", - unsharded: ks1, - tables: []*vindexes.Table{ - tableT(), - tableT(), - }, + unsharded: unsharded, }, { query: "insert into t select * from t", - unsharded: ks1, - tables: []*vindexes.Table{ - tableT(), - tableT(), - }, + unsharded: unsharded, }, } for _, test := range tests { t.Run(test.query, func(t *testing.T) { _, semTable := parseAndAnalyze(t, test.query, "d") - queryIsUnsharded, tables := semTable.SingleUnshardedKeyspace() + queryIsUnsharded, _ := semTable.SingleUnshardedKeyspace() assert.Equal(t, test.unsharded, queryIsUnsharded) - assert.Equal(t, test.tables, tables) }) } } @@ -1611,13 +1364,13 @@ func TestScopingSubQueryJoinClause(t *testing.T) { } -var ks1 = &vindexes.Keyspace{ - Name: "ks1", +var unsharded = &vindexes.Keyspace{ + Name: "unsharded", Sharded: false, } var ks2 = &vindexes.Keyspace{ Name: "ks2", - Sharded: false, + Sharded: true, } var ks3 = &vindexes.Keyspace{ Name: "ks3", @@ -1628,7 +1381,6 @@ var ks3 = &vindexes.Keyspace{ // create table t1(id bigint) // create table t2(uid bigint, name varchar(255)) func fakeSchemaInfo() *FakeSI { - si := &FakeSI{ Tables: map[string]*vindexes.Table{ "t": tableT(), @@ -1642,10 +1394,7 @@ func fakeSchemaInfo() *FakeSI { func tableT() *vindexes.Table { return &vindexes.Table{ Name: sqlparser.NewIdentifierCS("t"), - Keyspace: ks1, - ColumnVindexes: []*vindexes.ColumnVindex{ - {Name: "id_vindex"}, - }, + Keyspace: unsharded, } } func tableT1() *vindexes.Table { @@ -1656,7 +1405,10 @@ func tableT1() *vindexes.Table { Type: querypb.Type_INT64, }}, ColumnListAuthoritative: true, - Keyspace: ks2, + ColumnVindexes: []*vindexes.ColumnVindex{ + {Name: "id_vindex"}, + }, + Keyspace: ks2, } } func tableT2() *vindexes.Table { diff --git a/go/vt/vtgate/semantics/derived_test.go b/go/vt/vtgate/semantics/derived_test.go new file mode 100644 index 00000000000..8344fd1e261 --- /dev/null +++ b/go/vt/vtgate/semantics/derived_test.go @@ -0,0 +1,265 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package semantics + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtgate/vindexes" +) + +func TestScopingWDerivedTables(t *testing.T) { + queries := []struct { + query string + errorMessage string + recursiveDeps TableSet + directDeps TableSet + }{ + { + query: "select id from (select x as id from user) as t", + recursiveDeps: TS0, + directDeps: TS1, + }, { + query: "select id from (select foo as id from user) as t", + recursiveDeps: TS0, + directDeps: TS1, + }, { + query: "select id from (select foo as id from (select x as foo from user) as c) as t", + recursiveDeps: TS0, + directDeps: TS2, + }, { + query: "select t.id from (select foo as id from user) as t", + recursiveDeps: TS0, + directDeps: TS1, + }, { + query: "select t.id2 from (select foo as id from user) as t", + errorMessage: "column 't.id2' not found", + }, { + query: "select id from (select 42 as id) as t", + recursiveDeps: NoTables, + directDeps: TS1, + }, { + query: "select t.id from (select 42 as id) as t", + recursiveDeps: NoTables, + directDeps: TS1, + }, { + query: "select ks.t.id from (select 42 as id) as t", + errorMessage: "column 'ks.t.id' not found", + }, { + query: "select * from (select id, id from user) as t", + errorMessage: "Duplicate column name 'id'", + }, { + query: "select t.baz = 1 from (select id as baz from user) as t", + directDeps: TS1, + recursiveDeps: TS0, + }, { + query: "select t.id from (select * from user, music) as t", + directDeps: TS2, + recursiveDeps: MergeTableSets(TS0, TS1), + }, { + query: "select t.id from (select * from user, music) as t order by t.id", + directDeps: TS2, + recursiveDeps: MergeTableSets(TS0, TS1), + }, { + query: "select t.id from (select * from user) as t join user as u on t.id = u.id", + directDeps: TS2, + recursiveDeps: TS0, + }, { + query: "select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t", + directDeps: TS3, + recursiveDeps: TS1, + }, { + query: "select uu.test from (select id from t1) uu", + errorMessage: "column 'uu.test' not found", + }, { + query: "select uu.id from (select id as col from t1) uu", + errorMessage: "column 'uu.id' not found", + }, { + query: "select uu.id from (select id as col from t1) uu", + errorMessage: "column 'uu.id' not found", + }, { + query: "select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)", + directDeps: TS2, + recursiveDeps: TS0, + }, { + query: "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", + directDeps: NoTables, + recursiveDeps: NoTables, + }, { + query: "select uu.count from (select count(*) as `count` from t1) uu", + directDeps: TS1, + recursiveDeps: TS0, + }} + for _, query := range queries { + t.Run(query.query, func(t *testing.T) { + parse, err := sqlparser.NewTestParser().Parse(query.query) + require.NoError(t, err) + st, err := Analyze(parse, "user", &FakeSI{ + Tables: map[string]*vindexes.Table{ + "t": {Name: sqlparser.NewIdentifierCS("t"), Keyspace: ks2}, + }, + }) + + switch { + case query.errorMessage != "" && err != nil: + require.EqualError(t, err, query.errorMessage) + case query.errorMessage != "": + require.EqualError(t, st.NotUnshardedErr, query.errorMessage) + default: + require.NoError(t, err) + sel := parse.(*sqlparser.Select) + assert.Equal(t, query.recursiveDeps, st.RecursiveDeps(extract(sel, 0)), "RecursiveDeps") + assert.Equal(t, query.directDeps, st.DirectDeps(extract(sel, 0)), "DirectDeps") + } + }) + } +} + +func TestDerivedTablesOrderClause(t *testing.T) { + queries := []struct { + query string + recursiveExpectation TableSet + expectation TableSet + }{{ + query: "select 1 from (select id from user) as t order by id", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select id from (select id from user) as t order by id", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select id from (select id from user) as t order by t.id", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select id as foo from (select id from user) as t order by foo", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select bar from (select id as bar from user) as t order by bar", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select bar as foo from (select id as bar from user) as t order by bar", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select bar as foo from (select id as bar from user) as t order by foo", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select bar as foo from (select id as bar, oo from user) as t order by oo", + recursiveExpectation: TS0, + expectation: TS1, + }, { + query: "select bar as foo from (select id, oo from user) as t(bar,oo) order by bar", + recursiveExpectation: TS0, + expectation: TS1, + }} + si := &FakeSI{Tables: map[string]*vindexes.Table{"t": {Name: sqlparser.NewIdentifierCS("t")}}} + for _, query := range queries { + t.Run(query.query, func(t *testing.T) { + parse, err := sqlparser.NewTestParser().Parse(query.query) + require.NoError(t, err) + + st, err := Analyze(parse, "user", si) + require.NoError(t, err) + + sel := parse.(*sqlparser.Select) + assert.Equal(t, query.recursiveExpectation, st.RecursiveDeps(sel.OrderBy[0].Expr), "RecursiveDeps") + assert.Equal(t, query.expectation, st.DirectDeps(sel.OrderBy[0].Expr), "DirectDeps") + + }) + } +} + +func TestScopingWComplexDerivedTables(t *testing.T) { + queries := []struct { + query string + errorMessage string + rightExpectation TableSet + leftExpectation TableSet + }{ + { + query: "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", + rightExpectation: TS0, + leftExpectation: TS0, + }, + { + query: "select 1 from user.user uu where exists (select 1 from user.user as uu where exists (select 1 from (select 1 from user.t1) uu where uu.user_id = uu.id))", + rightExpectation: TS1, + leftExpectation: TS1, + }, + } + for _, query := range queries { + t.Run(query.query, func(t *testing.T) { + parse, err := sqlparser.NewTestParser().Parse(query.query) + require.NoError(t, err) + st, err := Analyze(parse, "user", &FakeSI{ + Tables: map[string]*vindexes.Table{ + "t": {Name: sqlparser.NewIdentifierCS("t")}, + }, + }) + if query.errorMessage != "" { + require.EqualError(t, err, query.errorMessage) + } else { + require.NoError(t, err) + sel := parse.(*sqlparser.Select) + comparisonExpr := sel.Where.Expr.(*sqlparser.ExistsExpr).Subquery.Select.(*sqlparser.Select).Where.Expr.(*sqlparser.ExistsExpr).Subquery.Select.(*sqlparser.Select).Where.Expr.(*sqlparser.ComparisonExpr) + left := comparisonExpr.Left + right := comparisonExpr.Right + assert.Equal(t, query.leftExpectation, st.RecursiveDeps(left), "Left RecursiveDeps") + assert.Equal(t, query.rightExpectation, st.RecursiveDeps(right), "Right RecursiveDeps") + } + }) + } +} + +func BenchmarkAnalyzeDerivedTableQueries(b *testing.B) { + queries := []string{ + "select id from (select x as id from user) as t", + "select id from (select foo as id from user) as t", + "select id from (select foo as id from (select x as foo from user) as c) as t", + "select t.id from (select foo as id from user) as t", + "select t.id2 from (select foo as id from user) as t", + "select id from (select 42 as id) as t", + "select t.id from (select 42 as id) as t", + "select ks.t.id from (select 42 as id) as t", + "select * from (select id, id from user) as t", + "select t.baz = 1 from (select id as baz from user) as t", + "select t.id from (select * from user, music) as t", + "select t.id from (select * from user, music) as t order by t.id", + "select t.id from (select * from user) as t join user as u on t.id = u.id", + "select t.col1 from t3 ua join (select t1.id, t1.col1 from t1 join t2) as t", + "select uu.id from (select id from t1) as uu where exists (select * from t2 as uu where uu.id = uu.uid)", + "select 1 from user uu where exists (select 1 from user where exists (select 1 from (select 1 from t1) uu where uu.user_id = uu.id))", + } + + for i := 0; i < b.N; i++ { + for _, query := range queries { + parse, err := sqlparser.NewTestParser().Parse(query) + require.NoError(b, err) + + _, _ = Analyze(parse, "d", fakeSchemaInfo()) + } + } +} diff --git a/go/vt/vtgate/semantics/early_rewriter_test.go b/go/vt/vtgate/semantics/early_rewriter_test.go index 3b7b30d5f39..e681f722b1d 100644 --- a/go/vt/vtgate/semantics/early_rewriter_test.go +++ b/go/vt/vtgate/semantics/early_rewriter_test.go @@ -32,7 +32,7 @@ import ( func TestExpandStar(t *testing.T) { ks := &vindexes.Keyspace{ Name: "main", - Sharded: false, + Sharded: true, } schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{ @@ -483,7 +483,7 @@ func TestSemTableDependenciesAfterExpandStar(t *testing.T) { func TestRewriteNot(t *testing.T) { ks := &vindexes.Keyspace{ Name: "main", - Sharded: false, + Sharded: true, } schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{ @@ -535,7 +535,7 @@ func TestRewriteNot(t *testing.T) { func TestConstantFolding(t *testing.T) { ks := &vindexes.Keyspace{ Name: "main", - Sharded: false, + Sharded: true, } schemaInfo := &FakeSI{ Tables: map[string]*vindexes.Table{ @@ -612,13 +612,13 @@ func TestDeleteTargetTableRewrite(t *testing.T) { sql string target string }{{ - sql: "delete from t", - target: "t", - }, { - sql: "delete from t t1", + sql: "delete from t1", target: "t1", }, { - sql: "delete t2 from t t1, t t2", + sql: "delete from t1 XYZ", + target: "XYZ", + }, { + sql: "delete t2 from t1 t1, t t2", target: "t2", }, { sql: "delete t2,t1 from t t1, t t2", diff --git a/go/vt/vtgate/semantics/semantic_state.go b/go/vt/vtgate/semantics/semantic_state.go index 74bd9dd1d69..91c535ffaff 100644 --- a/go/vt/vtgate/semantics/semantic_state.go +++ b/go/vt/vtgate/semantics/semantic_state.go @@ -306,7 +306,7 @@ func (st *SemTable) ErrIfFkDependentColumnUpdated(updateExprs sqlparser.UpdateEx for _, updateExpr := range updateExprs { deps := st.RecursiveDeps(updateExpr.Name) if deps.NumberOfTables() != 1 { - panic("expected to have single table dependency") + return vterrors.VT13001("expected to have single table dependency") } // Get all the child and parent foreign keys for the given table that the update expression belongs to. childFks := st.childForeignKeysInvolved[deps] @@ -730,6 +730,10 @@ func (st *SemTable) ColumnLookup(col *sqlparser.ColName) (int, error) { // SingleUnshardedKeyspace returns the single keyspace if all tables in the query are in the same, unsharded keyspace func (st *SemTable) SingleUnshardedKeyspace() (ks *vindexes.Keyspace, tables []*vindexes.Table) { + return singleUnshardedKeyspace(st.Tables) +} + +func singleUnshardedKeyspace(tableInfos []TableInfo) (ks *vindexes.Keyspace, tables []*vindexes.Table) { validKS := func(this *vindexes.Keyspace) bool { if this == nil || this.Sharded { return false @@ -744,7 +748,7 @@ func (st *SemTable) SingleUnshardedKeyspace() (ks *vindexes.Keyspace, tables []* return true } - for _, table := range st.Tables { + for _, table := range tableInfos { if _, isDT := table.(*DerivedTable); isDT { continue } diff --git a/go/vt/vtgate/semantics/semantic_state_test.go b/go/vt/vtgate/semantics/semantic_state_test.go index 4ae0a5562b5..84f8cec6cf9 100644 --- a/go/vt/vtgate/semantics/semantic_state_test.go +++ b/go/vt/vtgate/semantics/semantic_state_test.go @@ -844,7 +844,7 @@ func TestIsFkDependentColumnUpdated(t *testing.T) { Tables: map[string]*vindexes.Table{ "t1": { Name: sqlparser.NewIdentifierCS("t1"), - Keyspace: &vindexes.Keyspace{Name: keyspaceName}, + Keyspace: &vindexes.Keyspace{Name: keyspaceName, Sharded: true}, }, }, }, diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 90f939538a6..5bc160f52a6 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -35,15 +35,73 @@ type tableCollector struct { currentDb string org originable unionInfo map[*sqlparser.Union]unionInfo + done map[*sqlparser.AliasedTableExpr]TableInfo } -func newTableCollector(scoper *scoper, si SchemaInformation, currentDb string) *tableCollector { +type earlyTableCollector struct { + si SchemaInformation + currentDb string + Tables []TableInfo + done map[*sqlparser.AliasedTableExpr]TableInfo + withTables map[sqlparser.IdentifierCS]any +} + +func newEarlyTableCollector(si SchemaInformation, currentDb string) *earlyTableCollector { + return &earlyTableCollector{ + si: si, + currentDb: currentDb, + done: map[*sqlparser.AliasedTableExpr]TableInfo{}, + withTables: map[sqlparser.IdentifierCS]any{}, + } +} + +func (etc *earlyTableCollector) up(cursor *sqlparser.Cursor) { + switch node := cursor.Node().(type) { + case *sqlparser.AliasedTableExpr: + etc.visitAliasedTableExpr(node) + case *sqlparser.With: + for _, cte := range node.CTEs { + etc.withTables[cte.ID] = nil + } + } +} + +func (etc *earlyTableCollector) visitAliasedTableExpr(aet *sqlparser.AliasedTableExpr) { + tbl, ok := aet.Expr.(sqlparser.TableName) + if !ok { + return + } + etc.handleTableName(tbl, aet) +} + +func (etc *earlyTableCollector) newTableCollector(scoper *scoper, org originable) *tableCollector { return &tableCollector{ + Tables: etc.Tables, scoper: scoper, - si: si, - currentDb: currentDb, + si: etc.si, + currentDb: etc.currentDb, unionInfo: map[*sqlparser.Union]unionInfo{}, + done: etc.done, + org: org, + } +} + +func (etc *earlyTableCollector) handleTableName(tbl sqlparser.TableName, aet *sqlparser.AliasedTableExpr) { + if tbl.Qualifier.IsEmpty() { + _, isCTE := etc.withTables[tbl.Name] + if isCTE { + // no need to handle these tables here, we wait for the late phase instead + return + } + } + tableInfo, err := getTableInfo(aet, tbl, etc.si, etc.currentDb) + if err != nil { + // this could just be a CTE that we haven't processed, so we'll give it the benefit of the doubt for now + return } + + etc.done[aet] = tableInfo + etc.Tables = append(etc.Tables, tableInfo) } func (tc *tableCollector) up(cursor *sqlparser.Cursor) error { @@ -103,28 +161,42 @@ func (tc *tableCollector) visitAliasedTableExpr(node *sqlparser.AliasedTableExpr return nil } -func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sqlparser.TableName) error { +func (tc *tableCollector) handleTableName(node *sqlparser.AliasedTableExpr, t sqlparser.TableName) (err error) { + var tableInfo TableInfo + var found bool + + tableInfo, found = tc.done[node] + if !found { + tableInfo, err = getTableInfo(node, t, tc.si, tc.currentDb) + if err != nil { + return err + } + tc.Tables = append(tc.Tables, tableInfo) + } + + scope := tc.scoper.currentScope() + return scope.addTable(tableInfo) +} + +func getTableInfo(node *sqlparser.AliasedTableExpr, t sqlparser.TableName, si SchemaInformation, currentDb string) (TableInfo, error) { var tbl *vindexes.Table var vindex vindexes.Vindex isInfSchema := sqlparser.SystemSchema(t.Qualifier.String()) var err error - tbl, vindex, _, _, _, err = tc.si.FindTableOrVindex(t) + tbl, vindex, _, _, _, err = si.FindTableOrVindex(t) if err != nil && !isInfSchema { // if we are dealing with a system table, it might not be available in the vschema, but that is OK - return err + return nil, err } if tbl == nil && vindex != nil { tbl = newVindexTable(t.Name) } - scope := tc.scoper.currentScope() - tableInfo, err := tc.createTable(t, node, tbl, isInfSchema, vindex) + tableInfo, err := createTable(t, node, tbl, isInfSchema, vindex, si, currentDb) if err != nil { - return err + return nil, err } - - tc.Tables = append(tc.Tables, tableInfo) - return scope.addTable(tableInfo) + return tableInfo, nil } func (tc *tableCollector) handleDerivedTable(node *sqlparser.AliasedTableExpr, t *sqlparser.DerivedTable) error { @@ -228,12 +300,14 @@ func (tc *tableCollector) tableInfoFor(id TableSet) (TableInfo, error) { return tc.Tables[offset], nil } -func (tc *tableCollector) createTable( +func createTable( t sqlparser.TableName, alias *sqlparser.AliasedTableExpr, tbl *vindexes.Table, isInfSchema bool, vindex vindexes.Vindex, + si SchemaInformation, + currentDb string, ) (TableInfo, error) { hint := getVindexHint(alias.Hints) @@ -247,13 +321,13 @@ func (tc *tableCollector) createTable( Table: tbl, VindexHint: hint, isInfSchema: isInfSchema, - collationEnv: tc.si.Environment().CollationEnv(), + collationEnv: si.Environment().CollationEnv(), } if alias.As.IsEmpty() { dbName := t.Qualifier.String() if dbName == "" { - dbName = tc.currentDb + dbName = currentDb } table.dbName = dbName