Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[release-20.0] fix: order by subquery planning (#16049) #16134

Merged
merged 2 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion go/test/endtoend/vtgate/queries/subquery/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,39 @@ create table t1
id2 bigint,
primary key (id1)
) Engine = InnoDB;

create table t1_id2_idx
(
id2 bigint,
keyspace_id varbinary(10),
primary key (id2)
) Engine = InnoDB;

create table t2
(
id3 bigint,
id4 bigint,
primary key (id3)
) Engine = InnoDB;

create table t2_id4_idx
(
id bigint not null auto_increment,
id4 bigint,
id3 bigint,
primary key (id),
key idx_id4 (id4)
) Engine = InnoDB;
) Engine = InnoDB;

CREATE TABLE user
(
id INT PRIMARY KEY,
name VARCHAR(100)
);

CREATE TABLE user_extra
(
user_id INT,
extra_info VARCHAR(100),
PRIMARY KEY (user_id, extra_info)
);
44 changes: 44 additions & 0 deletions go/test/endtoend/vtgate/queries/subquery/subquery_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package subquery

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -188,3 +189,46 @@ func TestSubqueryInDerivedTable(t *testing.T) {
mcmp.Exec(`select t.a from (select t1.id2, t2.id3, (select id2 from t1 order by id2 limit 1) as a from t1 join t2 on t1.id1 = t2.id4) t`)
mcmp.Exec(`SELECT COUNT(*) FROM (SELECT DISTINCT t1.id1 FROM t1 JOIN t2 ON t1.id1 = t2.id4) dt`)
}

