diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
index 0a3e4ebe2cd1..ac85892b7a3e 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxMetricsSuite.scala
@@ -20,6 +20,8 @@ import org.apache.gluten.GlutenConfig
 import org.apache.gluten.sql.shims.SparkShimLoader
 
 import org.apache.spark.SparkConf
+import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
+import org.apache.spark.sql.TestUtils
 import org.apache.spark.sql.execution.CommandResultExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.internal.SQLConf
@@ -201,4 +203,28 @@ class VeloxMetricsSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa
       }
     }
   }
+
+  test("File scan task input metrics") {
+    createTPCHNotNullTables()
+
+    @volatile var inputRecords = 0L
+    val partTableRecords = spark.sql("select * from part").count()
+    val itemTableRecords = spark.sql("select * from lineitem").count()
+    val inputMetricsListener = new SparkListener {
+      override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+        inputRecords += stageCompleted.stageInfo.taskMetrics.inputMetrics.recordsRead
+      }
+    }
+
+    TestUtils.withListener(spark.sparkContext, inputMetricsListener) {
+      _ =>
+        val df = spark.sql("""
+                             |select /*+ BROADCAST(part) */ * from part join lineitem
+                             |on l_partkey = p_partkey
+                             |""".stripMargin)
+        df.count()
+    }
+
+    assert(inputRecords == (partTableRecords + itemTableRecords))
+  }
 }
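Note on the test above: the broadcast join forces both `part` and `lineitem` to be scanned inside whole-stage pipelines, so the asserted `recordsRead` total only matches when every scan leaf reports its input metrics. The `WholeStageTransformer` change below makes that happen; previously the updater was resolved with `find`, which stops at the first matching leaf. A minimal sketch of that difference, using toy `Node`/`Leaf`/`Branch` types invented purely for illustration (this is not Gluten code):

```scala
import scala.collection.mutable.ArrayBuffer

object LeafLookupDemo extends App {
  sealed trait Node { def children: Seq[Node] }
  case class Leaf(name: String) extends Node { def children: Seq[Node] = Nil }
  case class Branch(children: Node*) extends Node

  // New behaviour: recurse into every child and keep all leaves.
  def collectLeaves(node: Node, buffer: ArrayBuffer[Leaf]): Unit = node match {
    case leaf: Leaf => buffer += leaf
    case other => other.children.foreach(collectLeaves(_, buffer))
  }

  // Old behaviour: a preorder find that stops at the first leaf.
  def findLeaf(node: Node): Option[Leaf] = node match {
    case leaf: Leaf => Some(leaf)
    case other => other.children.flatMap(findLeaf).headOption
  }

  val join = Branch(Leaf("scan part"), Leaf("scan lineitem"))

  val leaves = ArrayBuffer.empty[Leaf]
  collectLeaves(join, leaves)
  assert(leaves.map(_.name) == Seq("scan part", "scan lineitem")) // both scans seen
  assert(findLeaf(join).map(_.name).contains("scan part")) // lineitem scan missed
}
```

With the old lookup, the second scan's rows never reached the task's input metrics, which is exactly what the new test would catch.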
diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
index 78132c08c782..6892921f831f 100644
--- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
+++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/WholeStageTransformer.scala
@@ -38,11 +38,13 @@ import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.datasources.FilePartition
 import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.utils.SparkInputMetricsUtil.InputMetricsWrapper
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import com.google.common.collect.Lists
 
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 
 case class TransformContext(
     inputAttributes: Seq[Attribute],
@@ -300,7 +302,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
         inputPartitions,
         inputRDDs,
         pipelineTime,
-        leafMetricsUpdater().updateInputMetrics,
+        leafInputMetricsUpdater(),
         BackendsApiManager.getMetricsApiInstance.metricsUpdatingFunction(
           child,
           wsCtx.substraitContext.registeredRelMap,
@@ -354,14 +356,25 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
     }
   }
 
-  private def leafMetricsUpdater(): MetricsUpdater = {
-    child
-      .find {
-        case t: TransformSupport if t.children.forall(!_.isInstanceOf[TransformSupport]) => true
-        case _ => false
+  private def leafInputMetricsUpdater(): InputMetricsWrapper => Unit = {
+    def collectLeaves(plan: SparkPlan, buffer: ArrayBuffer[TransformSupport]): Unit = {
+      plan match {
+        case node: TransformSupport if node.children.forall(!_.isInstanceOf[TransformSupport]) =>
+          buffer.append(node)
+        case node: TransformSupport =>
+          node.children
+            .foreach(collectLeaves(_, buffer))
+        case _ =>
       }
-      .map(_.asInstanceOf[TransformSupport].metricsUpdater())
-      .getOrElse(MetricsUpdater.None)
+    }
+
+    val leafBuffer = new ArrayBuffer[TransformSupport]()
+    collectLeaves(child, leafBuffer)
+    val leafMetricsUpdater = leafBuffer.map(_.metricsUpdater())
+
+    (inputMetrics: InputMetricsWrapper) => {
+      leafMetricsUpdater.foreach(_.updateInputMetrics(inputMetrics))
+    }
   }
 
   override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
diff --git a/gluten-substrait/src/test/scala/org/apache/spark/sql/TestUtils.scala b/gluten-substrait/src/test/scala/org/apache/spark/sql/TestUtils.scala
index a679c2272879..c87f59466360 100644
--- a/gluten-substrait/src/test/scala/org/apache/spark/sql/TestUtils.scala
+++ b/gluten-substrait/src/test/scala/org/apache/spark/sql/TestUtils.scala
@@ -18,6 +18,9 @@ package org.apache.spark.sql
 
 import org.apache.gluten.exception.GlutenException
 
+import org.apache.spark.{TestUtils => SparkTestUtils}
+import org.apache.spark.SparkContext
+import org.apache.spark.scheduler.SparkListener
 import org.apache.spark.sql.test.SQLTestUtils
 
 object TestUtils {
@@ -27,4 +30,8 @@ object TestUtils {
       throw new GlutenException("Failed to compare answer" + result.get)
     }
   }
+
+  def withListener[L <: SparkListener](sc: SparkContext, listener: L)(body: L => Unit): Unit = {
+    SparkTestUtils.withListener(sc, listener)(body)
+  }
 }
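A final note on the `TestUtils.withListener` shim: it simply forwards to Spark's `org.apache.spark.TestUtils.withListener`, which, as I read Spark's implementation, registers the listener, runs the body, and waits for the listener bus to drain before deregistering, so counters accumulated in `onStageCompleted` are stable as soon as the block returns. A hedged usage sketch, assuming a suite with a live `SparkSession`; the `countRecordsRead` helper and its query argument are invented for illustration:

```scala
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
import org.apache.spark.sql.{SparkSession, TestUtils}

object InputMetricsDemo {
  // Hypothetical helper: run `query` and return the records read by its stages.
  def countRecordsRead(spark: SparkSession, query: String): Long = {
    @volatile var recordsRead = 0L
    val listener = new SparkListener {
      override def onStageCompleted(event: SparkListenerStageCompleted): Unit =
        recordsRead += event.stageInfo.taskMetrics.inputMetrics.recordsRead
    }
    // The wrapper registers the listener, runs the body, then (per Spark's
    // TestUtils, an assumption worth re-checking against your Spark version)
    // drains the listener bus before removing the listener.
    TestUtils.withListener(spark.sparkContext, listener)(_ => spark.sql(query).count())
    recordsRead
  }
}
```

Routing through the shim rather than calling `SparkTestUtils.withListener` directly keeps Gluten test suites insulated from where Spark hosts this utility across versions.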