Skip to content

Commit

Permalink
[Improve](array) support array_enumerate_uniq and array_shuffle for ne…
Browse files Browse the repository at this point in the history
…reids (apache#29936)
  • Loading branch information
amorynan authored Jan 15, 2024
1 parent cd2cf95 commit 18fe84d
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayDifference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayDistinct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerateUniq;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExcept;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExists;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFilter;
Expand All @@ -59,6 +60,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRemove;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRepeat;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayReverseSort;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayShuffle;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySlice;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySort;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySortBy;
Expand Down Expand Up @@ -445,6 +447,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(ArrayDifference.class, "array_difference"),
scalar(ArrayDistinct.class, "array_distinct"),
scalar(ArrayEnumerate.class, "array_enumerate"),
scalar(ArrayEnumerateUniq.class, "array_enumerate_uniq"),
scalar(ArrayExcept.class, "array_except"),
scalar(ArrayExists.class, "array_exists"),
scalar(ArrayFilter.class, "array_filter"),
Expand All @@ -470,6 +473,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(ArraySlice.class, "array_slice"),
scalar(ArraySort.class, "array_sort"),
scalar(ArraySortBy.class, "array_sortby"),
scalar(ArrayShuffle.class, "array_shuffle", "shuffle"),
scalar(ArraySum.class, "array_sum"),
scalar(ArrayUnion.class, "array_union"),
scalar(ArrayWithConstant.class, "array_with_constant"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
 * Nereids scalar function {@code array_enumerate_uniq}.
 *
 * <p>Accepts at least one argument; the single declared signature is variadic over
 * {@code ArrayType.of(new AnyDataType(0))} and returns an array of {@code BIGINT}.
 *
 * <p>NOTE(review): this class implements {@link BinaryExpression} even though its constructor
 * takes one or more arguments -- confirm that nothing relies on it having exactly two children.
 */
public class ArrayEnumerateUniq extends ScalarFunction
        implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {

    /** The only signature of this function: variadic array arguments, array-of-BIGINT result. */
    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
            FunctionSignature
                    .ret(ArrayType.of(BigIntType.INSTANCE))
                    .varArgs(ArrayType.of(new AnyDataType(0)))
    );

    /**
     * Creates the function from one mandatory argument followed by any number of extra ones.
     *
     * @param arg the first argument
     * @param varArgs the remaining arguments, possibly none
     */
    public ArrayEnumerateUniq(Expression arg, Expression... varArgs) {
        super("array_enumerate_uniq", ExpressionUtils.mergeArguments(arg, varArgs));
    }

    /**
     * Builds a new instance of this function over the given children.
     *
     * @param children the replacement arguments; must contain at least one element
     * @return a new {@code ArrayEnumerateUniq} whose arguments are {@code children}
     */
    @Override
    public ArrayEnumerateUniq withChildren(List<Expression> children) {
        Preconditions.checkArgument(children.size() > 0);
        // Split into head + tail to match the (arg, varArgs...) constructor shape.
        Expression first = children.get(0);
        Expression[] rest = children.subList(1, children.size()).toArray(new Expression[0]);
        return new ArrayEnumerateUniq(first, rest);
    }

    @Override
    public List<FunctionSignature> getSignatures() {
        return SIGNATURES;
    }

    @Override
    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
        return visitor.visitArrayEnumerateUniq(this, context);
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.coercion.AnyDataType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
 * Nereids scalar function {@code array_shuffle}.
 *
 * <p>Two call shapes are declared: {@code array_shuffle(arr)} and {@code array_shuffle(arr, seed)},
 * where the second argument is a {@code BIGINT}. Both signatures return the type of argument 0.
 *
 * <p>NOTE(review): this class implements {@link BinaryExpression} although the one-argument form
 * has a single child -- confirm that nothing relies on it having exactly two children.
 */
public class ArrayShuffle extends ScalarFunction
        implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {

    /** Supported signatures: array only, and array plus a BIGINT second argument. */
    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
            FunctionSignature.retArgType(0)
                    .args(ArrayType.of(new AnyDataType(0))),
            FunctionSignature.retArgType(0)
                    .args(ArrayType.of(new AnyDataType(0)), BigIntType.INSTANCE)
    );

    /**
     * Constructor with 1 argument.
     *
     * @param arg the array expression
     */
    public ArrayShuffle(Expression arg) {
        super("array_shuffle", arg);
    }

    /**
     * Constructor with 2 arguments.
     *
     * @param arg the array expression
     * @param arg1 the second argument (the seed, per the class-level call shapes)
     */
    public ArrayShuffle(Expression arg, Expression arg1) {
        super("array_shuffle", arg, arg1);
    }

    /**
     * Builds a new instance of this function over the given children.
     *
     * @param children the replacement arguments; must contain exactly one or two elements
     * @return a new {@code ArrayShuffle} using the matching constructor
     */
    @Override
    public ArrayShuffle withChildren(List<Expression> children) {
        int count = children.size();
        Preconditions.checkArgument(count == 1 || count == 2);
        return count == 1
                ? new ArrayShuffle(children.get(0))
                : new ArrayShuffle(children.get(0), children.get(1));
    }

    @Override
    public List<FunctionSignature> getSignatures() {
        return SIGNATURES;
    }

    @Override
    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
        return visitor.visitArrayShuffle(this, context);
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayDifference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayDistinct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerateUniq;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExcept;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExists;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFilter;
Expand All @@ -59,6 +60,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRemove;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRepeat;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayReverseSort;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayShuffle;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySlice;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySort;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySortBy;
Expand Down Expand Up @@ -484,6 +486,10 @@ default R visitArrayEnumerate(ArrayEnumerate arrayEnumerate, C context) {
return visitScalarFunction(arrayEnumerate, context);
}

/**
 * Visits an {@link ArrayEnumerateUniq}; the default implementation delegates to
 * {@code visitScalarFunction}.
 */
default R visitArrayEnumerateUniq(ArrayEnumerateUniq arrayEnumerateUniq, C context) {
    R delegated = visitScalarFunction(arrayEnumerateUniq, context);
    return delegated;
}

/**
 * Visits an {@link ArrayExcept}; the default implementation delegates to
 * {@code visitScalarFunction}.
 */
default R visitArrayExcept(ArrayExcept arrayExcept, C context) {
    R delegated = visitScalarFunction(arrayExcept, context);
    return delegated;
}
Expand Down Expand Up @@ -564,6 +570,10 @@ default R visitArraySortBy(ArraySortBy arraySortBy, C context) {
return visitScalarFunction(arraySortBy, context);
}

/**
 * Visits an {@link ArrayShuffle}; the default implementation delegates to
 * {@code visitScalarFunction}.
 */
default R visitArrayShuffle(ArrayShuffle arrayShuffle, C context) {
    R delegated = visitScalarFunction(arrayShuffle, context);
    return delegated;
}

/**
 * Visits an {@link ArrayMap}; the default implementation delegates to
 * {@code visitScalarFunction}.
 *
 * @param arrayMap the expression being visited
 * @param context the visitor context, passed through unchanged
 * @return whatever {@code visitScalarFunction} returns for this expression
 */
default R visitArrayMap(ArrayMap arrayMap, C context) {
    // The parameter was previously called "arraySort", a copy-paste leftover that did not match its type.
    return visitScalarFunction(arrayMap, context);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,15 @@ array_enumerate_uniq
-- !old_sql --
[1]

-- !sql --
array_enumerate_uniq

-- !nereid_sql --
[1, 1, 2]

-- !nereid_sql --
[1, 1, 1]

-- !nereid_sql --
[1]

Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,26 @@ suite("test_array_zip_array_enumerate_uniq", "p0") {

// nereids now supports array_enumerate_uniq
// ============= array_enumerate_uniq =========
// qt_sql "SELECT 'array_enumerate_uniq';"
// order_qt_nereid_sql """ SELECT array_enumerate_uniq(array_enumerate_uniq(array(cast(10 as LargeInt), cast(100 as LargeInt), cast(2 as LargeInt))), array(cast(123 as LargeInt), cast(1023 as LargeInt), cast(123 as LargeInt))); """
//
// order_qt_nereid_sql """SELECT array_enumerate_uniq(
// [111111, 222222, 333333],
// [444444, 555555, 666666],
// [111111, 222222, 333333],
// [444444, 555555, 666666],
// [111111, 222222, 333333],
// [444444, 555555, 666666],
// [111111, 222222, 333333],
// [444444, 555555, 666666]);"""
// order_qt_nereid_sql """SELECT array_enumerate_uniq(array(STDDEV_SAMP(910947.571364)), array(NULL)) from numbers;"""
qt_sql "SELECT 'array_enumerate_uniq';"
order_qt_nereid_sql """ SELECT array_enumerate_uniq(array_enumerate_uniq(array(cast(10 as LargeInt), cast(100 as LargeInt), cast(2 as LargeInt))), array(cast(123 as LargeInt), cast(1023 as LargeInt), cast(123 as LargeInt))); """

order_qt_nereid_sql """SELECT array_enumerate_uniq(
[111111, 222222, 333333],
[444444, 555555, 666666],
[111111, 222222, 333333],
[444444, 555555, 666666],
[111111, 222222, 333333],
[444444, 555555, 666666],
[111111, 222222, 333333],
[444444, 555555, 666666]);"""
order_qt_nereid_sql """SELECT array_enumerate_uniq(array(STDDEV_SAMP(910947.571364)), array(NULL)) from numbers;"""
// //order_qt_sql """ SELECT max(array_join(arr)) FROM (SELECT array_enumerate_uniq(group_array(DIV(number, 54321)) AS nums, group_array(cast(DIV(number, 98765) as string))) AS arr FROM (SELECT number FROM numbers LIMIT 1000000) GROUP BY bitmap_hash(number) % 100000);"""

// array_shuffle
// do not check result, since shuffle result is random
sql "SELECT array_sum(array_shuffle([1, 2, 3, 3, null, null, 4, 4])), array_shuffle([1, 2, 3, 3, null, null, 4, 4], 0), shuffle([1, 2, 3, 3, null, null, 4, 4], 0)"
sql "SELECT array_sum(array_shuffle([1.111, 2.222, 3.333])), array_shuffle([1.111, 2.222, 3.333], 0), shuffle([1.111, 2.222, 3.333], 0)"
sql "SELECT array_size(array_shuffle(['aaa', null, 'bbb', 'fff'])), array_shuffle(['aaa', null, 'bbb', 'fff'], 0), shuffle(['aaa', null, 'bbb', 'fff'], 0)"
sql """select array_size(array("2020-01-02", "2022-01-03", "2021-01-01", "1996-04-17")), array_shuffle(array("2020-01-02", "2022-01-03", "2021-01-01", "1996-04-17"), 0), shuffle(array("2020-01-02", "2022-01-03", "2021-01-01", "1996-04-17"), 0)"""

}

0 comments on commit 18fe84d

Please sign in to comment.