Skip to content

Commit

Permalink
[feature](mtmv) Support agg state roll up and optimize the roll up co…
Browse files Browse the repository at this point in the history
…de (apache#35026)

agg_state is agg  intermediate state, detail see 
state combinator: https://doris.apache.org/zh-CN/docs/dev/sql-manual/sql-functions/combinators/state

this support agg function roll up as following
 
+---------------------+---------------------------------------------+---------------------+
| query               | materialized view                           | roll up             |
| ------------------- | ------------------------------------------- | ------------------- |
| agg_funtion()       | agg_funtion_unoin()  or agg_funtion_state() | agg_funtion_merge() |
| agg_funtion_unoin() | agg_funtion_unoin() or agg_funtion_state()  | agg_funtion_union() |
| agg_funtion_merge() | agg_funtion_unoin() or agg_funtion_state()  | agg_funtion_merge() |
+---------------------+---------------------------------------------+---------------------+

for example which can be rewritten by mv sucessfully as following

MV defination is

```
            select
            o_orderstatus,
            l_partkey,
            l_suppkey,
            sum_union(sum_state(o_shippriority)),
            group_concat_union(group_concat_state(l_shipinstruct)),
            avg_union(avg_state(l_linenumber)),
            max_by_union(max_by_state(l_shipmode, l_suppkey)),
            count_union(count_state(l_orderkey)),
            multi_distinct_count_union(multi_distinct_count_state(l_shipmode))
            from lineitem
            left join orders
            on lineitem.l_orderkey = o_orderkey and l_shipdate = o_orderdate
            group by
            o_orderstatus,
            l_partkey,
            l_suppkey;
```

Query is

```
            select
            o_orderstatus,
            l_suppkey,
            sum(o_shippriority),
            group_concat(l_shipinstruct),
            avg(l_linenumber),
            max_by(l_shipmode,l_suppkey),
            count(l_orderkey),
            multi_distinct_count(l_shipmode)
            from lineitem
            left join orders 
            on l_orderkey = o_orderkey and l_shipdate = o_orderdate
            group by
            o_orderstatus,
            l_suppkey;
```
  • Loading branch information
seawinde authored May 24, 2024
1 parent fc27d7a commit 2d94142
Show file tree
Hide file tree
Showing 28 changed files with 1,350 additions and 244 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,47 +22,35 @@
import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanCheckContext;
import org.apache.doris.nereids.rules.exploration.mv.StructInfo.PlanSplitContext;
import org.apache.doris.nereids.rules.exploration.mv.mapping.SlotMapping;
import org.apache.doris.nereids.rules.exploration.mv.rollup.AggFunctionRollUpHandler;
import org.apache.doris.nereids.rules.exploration.mv.rollup.BothCombinatorRollupHandler;
import org.apache.doris.nereids.rules.exploration.mv.rollup.DirectRollupHandler;
import org.apache.doris.nereids.rules.exploration.mv.rollup.MappingRollupHandler;
import org.apache.doris.nereids.rules.exploration.mv.rollup.SingleCombinatorRollupHandler;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Any;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnionCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.CouldRollUp;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnionAgg;
import org.apache.doris.nereids.trees.expressions.functions.agg.Ndv;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllCardinality;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ToBitmap;
import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.ExpressionLineageReplacer;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

Expand All @@ -72,92 +60,15 @@
*/
public abstract class AbstractMaterializedViewAggregateRule extends AbstractMaterializedViewRule {

protected static final Multimap<Function, Expression>
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP = ArrayListMultimap.create();
public static final List<AggFunctionRollUpHandler> ROLL_UP_HANDLERS =
ImmutableList.of(DirectRollupHandler.INSTANCE,
MappingRollupHandler.INSTANCE,
SingleCombinatorRollupHandler.INSTANCE,
BothCombinatorRollupHandler.INSTANCE);

protected static final AggregateExpressionRewriter AGGREGATE_EXPRESSION_REWRITER =
new AggregateExpressionRewriter();

static {
// support roll up when count distinct is in query
// the column type is not bitMap
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE),
new BitmapUnion(new ToBitmap(Any.INSTANCE)));
// with bitmap_union, to_bitmap and cast
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Count(true, Any.INSTANCE),
new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))));

// support roll up when bitmap_union_count is in query
// the column type is bitMap
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new BitmapUnionCount(Any.INSTANCE),
new BitmapUnion(Any.INSTANCE));
// the column type is not bitMap
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new BitmapUnionCount(new ToBitmap(Any.INSTANCE)),
new BitmapUnion(new ToBitmap(Any.INSTANCE)));
// with bitmap_union, to_bitmap and cast
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new BitmapUnionCount(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))),
new BitmapUnion(new ToBitmap(new Cast(Any.INSTANCE, BigIntType.INSTANCE))));

