Skip to content

Commit

Permalink
[GLUTEN-7068][CORE] Fix issue updating leaf input metrics (#7067)
Browse files Browse the repository at this point in the history
Closes #7068
  • Loading branch information
ivoson authored Sep 4, 2024
1 parent bdf3421 commit 8bc1842
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import org.apache.gluten.GlutenConfig
import org.apache.gluten.sql.shims.SparkShimLoader

import org.apache.spark.SparkConf
import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted}
import org.apache.spark.sql.TestUtils
import org.apache.spark.sql.execution.CommandResultExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -201,4 +203,28 @@ class VeloxMetricsSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa
}
}
}

test("File scan task input metrics") {
createTPCHNotNullTables()

@volatile var inputRecords = 0L
val partTableRecords = spark.sql("select * from part").count()
val itemTableRecords = spark.sql("select * from lineitem").count()
val inputMetricsListener = new SparkListener {
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
inputRecords += stageCompleted.stageInfo.taskMetrics.inputMetrics.recordsRead
}
}

TestUtils.withListener(spark.sparkContext, inputMetricsListener) {
_ =>
val df = spark.sql("""
|select /*+ BROADCAST(part) */ * from part join lineitem
|on l_partkey = p_partkey
|""".stripMargin)
df.count()
}

assert(inputRecords == (partTableRecords + itemTableRecords))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.FilePartition
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.utils.SparkInputMetricsUtil.InputMetricsWrapper
import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.common.collect.Lists

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

case class TransformContext(
inputAttributes: Seq[Attribute],
Expand Down Expand Up @@ -300,7 +302,7 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
inputPartitions,
inputRDDs,
pipelineTime,
leafMetricsUpdater().updateInputMetrics,
leafInputMetricsUpdater(),
BackendsApiManager.getMetricsApiInstance.metricsUpdatingFunction(
child,
wsCtx.substraitContext.registeredRelMap,
Expand Down Expand Up @@ -354,14 +356,25 @@ case class WholeStageTransformer(child: SparkPlan, materializeInput: Boolean = f
}
}

private def leafMetricsUpdater(): MetricsUpdater = {
child
.find {
case t: TransformSupport if t.children.forall(!_.isInstanceOf[TransformSupport]) => true
case _ => false
private def leafInputMetricsUpdater(): InputMetricsWrapper => Unit = {
def collectLeaves(plan: SparkPlan, buffer: ArrayBuffer[TransformSupport]): Unit = {
plan match {
case node: TransformSupport if node.children.forall(!_.isInstanceOf[TransformSupport]) =>
buffer.append(node)
case node: TransformSupport =>
node.children
.foreach(collectLeaves(_, buffer))
case _ =>
}
.map(_.asInstanceOf[TransformSupport].metricsUpdater())
.getOrElse(MetricsUpdater.None)
}

val leafBuffer = new ArrayBuffer[TransformSupport]()
collectLeaves(child, leafBuffer)
val leafMetricsUpdater = leafBuffer.map(_.metricsUpdater())

(inputMetrics: InputMetricsWrapper) => {
leafMetricsUpdater.foreach(_.updateInputMetrics(inputMetrics))
}
}

override protected def withNewChildInternal(newChild: SparkPlan): WholeStageTransformer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ package org.apache.spark.sql

import org.apache.gluten.exception.GlutenException

import org.apache.spark.{TestUtils => SparkTestUtils}
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.SparkListener
import org.apache.spark.sql.test.SQLTestUtils

object TestUtils {
Expand All @@ -27,4 +30,8 @@ object TestUtils {
throw new GlutenException("Failed to compare answer" + result.get)
}
}

def withListener[L <: SparkListener](sc: SparkContext, listener: L)(body: L => Unit): Unit = {
SparkTestUtils.withListener(sc, listener)(body)
}
}

0 comments on commit 8bc1842

Please sign in to comment.