diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 0026764970e..80c159dfec0 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -2874,6 +2874,7 @@ type ( Expr GetArg() Expr GetArgs() Exprs + SetArg(expr Expr) // AggrName returns the lower case string representing this aggregation function AggrName() string } @@ -3401,6 +3402,25 @@ func (varS *VarSamp) GetArgs() Exprs { return Exprs{varS.Arg} } func (variance *Variance) GetArgs() Exprs { return Exprs{variance.Arg} } func (av *AnyValue) GetArgs() Exprs { return Exprs{av.Arg} } +func (min *Min) SetArg(expr Expr) { min.Arg = expr } +func (sum *Sum) SetArg(expr Expr) { sum.Arg = expr } +func (max *Max) SetArg(expr Expr) { max.Arg = expr } +func (avg *Avg) SetArg(expr Expr) { avg.Arg = expr } +func (*CountStar) SetArg(expr Expr) {} +func (count *Count) SetArg(expr Expr) { count.Args = Exprs{expr} } +func (grpConcat *GroupConcatExpr) SetArg(expr Expr) { grpConcat.Exprs = Exprs{expr} } +func (bAnd *BitAnd) SetArg(expr Expr) { bAnd.Arg = expr } +func (bOr *BitOr) SetArg(expr Expr) { bOr.Arg = expr } +func (bXor *BitXor) SetArg(expr Expr) { bXor.Arg = expr } +func (std *Std) SetArg(expr Expr) { std.Arg = expr } +func (stdD *StdDev) SetArg(expr Expr) { stdD.Arg = expr } +func (stdP *StdPop) SetArg(expr Expr) { stdP.Arg = expr } +func (stdS *StdSamp) SetArg(expr Expr) { stdS.Arg = expr } +func (varP *VarPop) SetArg(expr Expr) { varP.Arg = expr } +func (varS *VarSamp) SetArg(expr Expr) { varS.Arg = expr } +func (variance *Variance) SetArg(expr Expr) { variance.Arg = expr } +func (av *AnyValue) SetArg(expr Expr) { av.Arg = expr } + func (sum *Sum) IsDistinct() bool { return sum.Distinct } func (min *Min) IsDistinct() bool { return min.Distinct } func (max *Max) IsDistinct() bool { return max.Distinct } diff --git a/go/vt/vtgate/planbuilder/operators/aggregator.go b/go/vt/vtgate/planbuilder/operators/aggregator.go index 49081eb6a10..526c96d88b4 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregator.go +++ b/go/vt/vtgate/planbuilder/operators/aggregator.go @@ -379,6 +379,14 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) Operator { return nil } +func (aggr Aggr) setPushColumn(expr sqlparser.Expr) { + if aggr.Func != nil { + aggr.Func.SetArg(expr) + return + } + aggr.Original.Expr = expr +} + func (aggr Aggr) getPushColumn() sqlparser.Expr { switch aggr.OpCode { case opcode.AggregateAnyValue: diff --git a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go index 58a254add7a..074197ec8d3 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go @@ -185,10 +185,10 @@ func createProjectionWithAggr(ctx *plancontext.PlanningContext, qp *QueryProject sqc := &SubQueryBuilder{} outerID := TableID(src) for idx, aggr := range aggregations { - expr := aggr.Original.Expr - newExpr, subqs := sqc.pullOutValueSubqueries(ctx, expr, outerID, false) + newExpr, subqs := sqc.pullOutValueSubqueries(ctx, aggr.getPushColumn(), outerID, false) if newExpr != nil { aggregations[idx].SubQueryExpression = subqs + aggr.setPushColumn(newExpr) } } aggrOp.Source = sqc.getRootOperator(src, nil) diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.json b/go/vt/vtgate/planbuilder/testdata/onecase.json index da7543f706a..bf6e17d2dd5 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.json +++ b/go/vt/vtgate/planbuilder/testdata/onecase.json @@ -1,9 +1,8 @@ [ { "comment": "Add your test case here for debugging and run go test -run=One.", - "query": "", + "query": "SELECT id, name, (SELECT COUNT(*) FROM user_extra) AS total_extra_count FROM user GROUP BY id, name ORDER BY total_extra_count", "plan": { - } } ] \ No newline at end of file