diff --git a/go/test/endtoend/vtgate/queries/subquery/schema.sql b/go/test/endtoend/vtgate/queries/subquery/schema.sql index a64ac799a4e..9dfa963d340 100644 --- a/go/test/endtoend/vtgate/queries/subquery/schema.sql +++ b/go/test/endtoend/vtgate/queries/subquery/schema.sql @@ -4,18 +4,21 @@ create table t1 id2 bigint, primary key (id1) ) Engine = InnoDB; + create table t1_id2_idx ( id2 bigint, keyspace_id varbinary(10), primary key (id2) ) Engine = InnoDB; + create table t2 ( id3 bigint, id4 bigint, primary key (id3) ) Engine = InnoDB; + create table t2_id4_idx ( id bigint not null auto_increment, @@ -23,4 +26,17 @@ create table t2_id4_idx id3 bigint, primary key (id), key idx_id4 (id4) -) Engine = InnoDB; \ No newline at end of file +) Engine = InnoDB; + +CREATE TABLE user +( + id INT PRIMARY KEY, + name VARCHAR(100) +); + +CREATE TABLE user_extra +( + user_id INT, + extra_info VARCHAR(100), + PRIMARY KEY (user_id, extra_info) +); \ No newline at end of file diff --git a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go index 59dc42de060..fe7c65d9c2b 100644 --- a/go/test/endtoend/vtgate/queries/subquery/subquery_test.go +++ b/go/test/endtoend/vtgate/queries/subquery/subquery_test.go @@ -17,6 +17,7 @@ limitations under the License. package subquery import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -188,3 +189,46 @@ func TestSubqueryInDerivedTable(t *testing.T) { mcmp.Exec(`select t.a from (select t1.id2, t2.id3, (select id2 from t1 order by id2 limit 1) as a from t1 join t2 on t1.id1 = t2.id4) t`) mcmp.Exec(`SELECT COUNT(*) FROM (SELECT DISTINCT t1.id1 FROM t1 JOIN t2 ON t1.id1 = t2.id4) dt`) } + +func TestSubqueries(t *testing.T) { + // This method tests many types of subqueries. The queries should move to a vitess-tester test file once we have a way to run them. + // The commented out queries are failing because of wrong types being returned. + // The tests are commented out until the issue is fixed. + utils.SkipIfBinaryIsBelowVersion(t, 19, "vtgate") + mcmp, closer := start(t) + defer closer() + queries := []string{ + `INSERT INTO user (id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'), (4, 'David'), (5, 'Eve'), (6, 'Frank'), (7, 'Grace'), (8, 'Hannah'), (9, 'Ivy'), (10, 'Jack')`, + `INSERT INTO user_extra (user_id, extra_info) VALUES (1, 'info1'), (1, 'info2'), (2, 'info1'), (3, 'info1'), (3, 'info2'), (4, 'info1'), (5, 'info1'), (6, 'info1'), (7, 'info1'), (8, 'info1')`, + `SELECT (SELECT COUNT(*) FROM user_extra) AS order_count, id FROM user WHERE id = (SELECT COUNT(*) FROM user_extra)`, + `SELECT id, (SELECT COUNT(*) FROM user_extra) AS order_count FROM user ORDER BY (SELECT COUNT(*) FROM user_extra)`, + `SELECT id FROM user WHERE id = (SELECT COUNT(*) FROM user_extra) ORDER BY (SELECT COUNT(*) FROM user_extra)`, + `SELECT (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count, id, name FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0`, + `SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count FROM user ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`, + `SELECT id, name FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0 ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`, + `SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count FROM user GROUP BY id, name HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`, + `SELECT id, name, COUNT(*) FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0 GROUP BY id, name HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`, + `SELECT id, round(MAX(id + (SELECT COUNT(*) FROM user_extra where user_id = 42))) as r FROM user WHERE id = 42 GROUP BY id ORDER BY r`, + `SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) * 2 AS double_extra_count FROM user`, + `SELECT id, name FROM user WHERE id IN (SELECT user_id FROM user_extra WHERE LENGTH(extra_info) > 4)`, + `SELECT id, COUNT(*) FROM user GROUP BY id HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user_extra.user_id = user.id) + 1`, + `SELECT id, name FROM user ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) * id`, + `SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) + id AS extra_count_plus_id FROM user`, + `SELECT id, name FROM user WHERE id IN (SELECT user_id FROM user_extra WHERE extra_info = 'info1') OR id IN (SELECT user_id FROM user_extra WHERE extra_info = 'info2')`, + `SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, SUM(id) AS sum_ids FROM user GROUP BY id, name ORDER BY (SELECT COUNT(*) FROM user_extra)`, + // `SELECT id, name, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, AVG(id) AS avg_ids FROM user GROUP BY id, name HAVING (SELECT SUM(LENGTH(extra_info)) FROM user_extra) > 10`, + `SELECT id, name, (SELECT AVG(LENGTH(extra_info)) FROM user_extra) AS avg_length_extra_info, MAX(id) AS max_id FROM user WHERE id IN (SELECT user_id FROM user_extra) GROUP BY id, name`, + `SELECT id, name, (SELECT MAX(LENGTH(extra_info)) FROM user_extra) AS max_length_extra_info, MIN(id) AS min_id FROM user GROUP BY id, name ORDER BY (SELECT MAX(LENGTH(extra_info)) FROM user_extra)`, + `SELECT id, name, (SELECT MIN(LENGTH(extra_info)) FROM user_extra) AS min_length_extra_info, SUM(id) AS sum_ids FROM user GROUP BY id, name HAVING (SELECT MIN(LENGTH(extra_info)) FROM user_extra) < 5`, + `SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, AVG(id) AS avg_ids FROM user WHERE id > (SELECT COUNT(*) FROM user_extra) GROUP BY id, name`, + // `SELECT id, name, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, COUNT(id) AS count_ids FROM user GROUP BY id, name ORDER BY (SELECT SUM(LENGTH(extra_info)) FROM user_extra)`, + // `SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, (SELECT AVG(LENGTH(extra_info)) FROM user_extra) AS avg_length_extra_info, (SELECT MAX(LENGTH(extra_info)) FROM user_extra) AS max_length_extra_info, (SELECT MIN(LENGTH(extra_info)) FROM user_extra) AS min_length_extra_info, SUM(id) AS sum_ids FROM user GROUP BY id, name HAVING (SELECT AVG(LENGTH(extra_info)) FROM user_extra) > 2`, + `SELECT id, name, (SELECT COUNT(*) FROM user_extra) + id AS total_extra_count_plus_id, AVG(id) AS avg_ids FROM user WHERE id < (SELECT MAX(user_id) FROM user_extra) GROUP BY id, name`, + } + + for idx, query := range queries { + mcmp.Run(fmt.Sprintf("%d %s", idx, query), func(mcmp *utils.MySQLCompare) { + mcmp.Exec(query) + }) + } +} diff --git a/go/test/endtoend/vtgate/queries/subquery/vschema.json b/go/test/endtoend/vtgate/queries/subquery/vschema.json index da4e589f20f..a98255db65e 100644 --- a/go/test/endtoend/vtgate/queries/subquery/vschema.json +++ b/go/test/endtoend/vtgate/queries/subquery/vschema.json @@ -22,6 +22,9 @@ "autocommit": "true" }, "owner": "t2" + }, + "xxhash": { + "type": "xxhash" } }, "tables": { @@ -64,6 +67,34 @@ "name": "hash" } ] + }, + "user_extra": { + "name": "user_extra", + "column_vindexes": [ + { + "columns": [ + "user_id", + "extra_info" + ], + "type": "xxhash", + "name": "xxhash", + "vindex": null + } + ] + }, + "user": { + "name": "user", + "column_vindexes": [ + { + "columns": [ + "id" + ], + "type": "xxhash", + "name": "xxhash", + "vindex": null + } + ] } + } } \ No newline at end of file diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index ca7aae0f385..1dfc1050e0a 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -2864,6 +2864,8 @@ type ( Expr GetArg() Expr GetArgs() Exprs + SetArg(expr Expr) + SetArgs(exprs Exprs) error // AggrName returns the lower case string representing this aggregation function AggrName() string } @@ -3375,6 +3377,51 @@ func (varS *VarSamp) GetArgs() Exprs { return Exprs{varS.Arg} } func (variance *Variance) GetArgs() Exprs { return Exprs{variance.Arg} } func (av *AnyValue) GetArgs() Exprs { return Exprs{av.Arg} } +func (min *Min) SetArg(expr Expr) { min.Arg = expr } +func (sum *Sum) SetArg(expr Expr) { sum.Arg = expr } +func (max *Max) SetArg(expr Expr) { max.Arg = expr } +func (avg *Avg) SetArg(expr Expr) { avg.Arg = expr } +func (*CountStar) SetArg(expr Expr) {} +func (count *Count) SetArg(expr Expr) { count.Args = Exprs{expr} } +func (grpConcat *GroupConcatExpr) SetArg(expr Expr) { grpConcat.Exprs = Exprs{expr} } +func (bAnd *BitAnd) SetArg(expr Expr) { bAnd.Arg = expr } +func (bOr *BitOr) SetArg(expr Expr) { bOr.Arg = expr } +func (bXor *BitXor) SetArg(expr Expr) { bXor.Arg = expr } +func (std *Std) SetArg(expr Expr) { std.Arg = expr } +func (stdD *StdDev) SetArg(expr Expr) { stdD.Arg = expr } +func (stdP *StdPop) SetArg(expr Expr) { stdP.Arg = expr } +func (stdS *StdSamp) SetArg(expr Expr) { stdS.Arg = expr } +func (varP *VarPop) SetArg(expr Expr) { varP.Arg = expr } +func (varS *VarSamp) SetArg(expr Expr) { varS.Arg = expr } +func (variance *Variance) SetArg(expr Expr) { variance.Arg = expr } +func (av *AnyValue) SetArg(expr Expr) { av.Arg = expr } + +func (min *Min) SetArgs(exprs Exprs) error { return setFuncArgs(min, exprs, "MIN") } +func (sum *Sum) SetArgs(exprs Exprs) error { return setFuncArgs(sum, exprs, "SUM") } +func (max *Max) SetArgs(exprs Exprs) error { return setFuncArgs(max, exprs, "MAX") } +func (avg *Avg) SetArgs(exprs Exprs) error { return setFuncArgs(avg, exprs, "AVG") } +func (*CountStar) SetArgs(Exprs) error { return nil } +func (bAnd *BitAnd) SetArgs(exprs Exprs) error { return setFuncArgs(bAnd, exprs, "BIT_AND") } +func (bOr *BitOr) SetArgs(exprs Exprs) error { return setFuncArgs(bOr, exprs, "BIT_OR") } +func (bXor *BitXor) SetArgs(exprs Exprs) error { return setFuncArgs(bXor, exprs, "BIT_XOR") } +func (std *Std) SetArgs(exprs Exprs) error { return setFuncArgs(std, exprs, "STD") } +func (stdD *StdDev) SetArgs(exprs Exprs) error { return setFuncArgs(stdD, exprs, "STDDEV") } +func (stdP *StdPop) SetArgs(exprs Exprs) error { return setFuncArgs(stdP, exprs, "STDDEV_POP") } +func (stdS *StdSamp) SetArgs(exprs Exprs) error { return setFuncArgs(stdS, exprs, "STDDEV_SAMP") } +func (varP *VarPop) SetArgs(exprs Exprs) error { return setFuncArgs(varP, exprs, "VAR_POP") } +func (varS *VarSamp) SetArgs(exprs Exprs) error { return setFuncArgs(varS, exprs, "VAR_SAMP") } +func (variance *Variance) SetArgs(exprs Exprs) error { return setFuncArgs(variance, exprs, "VARIANCE") } +func (av *AnyValue) SetArgs(exprs Exprs) error { return setFuncArgs(av, exprs, "ANY_VALUE") } + +func (count *Count) SetArgs(exprs Exprs) error { + count.Args = exprs + return nil +} +func (grpConcat *GroupConcatExpr) SetArgs(exprs Exprs) error { + grpConcat.Exprs = exprs + return nil +} + func (sum *Sum) IsDistinct() bool { return sum.Distinct } func (min *Min) IsDistinct() bool { return min.Distinct } func (max *Max) IsDistinct() bool { return max.Distinct } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 62bdded7598..b3798bbd28f 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -2103,6 +2103,15 @@ func ContainsAggregation(e SQLNode) bool { return hasAggregates } +// setFuncArgs sets the arguments for the aggregation function, while checking that there is only one argument +func setFuncArgs(aggr AggrFunc, exprs Exprs, name string) error { + if len(exprs) != 1 { + return vterrors.VT03001(name) + } + aggr.SetArg(exprs[0]) + return nil +} + // GetFirstSelect gets the first select statement func GetFirstSelect(selStmt SelectStatement) *Select { if selStmt == nil { diff --git a/go/vt/vtgate/planbuilder/fuzz.go b/go/vt/vtgate/planbuilder/fuzz.go index 6b8b37ba43f..79dcca01a53 100644 --- a/go/vt/vtgate/planbuilder/fuzz.go +++ b/go/vt/vtgate/planbuilder/fuzz.go @@ -20,12 +20,12 @@ import ( "sync" "testing" + fuzz "github.com/AdaLogics/go-fuzz-headers" + "vitess.io/vitess/go/json2" "vitess.io/vitess/go/sqltypes" vschemapb "vitess.io/vitess/go/vt/proto/vschema" "vitess.io/vitess/go/vt/vtgate/vindexes" - - fuzz "github.com/AdaLogics/go-fuzz-headers" ) var initter sync.Once diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index a0963929eaa..25ab5f98b60 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -89,23 +89,21 @@ func reachedPhase(ctx *plancontext.PlanningContext, p Phase) bool { // Any columns that are needed to evaluate the subquery needs to be added as // grouping columns to the aggregation being pushed down, and then after the // subquery evaluation we are free to reassemble the total aggregation values. -// This is very similar to how we push aggregation through an shouldRun-join. +// This is very similar to how we push aggregation through an apply-join. func pushAggregationThroughSubquery( ctx *plancontext.PlanningContext, rootAggr *Aggregator, src *SubQueryContainer, ) (Operator, *ApplyResult) { - pushedAggr := rootAggr.Clone([]Operator{src.Outer}).(*Aggregator) - pushedAggr.Original = false - pushedAggr.Pushed = false - + pushedAggr := rootAggr.SplitAggregatorBelowOperators(ctx, []Operator{src.Outer}) for _, subQuery := range src.Inner { lhsCols := subQuery.OuterExpressionsNeeded(ctx, src.Outer) for _, colName := range lhsCols { - idx := slices.IndexFunc(pushedAggr.Columns, func(ae *sqlparser.AliasedExpr) bool { + findColName := func(ae *sqlparser.AliasedExpr) bool { return ctx.SemTable.EqualsExpr(ae.Expr, colName) - }) - if idx >= 0 { + } + if slices.IndexFunc(pushedAggr.Columns, findColName) >= 0 { + // we already have the column, no need to push it again continue } pushedAggr.addColumnWithoutPushing(ctx, aeWrap(colName), true) @@ -114,8 +112,10 @@ func pushAggregationThroughSubquery( src.Outer = pushedAggr - for _, aggregation := range pushedAggr.Aggregations { - aggregation.Original.Expr = rewriteColNameToArgument(ctx, aggregation.Original.Expr, aggregation.SubQueryExpression, src.Inner...) + for _, aggr := range pushedAggr.Aggregations { + // we rewrite columns in the aggregation to use the argument form of the subquery + aggr.Original.Expr = rewriteColNameToArgument(ctx, aggr.Original.Expr, aggr.SubQueryExpression, src.Inner...) + pushedAggr.Columns[aggr.ColOffset].Expr = rewriteColNameToArgument(ctx, pushedAggr.Columns[aggr.ColOffset].Expr, aggr.SubQueryExpression, src.Inner...) } if !rootAggr.Original { @@ -150,7 +150,7 @@ func pushAggregationThroughRoute( route *Route, ) (Operator, *ApplyResult) { // Create a new aggregator to be placed below the route. - aggrBelowRoute := aggregator.SplitAggregatorBelowRoute(route.Inputs()) + aggrBelowRoute := aggregator.SplitAggregatorBelowOperators(ctx, route.Inputs()) aggrBelowRoute.Aggregations = nil pushAggregations(ctx, aggregator, aggrBelowRoute) @@ -256,7 +256,6 @@ func pushAggregationThroughFilter( pushedAggr := aggregator.Clone([]Operator{filter.Source}).(*Aggregator) pushedAggr.Pushed = false pushedAggr.Original = false - withNextColumn: for _, col := range columnsNeeded { for _, gb := range pushedAggr.Grouping { diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 256372c172f..0f4b5181385 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -292,6 +292,21 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) Operator { return nil } +func (aggr Aggr) setPushColumn(exprs sqlparser.Exprs) { + if aggr.Func == nil { + if len(exprs) > 1 { + panic(vterrors.VT13001(fmt.Sprintf("unexpected number of expression in an random aggregation: %s", sqlparser.String(exprs)))) + } + aggr.Original.Expr = exprs[0] + return + } + + err := aggr.Func.SetArgs(exprs) + if err != nil { + panic(err) + } +} + func (aggr Aggr) getPushColumn() sqlparser.Expr { switch aggr.OpCode { case opcode.AggregateAnyValue: @@ -311,6 +326,17 @@ func (aggr Aggr) getPushColumn() sqlparser.Expr { } } +func (aggr Aggr) getPushColumnExprs() sqlparser.Exprs { + switch aggr.OpCode { + case opcode.AggregateAnyValue: + return sqlparser.Exprs{aggr.Original.Expr} + case opcode.AggregateCountStar: + return sqlparser.Exprs{sqlparser.NewIntLiteral("1")} + default: + return aggr.Func.GetArgs() + } +} + func (a *Aggregator) planOffsetsNotPushed(ctx *plancontext.PlanningContext) { a.Source = newAliasedProjection(a.Source) // we need to keep things in the column order, so we can't iterate over the aggregations or groupings @@ -408,14 +434,26 @@ func (a *Aggregator) internalAddColumn(ctx *plancontext.PlanningContext, aliased return offset } -// SplitAggregatorBelowRoute returns the aggregator that will live under the Route. +// SplitAggregatorBelowOperators returns the aggregator that will live under the Route. // This is used when we are splitting the aggregation so one part is done // at the mysql level and one part at the vtgate level -func (a *Aggregator) SplitAggregatorBelowRoute(input []Operator) *Aggregator { +func (a *Aggregator) SplitAggregatorBelowOperators(ctx *plancontext.PlanningContext, input []Operator) *Aggregator { newOp := a.Clone(input).(*Aggregator) newOp.Pushed = false newOp.Original = false newOp.DT = nil + + // We need to make sure that the columns are cloned so that the original operator is not affected + // by the changes we make to the new operator + newOp.Columns = slice.Map(a.Columns, func(from *sqlparser.AliasedExpr) *sqlparser.AliasedExpr { + return ctx.SemTable.Clone(from).(*sqlparser.AliasedExpr) + }) + for idx, aggr := range newOp.Aggregations { + newOp.Aggregations[idx].Original = ctx.SemTable.Clone(aggr.Original).(*sqlparser.AliasedExpr) + } + for idx, gb := range newOp.Grouping { + newOp.Grouping[idx].Inner = ctx.SemTable.Clone(gb.Inner).(sqlparser.Expr) + } return newOp } diff --git a/go/vt/vtgate/planbuilder/operators/helpers.go b/go/vt/vtgate/planbuilder/operators/helpers.go index 0049a919e2a..31d9bcfd279 100644 --- a/go/vt/vtgate/planbuilder/operators/helpers.go +++ b/go/vt/vtgate/planbuilder/operators/helpers.go @@ -26,13 +26,13 @@ import ( "vitess.io/vitess/go/vt/vtgate/vindexes" ) +type compactable interface { + // Compact implement this interface for operators that have easy to see optimisations + Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult) +} + // compact will optimise the operator tree into a smaller but equivalent version func compact(ctx *plancontext.PlanningContext, op Operator) Operator { - type compactable interface { - // Compact implement this interface for operators that have easy to see optimisations - Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult) - } - newOp := BottomUp(op, TableID, func(op Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) { newOp, ok := op.(compactable) if !ok { diff --git a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go index 68880bef90b..66190a4eeef 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go @@ -24,6 +24,7 @@ import ( "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" ) func expandHorizon(ctx *plancontext.PlanningContext, horizon *Horizon) (Operator, *ApplyResult) { @@ -114,17 +115,27 @@ func expandSelectHorizon(ctx *plancontext.PlanningContext, horizon *Horizon, sel } func expandOrderBy(ctx *plancontext.PlanningContext, op Operator, qp *QueryProjection) Operator { - proj := newAliasedProjection(op) var newOrder []OrderBy sqc := &SubQueryBuilder{} + proj, ok := op.(*Projection) + for _, expr := range qp.OrderExprs { + // Attempt to extract any subqueries within the expression newExpr, subqs := sqc.pullOutValueSubqueries(ctx, expr.SimplifiedExpr, TableID(op), false) if newExpr == nil { - // no subqueries found, let's move on + // If no subqueries are found, retain the original order expression newOrder = append(newOrder, expr) continue } - proj.addSubqueryExpr(aeWrap(newExpr), newExpr, subqs...) + + // If the operator is not a projection, we cannot handle subqueries with aggregation + if !ok { + panic(vterrors.VT12001("subquery with aggregation in order by")) + } + + // Add the new subquery expression to the projection + proj.addSubqueryExpr(ctx, aeWrap(newExpr), newExpr, subqs...) + // Replace the original order expression with the new expression containing subqueries newOrder = append(newOrder, OrderBy{ Inner: &sqlparser.Order{ Expr: newExpr, @@ -132,15 +143,14 @@ func expandOrderBy(ctx *plancontext.PlanningContext, op Operator, qp *QueryProje }, SimplifiedExpr: newExpr, }) - } - if len(proj.Columns.GetColumns()) > 0 { - // if we had to project columns for the ordering, - // we need the projection as source - op = proj + // Update the source of the projection if we have it + if proj != nil { + proj.Source = sqc.getRootOperator(proj.Source, nil) } + // Return the updated operator with the new order by expressions return &Ordering{ Source: op, Order: newOrder, @@ -152,6 +162,7 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horiz var dt *DerivedTable if horizon.TableId != nil { + // if we are dealing with a derived table, we need to create a derived table object dt = &DerivedTable{ TableID: *horizon.TableId, Alias: horizon.Alias, @@ -159,13 +170,13 @@ func createProjectionFromSelect(ctx *plancontext.PlanningContext, horizon *Horiz } } - if !qp.NeedsAggregation() { - projX := createProjectionWithoutAggr(ctx, qp, horizon.src()) - projX.DT = dt - return projX + if qp.NeedsAggregation() { + return createProjectionWithAggr(ctx, qp, dt, horizon.src()) } - return createProjectionWithAggr(ctx, qp, dt, horizon.src()) + projX := createProjectionWithoutAggr(ctx, qp, horizon.src()) + projX.DT = dt + return projX } func createProjectionWithAggr(ctx *plancontext.PlanningContext, qp *QueryProjection, dt *DerivedTable, src Operator) Operator { @@ -181,13 +192,8 @@ func createProjectionWithAggr(ctx *plancontext.PlanningContext, qp *QueryProject // Go through all aggregations and check for any subquery. sqc := &SubQueryBuilder{} - outerID := TableID(src) for idx, aggr := range aggregations { - expr := aggr.Original.Expr - newExpr, subqs := sqc.pullOutValueSubqueries(ctx, expr, outerID, false) - if newExpr != nil { - aggregations[idx].SubQueryExpression = subqs - } + aggregations[idx] = pullOutValueSubqueries(ctx, aggr, sqc, TableID(src)) } aggrOp.Source = sqc.getRootOperator(src, nil) @@ -198,6 +204,25 @@ func createProjectionWithAggr(ctx *plancontext.PlanningContext, qp *QueryProject return createProjectionForSimpleAggregation(ctx, aggrOp, qp) } +func pullOutValueSubqueries(ctx *plancontext.PlanningContext, aggr Aggr, sqc *SubQueryBuilder, outerID semantics.TableSet) Aggr { + exprs := aggr.getPushColumnExprs() + var newExprs sqlparser.Exprs + for _, expr := range exprs { + newExpr, subqs := sqc.pullOutValueSubqueries(ctx, expr, outerID, false) + if newExpr != nil { + newExprs = append(newExprs, newExpr) + aggr.SubQueryExpression = append(aggr.SubQueryExpression, subqs...) + } else { + newExprs = append(newExprs, expr) + } + } + if len(aggr.SubQueryExpression) > 0 { + aggr.setPushColumn(newExprs) + } + + return aggr +} + func createProjectionForSimpleAggregation(ctx *plancontext.PlanningContext, a *Aggregator, qp *QueryProjection) Operator { outer: for colIdx, expr := range qp.SelectExprs { @@ -280,7 +305,7 @@ func createProjectionWithoutAggr(ctx *plancontext.PlanningContext, qp *QueryProj // there was no subquery in this expression proj.addUnexploredExpr(org, expr) } else { - proj.addSubqueryExpr(org, newExpr, subqs...) + proj.addSubqueryExpr(ctx, org, newExpr, subqs...) } } proj.Source = sqc.getRootOperator(src, nil) diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index 1eae4e0e06e..f46cbf21928 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -182,13 +182,24 @@ var _ selectExpressions = (*Projection)(nil) // createSimpleProjection returns a projection where all columns are offsets. // used to change the name and order of the columns in the final output -func createSimpleProjection(ctx *plancontext.PlanningContext, qp *QueryProjection, src Operator) *Projection { +func createSimpleProjection(ctx *plancontext.PlanningContext, selExprs sqlparser.SelectExprs, src Operator) *Projection { p := newAliasedProjection(src) - for _, e := range qp.SelectExprs { - ae, err := e.GetAliasedExpr() - if err != nil { - panic(err) + for _, e := range selExprs { + ae, isAe := e.(*sqlparser.AliasedExpr) + if !isAe { + panic(vterrors.VT09015()) + } + + if ae.As.IsEmpty() { + // if we don't have an alias, we can use the column name as the alias + // the expectation is that when users use columns without aliases, they want the column name as the alias + // for more complex expressions, we just assume they'll use column offsets instead of column names + col, ok := ae.Expr.(*sqlparser.ColName) + if ok { + ae.As = col.Name + } } + offset := p.Source.AddColumn(ctx, true, false, ae) expr := newProjExpr(ae) expr.Info = Offset(offset) @@ -218,11 +229,14 @@ func (p *Projection) canPush(ctx *plancontext.PlanningContext) bool { } func (p *Projection) GetAliasedProjections() (AliasedProjections, error) { - ap, ok := p.Columns.(AliasedProjections) - if !ok { + switch cols := p.Columns.(type) { + case AliasedProjections: + return cols, nil + case nil: + return nil, nil + default: return nil, vterrors.VT09015() } - return ap, nil } func (p *Projection) isDerived() bool { @@ -263,8 +277,7 @@ func (p *Projection) addProjExpr(pe ...*ProjExpr) int { } offset := len(ap) - ap = append(ap, pe...) - p.Columns = ap + p.Columns = append(ap, pe...) return offset } @@ -273,7 +286,18 @@ func (p *Projection) addUnexploredExpr(ae *sqlparser.AliasedExpr, e sqlparser.Ex return p.addProjExpr(newProjExprWithInner(ae, e)) } -func (p *Projection) addSubqueryExpr(ae *sqlparser.AliasedExpr, expr sqlparser.Expr, sqs ...*SubQuery) { +func (p *Projection) addSubqueryExpr(ctx *plancontext.PlanningContext, ae *sqlparser.AliasedExpr, expr sqlparser.Expr, sqs ...*SubQuery) { + ap, err := p.GetAliasedProjections() + if err != nil { + panic(err) + } + for _, projExpr := range ap { + if ctx.SemTable.EqualsExprWithDeps(projExpr.EvalExpr, expr) { + // if we already have this column, we can just return the offset + return + } + } + pe := newProjExprWithInner(ae, expr) pe.Info = SubQueryExpression(sqs) diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index f412e783f42..12128831ca6 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -28,6 +28,12 @@ import ( ) func planQuery(ctx *plancontext.PlanningContext, root Operator) Operator { + var selExpr sqlparser.SelectExprs + if horizon, isHorizon := root.(*Horizon); isHorizon { + sel := sqlparser.GetFirstSelect(horizon.Query) + selExpr = sqlparser.CloneSelectExprs(sel.SelectExprs) + } + output := runPhases(ctx, root) output = planOffsets(ctx, output) @@ -38,7 +44,7 @@ func planQuery(ctx *plancontext.PlanningContext, root Operator) Operator { output = compact(ctx, output) - return addTruncationOrProjectionToReturnOutput(ctx, root, output) + return addTruncationOrProjectionToReturnOutput(ctx, selExpr, output) } // runPhases is the process of figuring out how to perform the operations in the Horizon @@ -303,23 +309,14 @@ func tryPushOrdering(ctx *plancontext.PlanningContext, in *Ordering) (Operator, return Swap(in, src, "push ordering under filter") case *ApplyJoin: if canPushLeft(ctx, src, in.Order) { - // ApplyJoin is stable in regard to the columns coming from the LHS, - // so if all the ordering columns come from the LHS, we can push down the Ordering there - src.LHS, in.Source = in, src.LHS - return src, Rewrote("push down ordering on the LHS of a join") + return pushOrderLeftOfJoin(src, in) } case *Ordering: // we'll just remove the order underneath. The top order replaces whatever was incoming in.Source = src.Source return in, Rewrote("remove double ordering") case *Projection: - // we can move ordering under a projection if it's not introducing a column we're sorting by - for _, by := range in.Order { - if !mustFetchFromInput(by.SimplifiedExpr) { - return in, NoRewrite - } - } - return Swap(in, src, "push ordering under projection") + return pushOrderingUnderProjection(ctx, in, src) case *Aggregator: if !src.QP.AlignGroupByAndOrderBy(ctx) && !overlaps(ctx, in.Order, src.Grouping) { return in, NoRewrite @@ -327,29 +324,65 @@ func tryPushOrdering(ctx *plancontext.PlanningContext, in *Ordering) (Operator, return pushOrderingUnderAggr(ctx, in, src) case *SubQueryContainer: - outerTableID := TableID(src.Outer) - for _, order := range in.Order { - deps := ctx.SemTable.RecursiveDeps(order.Inner.Expr) - if !deps.IsSolvedBy(outerTableID) { - return in, NoRewrite - } - } - src.Outer, in.Source = in, src.Outer - return src, Rewrote("push ordering into outer side of subquery") + return pushOrderingToOuterOfSubqueryContainer(ctx, in, src) case *SubQuery: - outerTableID := TableID(src.Outer) - for _, order := range in.Order { - deps := ctx.SemTable.RecursiveDeps(order.Inner.Expr) - if !deps.IsSolvedBy(outerTableID) { - return in, NoRewrite - } - } - src.Outer, in.Source = in, src.Outer - return src, Rewrote("push ordering into outer side of subquery") + return pushOrderingToOuterOfSubquery(ctx, in, src) } return in, NoRewrite } +func pushOrderingToOuterOfSubquery(ctx *plancontext.PlanningContext, in *Ordering, sq *SubQuery) (Operator, *ApplyResult) { + outerTableID := TableID(sq.Outer) + for idx, order := range in.Order { + deps := ctx.SemTable.RecursiveDeps(order.Inner.Expr) + if !deps.IsSolvedBy(outerTableID) { + return in, NoRewrite + } + in.Order[idx].SimplifiedExpr = sq.rewriteColNameToArgument(order.SimplifiedExpr) + in.Order[idx].Inner.Expr = sq.rewriteColNameToArgument(order.Inner.Expr) + } + sq.Outer, in.Source = in, sq.Outer + return sq, Rewrote("push ordering into outer side of subquery") +} + +func pushOrderingToOuterOfSubqueryContainer(ctx *plancontext.PlanningContext, in *Ordering, subq *SubQueryContainer) (Operator, *ApplyResult) { + outerTableID := TableID(subq.Outer) + for _, order := range in.Order { + deps := ctx.SemTable.RecursiveDeps(order.Inner.Expr) + if !deps.IsSolvedBy(outerTableID) { + return in, NoRewrite + } + } + subq.Outer, in.Source = in, subq.Outer + return subq, Rewrote("push ordering into outer side of subquery") +} + +func pushOrderingUnderProjection(ctx *plancontext.PlanningContext, in *Ordering, proj *Projection) (Operator, *ApplyResult) { + // we can move ordering under a projection if it's not introducing a column we're sorting by + for _, by := range in.Order { + if !mustFetchFromInput(by.SimplifiedExpr) { + return in, NoRewrite + } + } + ap, ok := proj.Columns.(AliasedProjections) + if !ok { + return in, NoRewrite + } + for _, projExpr := range ap { + if projExpr.Info != nil { + return in, NoRewrite + } + } + return Swap(in, proj, "push ordering under projection") +} + +func pushOrderLeftOfJoin(src *ApplyJoin, in *Ordering) (Operator, *ApplyResult) { + // ApplyJoin is stable in regard to the columns coming from the LHS, + // so if all the ordering columns come from the LHS, we can push down the Ordering there + src.LHS, in.Source = in, src.LHS + return src, Rewrote("push down ordering on the LHS of a join") +} + func overlaps(ctx *plancontext.PlanningContext, order []OrderBy, grouping []GroupBy) bool { ordering: for _, orderBy := range order { @@ -605,25 +638,56 @@ func tryPushUnion(ctx *plancontext.PlanningContext, op *Union) (Operator, *Apply } // addTruncationOrProjectionToReturnOutput uses the original Horizon to make sure that the output columns line up with what the user asked for -func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, oldHorizon Operator, output Operator) Operator { - horizon, ok := oldHorizon.(*Horizon) - if !ok { +func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, selExprs sqlparser.SelectExprs, output Operator) Operator { + if len(selExprs) == 0 { return output } cols := output.GetSelectExprs(ctx) - sel := sqlparser.GetFirstSelect(horizon.Query) - if len(sel.SelectExprs) == len(cols) { + sizeCorrect := len(selExprs) == len(cols) || tryTruncateColumnsAt(output, len(selExprs)) + if sizeCorrect && colNamesAlign(selExprs, cols) { return output } - if tryTruncateColumnsAt(output, len(sel.SelectExprs)) { - return output + return createSimpleProjection(ctx, selExprs, output) +} + +func colNamesAlign(expected, actual sqlparser.SelectExprs) bool { + if len(expected) > len(actual) { + // if we expect more columns than we have, we can't align + return false + } + + for i, seE := range expected { + switch se := seE.(type) { + case *sqlparser.AliasedExpr: + if !areColumnNamesAligned(se, actual[i]) { + return false + } + case *sqlparser.StarExpr: + actualStar, isStar := actual[i].(*sqlparser.StarExpr) + if !isStar { + panic(vterrors.VT13001(fmt.Sprintf("star expression is expected here, found: %T", actual[i]))) + } + if !sqlparser.Equals.RefOfStarExpr(se, actualStar) { + return false + } + } } + return true +} - qp := horizon.getQP(ctx) - proj := createSimpleProjection(ctx, qp, output) - return proj +func areColumnNamesAligned(expectation *sqlparser.AliasedExpr, actual sqlparser.SelectExpr) bool { + _, isCol := expectation.Expr.(*sqlparser.ColName) + if expectation.As.IsEmpty() && !isCol { + // is the user didn't specify a name, we don't care + return true + } + actualAE, isAe := actual.(*sqlparser.AliasedExpr) + if !isAe { + panic(vterrors.VT13001("used star expression when user did not")) + } + return expectation.ColumnName() == actualAE.ColumnName() } func stopAtRoute(operator Operator) VisitRule { diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 14bea4f4674..c9ea589381c 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -69,7 +69,7 @@ type ( // Aggr encodes all information needed for aggregation functions Aggr struct { Original *sqlparser.AliasedExpr - Func sqlparser.AggrFunc + Func sqlparser.AggrFunc // if we are missing a Func, it means this is a AggregateAnyValue OpCode opcode.AggregateOpcode // OriginalOpCode will contain opcode.AggregateUnassigned unless we are changing opcode while pushing them down @@ -314,8 +314,7 @@ func (qp *QueryProjection) addOrderBy(ctx *plancontext.PlanningContext, orderBy canPushSorting := true es := &expressionSet{} for _, order := range orderBy { - if sqlparser.IsNull(order.Expr) { - // ORDER BY null can safely be ignored + if canIgnoreOrdering(ctx, order.Expr) { continue } if !es.add(ctx, order.Expr) { @@ -329,6 +328,18 @@ func (qp *QueryProjection) addOrderBy(ctx *plancontext.PlanningContext, orderBy } } +// canIgnoreOrdering returns true if the ordering expression has no effect on the result. +func canIgnoreOrdering(ctx *plancontext.PlanningContext, expr sqlparser.Expr) bool { + switch expr.(type) { + case *sqlparser.NullVal, *sqlparser.Literal, *sqlparser.Argument: + return true + case *sqlparser.Subquery: + return ctx.SemTable.RecursiveDeps(expr).IsEmpty() + default: + return false + } +} + func (qp *QueryProjection) calculateDistinct(ctx *plancontext.PlanningContext) { if qp.Distinct && !qp.HasAggr { distinct := qp.useGroupingOverDistinct(ctx) diff --git a/go/vt/vtgate/planbuilder/operators/subquery.go b/go/vt/vtgate/planbuilder/operators/subquery.go index 537737363c8..16fda66d14d 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery.go +++ b/go/vt/vtgate/planbuilder/operators/subquery.go @@ -53,7 +53,8 @@ type SubQuery struct { // We use this information to fail the planning if we are unable to merge the subquery with a route. correlated bool - IsProjection bool + // IsArgument is set to true if the subquery puts the + IsArgument bool } func (sq *SubQuery) planOffsets(ctx *plancontext.PlanningContext) Operator { @@ -156,8 +157,8 @@ func (sq *SubQuery) SetInputs(inputs []Operator) { func (sq *SubQuery) ShortDescription() string { var typ string - if sq.IsProjection { - typ = "PROJ" + if sq.IsArgument { + typ = "ARGUMENT" } else { typ = "FILTER" } @@ -175,8 +176,11 @@ func (sq *SubQuery) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparse return sq } -func (sq *SubQuery) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, exprs *sqlparser.AliasedExpr) int { - return sq.Outer.AddColumn(ctx, reuseExisting, addToGroupBy, exprs) +func (sq *SubQuery) AddColumn(ctx *plancontext.PlanningContext, reuseExisting bool, addToGroupBy bool, ae *sqlparser.AliasedExpr) int { + ae = sqlparser.CloneRefOfAliasedExpr(ae) + // we need to rewrite the column name to an argument if it's the same as the subquery column name + ae.Expr = rewriteColNameToArgument(ctx, ae.Expr, []*SubQuery{sq}, sq) + return sq.Outer.AddColumn(ctx, reuseExisting, addToGroupBy, ae) } func (sq *SubQuery) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int { @@ -206,7 +210,7 @@ func (sq *SubQuery) settle(ctx *plancontext.PlanningContext, outer Operator) Ope if sq.correlated && sq.FilterType != opcode.PulloutExists { panic(correlatedSubqueryErr) } - if sq.IsProjection { + if sq.IsArgument { if len(sq.GetMergePredicates()) > 0 { // this means that we have a correlated subquery on our hands panic(correlatedSubqueryErr) @@ -289,3 +293,18 @@ func (sq *SubQuery) mapExpr(f func(expr sqlparser.Expr) sqlparser.Expr) { sq.Original = f(sq.Original) sq.originalSubquery = f(sq.originalSubquery).(*sqlparser.Subquery) } + +func (sq *SubQuery) rewriteColNameToArgument(expr sqlparser.Expr) sqlparser.Expr { + pre := func(cursor *sqlparser.Cursor) bool { + colName, ok := cursor.Node().(*sqlparser.ColName) + if !ok || colName.Qualifier.NonEmpty() || !colName.Name.EqualString(sq.ArgName) { + // we only want to rewrite the column name to an argument if it's the right column + return true + } + + cursor.Replace(sqlparser.NewArgument(sq.ArgName)) + return true + } + + return sqlparser.Rewrite(expr, pre, nil).(sqlparser.Expr) +} diff --git a/go/vt/vtgate/planbuilder/operators/subquery_builder.go b/go/vt/vtgate/planbuilder/operators/subquery_builder.go index 4caf3530075..42bbfdb7c99 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_builder.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_builder.go @@ -159,7 +159,7 @@ func createSubquery( parent sqlparser.Expr, argName string, filterType opcode.PulloutOpcode, - isProjection bool, + isArg bool, ) *SubQuery { topLevel := ctx.SemTable.EqualsExpr(original, parent) original = cloneASTAndSemState(ctx, original) @@ -181,7 +181,7 @@ func createSubquery( Original: original, ArgName: argName, originalSubquery: originalSq, - IsProjection: isProjection, + IsArgument: isArg, TopLevel: topLevel, JoinColumns: joinCols, correlated: correlated, diff --git a/go/vt/vtgate/planbuilder/operators/subquery_container.go b/go/vt/vtgate/planbuilder/operators/subquery_container.go index e4feeab49d8..41b645ac7b4 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_container.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_container.go @@ -43,7 +43,7 @@ func (sqc *SubQueryContainer) Clone(inputs []Operator) Operator { if !ok { panic("got bad input") } - result.Inner = append(result.Inner, inner) + result.addInner(inner) } return result } @@ -90,3 +90,13 @@ func (sqc *SubQueryContainer) GetColumns(ctx *plancontext.PlanningContext) []*sq func (sqc *SubQueryContainer) GetSelectExprs(ctx *plancontext.PlanningContext) sqlparser.SelectExprs { return sqc.Outer.GetSelectExprs(ctx) } + +func (sqc *SubQueryContainer) addInner(inner *SubQuery) { + for _, sq := range sqc.Inner { + if sq.ArgName == inner.ArgName { + // we already have this subquery + return + } + } + sqc.Inner = append(sqc.Inner, inner) +} diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index d951568502d..a85829bab6d 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -100,8 +100,11 @@ func settleSubqueries(ctx *plancontext.PlanningContext, op Operator) Operator { newExpr, rewritten := rewriteMergedSubqueryExpr(ctx, aggr.SubQueryExpression, aggr.Original.Expr) if rewritten { aggr.Original.Expr = newExpr + op.Columns[aggr.ColOffset].Expr = newExpr } } + case *Ordering: + op.settleOrderingExpressions(ctx) } return op, NoRewrite } @@ -109,6 +112,29 @@ func settleSubqueries(ctx *plancontext.PlanningContext, op Operator) Operator { return BottomUp(op, TableID, visit, nil) } +func (o *Ordering) settleOrderingExpressions(ctx *plancontext.PlanningContext) { + for idx, order := range o.Order { + for _, sq := range ctx.MergedSubqueries { + arg := ctx.GetReservedArgumentFor(sq) + expr := sqlparser.Rewrite(order.SimplifiedExpr, nil, func(cursor *sqlparser.Cursor) bool { + switch expr := cursor.Node().(type) { + case *sqlparser.ColName: + if expr.Name.String() == arg { + cursor.Replace(sq) + } + case *sqlparser.Argument: + if expr.Name == arg { + cursor.Replace(sq) + } + } + + return true + }) + o.Order[idx].SimplifiedExpr = expr.(sqlparser.Expr) + } + } +} + func mergeSubqueryExpr(ctx *plancontext.PlanningContext, pe *ProjExpr) { se, ok := pe.Info.(SubQueryExpression) if !ok { @@ -319,7 +345,7 @@ func addSubQuery(in Operator, inner *SubQuery) Operator { } } - sql.Inner = append(sql.Inner, inner) + sql.addInner(inner) return sql } @@ -477,7 +503,7 @@ func tryMergeSubqueryWithOuter(ctx *plancontext.PlanningContext, subQuery *SubQu if op == nil { return outer, NoRewrite } - if !subQuery.IsProjection { + if !subQuery.IsArgument { op.Source = newFilter(outer.Source, subQuery.Original) } ctx.MergedSubqueries = append(ctx.MergedSubqueries, subQuery.originalSubquery) @@ -582,7 +608,7 @@ func (s *subqueryRouteMerger) merge(ctx *plancontext.PlanningContext, inner, out var src Operator if isSharded { src = s.outer.Source - if !s.subq.IsProjection { + if !s.subq.IsArgument { src = newFilter(s.outer.Source, s.original) } } else { @@ -643,7 +669,7 @@ func (s *subqueryRouteMerger) rewriteASTExpression(ctx *plancontext.PlanningCont panic(err) } - if s.subq.IsProjection { + if s.subq.IsArgument { ctx.SemTable.CopySemanticInfo(s.subq.originalSubquery.Select, subqStmt) s.subq.originalSubquery.Select = subqStmt } else { diff --git a/go/vt/vtgate/planbuilder/plan_test.go b/go/vt/vtgate/planbuilder/plan_test.go index 381983cb0ee..62dd567b0d8 100644 --- a/go/vt/vtgate/planbuilder/plan_test.go +++ b/go/vt/vtgate/planbuilder/plan_test.go @@ -28,10 +28,9 @@ import ( "strings" "testing" - "github.com/stretchr/testify/suite" - "github.com/nsf/jsondiff" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/test/utils" diff --git a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go index ba1d60ff234..240c7ff3581 100644 --- a/go/vt/vtgate/planbuilder/predicate_rewrite_test.go +++ b/go/vt/vtgate/planbuilder/predicate_rewrite_test.go @@ -27,7 +27,6 @@ import ( "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtenv" diff --git a/go/vt/vtgate/planbuilder/set.go b/go/vt/vtgate/planbuilder/set.go index bf6820b7489..77f20be40f9 100644 --- a/go/vt/vtgate/planbuilder/set.go +++ b/go/vt/vtgate/planbuilder/set.go @@ -21,18 +21,14 @@ import ( "strconv" "strings" - "vitess.io/vitess/go/vt/sysvars" - "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" - - "vitess.io/vitess/go/vt/vtgate/evalengine" - - "vitess.io/vitess/go/vt/vtgate/vindexes" - - "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/key" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/sysvars" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/evalengine" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/vindexes" ) type ( diff --git a/go/vt/vtgate/planbuilder/show_test.go b/go/vt/vtgate/planbuilder/show_test.go index 931c914149d..bfdb9a623a0 100644 --- a/go/vt/vtgate/planbuilder/show_test.go +++ b/go/vt/vtgate/planbuilder/show_test.go @@ -21,15 +21,13 @@ import ( "fmt" "testing" - "vitess.io/vitess/go/test/vschemawrapper" - "vitess.io/vitess/go/vt/vtenv" - "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/test/vschemawrapper" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vtgate/vindexes" ) diff --git a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json index 4a879997925..122932cba4d 100644 --- a/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/memory_sort_cases.json @@ -288,49 +288,59 @@ "QueryType": "SELECT", "Original": "select user.col1 as a, user.col2 b, music.col3 c from user, music where user.id = music.id and user.id = 1 order by c", "Instructions": { - "OperatorType": "Sort", - "Variant": "Memory", - "OrderBy": "(2|3) ASC", - "ResultColumns": 3, + "OperatorType": "SimpleProjection", + "Columns": [ + 0, + 1, + 2 + ], "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:0,L:1,R:0,R:1", - "JoinVars": { - "user_id": 2 - }, - "TableName": "`user`_music", + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(2|3) ASC", + "ResultColumns": 3, "Inputs": [ { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select `user`.col1 as a, `user`.col2 as b, `user`.id from `user` where 1 != 1", - "Query": "select `user`.col1 as a, `user`.col2 as b, `user`.id from `user` where `user`.id = 1", - "Table": "`user`", - "Values": [ - "1" - ], - "Vindex": "user_index" - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,L:1,R:0,R:1", + "JoinVars": { + "user_id": 2 }, - "FieldQuery": "select music.col3 as c, weight_string(music.col3) from music where 1 != 1", - "Query": "select music.col3 as c, weight_string(music.col3) from music where music.id = :user_id", - "Table": "music", - "Values": [ - ":user_id" - ], - "Vindex": "music_user_map" + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.col1 as a, `user`.col2 as b, `user`.id from `user` where 1 != 1", + "Query": "select `user`.col1 as a, `user`.col2 as b, `user`.id from `user` where `user`.id = 1", + "Table": "`user`", + "Values": [ + "1" + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.col3 as c, weight_string(music.col3) from music where 1 != 1", + "Query": "select music.col3 as c, weight_string(music.col3) from music where music.id = :user_id", + "Table": "music", + "Values": [ + ":user_id" + ], + "Vindex": "music_user_map" + } + ] } ] } @@ -349,49 +359,59 @@ "QueryType": "SELECT", "Original": "select user.col1 as a, user.col2, music.col3 from user join music on user.id = music.id where user.id = 1 order by 1 asc, 3 desc, 2 asc", "Instructions": { - "OperatorType": "Sort", - "Variant": "Memory", - "OrderBy": "(0|3) ASC, (2|4) DESC, (1|5) ASC", - "ResultColumns": 3, + "OperatorType": "SimpleProjection", + "Columns": [ + 0, + 1, + 2 + ], "Inputs": [ { - "OperatorType": "Join", - "Variant": "Join", - "JoinColumnIndexes": "L:0,L:1,R:0,L:2,R:1,L:3", - "JoinVars": { - "user_id": 4 - }, - "TableName": "`user`_music", + "OperatorType": "Sort", + "Variant": "Memory", + "OrderBy": "(0|3) ASC, (2|4) DESC, (1|5) ASC", + "ResultColumns": 3, "Inputs": [ { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true + "OperatorType": "Join", + "Variant": "Join", + "JoinColumnIndexes": "L:0,L:1,R:0,L:2,R:1,L:3", + "JoinVars": { + "user_id": 4 }, - "FieldQuery": "select `user`.col1 as a, `user`.col2, weight_string(`user`.col1), weight_string(`user`.col2), `user`.id from `user` where 1 != 1", - "Query": "select `user`.col1 as a, `user`.col2, weight_string(`user`.col1), weight_string(`user`.col2), `user`.id from `user` where `user`.id = 1", - "Table": "`user`", - "Values": [ - "1" - ], - "Vindex": "user_index" - }, - { - "OperatorType": "Route", - "Variant": "EqualUnique", - "Keyspace": { - "Name": "user", - "Sharded": true - }, - "FieldQuery": "select music.col3, weight_string(music.col3) from music where 1 != 1", - "Query": "select music.col3, weight_string(music.col3) from music where music.id = :user_id", - "Table": "music", - "Values": [ - ":user_id" - ], - "Vindex": "music_user_map" + "TableName": "`user`_music", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.col1 as a, `user`.col2, weight_string(`user`.col1), weight_string(`user`.col2), `user`.id from `user` where 1 != 1", + "Query": "select `user`.col1 as a, `user`.col2, weight_string(`user`.col1), weight_string(`user`.col2), `user`.id from `user` where `user`.id = 1", + "Table": "`user`", + "Values": [ + "1" + ], + "Vindex": "user_index" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select music.col3, weight_string(music.col3) from music where 1 != 1", + "Query": "select music.col3, weight_string(music.col3) from music where music.id = :user_id", + "Table": "music", + "Values": [ + ":user_id" + ], + "Vindex": "music_user_map" + } + ] } ] } diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 14b8000557f..b6f9f1f36d3 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -2113,7 +2113,7 @@ } }, { - "comment": "select (select col from user limit 1) as a from user join user_extra order by a", + "comment": "ORDER BY subquery", "query": "select (select col from user limit 1) as a from user join user_extra order by a", "plan": { "QueryType": "SELECT", @@ -2157,9 +2157,9 @@ "Name": "user", "Sharded": true }, - "FieldQuery": "select :__sq1 as __sq1, weight_string(:__sq1) from `user` where 1 != 1", - "OrderBy": "(0|1) ASC", - "Query": "select :__sq1 as __sq1, weight_string(:__sq1) from `user` order by __sq1 asc", + "FieldQuery": "select :__sq1 as a from `user` where 1 != 1", + + "Query": "select :__sq1 as a from `user`", "Table": "`user`" } ] diff --git a/go/vt/vtgate/planbuilder/vexplain.go b/go/vt/vtgate/planbuilder/vexplain.go index 6f07c69d216..3d2b94a791b 100644 --- a/go/vt/vtgate/planbuilder/vexplain.go +++ b/go/vt/vtgate/planbuilder/vexplain.go @@ -20,8 +20,6 @@ import ( "context" "encoding/json" - "vitess.io/vitess/go/vt/vtgate/vindexes" - "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/key" querypb "vitess.io/vitess/go/vt/proto/query" @@ -31,6 +29,7 @@ import ( "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/planbuilder/operators" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/vindexes" ) func buildVExplainPlan(ctx context.Context, vexplainStmt *sqlparser.VExplainStmt, reservedVars *sqlparser.ReservedVars, vschema plancontext.VSchema, enableOnlineDDL, enableDirectDDL bool) (*planResult, error) { diff --git a/go/vt/vtgate/planbuilder/vindex_func.go b/go/vt/vtgate/planbuilder/vindex_func.go index abfd2d1d9b3..d3231249639 100644 --- a/go/vt/vtgate/planbuilder/vindex_func.go +++ b/go/vt/vtgate/planbuilder/vindex_func.go @@ -20,13 +20,11 @@ import ( "fmt" "vitess.io/vitess/go/mysql/collations" - "vitess.io/vitess/go/vt/vtgate/semantics" - - "vitess.io/vitess/go/vt/vterrors" - querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/semantics" ) var _ logicalPlan = (*vindexFunc)(nil) diff --git a/go/vt/vtgate/planbuilder/vstream.go b/go/vt/vtgate/planbuilder/vstream.go index fe07a4a021b..19713a6ffa3 100644 --- a/go/vt/vtgate/planbuilder/vstream.go +++ b/go/vt/vtgate/planbuilder/vstream.go @@ -20,13 +20,12 @@ import ( "strconv" "strings" - "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" - "vitess.io/vitess/go/vt/key" topodatapb "vitess.io/vitess/go/vt/proto/topodata" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/engine" + "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" ) const defaultLimit = 100