func TestSubqueries(t *testing.T) {
// This method tests many types of subqueries. The queries should move to a vitess-tester test file once we have a way to run them.
// The commented out queries are failing because of wrong types being returned.
// The tests are commented out until the issue is fixed.
utils.SkipIfBinaryIsBelowVersion(t, 20, "vtgate")
mcmp, closer := start(t)
defer closer()
queries := []string{
`INSERT INTO user (id, name) VALUES (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'), (4, 'David'), (5, 'Eve'), (6, 'Frank'), (7, 'Grace'), (8, 'Hannah'), (9, 'Ivy'), (10, 'Jack')`,
`INSERT INTO user_extra (user_id, extra_info) VALUES (1, 'info1'), (1, 'info2'), (2, 'info1'), (3, 'info1'), (3, 'info2'), (4, 'info1'), (5, 'info1'), (6, 'info1'), (7, 'info1'), (8, 'info1')`,
`SELECT (SELECT COUNT(*) FROM user_extra) AS order_count, id FROM user WHERE id = (SELECT COUNT(*) FROM user_extra)`,
`SELECT id, (SELECT COUNT(*) FROM user_extra) AS order_count FROM user ORDER BY (SELECT COUNT(*) FROM user_extra)`,
`SELECT id FROM user WHERE id = (SELECT COUNT(*) FROM user_extra) ORDER BY (SELECT COUNT(*) FROM user_extra)`,
`SELECT (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count, id, name FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count FROM user ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`,
`SELECT id, name FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0 ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) AS extra_count FROM user GROUP BY id, name HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`,
`SELECT id, name, COUNT(*) FROM user WHERE (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) > 0 GROUP BY id, name HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id)`,
`SELECT id, round(MAX(id + (SELECT COUNT(*) FROM user_extra where user_id = 42))) as r FROM user WHERE id = 42 GROUP BY id ORDER BY r`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) * 2 AS double_extra_count FROM user`,
`SELECT id, name FROM user WHERE id IN (SELECT user_id FROM user_extra WHERE LENGTH(extra_info) > 4)`,
`SELECT id, COUNT(*) FROM user GROUP BY id HAVING COUNT(*) > (SELECT COUNT(*) FROM user_extra WHERE user_extra.user_id = user.id) + 1`,
`SELECT id, name FROM user ORDER BY (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) * id`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra WHERE user.id = user_extra.user_id) + id AS extra_count_plus_id FROM user`,
`SELECT id, name FROM user WHERE id IN (SELECT user_id FROM user_extra WHERE extra_info = 'info1') OR id IN (SELECT user_id FROM user_extra WHERE extra_info = 'info2')`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, SUM(id) AS sum_ids FROM user GROUP BY id, name ORDER BY (SELECT COUNT(*) FROM user_extra)`,
// `SELECT id, name, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, AVG(id) AS avg_ids FROM user GROUP BY id, name HAVING (SELECT SUM(LENGTH(extra_info)) FROM user_extra) > 10`,
`SELECT id, name, (SELECT AVG(LENGTH(extra_info)) FROM user_extra) AS avg_length_extra_info, MAX(id) AS max_id FROM user WHERE id IN (SELECT user_id FROM user_extra) GROUP BY id, name`,
`SELECT id, name, (SELECT MAX(LENGTH(extra_info)) FROM user_extra) AS max_length_extra_info, MIN(id) AS min_id FROM user GROUP BY id, name ORDER BY (SELECT MAX(LENGTH(extra_info)) FROM user_extra)`,
`SELECT id, name, (SELECT MIN(LENGTH(extra_info)) FROM user_extra) AS min_length_extra_info, SUM(id) AS sum_ids FROM user GROUP BY id, name HAVING (SELECT MIN(LENGTH(extra_info)) FROM user_extra) < 5`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, AVG(id) AS avg_ids FROM user WHERE id > (SELECT COUNT(*) FROM user_extra) GROUP BY id, name`,
// `SELECT id, name, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, COUNT(id) AS count_ids FROM user GROUP BY id, name ORDER BY (SELECT SUM(LENGTH(extra_info)) FROM user_extra)`,
// `SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count, (SELECT SUM(LENGTH(extra_info)) FROM user_extra) AS total_length_extra_info, (SELECT AVG(LENGTH(extra_info)) FROM user_extra) AS avg_length_extra_info, (SELECT MAX(LENGTH(extra_info)) FROM user_extra) AS max_length_extra_info, (SELECT MIN(LENGTH(extra_info)) FROM user_extra) AS min_length_extra_info, SUM(id) AS sum_ids FROM user GROUP BY id, name HAVING (SELECT AVG(LENGTH(extra_info)) FROM user_extra) > 2`,
`SELECT id, name, (SELECT COUNT(*) FROM user_extra) + id AS total_extra_count_plus_id, AVG(id) AS avg_ids FROM user WHERE id < (SELECT MAX(user_id) FROM user_extra) GROUP BY id, name`,
}

for idx, query := range queries {
mcmp.Run(fmt.Sprintf("%d %s", idx, query), func(mcmp *utils.MySQLCompare) {
mcmp.Exec(query)
})
}
}
31 changes: 31 additions & 0 deletions go/test/endtoend/vtgate/queries/subquery/vschema.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
"autocommit": "true"
},
"owner": "t2"
},
"xxhash": {
"type": "xxhash"
}
},
"tables": {
Expand Down Expand Up @@ -64,6 +67,34 @@
"name": "hash"
}
]
},
"user_extra": {
"name": "user_extra",
"column_vindexes": [
{
"columns": [
"user_id",
"extra_info"
],
"type": "xxhash",
"name": "xxhash",
"vindex": null
}
]
},
"user": {
"name": "user",
"column_vindexes": [
{
"columns": [
"id"
],
"type": "xxhash",
"name": "xxhash",
"vindex": null
}
]
}

}
}
47 changes: 47 additions & 0 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2874,6 +2874,8 @@ type (
Expr
GetArg() Expr
GetArgs() Exprs
SetArg(expr Expr)
SetArgs(exprs Exprs) error
// AggrName returns the lower case string representing this aggregation function
AggrName() string
}
Expand Down Expand Up @@ -3401,6 +3403,51 @@ func (varS *VarSamp) GetArgs() Exprs { return Exprs{varS.Arg} }
func (variance *Variance) GetArgs() Exprs { return Exprs{variance.Arg} }
func (av *AnyValue) GetArgs() Exprs { return Exprs{av.Arg} }