// support roll up when the column type is not hll
// query is approx_count_distinct
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Ndv(Any.INSTANCE),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new Ndv(Any.INSTANCE),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));

// query is HLL_UNION_AGG
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnionAgg(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnionAgg(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnionAgg(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnionAgg(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));

// query is HLL_CARDINALITY
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllCardinality(new HllUnion(new HllHash(Any.INSTANCE))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllCardinality(new HllUnion(new HllHash(Any.INSTANCE))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllCardinality(new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT)))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllCardinality(new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT)))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));

// query is HLL_RAW_AGG or HLL_UNION
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnion(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnion(new HllHash(Any.INSTANCE)),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(Any.INSTANCE)));
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))),
new HllUnion(new HllHash(new Cast(Any.INSTANCE, VarcharType.SYSTEM_DEFAULT))));

// support roll up when the column type is hll
// query is HLL_UNION_AGG
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnionAgg(Any.INSTANCE),
new HllUnion(Any.INSTANCE));

// query is HLL_CARDINALITY
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllCardinality(new HllUnion(Any.INSTANCE)),
new HllUnion(Any.INSTANCE));

// query is HLL_RAW_AGG or HLL_UNION
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.put(new HllUnion(Any.INSTANCE),
new HllUnion(Any.INSTANCE));

}

