Refactor SparkSqlMeasure.scala
sjgllgh committed Nov 25, 2024
1 parent 78a9344 commit 66e67a3
Showing 2 changed files with 133 additions and 76 deletions.
SparkSqlExecutor.scala
@@ -27,16 +27,21 @@ import org.apache.linkis.engineplugin.spark.sparkmeasure.SparkSqlMeasure
import org.apache.linkis.engineplugin.spark.utils.{DirectPushCache, EngineUtils}
import org.apache.linkis.governance.common.constant.job.JobRequestConstants
import org.apache.linkis.governance.common.paser.SQLCodeParser
import org.apache.linkis.governance.common.utils.JobUtils
import org.apache.linkis.manager.label.utils.LabelUtil
import org.apache.linkis.scheduler.executer._

import org.apache.commons.lang3.ObjectUtils
import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.spark.sql.DataFrame

import java.lang.reflect.InvocationTargetException
import java.util
import java.util.Date

import scala.jdk.CollectionConverters.seqAsJavaListConverter

import ch.cern.sparkmeasure.{StageMetrics, TaskMetrics}

class SparkSqlExecutor(
sparkEngineSession: SparkEngineSession,
id: Long,
@@ -91,30 +96,19 @@ class SparkSqlExecutor(
.setContextClassLoader(sparkEngineSession.sparkSession.sharedState.jarClassLoader)
val extensions =
org.apache.linkis.engineplugin.spark.extension.SparkSqlExtension.getSparkSqlExtensions()
val sparkMeasureType = engineExecutionContext.getProperties
.get(SparkConfiguration.SPARKMEASURE_AGGREGATE_TYPE)
val df = if (ObjectUtils.isNotEmpty(sparkMeasureType)) {
val outputPrefix = SparkConfiguration.SPARKMEASURE_OUTPUT_PREFIX.getValue(options)
val outputPath = FsPath.getFsPath(
outputPrefix,
options.get("user"),
sparkMeasureType.toString,
options.get("jobId"),
new Date().getTime.toString
)
val sparkMeasure =
new SparkSqlMeasure(
sparkEngineSession.sparkSession,
code,
sparkMeasureType.toString,
outputPath
)

sparkMeasure.generatorMetrics()
} else {
sparkEngineSession.sqlContext.sql(code)
// Start capturing Spark metrics
val sparkMeasure: Option[SparkSqlMeasure] =
createSparkMeasure(engineExecutionContext, sparkEngineSession, code)
val sparkMetrics: Option[Either[StageMetrics, TaskMetrics]] = sparkMeasure.flatMap {
measure =>
val metrics = measure.getSparkMetrics
metrics.foreach(measure.begin)
metrics
}

val df = sparkEngineSession.sqlContext.sql(code)

Utils.tryQuietly(
extensions.foreach(
_.afterExecutingSQL(
@@ -139,6 +133,13 @@ class SparkSqlExecutor(
engineExecutionContext
)
}

// Stop capturing Spark metrics and output the records to the specified file.
sparkMeasure.foreach { measure =>
sparkMetrics.foreach(measure.end)
sparkMetrics.foreach(measure.outputMetrics)
}

SuccessExecuteResponse()
} catch {
case e: InvocationTargetException =>
@@ -154,5 +155,32 @@
}
}

/**
* Helper method that creates a SparkSqlMeasure instance when a sparkmeasure aggregate type is configured.
*/
private def createSparkMeasure(
engineExecutionContext: EngineExecutionContext,
sparkEngineSession: SparkEngineSession,
code: String
): Option[SparkSqlMeasure] = {
val sparkMeasureType = engineExecutionContext.getProperties
.getOrDefault(SparkConfiguration.SPARKMEASURE_AGGREGATE_TYPE, "")
.toString

if (sparkMeasureType.nonEmpty) {
val outputPrefix = SparkConfiguration.SPARKMEASURE_OUTPUT_PREFIX.getValue(options)
val outputPath = FsPath.getFsPath(
outputPrefix,
LabelUtil.getUserCreator(engineExecutionContext.getLabels.toList.asJava)._1,
sparkMeasureType,
JobUtils.getJobIdFromMap(engineExecutionContext.getProperties),
new Date().getTime.toString
)
Some(new SparkSqlMeasure(sparkEngineSession.sparkSession, code, sparkMeasureType, outputPath))
} else {
None
}
}

override protected def getExecutorIdPreFix: String = "SparkSqlExecutor_"
}
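
For context, the executor now brackets SQL execution with sparkmeasure's begin/end collection instead of delegating the whole run to SparkSqlMeasure.generatorMetrics(). Below is a minimal standalone sketch of that begin/end pattern, not part of this commit, assuming a local SparkSession named spark and the ch.cern.sparkmeasure dependency on the classpath.

import org.apache.spark.sql.SparkSession
import ch.cern.sparkmeasure.StageMetrics

object StageMetricsSketch extends App {
  val spark = SparkSession.builder().master("local[*]").appName("sparkmeasure-sketch").getOrCreate()

  val stageMetrics = StageMetrics(spark)
  stageMetrics.begin() // start listening for stage-completion events
  spark.sql("SELECT count(*) FROM range(1000)").collect() // run the query while metrics are recorded
  stageMetrics.end() // stop listening
  // Aggregated metrics as a java.util.Map[String, Long], the same call the refactored collectMetrics uses
  println(stageMetrics.aggregateStageMetricsJavaMap())

  spark.stop()
}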
SparkSqlMeasure.scala
@@ -37,72 +37,101 @@ class SparkSqlMeasure(
outputPath: FsPath
) extends Logging {

private val sqlType = getDQLSqlType

def generatorMetrics(): DataFrame = {
var df: DataFrame = null
val metricsMap: java.util.Map[String, Long] = metricType match {
case "stage" =>
val metrics = StageMetrics(sparkSession)
sqlType match {
case "SELECT" =>
df = sparkSession.sql(sql)
metrics.runAndMeasure(df.show(0))
metrics.aggregateStageMetricsJavaMap()
case "INSERT" =>
df = metrics.runAndMeasure(sparkSession.sql(sql))
metrics.aggregateStageMetricsJavaMap()
case _ =>
df = sparkSession.sql(sql)
new java.util.HashMap[String, Long]()
}
case "task" =>
val metrics = TaskMetrics(sparkSession)
sqlType match {
case "SELECT" =>
df = sparkSession.sql(sql)
metrics.runAndMeasure(df.show(0))
metrics.aggregateTaskMetricsJavaMap()
case "INSERT" =>
df = metrics.runAndMeasure(sparkSession.sql(sql))
metrics.aggregateTaskMetricsJavaMap()
case _ =>
df = sparkSession.sql(sql)
new java.util.HashMap[String, Long]()
}
case _ =>
df = sparkSession.sql(sql)
new java.util.HashMap[String, Long]()
private val sqlType: String = determineSqlType

def begin(metrics: Either[StageMetrics, TaskMetrics]): Unit = {
metrics match {
case Left(stageMetrics) =>
stageMetrics.begin()
case Right(taskMetrics) =>
taskMetrics.begin()
}
}

def end(metrics: Either[StageMetrics, TaskMetrics]): Unit = {
metrics match {
case Left(stageMetrics) =>
stageMetrics.end()
case Right(taskMetrics) =>
taskMetrics.end()
}
if (MapUtils.isNotEmpty(metricsMap)) {
val retMap = new util.HashMap[String, Object]()
retMap.put("execution_code", sql)
retMap.put("metrics", metricsMap)
val mapper = new ObjectMapper()
val bytes = mapper.writeValueAsBytes(retMap)
val fs = FSFactory.getFs(outputPath)
if (!fs.exists(outputPath.getParent)) fs.mkdirs(outputPath.getParent)
val out = fs.write(outputPath, true)
out.write(bytes)
IOUtils.close(out)
fs.close()
}

private def enableSparkMeasure: Boolean = {
sqlType match {
case "SELECT" | "INSERT" => true
case _ => false
}
}

def getSparkMetrics: Option[Either[StageMetrics, TaskMetrics]] = {
if (enableSparkMeasure) {
metricType match {
case "stage" => Some(Left(StageMetrics(sparkSession)))
case "task" => Some(Right(TaskMetrics(sparkSession)))
case _ => None
}
} else {
None
}
df
}

private def getDQLSqlType: String = {
def outputMetrics(metrics: Either[StageMetrics, TaskMetrics]): Unit = {
if (enableSparkMeasure) {
val metricsMap = collectMetrics(metrics)

if (MapUtils.isNotEmpty(metricsMap)) {
val retMap = new util.HashMap[String, Object]()
retMap.put("execution_code", sql)
retMap.put("metrics", metricsMap)

val mapper = new ObjectMapper()
val bytes = mapper.writeValueAsBytes(retMap)

val fs = FSFactory.getFs(outputPath)
try {
if (!fs.exists(outputPath.getParent)) fs.mkdirs(outputPath.getParent)
val out = fs.write(outputPath, true)
try {
out.write(bytes)
} finally {
IOUtils.closeQuietly(out)
}
} finally {
fs.close()
}
}
}
}

private def determineSqlType: String = {
val parser = sparkSession.sessionState.sqlParser
val logicalPlan = parser.parsePlan(sql)

val planName = logicalPlan.getClass.getSimpleName
planName match {
logicalPlan.getClass.getSimpleName match {
case "UnresolvedWith" | "Project" | "GlobalLimit" => "SELECT"
case "InsertIntoStatement" | "CreateTableAsSelectStatement" | "CreateTableAsSelect" =>
"INSERT"
case _ =>
logger.info("当前SQL解析类型为{}, 跳过sparkmeasure", planName)
case planName =>
logger.info(s"Unsupported sql type")
planName
}
}

private def collectMetrics(
metrics: Either[StageMetrics, TaskMetrics]
): java.util.Map[String, Long] = {
metrics match {
case Left(stageMetrics) =>
logger.info("StageMetrics begin executed.")
stageMetrics.aggregateStageMetricsJavaMap()

case Right(taskMetrics) =>
logger.info("TaskMetrics begin executed.")
taskMetrics.aggregateTaskMetricsJavaMap()

case _ => new util.HashMap[String, Long]()
}
}

}
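
Taken together, the refactored SparkSqlMeasure is driven through getSparkMetrics, begin, end, and outputMetrics rather than a single generatorMetrics call. A minimal sketch of that lifecycle follows; it is not part of the commit, and sparkSession, code, and outputPath are assumed to be in scope with the same meaning as in the executor above.

val measure = new SparkSqlMeasure(sparkSession, code, "stage", outputPath)

measure.getSparkMetrics match {
  case Some(metrics) =>
    measure.begin(metrics) // start stage/task metric collection
    val df = sparkSession.sql(code) // execute the SQL while metrics are recording
    df.collect()
    measure.end(metrics) // stop collection
    measure.outputMetrics(metrics) // persist {"execution_code": ..., "metrics": ...} to outputPath
  case None =>
    sparkSession.sql(code) // unsupported statement type: run without measurement
}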
