Skip to content

Commit

Permalink
Gen4 Planner: support aggregate UDFs (#15710)
Browse files Browse the repository at this point in the history
  • Loading branch information
systay authored Apr 17, 2024
1 parent 178e6e8 commit a63f9c9
Show file tree
Hide file tree
Showing 37 changed files with 314 additions and 156 deletions.
4 changes: 4 additions & 0 deletions go/test/vschemawrapper/vschema_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ func (vw *VSchemaWrapper) KeyspaceError(keyspace string) error {
return nil
}

func (vw *VSchemaWrapper) GetAggregateUDFs() (udfs []string) {
return vw.V.GetAggregateUDFs()
}

func (vw *VSchemaWrapper) GetForeignKeyChecksState() *bool {
return vw.ForeignKeyChecksState
}
Expand Down
4 changes: 4 additions & 0 deletions go/vt/schemadiff/semantics.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ func (si *declarativeSchemaInformation) KeyspaceError(keyspace string) error {
return nil
}

func (si *declarativeSchemaInformation) GetAggregateUDFs() []string {
return nil
}

func (si *declarativeSchemaInformation) GetForeignKeyChecksState() *bool {
return nil
}
Expand Down
10 changes: 10 additions & 0 deletions go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,16 @@ func (node IdentifierCI) EqualString(str string) bool {
return node.Lowered() == strings.ToLower(str)
}

// EqualsAnyString returns true if any of these strings match
func (node IdentifierCI) EqualsAnyString(str []string) bool {
for _, s := range str {
if node.EqualString(s) {
return true
}
}
return false
}

// MarshalJSON marshals into JSON.
func (node IdentifierCI) MarshalJSON() ([]byte, error) {
return json.Marshal(node.val)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vterrors/code.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ var (
VT03029 = errorWithState("VT03029", vtrpcpb.Code_INVALID_ARGUMENT, WrongValueCountOnRow, "column count does not match value count with the row for vindex '%s'", "The number of columns you want to insert do not match the number of columns of your SELECT query.")
VT03030 = errorWithState("VT03030", vtrpcpb.Code_INVALID_ARGUMENT, WrongValueCountOnRow, "lookup column count does not match value count with the row (columns, count): (%v, %d)", "The number of columns you want to insert do not match the number of columns of your SELECT query.")
VT03031 = errorWithoutState("VT03031", vtrpcpb.Code_INVALID_ARGUMENT, "EXPLAIN is only supported for single keyspace", "EXPLAIN has to be sent down as a single query to the underlying MySQL, and this is not possible if it uses tables from multiple keyspaces")
VT03032 = errorWithState("VT03031", vtrpcpb.Code_INVALID_ARGUMENT, NonUpdateableTable, "the target table %s of the UPDATE is not updatable", "You cannot update a table that is not a real MySQL table.")
VT03032 = errorWithState("VT03032", vtrpcpb.Code_INVALID_ARGUMENT, NonUpdateableTable, "the target table %s of the UPDATE is not updatable", "You cannot update a table that is not a real MySQL table.")

VT05001 = errorWithState("VT05001", vtrpcpb.Code_NOT_FOUND, DbDropExists, "cannot drop database '%s'; database does not exists", "The given database does not exist; Vitess cannot drop it.")
VT05002 = errorWithState("VT05002", vtrpcpb.Code_NOT_FOUND, BadDb, "cannot alter database '%s'; unknown database", "The given database does not exist; Vitess cannot alter it.")
Expand Down
16 changes: 3 additions & 13 deletions go/vt/vtgate/engine/opcode/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,10 @@ const (
AggregateCountStar
AggregateGroupConcat
AggregateAvg
AggregateUDF // This is an opcode used to represent UDFs
_NumOfOpCodes // This line must be last of the opcodes!
)

var (
// OpcodeType keeps track of the known output types for different aggregate functions
OpcodeType = map[AggregateOpcode]querypb.Type{
AggregateCountDistinct: sqltypes.Int64,
AggregateCount: sqltypes.Int64,
AggregateCountStar: sqltypes.Int64,
AggregateSumDistinct: sqltypes.Decimal,
AggregateSum: sqltypes.Decimal,
AggregateAvg: sqltypes.Decimal,
AggregateGtid: sqltypes.VarChar,
}
)

// SupportedAggregates maps the list of supported aggregate
// functions to their opcodes.
var SupportedAggregates = map[string]AggregateOpcode{
Expand Down Expand Up @@ -166,6 +154,8 @@ func (code AggregateOpcode) SQLType(typ querypb.Type) querypb.Type {
return sqltypes.Int64
case AggregateGtid:
return sqltypes.VarChar
case AggregateUDF:
return sqltypes.Unknown
default:
panic(code.String()) // we have a unit test checking we never reach here
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func gen4DeleteStmtPlanner(
return nil, err
}

err = queryRewrite(ctx.SemTable, reservedVars, deleteStmt)
err = queryRewrite(ctx, deleteStmt)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func gen4InsertStmtPlanner(version querypb.ExecuteOptions_PlannerVersion, insStm
return nil, err
}

err = queryRewrite(ctx.SemTable, reservedVars, insStmt)
err = queryRewrite(ctx, insStmt)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) {

switch stmt := qb.stmt.(type) {
case *sqlparser.Select:
if containsAggr(expr) {
if ContainsAggr(qb.ctx, expr) {
addPred = stmt.AddHaving
} else {
addPred = stmt.AddWhere
Expand Down
8 changes: 7 additions & 1 deletion go/vt/vtgate/planbuilder/operators/aggregator.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func (a *Aggregator) AddPredicate(_ *plancontext.PlanningContext, expr sqlparser
return newFilter(a, expr)
}

func (a *Aggregator) addColumnWithoutPushing(_ *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int {
func (a *Aggregator) addColumnWithoutPushing(ctx *plancontext.PlanningContext, expr *sqlparser.AliasedExpr, addToGroupBy bool) int {
offset := len(a.Columns)
a.Columns = append(a.Columns, expr)

Expand All @@ -96,6 +96,12 @@ func (a *Aggregator) addColumnWithoutPushing(_ *plancontext.PlanningContext, exp
switch e := expr.Expr.(type) {
case sqlparser.AggrFunc:
aggr = createAggrFromAggrFunc(e, expr)
case *sqlparser.FuncExpr:
if IsAggr(ctx, e) {
aggr = NewAggr(opcode.AggregateUDF, nil, expr, expr.As.String())
} else {
aggr = NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String())
}
default:
aggr = NewAggr(opcode.AggregateAnyValue, nil, expr, expr.As.String())
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func breakExpressionInLHSandRHSForApplyJoin(
) (col applyJoinColumn) {
rewrittenExpr := sqlparser.CopyOnRewrite(expr, nil, func(cursor *sqlparser.CopyOnWriteCursor) {
nodeExpr, ok := cursor.Node().(sqlparser.Expr)
if !ok || !mustFetchFromInput(nodeExpr) {
if !ok || !mustFetchFromInput(ctx, nodeExpr) {
return
}
deps := ctx.SemTable.RecursiveDeps(nodeExpr)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/hash_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (hj *HashJoin) addColumn(ctx *plancontext.PlanningContext, in sqlparser.Exp
}
inOffset := op.FindCol(ctx, expr, false)
if inOffset == -1 {
if !mustFetchFromInput(expr) {
if !mustFetchFromInput(ctx, expr) {
return -1
}

Expand Down Expand Up @@ -398,7 +398,7 @@ func (hj *HashJoin) addSingleSidedColumn(
}
inOffset := op.FindCol(ctx, expr, false)
if inOffset == -1 {
if !mustFetchFromInput(expr) {
if !mustFetchFromInput(ctx, expr) {
return -1
}

Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/horizon.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func (h *Horizon) AddPredicate(ctx *plancontext.PlanningContext, expr sqlparser.
}

newExpr := semantics.RewriteDerivedTableExpression(expr, tableInfo)
if sqlparser.ContainsAggregation(newExpr) {
if ContainsAggr(ctx, newExpr) {
return newFilter(h, expr)
}
h.Source = h.Source.AddPredicate(ctx, newExpr)
Expand Down
43 changes: 34 additions & 9 deletions go/vt/vtgate/planbuilder/operators/offset_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package operators
import (
"fmt"

"vitess.io/vitess/go/vt/vtgate/engine/opcode"

"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
Expand Down Expand Up @@ -56,10 +58,12 @@ func planOffsets(ctx *plancontext.PlanningContext, root Operator) Operator {
}

// mustFetchFromInput returns true for expressions that have to be fetched from the input and cannot be evaluated
func mustFetchFromInput(e sqlparser.SQLNode) bool {
switch e.(type) {
func mustFetchFromInput(ctx *plancontext.PlanningContext, e sqlparser.SQLNode) bool {
switch fun := e.(type) {
case *sqlparser.ColName, sqlparser.AggrFunc:
return true
case *sqlparser.FuncExpr:
return fun.Name.EqualsAnyString(ctx.VSchema.GetAggregateUDFs())
default:
return false
}
Expand Down Expand Up @@ -93,10 +97,10 @@ func useOffsets(ctx *plancontext.PlanningContext, expr sqlparser.Expr, op Operat
return rewritten.(sqlparser.Expr)
}

// addColumnsToInput adds columns needed by an operator to its input.
// This happens only when the filter expression can be retrieved as an offset from the underlying mysql.
func addColumnsToInput(ctx *plancontext.PlanningContext, root Operator) Operator {
visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) {
// addColumnsToInput adds columns needed by an operator to its input.
// This happens only when the filter expression can be retrieved as an offset from the underlying mysql.
addColumnsNeededByFilter := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) {
filter, ok := in.(*Filter)
if !ok {
return in, NoRewrite
Expand Down Expand Up @@ -126,12 +130,33 @@ func addColumnsToInput(ctx *plancontext.PlanningContext, root Operator) Operator
return in, NoRewrite
}

// while we are out here walking the operator tree, if we find a UDF in an aggregation, we should fail
failUDFAggregation := func(in Operator, _ semantics.TableSet, _ bool) (Operator, *ApplyResult) {
aggrOp, ok := in.(*Aggregator)
if !ok {
return in, NoRewrite
}
for _, aggr := range aggrOp.Aggregations {
if aggr.OpCode == opcode.AggregateUDF {
// we don't support UDFs in aggregation if it's still above a route
message := fmt.Sprintf("Aggregate UDF '%s' must be pushed down to MySQL", sqlparser.String(aggr.Original.Expr))
panic(vterrors.VT12001(message))
}
}
return in, NoRewrite
}

visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) {
out, res := addColumnsNeededByFilter(in, semantics.EmptyTableSet(), isRoot)
failUDFAggregation(in, semantics.EmptyTableSet(), isRoot)
return out, res
}

return TopDown(root, TableID, visitor, stopAtRoute)
}

// addColumnsToInput adds columns needed by an operator to its input.
// This happens only when the filter expression can be retrieved as an offset from the underlying mysql.
func pullDistinctFromUNION(_ *plancontext.PlanningContext, root Operator) Operator {
// isolateDistinctFromUnion will pull out the distinct from a union operator
func isolateDistinctFromUnion(_ *plancontext.PlanningContext, root Operator) Operator {
visitor := func(in Operator, _ semantics.TableSet, isRoot bool) (Operator, *ApplyResult) {
union, ok := in.(*Union)
if !ok || !union.distinct {
Expand Down Expand Up @@ -170,7 +195,7 @@ func getOffsetRewritingVisitor(
return false
}

if mustFetchFromInput(e) {
if mustFetchFromInput(ctx, e) {
notFound(e)
return false
}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/phases.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (p Phase) shouldRun(s semantics.QuerySignature) bool {
func (p Phase) act(ctx *plancontext.PlanningContext, op Operator) Operator {
switch p {
case pullDistinctFromUnion:
return pullDistinctFromUNION(ctx, op)
return isolateDistinctFromUnion(ctx, op)
case delegateAggregation:
return enableDelegateAggregation(ctx, op)
case addAggrOrdering:
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/operators/query_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ func tryPushOrdering(ctx *plancontext.PlanningContext, in *Ordering) (Operator,
case *Projection:
// we can move ordering under a projection if it's not introducing a column we're sorting by
for _, by := range in.Order {
if !mustFetchFromInput(by.SimplifiedExpr) {
if !mustFetchFromInput(ctx, by.SimplifiedExpr) {
return in, NoRewrite
}
}
Expand Down Expand Up @@ -459,7 +459,7 @@ func pushFilterUnderProjection(ctx *plancontext.PlanningContext, filter *Filter,
for _, p := range filter.Predicates {
cantPush := false
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
if !mustFetchFromInput(node) {
if !mustFetchFromInput(ctx, node) {
return true, nil
}

Expand Down
Loading

0 comments on commit a63f9c9

Please sign in to comment.