Skip to content

Commit

Permalink
[Spark] Handle case when Checkpoints.findLastCompleteCheckpoint is pa…
Browse files Browse the repository at this point in the history
…ssed MAX_VALUE (delta-io#3105)

#### Which Delta project/connector is this regarding?
<!--
Please add the component selected below to the beginning of the pull
request title
For example: [Spark] Title of my pull request
-->

- [X] Spark
- [ ] Standalone
- [ ] Flink
- [ ] Kernel
- [ ] Other (fill in here)

## Description

Fixes an issue where `Checkpoints.findLastCompleteCheckpointBefore` goes into
an almost infinite loop when it is passed `CheckpointInstance.MaxValue` as the upper bound.

## How was this patch tested?

Unit tests (UT).

## Does this PR introduce _any_ user-facing changes?

No
  • Loading branch information
prakharjain09 authored May 20, 2024
1 parent 3af4335 commit 57df2c0
Show file tree
Hide file tree
Showing 2 changed files with 151 additions and 26 deletions.
49 changes: 44 additions & 5 deletions spark/src/main/scala/org/apache/spark/sql/delta/Checkpoints.scala
Original file line number Diff line number Diff line change
Expand Up @@ -429,10 +429,46 @@ trait Checkpoints extends DeltaLogging {
*/
private[delta] def findLastCompleteCheckpointBefore(
    checkpointInstance: Option[CheckpointInstance] = None): Option[CheckpointInstance] = {
  // This method is only a telemetry wrapper: the actual search, including the handling of a
  // missing / negative / sentinel upper bound, is done by
  // `findLastCompleteCheckpointBeforeInternal`. No early return may happen before the try
  // block, otherwise that call would finish without emitting its usage event.
  val eventData = mutable.Map[String, String]()
  val startTimeMs = System.currentTimeMillis()

  // Emits one usage event with everything collected in `eventData` plus the elapsed time.
  def sendUsageLog(): Unit = {
    eventData("totalTimeTakenMs") = (System.currentTimeMillis() - startTimeMs).toString
    recordDeltaEvent(
      self, opType = "delta.findLastCompleteCheckpointBefore", data = eventData.toMap)
  }

  try {
    val resultOpt = findLastCompleteCheckpointBeforeInternal(eventData, checkpointInstance)
    // -1 is recorded when no checkpoint was found.
    eventData("resultantCheckpointVersion") = resultOpt.map(_.version).getOrElse(-1L).toString
    sendUsageLog()
    resultOpt
  } catch {
    // `NonFatal` does not match `InterruptedException`, so interruption-related exceptions are
    // listed explicitly: the usage event is still emitted before the exception is rethrown.
    case e@(NonFatal(_) | _: InterruptedException | _: java.io.InterruptedIOException |
        _: java.nio.channels.ClosedByInterruptException) =>
      eventData("exception") = Utils.exceptionString(e)
      sendUsageLog()
      throw e
  }
}

private def findLastCompleteCheckpointBeforeInternal(
eventData: mutable.Map[String, String],
checkpointInstance: Option[CheckpointInstance]): Option[CheckpointInstance] = {
val upperBoundCv =
checkpointInstance
// If someone passes the upperBound as a negative version or as the sentinel value
// (CheckpointInstance.MaxValue), we should not do backward listing. Instead we should list
// the entire directory from 0 and return the latest available checkpoint.
.filterNot(cv => cv.version < 0 || cv.version == CheckpointInstance.MaxValue.version)
.getOrElse {
logInfo(s"Try to find Delta last complete checkpoint")
eventData("listingFromZero") = true.toString
return findLastCompleteCheckpoint()
}
eventData("efficientBackwardListingEnabled") = true.toString
eventData("upperBoundVersion") = upperBoundCv.version.toString
eventData("upperBoundCheckpointType") = upperBoundCv.format.name
var iterations: Long = 0L
var numFilesScanned: Long = 0L
logInfo(s"Try to find Delta last complete checkpoint before version ${upperBoundCv.version}")
var listingEndVersion = upperBoundCv.version

Expand All @@ -446,9 +482,12 @@ trait Checkpoints extends DeltaLogging {
// |
// latest checkpoint
while (listingEndVersion >= 0) {
iterations += 1
eventData("iterations") = iterations.toString
val listingStartVersion = math.max(0, listingEndVersion - 1000)
val checkpoints = store
.listFrom(listingPrefix(logPath, listingStartVersion), newDeltaHadoopConf())
.map { file => numFilesScanned += 1 ; file }
.collect {
// Also collect delta files from the listing result so that the next takeWhile helps us
// terminate iterator early if no checkpoint exists upto the `listingEndVersion`
Expand All @@ -471,6 +510,7 @@ trait Checkpoints extends DeltaLogging {
.toArray
val lastCheckpoint =
getLatestCompleteCheckpointFromList(checkpoints, Some(upperBoundCv.version))
eventData("numFilesScanned") = numFilesScanned.toString
if (lastCheckpoint.isDefined) {
logInfo(s"Delta checkpoint is found at version ${lastCheckpoint.get.version}")
return lastCheckpoint
Expand All @@ -494,7 +534,6 @@ trait Checkpoints extends DeltaLogging {
getLatestCompleteCheckpointFromList(files.map(f => CheckpointInstance(f.getPath)).toArray)
}.foldLeft(Option.empty[CheckpointInstance])((_, right) => Some(right))
// ^The foldLeft here emulates the non-existing Iterator.tailOption method.

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@

package org.apache.spark.sql.delta

import com.databricks.spark.util.Log4jUsageLogger
import org.apache.spark.sql.delta.CheckpointInstance.Format
import org.apache.spark.sql.delta.DeltaTestUtils.BOOLEAN_DOMAIN
import org.apache.spark.sql.delta.managedcommit.ManagedCommitBaseSuite
import org.apache.spark.sql.delta.storage.LocalLogStore
import org.apache.spark.sql.delta.test.DeltaSQLCommandTest
import org.apache.spark.sql.delta.util.FileNames
import org.apache.spark.sql.delta.util.{FileNames, JsonUtils}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

Expand Down Expand Up @@ -68,6 +69,18 @@ class FindLastCompleteCheckpointSuite
versions.map { version => pathToFileStatus(FileNames.checksumFile(logPath, version)) }
}

/**
 * Runs `f` while tracking usage records and returns the payload of the single
 * "delta.findLastCompleteCheckpointBefore" record emitted during that run.
 *
 * A record is selected when either its "opType" tag or the type name of its `opType` equals
 * the expected operation type. The helper asserts that exactly one such record exists and
 * parses its blob as a JSON object of string keys to string values.
 *
 * @param f the code under test
 * @return the parsed event data of the matching usage record
 */
def getLastCompleteCheckpointUsageLog(f: => Unit): Map[String, String] = {
  val expectedOpType = "delta.findLastCompleteCheckpointBefore"
  val matchingRecords = Log4jUsageLogger.track(f).filter { record =>
    record.tags.get("opType").contains(expectedOpType) ||
      record.opType.map(_.typeName).contains(expectedOpType)
  }
  assert(matchingRecords.size === 1)
  val payload = matchingRecords.head.blob
  JsonUtils.fromJson[Map[String, String]](payload)
}

test("findLastCompleteCheckpoint without any argument") {
withTempDir { dir =>
val log = DeltaLog.forTable(spark, dir.getAbsolutePath)
Expand All @@ -79,14 +92,20 @@ class FindLastCompleteCheckpointSuite
commitFiles(logPath, 0L to 3000) ++
singleCheckpointFiles(logPath, Seq(100, 200, 1000, 2000))
)
assert(log.findLastCompleteCheckpointBefore().contains(CheckpointInstance(version = 2000)))
val eventData1 = getLastCompleteCheckpointUsageLog {
assert(log.findLastCompleteCheckpointBefore().contains(CheckpointInstance(version = 2000)))
}
assert(!eventData1.contains("iterations"))
assert(logStore.listFromCount == 1)
assert(logStore.elementsConsumedFromListFromIter == 3005)
logStore.reset()

// Case-2: No checkpoint exists in table dir
logStore.customListingResult = Some(commitFiles(logPath, 0L to 3000))
assert(log.findLastCompleteCheckpointBefore().isEmpty)
val eventData2 = getLastCompleteCheckpointUsageLog {
assert(log.findLastCompleteCheckpointBefore().isEmpty)
}
assert(!eventData2.contains("iterations"))
assert(logStore.listFromCount == 1)
assert(logStore.elementsConsumedFromListFromIter == 3001)
logStore.reset()
Expand All @@ -97,8 +116,11 @@ class FindLastCompleteCheckpointSuite
singleCheckpointFiles(logPath, Seq(100, 200, 1000, 2000)) ++
multipartCheckpointFiles(logPath, Seq(300, 2000), numParts = 4)
)
assert(log.findLastCompleteCheckpointBefore().contains(
CheckpointInstance(version = 2000, Format.WITH_PARTS, numParts = Some(4))))
val eventData3 = getLastCompleteCheckpointUsageLog {
  assert(log.findLastCompleteCheckpointBefore().contains(
    CheckpointInstance(version = 2000, Format.WITH_PARTS, numParts = Some(4))))
}
// Check the usage log captured for this case: without an upper bound no backward listing
// happens, so the "iterations" counter must be absent from the event data.
assert(!eventData3.contains("iterations"))
assert(logStore.listFromCount == 1)
assert(logStore.elementsConsumedFromListFromIter == 3013)
logStore.reset()
Expand All @@ -117,11 +139,15 @@ class FindLastCompleteCheckpointSuite
commitFiles(logPath, 0L to 3000) ++
singleCheckpointFiles(logPath, Seq(100, 200, 1000, 2000))
)
assert(
log.findLastCompleteCheckpointBefore(Some(CheckpointInstance(version = 2000)))
.contains(CheckpointInstance(version = 1000)))
val eventData1 = getLastCompleteCheckpointUsageLog {
assert(
log.findLastCompleteCheckpointBefore(Some(CheckpointInstance(version = 2000)))
.contains(CheckpointInstance(version = 1000)))
}
assert(logStore.listFromCount == 1)
assert(logStore.elementsConsumedFromListFromIter == 1002 + 2) // commits + checkpoint
assert(eventData1("iterations") == "1")
assert(eventData1("numFilesScanned") == "1004")
logStore.reset()

// Case-2: The exact upperBound (a multi-part checkpoint) doesn't exist but another single
Expand All @@ -132,10 +158,14 @@ class FindLastCompleteCheckpointSuite
)
var sentinelCheckpoint =
CheckpointInstance(version = 2000, Format.WITH_PARTS, numParts = Some(4))
assert(log.findLastCompleteCheckpointBefore(Some(sentinelCheckpoint))
.contains(CheckpointInstance(version = 2000)))
val eventData2 = getLastCompleteCheckpointUsageLog {
assert(log.findLastCompleteCheckpointBefore(Some(sentinelCheckpoint))
.contains(CheckpointInstance(version = 2000)))
}
assert(logStore.listFromCount == 1)
assert(logStore.elementsConsumedFromListFromIter == 1002 + 2) // commits + checkpoint
assert(eventData2("iterations") == "1")
assert(eventData2("numFilesScanned") == "1004")
logStore.reset()

// Case-3: The last complete checkpoint doesn't exist in last 1000 elements and needs
Expand All @@ -144,14 +174,17 @@ class FindLastCompleteCheckpointSuite
commitFiles(logPath, 0L to 2500) ++
singleCheckpointFiles(logPath, Seq(100, 150))
)
assert(
log.findLastCompleteCheckpointBefore(2200)
.contains(CheckpointInstance(version = 150)))
val eventData3 = getLastCompleteCheckpointUsageLog {
assert(
log.findLastCompleteCheckpointBefore(2200).contains(CheckpointInstance(version = 150)))
}
assert(logStore.listFromCount == 3)
// the first listing will consume versions 1200 to 2201 => 1002 commits
// the second listing will consume versions 200 to 1201 => 1002 commits
// the third listing will consume versions 0 to 201 => 202 commits + 2 checkpoints = 204 files
assert(logStore.elementsConsumedFromListFromIter == 2208) // commits + checkpoint
assert(eventData3("iterations") == "3")
assert(eventData3("numFilesScanned") == "2208")
logStore.reset()
}
}
Expand Down Expand Up @@ -186,12 +219,22 @@ class FindLastCompleteCheckpointSuite
commitFiles(logPath, 0L to lastCommitVersion) ++
singleCheckpointFiles(logPath, Seq(100), length = 20) ++
singleCheckpointFiles(logPath, Seq(200), length = 0))
assert(
log.findLastCompleteCheckpointBefore(sentinelInstance)
.contains(CheckpointInstance(version = 100)))
val eventData1 = getLastCompleteCheckpointUsageLog {
assert(
log.findLastCompleteCheckpointBefore(sentinelInstance)
.contains(CheckpointInstance(version = 100)))
}
assert(logStore.listFromCount == expectedListCount)
assert(logStore.elementsConsumedFromListFromIter ===
getExpectedFileCount(filesPerCheckpoint = 1))
if (passSentinelInstance) {
assert(eventData1("iterations") == expectedListCount.toString)
assert(eventData1("numFilesScanned") ==
getExpectedFileCount(filesPerCheckpoint = 1).toString)
} else {
assert(Seq("iterations", "numFilesScanned").forall(!eventData1.contains(_)))
}

logStore.reset()

// Case-2: `findLastCompleteCheckpointBefore` invoked with upperBound, with a multi-part
Expand All @@ -206,8 +249,17 @@ class FindLastCompleteCheckpointSuite
multipartCheckpointFiles(logPath, Seq(100), numParts = 4) ++
badCheckpointV200
)
assert(log.findLastCompleteCheckpointBefore(sentinelInstance)
.contains(CheckpointInstance(version = 100, Format.WITH_PARTS, numParts = Some(4))))
val eventData2 = getLastCompleteCheckpointUsageLog {
assert(log.findLastCompleteCheckpointBefore(sentinelInstance)
.contains(CheckpointInstance(version = 100, Format.WITH_PARTS, numParts = Some(4))))
}
if (passSentinelInstance) {
assert(eventData2("iterations") == expectedListCount.toString)
assert(eventData2("numFilesScanned") ==
getExpectedFileCount(filesPerCheckpoint = 4).toString)
} else {
assert(Seq("iterations", "numFilesScanned").forall(!eventData2.contains(_)))
}
assert(logStore.listFromCount == expectedListCount)
assert(logStore.elementsConsumedFromListFromIter ===
getExpectedFileCount(filesPerCheckpoint = 4))
Expand Down Expand Up @@ -245,16 +297,50 @@ class FindLastCompleteCheckpointSuite
commitFiles(logPath, 0L to lastCommitVersion) ++
multipartCheckpointFiles(logPath, Seq(100), numParts = 4, length = 20) ++
multipartCheckpointFiles(logPath, Seq(200), numParts = 4, length = 20).take(3))
assert(
log.findLastCompleteCheckpointBefore(sentinelInstance)
.contains(CheckpointInstance(100, Format.WITH_PARTS, numParts = Some(4))))
val eventData1 = getLastCompleteCheckpointUsageLog {
assert(
log.findLastCompleteCheckpointBefore(sentinelInstance)
.contains(CheckpointInstance(100, Format.WITH_PARTS, numParts = Some(4))))
}
assert(logStore.listFromCount == expectedListCount)
assert(logStore.elementsConsumedFromListFromIter ===
getExpectedFileCount(fileInCheckpointV200 = 3, filesInCheckpointV100 = 4))
if (passSentinelInstance) {
assert(eventData1("iterations") == expectedListCount.toString)
assert(eventData1("numFilesScanned") ==
getExpectedFileCount(fileInCheckpointV200 = 3, filesInCheckpointV100 = 4).toString)
} else {
assert(Seq("iterations", "numFilesScanned").forall(!eventData1.contains(_)))
}