@Override
protected Plan rewriteQueryByView(MatchMode matchMode,
StructInfo queryStructInfo,
Expand Down Expand Up @@ -375,35 +286,22 @@ private boolean isGroupByEquals(Pair<Plan, LogicalAggregate<Plan>> queryTopPlanA
private static Function rollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Map<Expression, Expression> mvExprToMvScanExprQueryBased) {
if (!(queryAggregateFunction instanceof CouldRollUp)) {
return null;
}
Expression rollupParam = null;
Expression viewRollupFunction = null;
// handle simple aggregate function roll up which is not in the AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP
if (mvExprToMvScanExprQueryBased.containsKey(queryAggregateFunctionShuttled)
&& AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.keySet().stream()
.noneMatch(aggFunction -> aggFunction.equals(queryAggregateFunction))) {
rollupParam = mvExprToMvScanExprQueryBased.get(queryAggregateFunctionShuttled);
viewRollupFunction = queryAggregateFunctionShuttled;
} else {
// handle complex functions roll up
// eg: query is count(distinct param), mv sql is bitmap_union(to_bitmap(param))
for (Expression mvExprShuttled : mvExprToMvScanExprQueryBased.keySet()) {
if (!(mvExprShuttled instanceof Function)) {
for (Map.Entry<Expression, Expression> expressionEntry : mvExprToMvScanExprQueryBased.entrySet()) {
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair = Pair.of(expressionEntry.getKey(),
expressionEntry.getValue());
for (AggFunctionRollUpHandler rollUpHandler : ROLL_UP_HANDLERS) {
if (!rollUpHandler.canRollup(queryAggregateFunction, queryAggregateFunctionShuttled,
mvExprToMvScanExprQueryBasedPair)) {
continue;
}
if (isAggregateFunctionEquivalent(queryAggregateFunction, (Function) mvExprShuttled)) {
rollupParam = mvExprToMvScanExprQueryBased.get(mvExprShuttled);
viewRollupFunction = mvExprShuttled;
Function rollupFunction = rollUpHandler.doRollup(queryAggregateFunction,
queryAggregateFunctionShuttled, mvExprToMvScanExprQueryBasedPair);
if (rollupFunction != null) {
return rollupFunction;
}
}
}
if (rollupParam == null || !canRollup(viewRollupFunction)) {
return null;
}
// do roll up
return ((CouldRollUp) queryAggregateFunction).constructRollUp(rollupParam);
return null;
}

// Check the aggregate function can roll up or not, return true if could roll up
Expand All @@ -418,7 +316,7 @@ private static boolean canRollup(Expression rollupExpression) {
}
if (rollupExpression instanceof AggregateFunction) {
AggregateFunction aggregateFunction = (AggregateFunction) rollupExpression;
return !aggregateFunction.isDistinct() && aggregateFunction instanceof CouldRollUp;
return !aggregateFunction.isDistinct() && aggregateFunction instanceof RollUpTrait;
}
return true;
}
Expand Down Expand Up @@ -480,60 +378,6 @@ protected boolean checkPattern(StructInfo structInfo, CascadesContext cascadesCo
&& checkContext.isContainsTopAggregate() && checkContext.getTopAggregateNum() <= 1;
}

/**
* Check the queryFunction is equivalent to view function when function roll up.
* Not only check the function name but also check the argument between query and view aggregate function.
* Such as query is
* select count(distinct a) + 1 from table group by b.
* mv is
* select bitmap_union(to_bitmap(a)) from table group by a, b.
* the queryAggregateFunction is count(distinct a), queryAggregateFunctionShuttled is count(distinct a) + 1
* mvExprToMvScanExprQueryBased is { bitmap_union(to_bitmap(a)) : MTMVScan(output#0) }
* This will check the count(distinct a) in query is equivalent to bitmap_union(to_bitmap(a)) in mv,
* and then check their arguments is equivalent.
*/
private static boolean isAggregateFunctionEquivalent(Function queryFunction, Function viewFunction) {
if (queryFunction.equals(viewFunction)) {
return true;
}
// check the argument of rollup function is equivalent to view function or not
for (Map.Entry<Function, Collection<Expression>> equivalentFunctionEntry :
AGGREGATE_ROLL_UP_EQUIVALENT_FUNCTION_MAP.asMap().entrySet()) {
if (equivalentFunctionEntry.getKey().equals(queryFunction)) {
// check is have equivalent function or not
for (Expression equivalentFunction : equivalentFunctionEntry.getValue()) {
if (!Any.equals(equivalentFunction, viewFunction)) {
continue;
}
// check param in query function is same as the view function
List<Expression> viewFunctionArguments = extractArguments(equivalentFunction, viewFunction);
List<Expression> queryFunctionArguments =
extractArguments(equivalentFunctionEntry.getKey(), queryFunction);
// check argument size,we only support roll up function which has only one argument currently
if (queryFunctionArguments.size() != 1 || viewFunctionArguments.size() != 1) {
continue;
}
if (Objects.equals(queryFunctionArguments.get(0), viewFunctionArguments.get(0))) {
return true;
}
}
}
}
return false;
}

/**
* Extract the function arguments by functionWithAny pattern
* Such as functionWithAny def is bitmap_union(to_bitmap(Any.INSTANCE)),
* actualFunction is bitmap_union(to_bitmap(case when a = 5 then 1 else 2 end))
* after extracting, the return argument is: case when a = 5 then 1 else 2 end
*/
private static List<Expression> extractArguments(Expression functionWithAny, Function actualFunction) {
Set<Object> exprSetToRemove = functionWithAny.collectToSet(expr -> !(expr instanceof Any));
return actualFunction.collectFirst(expr ->
exprSetToRemove.stream().noneMatch(exprToRemove -> exprToRemove.equals(expr)));
}

/**
* Aggregate expression rewriter which is responsible for rewriting group by and
* aggregate function expression
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.exploration.mv.rollup;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Any;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.RollUpTrait;

import java.util.List;
import java.util.Set;

/**
* Aggregate function roll up handler
*/
public abstract class AggFunctionRollUpHandler {

/**
* Decide the query and view function can roll up or not
*/
public boolean canRollup(AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair) {
Expression viewExpression = mvExprToMvScanExprQueryBasedPair.key();
if (!(viewExpression instanceof RollUpTrait) || !((RollUpTrait) viewExpression).canRollUp()) {
return false;
}
AggregateFunction aggregateFunction = (AggregateFunction) viewExpression;
return !aggregateFunction.isDistinct();
}

/**
* Do the aggregate function roll up
*/
public abstract Function doRollup(
AggregateFunction queryAggregateFunction,
Expression queryAggregateFunctionShuttled,
Pair<Expression, Expression> mvExprToMvScanExprQueryBasedPair);

/**
* Extract the function arguments by functionWithAny pattern
* Such as functionWithAny def is bitmap_union(to_bitmap(Any.INSTANCE)),
* actualFunction is bitmap_union(to_bitmap(case when a = 5 then 1 else 2 end))
* after extracting, the return argument is: case when a = 5 then 1 else 2 end
*/
protected static List<Expression> extractArguments(Expression functionWithAny, Function actualFunction) {
Set<Object> exprSetToRemove = functionWithAny.collectToSet(expr -> !(expr instanceof Any));
return actualFunction.collectFirst(expr ->
exprSetToRemove.stream().noneMatch(exprToRemove -> exprToRemove.equals(expr)));
}

/**
* Extract the target expression in actualFunction by targetClazz
* Such as actualFunction def is avg_merge(avg_union(c1)), target Clazz is Combinator
* after extracting, the return argument is avg_union(c1)
*/
protected static <T> T extractLastExpression(Expression actualFunction, Class<T> targetClazz) {
List<Expression> expressions = actualFunction.collectToList(targetClazz::isInstance);
return targetClazz.cast(expressions.get(expressions.size() - 1));
}
}
Loading

0 comments on commit 2d94142

Please sign in to comment.