From e3169d74b1a4bfd80174ecc4a9a061de9fa9ae3b Mon Sep 17 00:00:00 2001
From: Sumeet Varma
Date: Fri, 29 Mar 2024 16:09:03 -0700
Subject: [PATCH] Fix bug: avoid unnecessary DeltaLog update in getHistory

---
 .../spark/sql/delta/DeltaHistoryManager.scala | 16 +++++++++--
 .../sql/delta/DeltaHistoryManagerSuite.scala  | 28 +++++++++++++------
 2 files changed, 32 insertions(+), 12 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaHistoryManager.scala b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaHistoryManager.scala
index 6106a246026..c6910fbdf12 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/DeltaHistoryManager.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/DeltaHistoryManager.scala
@@ -83,10 +83,20 @@ class DeltaHistoryManager(
     import org.apache.spark.sql.delta.implicits._
     val conf = getSerializableHadoopConf
     val logPath = deltaLog.logPath.toString
-    val snapshot = endOpt.map(end => deltaLog.getSnapshotAt(end)).getOrElse(deltaLog.update())
-    val commitFileProvider = DeltaCommitFileProvider(snapshot)
+    val currentSnapshot = deltaLog.unsafeVolatileSnapshot
+    val (snapshotForCommitFileProvider, end) = endOpt match {
+      case Some(end) if currentSnapshot.version >= end =>
+        // Use the cached snapshot if it's fresh enough for the [start, end] query.
+        (currentSnapshot, end)
+      case _ =>
+        // Either end doesn't exist or the currently cached snapshot isn't new enough to satisfy it.
+        val newSnapshot = deltaLog.update()
+        val endVersion = endOpt.getOrElse(newSnapshot.version).min(newSnapshot.version)
+        (newSnapshot, endVersion)
+    }
+    val commitFileProvider = DeltaCommitFileProvider(snapshotForCommitFileProvider)
     // We assume that commits are contiguous, therefore we try to load all of them in order
-    val info = spark.range(start, snapshot.version + 1)
+    val info = spark.range(start, end + 1)
       .mapPartitions { versions =>
         val logStore = LogStore(SparkEnv.get.conf, conf.value)
         val basePath = new Path(logPath)
diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaHistoryManagerSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaHistoryManagerSuite.scala
index a6a52af6d2b..2da3ff29533 100644
--- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaHistoryManagerSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaHistoryManagerSuite.scala
@@ -25,7 +25,9 @@ import java.util.{Date, Locale}
 import scala.concurrent.duration._
 import scala.language.implicitConversions
 
+import com.databricks.spark.util.Log4jUsageLogger
 import org.apache.spark.sql.delta.DeltaTestUtils.createTestAddFile
+import org.apache.spark.sql.delta.DeltaTestUtils.filterUsageRecords
 import org.apache.spark.sql.delta.managedcommit.ManagedCommitBaseSuite
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
 import org.apache.spark.sql.delta.stats.StatsUtils
@@ -613,17 +615,25 @@ abstract class DeltaHistoryManagerBase extends DeltaTimeTravelTests
       val start = 1540415658000L
       generateCommits(tblName, start, start + 20.minutes, start + 40.minutes, start + 60.minutes)
       val deltaLog = DeltaLog.forTable(spark, getTableLocation(tblName))
-      val history_02 = deltaLog.history.getHistory(start = 0, endOpt = Some(2))
-      assert(history_02.size == 3)
-      assert(history_02.map(_.getVersion) == Seq(2, 1, 0))
 
-      val history_13 = deltaLog.history.getHistory(start = 1, endOpt = Some(1))
-      assert(history_13.size == 1)
-      assert(history_13.map(_.getVersion) == Seq(1))
+      def testGetHistory(
+          start: Long,
+          endOpt: Option[Long],
+          versions: Seq[Long],
+          expectedLogUpdates: Int): Unit = {
+        val usageRecords = Log4jUsageLogger.track {
+          val history = deltaLog.history.getHistory(start, endOpt)
+          assert(history.map(_.getVersion) == versions)
+        }
+        assert(filterUsageRecords(usageRecords, "deltaLog.update").size === expectedLogUpdates)
+      }
 
-      val history_2 = deltaLog.history.getHistory(start = 2, endOpt = None)
-      assert(history_2.size == 2)
-      assert(history_2.map(_.getVersion) == Seq(3, 2))
+      testGetHistory(start = 0, endOpt = Some(2), versions = Seq(2, 1, 0), expectedLogUpdates = 0)
+      testGetHistory(start = 1, endOpt = Some(1), versions = Seq(1), expectedLogUpdates = 0)
+      testGetHistory(start = 2, endOpt = None, versions = Seq(3, 2), expectedLogUpdates = 1)
+      testGetHistory(start = 1, endOpt = Some(5), versions = Seq(3, 2, 1), expectedLogUpdates = 1)
+      testGetHistory(start = 4, endOpt = None, versions = Seq.empty, expectedLogUpdates = 1)
+      testGetHistory(start = 2, endOpt = Some(1), versions = Seq.empty, expectedLogUpdates = 0)
     }
   }
 }