logStore.reset()
}
}

test("findLastCompleteCheckpoint with CheckpointInstance.MAX value") {
  // Passing the MaxValue sentinel as the upper bound must behave like passing no upper bound:
  // a single listing of the whole log, with no backward-listing metrics in the usage event.
  withTempDir { dir =>
    val log = DeltaLog.forTable(spark, dir.getAbsolutePath)
    val logPath = log.logPath
    val logStore = log.store.asInstanceOf[CustomListingLogStore]
    logStore.reset()

    logStore.customListingResult = Some(
      commitFiles(logPath, 0L to 3000) ++
        singleCheckpointFiles(logPath, Seq(100, 200, 1000, 1200))
    )
    val eventData = getLastCompleteCheckpointUsageLog {
      assert(
        log.findLastCompleteCheckpointBefore(Some(CheckpointInstance.MaxValue))
          .contains(CheckpointInstance(version = 1200)))
    }
    assert(!eventData.contains("iterations"))
    assert(!eventData.contains("upperBoundVersion"))
    // The elapsed time comes from a millisecond clock, so a fast run can legitimately report 0.
    // Only require that the metric is present and non-negative.
    assert(eventData("totalTimeTakenMs").toLong >= 0)
    assert(logStore.listFromCount == 1)
    assert(logStore.elementsConsumedFromListFromIter == 3001 + 4) // commits + checkpoint
    logStore.reset()
  }
}
}

/**
Expand Down

0 comments on commit 57df2c0

Please sign in to comment.