From 60eaface89bd8174c2049efac5270ab18edce470 Mon Sep 17 00:00:00 2001 From: Manan Gupta <35839558+GuptaManan100@users.noreply.github.com> Date: Wed, 3 Apr 2024 22:05:00 +0530 Subject: [PATCH] Fix AVG() sharded planning (#15626) Signed-off-by: Manan Gupta --- .../queries/aggregation/aggregation_test.go | 4 ++ .../planbuilder/operators/projection.go | 10 ++-- .../planbuilder/operators/query_planning.go | 21 ++++---- .../planbuilder/testdata/select_cases.json | 52 +++++++++++++++++++ 4 files changed, 73 insertions(+), 14 deletions(-) diff --git a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go index 05a45abac69..67a6c5e3710 100644 --- a/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go +++ b/go/test/endtoend/vtgate/queries/aggregation/aggregation_test.go @@ -91,6 +91,10 @@ func TestAggregateTypes(t *testing.T) { mcmp.SkipIfBinaryIsBelowVersion(19, "vtgate") mcmp.AssertMatches("select avg(val1) from aggr_test", `[[FLOAT64(0)]]`) }) + mcmp.Run("Average with group by without selecting the grouped columns", func(mcmp *utils.MySQLCompare) { + mcmp.SkipIfBinaryIsBelowVersion(20, "vtgate") + mcmp.AssertMatches("select avg(val2) from aggr_test group by val1 order by val1", `[[DECIMAL(1.0000)] [DECIMAL(1.0000)] [DECIMAL(3.5000)] [NULL] [DECIMAL(1.0000)]]`) + }) } func TestGroupBy(t *testing.T) { diff --git a/go/vt/vtgate/planbuilder/operators/projection.go b/go/vt/vtgate/planbuilder/operators/projection.go index 787d3a3b1de..2de282571df 100644 --- a/go/vt/vtgate/planbuilder/operators/projection.go +++ b/go/vt/vtgate/planbuilder/operators/projection.go @@ -182,12 +182,12 @@ var _ selectExpressions = (*Projection)(nil) // createSimpleProjection returns a projection where all columns are offsets. // used to change the name and order of the columns in the final output -func createSimpleProjection(ctx *plancontext.PlanningContext, qp *QueryProjection, src Operator) *Projection { +func createSimpleProjection(ctx *plancontext.PlanningContext, selExprs []sqlparser.SelectExpr, src Operator) *Projection { p := newAliasedProjection(src) - for _, e := range qp.SelectExprs { - ae, err := e.GetAliasedExpr() - if err != nil { - panic(err) + for _, e := range selExprs { + ae, isAe := e.(*sqlparser.AliasedExpr) + if !isAe { + panic(vterrors.VT09015()) } offset := p.Source.AddColumn(ctx, true, false, ae) expr := newProjExpr(ae) diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index e6db9d407e3..1b54a94201d 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -26,6 +26,12 @@ import ( ) func planQuery(ctx *plancontext.PlanningContext, root Operator) Operator { + var selExpr sqlparser.SelectExprs + if horizon, isHorizon := root.(*Horizon); isHorizon { + sel := sqlparser.GetFirstSelect(horizon.Query) + selExpr = sqlparser.CloneSelectExprs(sel.SelectExprs) + } + output := runPhases(ctx, root) output = planOffsets(ctx, output) @@ -36,7 +42,7 @@ func planQuery(ctx *plancontext.PlanningContext, root Operator) Operator { output = compact(ctx, output) - return addTruncationOrProjectionToReturnOutput(ctx, root, output) + return addTruncationOrProjectionToReturnOutput(ctx, selExpr, output) } // runPhases is the process of figuring out how to perform the operations in the Horizon @@ -571,24 +577,21 @@ func tryPushUnion(ctx *plancontext.PlanningContext, op *Union) (Operator, *Apply } // addTruncationOrProjectionToReturnOutput uses the original Horizon to make sure that the output columns line up with what the user asked for -func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, oldHorizon Operator, output Operator) Operator { - horizon, ok := oldHorizon.(*Horizon) - if !ok { +func addTruncationOrProjectionToReturnOutput(ctx *plancontext.PlanningContext, selExprs sqlparser.SelectExprs, output Operator) Operator { + if len(selExprs) == 0 { return output } cols := output.GetSelectExprs(ctx) - sel := sqlparser.GetFirstSelect(horizon.Query) - if len(sel.SelectExprs) == len(cols) { + if len(selExprs) == len(cols) { return output } - if tryTruncateColumnsAt(output, len(sel.SelectExprs)) { + if tryTruncateColumnsAt(output, len(selExprs)) { return output } - qp := horizon.getQP(ctx) - proj := createSimpleProjection(ctx, qp, output) + proj := createSimpleProjection(ctx, selExprs, output) return proj } diff --git a/go/vt/vtgate/planbuilder/testdata/select_cases.json b/go/vt/vtgate/planbuilder/testdata/select_cases.json index 9707fa68d2c..e925db28dbc 100644 --- a/go/vt/vtgate/planbuilder/testdata/select_cases.json +++ b/go/vt/vtgate/planbuilder/testdata/select_cases.json @@ -1695,6 +1695,58 @@ ] } }, + { + "comment": "avg in sharded keyspace with group by without selecting the group by columns", + "query": "select avg(intcol) as avg_col from user group by textcol1, textcol2 order by textcol1, textcol2;", + "plan": { + "QueryType": "SELECT", + "Original": "select avg(intcol) as avg_col from user group by textcol1, textcol2 order by textcol1, textcol2;", + "Instructions": { + "OperatorType": "SimpleProjection", + "ColumnNames": [ + "avg_col" + ], + "Columns": [ + 0 + ], + "Inputs": [ + { + "OperatorType": "Projection", + "Expressions": [ + "sum(intcol) / count(intcol) as avg_col", + ":1 as textcol1", + ":2 as textcol2" + ], + "Inputs": [ + { + "OperatorType": "Aggregate", + "Variant": "Ordered", + "Aggregates": "sum(0) AS avg_col, sum_count(3) AS count(intcol)", + "GroupBy": "1 COLLATE latin1_swedish_ci, (2|4) COLLATE ", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select sum(intcol) as avg_col, textcol1, textcol2, count(intcol), weight_string(textcol2) from `user` where 1 != 1 group by textcol1, textcol2, weight_string(textcol2)", + "OrderBy": "1 ASC COLLATE latin1_swedish_ci, (2|4) ASC COLLATE ", + "Query": "select sum(intcol) as avg_col, textcol1, textcol2, count(intcol), weight_string(textcol2) from `user` group by textcol1, textcol2, weight_string(textcol2) order by textcol1 asc, textcol2 asc", + "Table": "`user`" + } + ] + } + ] + } + ] + }, + "TablesUsed": [ + "user.user" + ] + } + }, { "comment": "don't filter on the vtgate", "query": "select 42 from dual where false",