
[CORE] Query planner: Simplify validator FallbackByNativeValidation #8177

Merged: 16 commits, Dec 9, 2024
@@ -18,9 +18,9 @@ package org.apache.gluten.backendsapi.clickhouse
 
 import org.apache.gluten.GlutenBuildInfo._
 import org.apache.gluten.GlutenConfig
-import org.apache.gluten.backend.Component.BuildInfo
 import org.apache.gluten.backendsapi._
 import org.apache.gluten.columnarbatch.CHBatch
+import org.apache.gluten.component.Component.BuildInfo
 import org.apache.gluten.execution.WriteFilesExecTransformer
 import org.apache.gluten.expression.WindowFunctionsBuilder
 import org.apache.gluten.extension.ValidationResult
@@ -82,6 +82,7 @@ object CHRuleApi {
     injector.injectPreTransform(_ => WriteFilesWithBucketValue)
 
     // Legacy: The legacy transform rule.
+    val offloads = Seq(OffloadOthers(), OffloadExchange(), OffloadJoin())
     val validatorBuilder: GlutenConfig => Validator = conf =>
       Validator
         .builder()
@@ -91,11 +92,10 @@
         .fallbackByBackendSettings()
         .fallbackByUserOptions()
         .fallbackByTestInjects()
-        .fallbackByNativeValidation()
+        .fallbackByNativeValidation(offloads)
        .build()
     val rewrites =
       Seq(RewriteIn, RewriteMultiChildrenCount, RewriteJoin, PullOutPreProject, PullOutPostProject)
-    val offloads = Seq(OffloadOthers(), OffloadExchange(), OffloadJoin())
     injector.injectTransform(
       c => intercept(HeuristicTransform.Single(validatorBuilder(c.glutenConf), rewrites, offloads)))
 
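Passing the offload rules into fallbackByNativeValidation lets validation exercise the exact OffloadSingleNode rules that HeuristicTransform later applies, instead of maintaining a parallel validation path. A hedged sketch of the shape this enables, not necessarily the actual implementation: selfValidate is an assumed helper name, while Validator, OffloadSingleNode, GlutenPlan, offload and pass() all appear elsewhere in this diff.

private class FallbackByNativeValidation(offloads: Seq[OffloadSingleNode]) extends Validator {
  override def validate(plan: SparkPlan): Validator.OutCome = {
    // Tentatively offload the node with the shared rules, then let the
    // resulting Gluten operator validate itself against the native backend.
    val offloaded = offloads.foldLeft(plan)((p, rule) => rule.offload(p))
    offloaded match {
      case gluten: GlutenPlan => gluten.selfValidate() // assumed: adapts a ValidationResult to an OutCome
      case _ => pass() // no rule offloaded the node; nothing to validate natively
    }
  }
}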
@@ -656,7 +656,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
   }
 
   override def createColumnarWriteFilesExec(
-      child: SparkPlan,
+      child: WriteFilesExecTransformer,
       noop: SparkPlan,
       fileFormat: FileFormat,
       partitionColumns: Seq[Attribute],
@@ -666,6 +666,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
     CHColumnarWriteFilesExec(
       child,
       noop,
+      child,
       fileFormat,
       partitionColumns,
       bucketSpec,
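In the call above, child is passed twice on purpose: first upcast to SparkPlan as the columnar write branch, then again as the retained WriteFilesExecTransformer t that the new doValidateInternal override delegates to (see CHColumnarWriteFilesExec below). An annotated restatement of the same call, with the trailing arguments as the surrounding hunks suggest:

CHColumnarWriteFilesExec(
  child, // left: the columnar write plan, upcast to SparkPlan
  noop, // right: the row-based no-op branch
  child, // t: the same transformer, kept only so validation can be delegated
  fileFormat,
  partitionColumns,
  bucketSpec,
  options,
  staticPartitions)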
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.execution
 
+import org.apache.gluten.execution.WriteFilesExecTransformer
+import org.apache.gluten.extension.ValidationResult
 import org.apache.gluten.memory.CHThreadGroup
 
 import org.apache.spark.{Partition, SparkException, TaskContext, TaskOutputFileAlreadyExistException}
@@ -149,17 +151,22 @@ class CHColumnarWriteFilesRDD(
 case class CHColumnarWriteFilesExec(
     override val left: SparkPlan,
     override val right: SparkPlan,
+    t: WriteFilesExecTransformer,
     fileFormat: FileFormat,
     partitionColumns: Seq[Attribute],
     bucketSpec: Option[BucketSpec],
     options: Map[String, String],
     staticPartitions: TablePartitionSpec
 ) extends ColumnarWriteFilesExec(left, right) {
 
+  override protected def doValidateInternal(): ValidationResult = {
+    t.doValidateInternal()
+  }
+
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan,
       newRight: SparkPlan): SparkPlan =
-    copy(newLeft, newRight, fileFormat, partitionColumns, bucketSpec, options, staticPartitions)
+    copy(newLeft, newRight, t, fileFormat, partitionColumns, bucketSpec, options, staticPartitions)
 
   override def doExecuteWrite(writeFilesSpec: WriteFilesSpec): RDD[WriterCommitMessage] = {
     assert(child.supportsColumnar)
@@ -18,9 +18,9 @@ package org.apache.gluten.backendsapi.velox
 
 import org.apache.gluten.GlutenBuildInfo._
 import org.apache.gluten.GlutenConfig
-import org.apache.gluten.backend.Component.BuildInfo
 import org.apache.gluten.backendsapi._
 import org.apache.gluten.columnarbatch.VeloxBatch
+import org.apache.gluten.component.Component.BuildInfo
 import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution.WriteFilesExecTransformer
 import org.apache.gluten.expression.WindowFunctionsBuilder
@@ -74,6 +74,7 @@ object VeloxRuleApi {
     injector.injectPreTransform(c => ArrowScanReplaceRule.apply(c.session))
 
     // Legacy: The legacy transform rule.
+    val offloads = Seq(OffloadOthers(), OffloadExchange(), OffloadJoin())
     val validatorBuilder: GlutenConfig => Validator = conf =>
       Validator
         .builder()
@@ -83,11 +84,10 @@
         .fallbackByBackendSettings()
         .fallbackByUserOptions()
         .fallbackByTestInjects()
-        .fallbackByNativeValidation()
+        .fallbackByNativeValidation(offloads)
        .build()
     val rewrites =
       Seq(RewriteIn, RewriteMultiChildrenCount, RewriteJoin, PullOutPreProject, PullOutPostProject)
-    val offloads = Seq(OffloadOthers(), OffloadExchange(), OffloadJoin())
     injector.injectTransform(
       c => HeuristicTransform.Single(validatorBuilder(c.glutenConf), rewrites, offloads))
 
@@ -560,7 +560,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     ShuffleUtil.genColumnarShuffleWriter(parameters)
   }
   override def createColumnarWriteFilesExec(
-      child: SparkPlan,
+      child: WriteFilesExecTransformer,
       noop: SparkPlan,
       fileFormat: FileFormat,
       partitionColumns: Seq[Attribute],
@@ -570,6 +570,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
     VeloxColumnarWriteFilesExec(
       child,
       noop,
+      child,
       fileFormat,
       partitionColumns,
       bucketSpec,
@@ -18,6 +18,8 @@ package org.apache.spark.sql.execution
 
 import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.columnarbatch.ColumnarBatches
+import org.apache.gluten.execution.WriteFilesExecTransformer
+import org.apache.gluten.extension.ValidationResult
 import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
 
 import org.apache.spark.{Partition, SparkException, TaskContext, TaskOutputFileAlreadyExistException}
@@ -250,13 +252,18 @@ class VeloxColumnarWriteFilesRDD(
 case class VeloxColumnarWriteFilesExec private (
     override val left: SparkPlan,
     override val right: SparkPlan,
+    t: WriteFilesExecTransformer,
     fileFormat: FileFormat,
     partitionColumns: Seq[Attribute],
     bucketSpec: Option[BucketSpec],
     options: Map[String, String],
     staticPartitions: TablePartitionSpec)
   extends ColumnarWriteFilesExec(left, right) {
 
+  override protected def doValidateInternal(): ValidationResult = {
+    t.doValidateInternal()
+  }
+
   override def doExecuteWrite(writeFilesSpec: WriteFilesSpec): RDD[WriterCommitMessage] = {
     assert(child.supportsColumnar)
 
@@ -276,5 +283,5 @@ case class VeloxColumnarWriteFilesExec private (
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan,
       newRight: SparkPlan): SparkPlan =
-    copy(newLeft, newRight, fileFormat, partitionColumns, bucketSpec, options, staticPartitions)
+    copy(newLeft, newRight, t, fileFormat, partitionColumns, bucketSpec, options, staticPartitions)
 }
@@ -18,7 +18,7 @@ package org.apache.gluten
 
 import org.apache.gluten.GlutenBuildInfo._
 import org.apache.gluten.GlutenConfig._
-import org.apache.gluten.backend.Component
+import org.apache.gluten.component.Component
 import org.apache.gluten.events.GlutenBuildInfoEvent
 import org.apache.gluten.exception.GlutenException
 import org.apache.gluten.extension.GlutenSessionExtensions
@@ -16,6 +16,8 @@
  */
 package org.apache.gluten.backend
 
+import org.apache.gluten.component.Component
+
 trait Backend extends Component {
 
   /**
@@ -14,7 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.backend
+package org.apache.gluten.component
 
 import org.apache.gluten.extension.columnar.transition.ConventionFunc
 import org.apache.gluten.extension.injector.Injector
@@ -100,7 +100,7 @@ object Component {
     graph.sorted()
   }
 
-  private[backend] def sortedUnsafe(): Seq[Component] = {
+  private[component] def sortedUnsafe(): Seq[Component] = {
     graph.sorted()
   }
 
@@ -16,17 +16,19 @@
  */
 package org.apache.gluten
 
+import org.apache.gluten.backend.Backend
+
 import org.apache.spark.internal.Logging
 
 import java.util.ServiceLoader
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.JavaConverters._
 
-package object backend extends Logging {
-  private[backend] val allComponentsLoaded: AtomicBoolean = new AtomicBoolean(false)
+package object component extends Logging {
+  private val allComponentsLoaded: AtomicBoolean = new AtomicBoolean(false)
 
-  private[backend] def ensureAllComponentsRegistered(): Unit = {
+  private[component] def ensureAllComponentsRegistered(): Unit = {
     if (!allComponentsLoaded.compareAndSet(false, true)) {
       return
     }
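The ServiceLoader and JavaConverters imports above outline the discovery idiom; the rest of ensureAllComponentsRegistered is truncated in this view. A hedged sketch of a plausible continuation, where Component.register() is an assumed hook rather than an API confirmed by this diff:

// Discover every Component registered under META-INF/services and register
// it; the compareAndSet guard above makes repeated calls cheap no-ops.
ServiceLoader.load(classOf[Component]).asScala.foreach(_.register())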
@@ -17,7 +17,7 @@
 package org.apache.gluten.extension
 
 import org.apache.gluten.GlutenConfig
-import org.apache.gluten.backend.Component
+import org.apache.gluten.component.Component
 import org.apache.gluten.extension.injector.Injector
 
 import org.apache.spark.internal.Logging
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.extension.columnar.enumerated
 
-import org.apache.gluten.backend.Component
+import org.apache.gluten.component.Component
 import org.apache.gluten.exception.GlutenException
 import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
 import org.apache.gluten.extension.columnar.enumerated.planner.GlutenOptimization
@@ -17,7 +17,7 @@
 package org.apache.gluten.extension.columnar.enumerated.planner.plan
 
 import org.apache.gluten.execution.GlutenPlan
-import org.apache.gluten.extension.columnar.enumerated.planner.metadata.GlutenMetadata
+import org.apache.gluten.extension.columnar.enumerated.planner.metadata.{GlutenMetadata, LogicalLink}
 import org.apache.gluten.extension.columnar.enumerated.planner.property.{Conv, ConvDef}
 import org.apache.gluten.extension.columnar.transition.{Convention, ConventionReq}
 import org.apache.gluten.ras.{Metadata, PlanModel}
@@ -27,6 +27,7 @@ import org.apache.gluten.sql.shims.SparkShimLoader
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{ColumnarToRowExec, LeafExecNode, SparkPlan}
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExecBase
 import org.apache.spark.task.{SparkTaskUtil, TaskResources}
@@ -75,6 +76,16 @@ object GlutenPlanModel {
     final override val supportsRowBased: Boolean = {
       rowType() != Convention.RowType.None
     }
+
+    override def logicalLink: Option[LogicalPlan] = {
+      if (metadata.logicalLink() eq LogicalLink.notFound) {
+        return None
+      }
+      Some(metadata.logicalLink().plan)
+    }
+
+    override def setLogicalLink(logicalPlan: LogicalPlan): Unit =
+      throw new UnsupportedOperationException()
   }
 
   private object PlanModelImpl extends PlanModel[SparkPlan] {
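These overrides bridge the RAS memo's GlutenMetadata to Spark's standard logicalLink/setLogicalLink accessors: reads are derived from metadata, and writes are rejected because the memo owns that state. A hedged sketch of the sentinel shape the eq comparison implies; LogicalLink's real definition is not part of this diff:

// Assumed shape, for illustration only: notFound is a dedicated instance
// compared by reference, so metadata avoids null checks and Option fields.
class LogicalLink(val plan: LogicalPlan)
object LogicalLink {
  val notFound: LogicalLink = new LogicalLink(null)
}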
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.extension.columnar.heuristic
 
-import org.apache.gluten.backend.Component
+import org.apache.gluten.component.Component
 import org.apache.gluten.exception.GlutenException
 import org.apache.gluten.extension.columnar.ColumnarRuleApplier.ColumnarRuleCall
 import org.apache.gluten.extension.columnar.offload.OffloadSingleNode
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
 
 class LegacyOffload(rules: Seq[OffloadSingleNode]) extends Rule[SparkPlan] with LogLevelUtil {
-
   def apply(plan: SparkPlan): SparkPlan = {
     val out =
       rules.foldLeft(plan)((p, rule) => p.transformUp { case p => rule.offload(p) })
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.extension.columnar.transition
 
-import org.apache.gluten.backend.Component
+import org.apache.gluten.component.Component
 import org.apache.gluten.extension.columnar.transition.ConventionReq.KnownChildConvention
 import org.apache.gluten.sql.shims.SparkShimLoader
 
@@ -50,7 +50,7 @@ object Validator {
 
     /** Add a custom validator to pipeline. */
     def add(validator: Validator): Builder = {
-      buffer += validator
+      buffer ++= flatten(validator)
       this
     }
 
@@ -64,7 +64,15 @@
       new ValidatorPipeline(buffer.toSeq)
     }
 
-    private class ValidatorPipeline(validators: Seq[Validator]) extends Validator {
+    private def flatten(validator: Validator): Seq[Validator] = validator match {
+      case p: ValidatorPipeline =>
+        p.validators.flatMap(flatten)
+      case other => Seq(other)
+    }
+
+    private class ValidatorPipeline(val validators: Seq[Validator]) extends Validator {
+      assert(!validators.exists(_.isInstanceOf[ValidatorPipeline]))
+
       override def validate(plan: SparkPlan): Validator.OutCome = {
         val init: Validator.OutCome = pass()
         val finalOut = validators.foldLeft(init) {
@@ -86,4 +94,10 @@
   private object Builder {
     def apply(): Builder = new Builder()
   }
+
+  implicit class ValidatorImplicits(v: Validator) {
+    def andThen(other: Validator): Validator = {
+      builder().add(v).add(other).build()
+    }
+  }
 }
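Taken together, flatten, the new assert, and andThen keep every built pipeline flat no matter how validators are composed. A usage sketch: a, b and c stand for any concrete validators (hypothetical values); the API names are all from the code above.

// andThen wraps (a, b) in a ValidatorPipeline; add() flattens it again, so
// the final result is one flat pipeline over (a, b, c), never a nested one,
// which is exactly the invariant the assert in ValidatorPipeline checks.
val combined: Validator = a.andThen(b)
val pipeline: Validator = Validator
  .builder()
  .add(combined)
  .add(c)
  .build()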
@@ -14,8 +14,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.backend
+package org.apache.gluten.component
 
+import org.apache.gluten.backend.Backend
 import org.apache.gluten.extension.injector.Injector
 
 import org.scalatest.BeforeAndAfterAll
@@ -16,7 +16,7 @@
  */
 package org.apache.gluten.backendsapi
 
-import org.apache.gluten.backend.Component
+import org.apache.gluten.component.Component
 
 object BackendsApiManager {
   private lazy val backend: SubstraitBackend = initializeInternal()
@@ -365,7 +365,7 @@ trait SparkPlanExecApi {
 
   /** Create ColumnarWriteFilesExec */
   def createColumnarWriteFilesExec(
-      child: SparkPlan,
+      child: WriteFilesExecTransformer,
       noop: SparkPlan,
       fileFormat: FileFormat,
       partitionColumns: Seq[Attribute],
@@ -198,7 +198,6 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = false)(
     val transformStageId: Int
 ) extends WholeStageTransformerGenerateTreeStringShim
   with UnaryTransformSupport {
-  assert(child.isInstanceOf[TransformSupport])
 
   def stageId: Int = transformStageId
 
@@ -353,6 +352,7 @@
   }
 
   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    assert(child.isInstanceOf[TransformSupport])
     val pipelineTime: SQLMetric = longMetric("pipelineTime")
     // We should do transform first to make sure all subqueries are materialized
     val wsCtx = GlutenTimeMetric.withMillisTime {
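Relocating the assert from the constructor into doExecuteColumnar loosens construction-time requirements while keeping the execution-time invariant. A hedged illustration of the consequence; someValidatingChild is a placeholder, and the constructor shape follows the case class header above:

// A WholeStageTransformer can now be built around a child that is not (yet)
// a TransformSupport, for example while the planner is still validating
// offloads; executing it still fails fast via the relocated assert.
val wst = WholeStageTransformer(someValidatingChild)(transformStageId = 1)
// wst.doExecuteColumnar() asserts child.isInstanceOf[TransformSupport]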