func (min *Min) SetArg(expr Expr) { min.Arg = expr }
func (sum *Sum) SetArg(expr Expr) { sum.Arg = expr }
func (max *Max) SetArg(expr Expr) { max.Arg = expr }
func (avg *Avg) SetArg(expr Expr) { avg.Arg = expr }
func (*CountStar) SetArg(expr Expr) {}
func (count *Count) SetArg(expr Expr) { count.Args = Exprs{expr} }
func (grpConcat *GroupConcatExpr) SetArg(expr Expr) { grpConcat.Exprs = Exprs{expr} }
func (bAnd *BitAnd) SetArg(expr Expr) { bAnd.Arg = expr }
func (bOr *BitOr) SetArg(expr Expr) { bOr.Arg = expr }
func (bXor *BitXor) SetArg(expr Expr) { bXor.Arg = expr }
func (std *Std) SetArg(expr Expr) { std.Arg = expr }
func (stdD *StdDev) SetArg(expr Expr) { stdD.Arg = expr }
func (stdP *StdPop) SetArg(expr Expr) { stdP.Arg = expr }
func (stdS *StdSamp) SetArg(expr Expr) { stdS.Arg = expr }
func (varP *VarPop) SetArg(expr Expr) { varP.Arg = expr }
func (varS *VarSamp) SetArg(expr Expr) { varS.Arg = expr }
func (variance *Variance) SetArg(expr Expr) { variance.Arg = expr }
func (av *AnyValue) SetArg(expr Expr) { av.Arg = expr }

func (min *Min) SetArgs(exprs Exprs) error { return setFuncArgs(min, exprs, "MIN") }
func (sum *Sum) SetArgs(exprs Exprs) error { return setFuncArgs(sum, exprs, "SUM") }
func (max *Max) SetArgs(exprs Exprs) error { return setFuncArgs(max, exprs, "MAX") }
func (avg *Avg) SetArgs(exprs Exprs) error { return setFuncArgs(avg, exprs, "AVG") }
func (*CountStar) SetArgs(Exprs) error { return nil }
func (bAnd *BitAnd) SetArgs(exprs Exprs) error { return setFuncArgs(bAnd, exprs, "BIT_AND") }
func (bOr *BitOr) SetArgs(exprs Exprs) error { return setFuncArgs(bOr, exprs, "BIT_OR") }
func (bXor *BitXor) SetArgs(exprs Exprs) error { return setFuncArgs(bXor, exprs, "BIT_XOR") }
func (std *Std) SetArgs(exprs Exprs) error { return setFuncArgs(std, exprs, "STD") }
func (stdD *StdDev) SetArgs(exprs Exprs) error { return setFuncArgs(stdD, exprs, "STDDEV") }
func (stdP *StdPop) SetArgs(exprs Exprs) error { return setFuncArgs(stdP, exprs, "STDDEV_POP") }
func (stdS *StdSamp) SetArgs(exprs Exprs) error { return setFuncArgs(stdS, exprs, "STDDEV_SAMP") }
func (varP *VarPop) SetArgs(exprs Exprs) error { return setFuncArgs(varP, exprs, "VAR_POP") }
func (varS *VarSamp) SetArgs(exprs Exprs) error { return setFuncArgs(varS, exprs, "VAR_SAMP") }
func (variance *Variance) SetArgs(exprs Exprs) error { return setFuncArgs(variance, exprs, "VARIANCE") }
func (av *AnyValue) SetArgs(exprs Exprs) error { return setFuncArgs(av, exprs, "ANY_VALUE") }

func (count *Count) SetArgs(exprs Exprs) error {
count.Args = exprs
return nil
}
func (grpConcat *GroupConcatExpr) SetArgs(exprs Exprs) error {
grpConcat.Exprs = exprs
return nil
}

func (sum *Sum) IsDistinct() bool { return sum.Distinct }
func (min *Min) IsDistinct() bool { return min.Distinct }
func (max *Max) IsDistinct() bool { return max.Distinct }
Expand Down
9 changes: 9 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2184,6 +2184,15 @@ func ContainsAggregation(e SQLNode) bool {
return hasAggregates
}

