diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index c33b8b50609635..2069e8ae8a49de 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -31,6 +31,7 @@ void register_aggregate_function_combinator_distinct(AggregateFunctionSimpleFact void register_aggregate_function_combinator_foreach(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_sum0(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_minmax(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_min_by(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_max_by(AggregateFunctionSimpleFactory& factory); @@ -70,6 +71,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { static AggregateFunctionSimpleFactory instance; std::call_once(oc, [&]() { register_aggregate_function_sum(instance); + register_aggregate_function_sum0(instance); register_aggregate_function_minmax(instance); register_aggregate_function_min_by(instance); register_aggregate_function_max_by(instance); diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp index 3ee7dc6ff48333..e0676957d467df 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sum.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_sum.cpp @@ -31,4 +31,8 @@ void register_aggregate_function_sum(AggregateFunctionSimpleFactory& factory) { "sum_decimal256", creator_with_type::creator); } +void register_aggregate_function_sum0(AggregateFunctionSimpleFactory& factory) { + factory.register_function_both("sum0", creator_with_type::creator); +} + } // namespace doris::vectorized diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java index e6134cfc31f9e5..1e631329e28ac3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java @@ -64,6 +64,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Stddev; import org.apache.doris.nereids.trees.expressions.functions.agg.StddevSamp; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0; import org.apache.doris.nereids.trees.expressions.functions.agg.TopN; import org.apache.doris.nereids.trees.expressions.functions.agg.TopNArray; import org.apache.doris.nereids.trees.expressions.functions.agg.TopNWeighted; @@ -132,6 +133,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper { agg(Stddev.class, "stddev_pop", "stddev"), agg(StddevSamp.class, "stddev_samp"), agg(Sum.class, "sum"), + agg(Sum0.class, "sum0"), agg(TopN.class, "topn"), agg(TopNArray.class, "topn_array"), agg(TopNWeighted.class, "topn_weighted"), diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index b69d52b2c11c1b..5b22d5bbc08470 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -46,6 +46,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Set; public class FunctionSet { @@ -1174,86 +1175,91 @@ private void initAggregateBuiltins() { } // Sum - String []sumNames = {"sum", "sum_distinct"}; - for (String name : sumNames) { - addBuiltin(AggregateFunction.createBuiltin(name, + // functionName(String) -> returnsNonNullOnEmpty(Boolean) + Map sumNames = ImmutableMap.of( + "sum", false, + "sum_distinct", false, + "sum0", true + ); + for (Entry nameWithReturn : sumNames.entrySet()) { + addBuiltin(AggregateFunction.createBuiltin(nameWithReturn.getKey(), Lists.newArrayList(Type.BOOLEAN), Type.BIGINT, Type.BIGINT, "", "", "", null, null, "", - null, false, true, false, true)); + null, false, true, nameWithReturn.getValue(), true)); - addBuiltin(AggregateFunction.createBuiltin(name, + addBuiltin(AggregateFunction.createBuiltin(nameWithReturn.getKey(), Lists.newArrayList(Type.TINYINT), Type.BIGINT, Type.BIGINT, "", "", "", null, null, "", - null, false, true, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, + null, false, true, nameWithReturn.getValue(), true)); + addBuiltin(AggregateFunction.createBuiltin(nameWithReturn.getKey(), Lists.newArrayList(Type.SMALLINT), Type.BIGINT, Type.BIGINT, "", "", "", null, null, "", - null, false, true, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, + null, false, true, nameWithReturn.getValue(), true)); + addBuiltin(AggregateFunction.createBuiltin(nameWithReturn.getKey(), Lists.newArrayList(Type.INT), Type.BIGINT, Type.BIGINT, "", "", "", null, null, "", - null, false, true, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, + null, false, true, nameWithReturn.getValue(), true)); + addBuiltin(AggregateFunction.createBuiltin(nameWithReturn.getKey(), Lists.newArrayList(Type.BIGINT), Type.BIGINT, Type.BIGINT, "", "", "", null, null, "", - null, false, true, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, + null, false, true, nameWithReturn.getValue(), true)); + addBuiltin(AggregateFunction.createBuiltin(nameWithReturn.getKey(), Lists.newArrayList(Type.DOUBLE), Type.DOUBLE, Type.DOUBLE, "", "", "", null, null, "", - null, false, true, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, + null, false, true, nameWithReturn.getValue(), true)); + addBuiltin(AggregateFunction.createBuiltin(nameWithReturn.getKey(), Lists.newArrayList(Type.MAX_DECIMALV2_TYPE), Type.MAX_DECIMALV2_TYPE, Type.MAX_DECIMALV2_TYPE, "", "", "", null, null, "", - null, false, true, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, + null, false, true, nameWithReturn.getValue(), true)); + addBuiltin(AggregateFunction.createBuiltin(nameWithReturn.getKey(), Lists.newArrayList(Type.DECIMAL32), ScalarType.DECIMAL128, Type.DECIMAL128, "", "", "", null, null, "", - null, false, true, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, + null, false, true, nameWithReturn.getValue(), true)); + addBuiltin(AggregateFunction.createBuiltin(nameWithReturn.getKey(), Lists.newArrayList(Type.DECIMAL64), Type.DECIMAL128, Type.DECIMAL128, "", "", "", null, null, "", - null, false, true, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, + null, false, true, nameWithReturn.getValue(), true)); + addBuiltin(AggregateFunction.createBuiltin(nameWithReturn.getKey(), Lists.newArrayList(Type.DECIMAL128), Type.DECIMAL128, Type.DECIMAL128, "", "", "", null, null, "", - null, false, true, false, true)); - addBuiltin(AggregateFunction.createBuiltin(name, + null, false, true, nameWithReturn.getValue(), true)); + addBuiltin(AggregateFunction.createBuiltin(nameWithReturn.getKey(), Lists.newArrayList(Type.LARGEINT), Type.LARGEINT, Type.LARGEINT, "", "", "", null, null, "", - null, false, true, false, true)); + null, false, true, nameWithReturn.getValue(), true)); } Type[] types = {Type.SMALLINT, Type.TINYINT, Type.INT, Type.BIGINT, Type.FLOAT, Type.DOUBLE, Type.CHAR, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java new file mode 100644 index 00000000000000..da84d0dc75940a --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum0.java @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions.agg; + +import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable; +import org.apache.doris.nereids.trees.expressions.functions.ComputePrecisionForSum; +import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindowAnalytic; +import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.BooleanType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.LargeIntType; +import org.apache.doris.nereids.types.SmallIntType; +import org.apache.doris.nereids.types.TinyIntType; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * AggregateFunction 'sum0'. sum0 returns the sum of the values which go into it like sum. + * It differs in that when no non null values are applied zero is returned instead of null. + */ +public class Sum0 extends AggregateFunction + implements UnaryExpression, AlwaysNotNullable, ExplicitlyCastableSignature, ComputePrecisionForSum, + SupportWindowAnalytic, CouldRollUp { + + public static final List SIGNATURES = ImmutableList.of( + FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE), + FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE), + FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE), + FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE), + FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE), + FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE), + FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD), + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE) + ); + + /** + * constructor with 1 argument. + */ + public Sum0(Expression arg) { + this(false, arg); + } + + /** + * constructor with 2 argument. + */ + public Sum0(boolean distinct, Expression arg) { + super("sum0", distinct, arg); + } + + public MultiDistinctSum convertToMultiDistinct() { + Preconditions.checkArgument(distinct, + "can't convert to multi_distinct_sum because there is no distinct args"); + return new MultiDistinctSum(false, child()); + } + + @Override + public void checkLegalityBeforeTypeCoercion() { + DataType argType = child().getDataType(); + if ((!argType.isNumericType() && !argType.isBooleanType() && !argType.isNullType()) + || argType.isOnlyMetricType()) { + throw new AnalysisException("sum requires a numeric or boolean parameter: " + this.toSql()); + } + } + + /** + * withDistinctAndChildren. + */ + @Override + public Sum0 withDistinctAndChildren(boolean distinct, List children) { + Preconditions.checkArgument(children.size() == 1); + return new Sum0(distinct, children.get(0)); + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visitSum0(this, context); + } + + @Override + public List getSignatures() { + return SIGNATURES; + } + + @Override + public FunctionSignature searchSignature(List signatures) { + if (getArgument(0).getDataType() instanceof FloatType) { + return FunctionSignature.ret(DoubleType.INSTANCE).args(FloatType.INSTANCE); + } + return ExplicitlyCastableSignature.super.searchSignature(signatures); + } + + @Override + public Function constructRollUp(Expression param, Expression... varParams) { + return new Sum0(this.distinct, param); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java index 594f9c754335aa..bb7c9286d4912a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java @@ -66,6 +66,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Stddev; import org.apache.doris.nereids.trees.expressions.functions.agg.StddevSamp; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0; import org.apache.doris.nereids.trees.expressions.functions.agg.TopN; import org.apache.doris.nereids.trees.expressions.functions.agg.TopNArray; import org.apache.doris.nereids.trees.expressions.functions.agg.TopNWeighted; @@ -274,6 +275,10 @@ default R visitSum(Sum sum, C context) { return visitNullableAggregateFunction(sum, context); } + default R visitSum0(Sum0 sum0, C context) { + return visitAggregateFunction(sum0, context); + } + default R visitTopN(TopN topN, C context) { return visitAggregateFunction(topN, context); }