[VL] Don't rewrite collect_list/collect_set in window
zml1206 committed Dec 13, 2024
1 parent 05b1e7a commit e5bfc52
Showing 2 changed files with 7 additions and 11 deletions.
File 1 of 2 (the file defining class VeloxSparkPlanExecApi):

@@ -35,7 +35,7 @@ import org.apache.spark.shuffle.utils.ShuffleUtil
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, CollectList, CollectSet}
 import org.apache.spark.sql.catalyst.optimizer.BuildSide
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical._
@@ -759,7 +759,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       Sig[UserDefinedAggregateFunction](ExpressionNames.UDAF_PLACEHOLDER),
       Sig[NaNvl](ExpressionNames.NANVL),
       Sig[VeloxCollectList](ExpressionNames.COLLECT_LIST),
+      Sig[CollectList](ExpressionNames.COLLECT_LIST),
       Sig[VeloxCollectSet](ExpressionNames.COLLECT_SET),
+      Sig[CollectSet](ExpressionNames.COLLECT_SET),
       Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN),
       Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG),
       // For test purpose.
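For context, the two added Sig entries register Spark's built-in CollectList and CollectSet under the same backend function names already used by the Gluten-specific VeloxCollectList and VeloxCollectSet. A minimal sketch of that idea follows; the names supported and isOffloadable are hypothetical stand-ins for Gluten's actual Sig/ExpressionNames machinery, and the string values are illustrative only:

import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet}

object CollectSignatureSketch {
  // Hypothetical stand-in for the Sig list above: expression class -> backend function name.
  // With the two new entries, the Spark-native collect aggregates resolve to the same
  // names as the Velox-specific ones, so a window-side collect_list/collect_set can be
  // offloaded without first being rewritten.
  val supported: Map[Class[_], String] = Map(
    classOf[CollectList] -> "collect_list",
    classOf[CollectSet] -> "collect_set"
  )

  def isOffloadable(aggregateClass: Class[_]): Boolean = supported.contains(aggregateClass)
}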
File 2 of 2 (the file defining case class CollectRewriteRule):

@@ -20,11 +20,11 @@ import org.apache.gluten.expression.ExpressionMappings
 import org.apache.gluten.expression.aggregate.{VeloxCollectList, VeloxCollectSet}
 
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions.{Expression, WindowExpression}
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Window}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, AGGREGATE_EXPRESSION, WINDOW, WINDOW_EXPRESSION}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, AGGREGATE_EXPRESSION}
 
 import scala.reflect.{classTag, ClassTag}
 
@@ -40,7 +40,7 @@ case class CollectRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
       return plan
     }
 
-    val newPlan = plan.transformUpWithPruning(_.containsAnyPattern(WINDOW, AGGREGATE)) {
+    val newPlan = plan.transformUpWithPruning(_.containsPattern(AGGREGATE)) {
       case node =>
         replaceAggCollect(node)
     }
@@ -57,12 +57,6 @@ case class CollectRewriteRule(spark: SparkSession) extends Rule[LogicalPlan] {
           case ToVeloxCollect(newAggExpr) =>
            newAggExpr
         }
-      case w: Window =>
-        w.transformExpressionsWithPruning(
-          _.containsAllPatterns(AGGREGATE_EXPRESSION, WINDOW_EXPRESSION)) {
-          case windowExpr @ WindowExpression(ToVeloxCollect(newAggExpr), _) =>
-            windowExpr.copy(newAggExpr)
-        }
       case other => other
     }
   }
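As a usage illustration (a hedged sketch; spark, table t, and its columns are made-up names): after this commit, CollectRewriteRule only transforms Aggregate nodes, so collect_list/collect_set under a GROUP BY are still rewritten to VeloxCollectList/VeloxCollectSet, while the same functions used as window functions are left as Spark's native CollectList/CollectSet, which the new Sig entries mark as supported.

// Group-by aggregate: still matched by CollectRewriteRule and rewritten to the
// Velox-specific collect aggregates.
val grouped = spark.sql("SELECT id, collect_set(v) AS vs FROM t GROUP BY id")

// Window function: no longer rewritten; Spark's CollectList stays in the plan and is
// offloaded via the newly registered signature.
val windowed = spark.sql(
  "SELECT id, collect_list(v) OVER (PARTITION BY id ORDER BY ts) AS running_vs FROM t")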