// setFuncArgs sets the arguments for the aggregation function, while checking that there is only one argument
func setFuncArgs(aggr AggrFunc, exprs Exprs, name string) error {
if len(exprs) != 1 {
return vterrors.VT03001(name)
}
aggr.SetArg(exprs[0])
return nil
}

// GetFirstSelect gets the first select statement
func GetFirstSelect(selStmt SelectStatement) *Select {
if selStmt == nil {
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ import (
"sync"
"testing"

fuzz "github.com/AdaLogics/go-fuzz-headers"

"vitess.io/vitess/go/json2"
"vitess.io/vitess/go/sqltypes"
vschemapb "vitess.io/vitess/go/vt/proto/vschema"
"vitess.io/vitess/go/vt/vtgate/vindexes"

fuzz "github.com/AdaLogics/go-fuzz-headers"
)

var initter sync.Once
Expand Down
21 changes: 12 additions & 9 deletions go/vt/vtgate/planbuilder/operators/aggregation_pushing.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,20 +89,21 @@ func reachedPhase(ctx *plancontext.PlanningContext, p Phase) bool {
// Any columns that are needed to evaluate the subquery needs to be added as
// grouping columns to the aggregation being pushed down, and then after the
// subquery evaluation we are free to reassemble the total aggregation values.
// This is very similar to how we push aggregation through an shouldRun-join.
// This is very similar to how we push aggregation through an apply-join.
func pushAggregationThroughSubquery(
ctx *plancontext.PlanningContext,
rootAggr *Aggregator,
src *SubQueryContainer,
) (Operator, *ApplyResult) {
pushedAggr := rootAggr.SplitAggregatorBelowOperators([]Operator{src.Outer})
pushedAggr := rootAggr.SplitAggregatorBelowOperators(ctx, []Operator{src.Outer})
for _, subQuery := range src.Inner {
lhsCols := subQuery.OuterExpressionsNeeded(ctx, src.Outer)
for _, colName := range lhsCols {
idx := slices.IndexFunc(pushedAggr.Columns, func(ae *sqlparser.AliasedExpr) bool {
findColName := func(ae *sqlparser.AliasedExpr) bool {
return ctx.SemTable.EqualsExpr(ae.Expr, colName)
})
if idx >= 0 {
}
if slices.IndexFunc(pushedAggr.Columns, findColName) >= 0 {
// we already have the column, no need to push it again
continue
}
pushedAggr.addColumnWithoutPushing(ctx, aeWrap(colName), true)
Expand All @@ -111,8 +112,10 @@ func pushAggregationThroughSubquery(

src.Outer = pushedAggr

for _, aggregation := range pushedAggr.Aggregations {
aggregation.Original.Expr = rewriteColNameToArgument(ctx, aggregation.Original.Expr, aggregation.SubQueryExpression, src.Inner...)
for _, aggr := range pushedAggr.Aggregations {
// we rewrite columns in the aggregation to use the argument form of the subquery
aggr.Original.Expr = rewriteColNameToArgument(ctx, aggr.Original.Expr, aggr.SubQueryExpression, src.Inner...)
pushedAggr.Columns[aggr.ColOffset].Expr = rewriteColNameToArgument(ctx, pushedAggr.Columns[aggr.ColOffset].Expr, aggr.SubQueryExpression, src.Inner...)
}

if !rootAggr.Original {
Expand Down Expand Up @@ -147,7 +150,7 @@ func pushAggregationThroughRoute(
route *Route,
) (Operator, *ApplyResult) {
// Create a new aggregator to be placed below the route.
aggrBelowRoute := aggregator.SplitAggregatorBelowOperators(route.Inputs())
aggrBelowRoute := aggregator.SplitAggregatorBelowOperators(ctx, route.Inputs())
aggrBelowRoute.Aggregations = nil

pushAggregations(ctx, aggregator, aggrBelowRoute)
Expand Down Expand Up @@ -248,7 +251,7 @@ func pushAggregationThroughFilter(
) (Operator, *ApplyResult) {

columnsNeeded := collectColNamesNeeded(ctx, filter)
pushedAggr := aggregator.SplitAggregatorBelowOperators([]Operator{filter.Source})
pushedAggr := aggregator.SplitAggregatorBelowOperators(ctx, []Operator{filter.Source})
withNextColumn:
for _, col := range columnsNeeded {
for _, gb := range pushedAggr.Grouping {
Expand Down
40 changes: 39 additions & 1 deletion go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,21 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) Operator {
return nil
}

func (aggr Aggr) setPushColumn(exprs sqlparser.Exprs) {
if aggr.Func == nil {
if len(exprs) > 1 {
panic(vterrors.VT13001(fmt.Sprintf("unexpected number of expression in an random aggregation: %s", sqlparser.String(exprs))))
}
aggr.Original.Expr = exprs[0]
return
}

err := aggr.Func.SetArgs(exprs)
if err != nil {
panic(err)
}
}

func (aggr Aggr) getPushColumn() sqlparser.Expr {
switch aggr.OpCode {
case opcode.AggregateAnyValue:
Expand All @@ -398,6 +413,17 @@ func (aggr Aggr) getPushColumn() sqlparser.Expr {
}
}

func (aggr Aggr) getPushColumnExprs() sqlparser.Exprs {
switch aggr.OpCode {
case opcode.AggregateAnyValue:
return sqlparser.Exprs{aggr.Original.Expr}
case opcode.AggregateCountStar:
return sqlparser.Exprs{sqlparser.NewIntLiteral("1")}
default:
return aggr.Func.GetArgs()
}
}

func (a *Aggregator) planOffsetsNotPushed(ctx *plancontext.PlanningContext) {
a.Source = newAliasedProjection(a.Source)
// we need to keep things in the column order, so we can't iterate over the aggregations or groupings
Expand Down Expand Up @@ -498,11 +524,23 @@ func (a *Aggregator) internalAddColumn(ctx *plancontext.PlanningContext, aliased
// SplitAggregatorBelowOperators returns the aggregator that will live under the Route.
// This is used when we are splitting the aggregation so one part is done
// at the mysql level and one part at the vtgate level
func (a *Aggregator) SplitAggregatorBelowOperators(input []Operator) *Aggregator {
func (a *Aggregator) SplitAggregatorBelowOperators(ctx *plancontext.PlanningContext, input []Operator) *Aggregator {
newOp := a.Clone(input).(*Aggregator)
newOp.Pushed = false
newOp.Original = false
newOp.DT = nil

// We need to make sure that the columns are cloned so that the original operator is not affected
// by the changes we make to the new operator
newOp.Columns = slice.Map(a.Columns, func(from *sqlparser.AliasedExpr) *sqlparser.AliasedExpr {
return ctx.SemTable.Clone(from).(*sqlparser.AliasedExpr)
})
for idx, aggr := range newOp.Aggregations {
newOp.Aggregations[idx].Original = ctx.SemTable.Clone(aggr.Original).(*sqlparser.AliasedExpr)
}
for idx, gb := range newOp.Grouping {
newOp.Grouping[idx].Inner = ctx.SemTable.Clone(gb.Inner).(sqlparser.Expr)
}
return newOp
}

Expand Down
10 changes: 5 additions & 5 deletions go/vt/vtgate/planbuilder/operators/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ import (
"vitess.io/vitess/go/vt/vtgate/vindexes"
)

type compactable interface {
// Compact implement this interface for operators that have easy to see optimisations
Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult)
}

// compact will optimise the operator tree into a smaller but equivalent version
func compact(ctx *plancontext.PlanningContext, op Operator) Operator {
type compactable interface {
// Compact implement this interface for operators that have easy to see optimisations
Compact(ctx *plancontext.PlanningContext) (Operator, *ApplyResult)
}

newOp := BottomUp(op, TableID, func(op Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) {
newOp, ok := op.(compactable)
if !ok {
Expand Down
Loading
Loading