Skip to content

Commit

Permalink
extract aggregation function arguments for subquery handling
Browse files Browse the repository at this point in the history
Signed-off-by: Harshit Gangal <harshit@planetscale.com>
  • Loading branch information
harshit-gangal committed Jun 11, 2024
1 parent 0923bb0 commit 7ec219b
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 5 deletions.
21 changes: 21 additions & 0 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -2874,6 +2874,8 @@ type (
Expr
GetArg() Expr
GetArgs() Exprs
SetArg(expr Expr)
SetArgs(exprs Exprs) error
// AggrName returns the lower case string representing this aggregation function
AggrName() string
}
Expand Down Expand Up @@ -3401,6 +3403,25 @@ func (varS *VarSamp) GetArgs() Exprs { return Exprs{varS.Arg} }
func (variance *Variance) GetArgs() Exprs { return Exprs{variance.Arg} }
func (av *AnyValue) GetArgs() Exprs { return Exprs{av.Arg} }

func (min *Min) SetArg(expr Expr) { min.Arg = expr }
func (sum *Sum) SetArg(expr Expr) { sum.Arg = expr }
func (max *Max) SetArg(expr Expr) { max.Arg = expr }
func (avg *Avg) SetArg(expr Expr) { avg.Arg = expr }
func (*CountStar) SetArg(expr Expr) {}
func (count *Count) SetArg(expr Expr) { count.Args = Exprs{expr} }
func (grpConcat *GroupConcatExpr) SetArg(expr Expr) { grpConcat.Exprs = Exprs{expr} }
func (bAnd *BitAnd) SetArg(expr Expr) { bAnd.Arg = expr }
func (bOr *BitOr) SetArg(expr Expr) { bOr.Arg = expr }
func (bXor *BitXor) SetArg(expr Expr) { bXor.Arg = expr }
func (std *Std) SetArg(expr Expr) { std.Arg = expr }
func (stdD *StdDev) SetArg(expr Expr) { stdD.Arg = expr }
func (stdP *StdPop) SetArg(expr Expr) { stdP.Arg = expr }
func (stdS *StdSamp) SetArg(expr Expr) { stdS.Arg = expr }
func (varP *VarPop) SetArg(expr Expr) { varP.Arg = expr }
func (varS *VarSamp) SetArg(expr Expr) { varS.Arg = expr }
func (variance *Variance) SetArg(expr Expr) { variance.Arg = expr }
func (av *AnyValue) SetArg(expr Expr) { av.Arg = expr }

func (sum *Sum) IsDistinct() bool { return sum.Distinct }
func (min *Min) IsDistinct() bool { return min.Distinct }
func (max *Max) IsDistinct() bool { return max.Distinct }
Expand Down
65 changes: 65 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -2184,6 +2184,71 @@ func ContainsAggregation(e SQLNode) bool {
return hasAggregates
}

func (min *Min) SetArgs(exprs Exprs) error {
return setFuncArgs(min, exprs, "MIN")
}
func (sum *Sum) SetArgs(exprs Exprs) error {
return setFuncArgs(sum, exprs, "SUM")
}
func (max *Max) SetArgs(exprs Exprs) error {
return setFuncArgs(max, exprs, "MAX")
}
func (avg *Avg) SetArgs(exprs Exprs) error {
return setFuncArgs(avg, exprs, "AVG")
}
func (*CountStar) SetArgs(exprs Exprs) error {
return nil
}
func (count *Count) SetArgs(exprs Exprs) error {
count.Args = exprs
return nil
}
func (grpConcat *GroupConcatExpr) SetArgs(exprs Exprs) error {
grpConcat.Exprs = exprs
return nil
}
func (bAnd *BitAnd) SetArgs(exprs Exprs) error {
return setFuncArgs(bAnd, exprs, "BIT_AND")
}
func (bOr *BitOr) SetArgs(exprs Exprs) error {
return setFuncArgs(bOr, exprs, "BIT_OR")
}
func (bXor *BitXor) SetArgs(exprs Exprs) error {
return setFuncArgs(bXor, exprs, "BIT_XOR")
}
func (std *Std) SetArgs(exprs Exprs) error {
return setFuncArgs(std, exprs, "STD")
}
func (stdD *StdDev) SetArgs(exprs Exprs) error {
return setFuncArgs(stdD, exprs, "STDDEV")
}
func (stdP *StdPop) SetArgs(exprs Exprs) error {
return setFuncArgs(stdP, exprs, "STDDEV_POP")
}
func (stdS *StdSamp) SetArgs(exprs Exprs) error {
return setFuncArgs(stdS, exprs, "STDDEV_SAMP")
}
func (varP *VarPop) SetArgs(exprs Exprs) error {
return setFuncArgs(varP, exprs, "VAR_POP")
}
func (varS *VarSamp) SetArgs(exprs Exprs) error {
return setFuncArgs(varS, exprs, "VAR_SAMP")
}
func (variance *Variance) SetArgs(exprs Exprs) error {
return setFuncArgs(variance, exprs, "VARIANCE")
}
func (av *AnyValue) SetArgs(exprs Exprs) error {
return setFuncArgs(av, exprs, "ANY_VALUE")
}

func setFuncArgs(aggr AggrFunc, exprs Exprs, name string) error {
if len(exprs) != 1 {
return vterrors.VT03001(name)
}
aggr.SetArg(exprs[0])
return nil
}

// GetFirstSelect gets the first select statement
func GetFirstSelect(selStmt SelectStatement) *Select {
if selStmt == nil {
Expand Down
24 changes: 24 additions & 0 deletions go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,19 @@ func (a *Aggregator) planOffsets(ctx *plancontext.PlanningContext) Operator {
return nil
}

func (aggr Aggr) setPushColumn(exprs sqlparser.Exprs) {
if aggr.Func != nil {
err := aggr.Func.SetArgs(exprs)
if err != nil {
panic(err)
}
}
if len(exprs) > 1 {
panic(vterrors.VT13001(fmt.Sprintf("unexpected number of expression in an random aggregation: %s", sqlparser.String(exprs))))
}
aggr.Original.Expr = exprs[0]
}

func (aggr Aggr) getPushColumn() sqlparser.Expr {
switch aggr.OpCode {
case opcode.AggregateAnyValue:
Expand All @@ -398,6 +411,17 @@ func (aggr Aggr) getPushColumn() sqlparser.Expr {
}
}

func (aggr Aggr) getPushColumnExprs() sqlparser.Exprs {
switch aggr.OpCode {
case opcode.AggregateAnyValue:
return sqlparser.Exprs{aggr.Original.Expr}
case opcode.AggregateCountStar:
return sqlparser.Exprs{sqlparser.NewIntLiteral("1")}
default:
return aggr.Func.GetArgs()
}
}

func (a *Aggregator) planOffsetsNotPushed(ctx *plancontext.PlanningContext) {
a.Source = newAliasedProjection(a.Source)
// we need to keep things in the column order, so we can't iterate over the aggregations or groupings
Expand Down
17 changes: 13 additions & 4 deletions go/vt/vtgate/planbuilder/operators/horizon_expanding.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,19 @@ func createProjectionWithAggr(ctx *plancontext.PlanningContext, qp *QueryProject
sqc := &SubQueryBuilder{}
outerID := TableID(src)
for idx, aggr := range aggregations {
expr := aggr.Original.Expr
newExpr, subqs := sqc.pullOutValueSubqueries(ctx, expr, outerID, false)
if newExpr != nil {
aggregations[idx].SubQueryExpression = subqs
exprs := aggr.getPushColumnExprs()
var newExprs sqlparser.Exprs
for _, expr := range exprs {
newExpr, subqs := sqc.pullOutValueSubqueries(ctx, expr, outerID, false)
if newExpr != nil {
newExprs = append(newExprs, newExpr)
aggregations[idx].SubQueryExpression = append(aggregations[idx].SubQueryExpression, subqs...)
} else {
newExprs = append(newExprs, expr)
}
}
if len(aggregations[idx].SubQueryExpression) > 0 {
aggr.setPushColumn(newExprs)
}
}
aggrOp.Source = sqc.getRootOperator(src, nil)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/subquery_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func rewriteMergedSubqueryExpr(ctx *plancontext.PlanningContext, se SubQueryExpr
return true
}
case *sqlparser.Argument:
if expr.Name != sq. /**/ ArgName {
if expr.Name != sq.ArgName {
return true
}
default:
Expand Down

0 comments on commit 7ec219b

Please sign in to comment.