From 8a56bdc5aba88302e6d3da8d5e9a1a932e14d7d6 Mon Sep 17 00:00:00 2001
From: Jun <85203301+junlee-db@users.noreply.github.com>
Date: Wed, 8 May 2024 17:44:45 -0700
Subject: [PATCH] [Spark] Extend Fuzz test For Managed Commit (#3049)

This PR extends the fuzz tests to cover managed commit features. Specifically, it adds a new
event phase inside the commit operation so that the backfill is captured as a separate
operation. With backfill as its own phase, multiple commits can go through before a backfill
happens, and managed commit is expected to handle these situations and still return the
correct output.

## How was this patch tested?

Existing fuzz tests should naturally use the extended backfill phases; an illustrative sketch
of the new phase ordering follows the diff below.

---
 .../sql/delta/OptimisticTransaction.scala     | 10 ++-
 .../delta/TransactionExecutionObserver.scala  |  5 ++
 .../fuzzer/OptimisticTransactionPhases.scala  |  7 ++-
 ...eLockingTransactionExecutionObserver.scala | 30 ++++++---
 ...actBatchBackfillingCommitOwnerClient.scala |  4 +-
 .../TransactionExecutionObserverSuite.scala   | 15 ++++-
 .../TransactionExecutionTestMixin.scala       | 11 +++-
 .../managedcommit/ManagedCommitSuite.scala    | 62 +++++++++----------
 .../ManagedCommitTestUtils.scala              | 43 +++++++------
 9 files changed, 121 insertions(+), 66 deletions(-)

diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala b/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala
index 6ccb2a7d1b1..8d9f37686ba 100644
--- a/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala
@@ -1427,8 +1427,9 @@ trait OptimisticTransactionImpl extends TransactionalWrite
     }
     val updatedActions = UpdatedActions(
       commitInfo, metadata, protocol, snapshot.metadata, snapshot.protocol)
-    val commitResponse =
+    val commitResponse = TransactionExecutionObserver.withObserver(executionObserver) {
       effectiveTableCommitOwnerClient.commit(attemptVersion, jsonActions, updatedActions)
+    }
     // TODO(managed-commits): Use the right timestamp method on top of CommitInfo once ICT is
     // merged.
     // If the metadata didn't change, `newMetadata` is empty, and we can re-use the old id.
@@ -2094,9 +2095,12 @@ trait OptimisticTransactionImpl extends TransactionalWrite
       commitVersion: Long,
       actions: Iterator[String],
       updatedActions: UpdatedActions): CommitResponse = {
+    // Get thread local observer for Fuzz testing purpose.
+    val executionObserver = TransactionExecutionObserver.threadObserver.get()
     val commitFile = util.FileNames.unsafeDeltaFile(logPath, commitVersion)
     val commitFileStatus =
       doCommit(logStore, hadoopConf, logPath, commitFile, commitVersion, actions)
+    executionObserver.beginBackfill()
     // TODO(managed-commits): Integrate with ICT and pass the correct commitTimestamp
     CommitResponse(Commit(
       commitVersion,
@@ -2174,7 +2178,9 @@ trait OptimisticTransactionImpl extends TransactionalWrite
   ): Commit = {
     val updatedActions =
       currentTransactionInfo.getUpdatedActions(snapshot.metadata, snapshot.protocol)
-    val commitResponse = tableCommitOwnerClient.commit(attemptVersion, jsonActions, updatedActions)
+    val commitResponse = TransactionExecutionObserver.withObserver(executionObserver) {
+      tableCommitOwnerClient.commit(attemptVersion, jsonActions, updatedActions)
+    }
     // TODO(managed-commits): Use the right timestamp method on top of CommitInfo once ICT is
     // merged.
val commitTimestamp = commitResponse.getCommit.getFileStatus.getModificationTime diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/TransactionExecutionObserver.scala b/spark/src/main/scala/org/apache/spark/sql/delta/TransactionExecutionObserver.scala index c3bc74382d1..11d5bbc3d35 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/TransactionExecutionObserver.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/TransactionExecutionObserver.scala @@ -66,6 +66,9 @@ trait TransactionExecutionObserver /** Called before the first `doCommit` attempt. */ def beginDoCommit(): Unit + /** Called after publishing the commit file but before the `backfill` attempt. */ + def beginBackfill(): Unit + /** Called once a commit succeeded. */ def transactionCommitted(): Unit @@ -120,6 +123,8 @@ object NoOpTransactionExecutionObserver extends TransactionExecutionObserver { override def beginDoCommit(): Unit = () + override def beginBackfill(): Unit = () + override def transactionCommitted(): Unit = () override def transactionAborted(): Unit = () diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/fuzzer/OptimisticTransactionPhases.scala b/spark/src/main/scala/org/apache/spark/sql/delta/fuzzer/OptimisticTransactionPhases.scala index 2e0cb1d51bb..bb59574728e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/fuzzer/OptimisticTransactionPhases.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/fuzzer/OptimisticTransactionPhases.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.delta.fuzzer case class OptimisticTransactionPhases( initialPhase: ExecutionPhaseLock, preparePhase: ExecutionPhaseLock, - commitPhase: ExecutionPhaseLock) + commitPhase: ExecutionPhaseLock, + backfillPhase: ExecutionPhaseLock) object OptimisticTransactionPhases { @@ -28,6 +29,7 @@ object OptimisticTransactionPhases { final val INITIAL_PHASE_LABEL = PREFIX + "INIT" final val PREPARE_PHASE_LABEL = PREFIX + "PREPARE" final val COMMIT_PHASE_LABEL = PREFIX + "COMMIT" + final val BACKFILL_PHASE_LABEL = PREFIX + "BACKFILL" def forName(txnName: String): OptimisticTransactionPhases = { @@ -37,6 +39,7 @@ object OptimisticTransactionPhases { OptimisticTransactionPhases( initialPhase = ExecutionPhaseLock(toTxnPhaseLabel(INITIAL_PHASE_LABEL)), preparePhase = ExecutionPhaseLock(toTxnPhaseLabel(PREPARE_PHASE_LABEL)), - commitPhase = ExecutionPhaseLock(toTxnPhaseLabel(COMMIT_PHASE_LABEL))) + commitPhase = ExecutionPhaseLock(toTxnPhaseLabel(COMMIT_PHASE_LABEL)), + backfillPhase = ExecutionPhaseLock(toTxnPhaseLabel(BACKFILL_PHASE_LABEL))) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/fuzzer/PhaseLockingTransactionExecutionObserver.scala b/spark/src/main/scala/org/apache/spark/sql/delta/fuzzer/PhaseLockingTransactionExecutionObserver.scala index f93b843b92b..b6fcdfcfb5b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/fuzzer/PhaseLockingTransactionExecutionObserver.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/fuzzer/PhaseLockingTransactionExecutionObserver.scala @@ -26,7 +26,8 @@ private[delta] class PhaseLockingTransactionExecutionObserver( override val phaseLocks: Seq[ExecutionPhaseLock] = Seq( phases.initialPhase, phases.preparePhase, - phases.commitPhase) + phases.commitPhase, + phases.backfillPhase) /** * When set to true this observer will automatically update the thread's current observer to @@ -42,36 +43,49 @@ private[delta] class PhaseLockingTransactionExecutionObserver( override def preparingCommit[T](f: => T): T = phases.preparePhase.execute(f) 
- override def beginDoCommit(): Unit = phases.commitPhase.waitToEnter() + override def beginDoCommit(): Unit = { + phases.commitPhase.waitToEnter() + } + + override def beginBackfill(): Unit = { + phases.commitPhase.leave() + phases.backfillPhase.waitToEnter() + } override def transactionCommitted(): Unit = { if (nextObserver.nonEmpty && autoAdvanceNextObserver) { waitForCommitPhaseAndAdvanceToNextObserver() } else { - phases.commitPhase.leave() + phases.backfillPhase.leave() } } override def transactionAborted(): Unit = { - if (!phases.commitPhase.hasEntered) { - phases.commitPhase.waitToEnter() + if (!phases.commitPhase.hasLeft) { + if (!phases.commitPhase.hasEntered) { + phases.commitPhase.waitToEnter() + } + phases.commitPhase.leave() + } + if (!phases.backfillPhase.hasEntered) { + phases.backfillPhase.waitToEnter() } if (nextObserver.nonEmpty && autoAdvanceNextObserver) { waitForCommitPhaseAndAdvanceToNextObserver() } else { - phases.commitPhase.leave() + phases.backfillPhase.leave() } } /* - * Wait for the commit phase to pass but do not unblock it so that callers can write tests + * Wait for the backfill phase to pass but do not unblock it so that callers can write tests * that capture errors caused by code between the end of the last txn and the start of the * new txn. After the commit phase is passed, update the thread observer of the thread to * the next observer. */ def waitForCommitPhaseAndAdvanceToNextObserver(): Unit = { require(nextObserver.nonEmpty) - phases.commitPhase.waitToLeave() + phases.backfillPhase.waitToLeave() advanceToNextThreadObserver() } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/AbstractBatchBackfillingCommitOwnerClient.scala b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/AbstractBatchBackfillingCommitOwnerClient.scala index b3e92357825..9187f6b539f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/AbstractBatchBackfillingCommitOwnerClient.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/AbstractBatchBackfillingCommitOwnerClient.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.delta.managedcommit import java.nio.file.FileAlreadyExistsException import java.util.UUID -import org.apache.spark.sql.delta.DeltaLog +import org.apache.spark.sql.delta.TransactionExecutionObserver import org.apache.spark.sql.delta.actions.CommitInfo import org.apache.spark.sql.delta.actions.Metadata import org.apache.spark.sql.delta.storage.LogStore @@ -62,6 +62,7 @@ trait AbstractBatchBackfillingCommitOwnerClient extends CommitOwnerClient with L commitVersion: Long, actions: Iterator[String], updatedActions: UpdatedActions): CommitResponse = { + val executionObserver = TransactionExecutionObserver.threadObserver.get() val tablePath = ManagedCommitUtils.getTablePath(logPath) if (commitVersion == 0) { throw CommitFailedException( @@ -99,6 +100,7 @@ trait AbstractBatchBackfillingCommitOwnerClient extends CommitOwnerClient with L val mcToFsConversion = isManagedCommitToFSConversion(commitVersion, updatedActions) // Backfill if needed + executionObserver.beginBackfill() if (batchSize <= 1) { // Always backfill when batch size is configured as 1 backfill(logStore, hadoopConf, logPath, commitVersion, fileStatus) diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/concurrency/TransactionExecutionObserverSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/concurrency/TransactionExecutionObserverSuite.scala index 423a26e7707..ce7c80a48e5 100644 --- 
a/spark/src/test/scala/org/apache/spark/sql/delta/concurrency/TransactionExecutionObserverSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/concurrency/TransactionExecutionObserverSuite.scala @@ -60,10 +60,12 @@ class TransactionExecutionObserverSuite extends QueryTest with SharedSparkSessio assert(observer.phases.initialPhase.hasLeft) assert(!observer.phases.preparePhase.hasEntered) assert(!observer.phases.commitPhase.hasEntered) + assert(!observer.phases.backfillPhase.hasEntered) // allow things to progress observer.phases.preparePhase.entryBarrier.unblock() observer.phases.commitPhase.entryBarrier.unblock() + observer.phases.backfillPhase.entryBarrier.unblock() val removedFiles = txn.snapshot.allFiles.collect().map(_.remove).toSeq txn.commit(removedFiles, DeltaOperations.ManualUpdate) @@ -71,6 +73,8 @@ class TransactionExecutionObserverSuite extends QueryTest with SharedSparkSessio assert(observer.phases.preparePhase.hasLeft) assert(observer.phases.commitPhase.hasEntered) assert(observer.phases.commitPhase.hasLeft) + assert(observer.phases.backfillPhase.hasEntered) + assert(observer.phases.backfillPhase.hasLeft) } } val res = spark.read.format("delta").load(tempPath).collect() @@ -118,6 +122,10 @@ class TransactionExecutionObserverSuite extends QueryTest with SharedSparkSessio observer.phases.commitPhase.entryBarrier.unblock() busyWaitFor(observer.phases.commitPhase.hasEntered, timeout) busyWaitFor(observer.phases.commitPhase.hasLeft, timeout) + + observer.phases.backfillPhase.entryBarrier.unblock() + busyWaitFor(observer.phases.backfillPhase.hasEntered, timeout) + busyWaitFor(observer.phases.backfillPhase.hasLeft, timeout) testThread.join(timeout.toMillis) assert(!testThread.isAlive) // should have passed the barrier and completed @@ -146,6 +154,7 @@ class TransactionExecutionObserverSuite extends QueryTest with SharedSparkSessio // allow things to progress observer.phases.preparePhase.entryBarrier.unblock() observer.phases.commitPhase.entryBarrier.unblock() + observer.phases.backfillPhase.entryBarrier.unblock() val removedFiles = txn.snapshot.allFiles.collect().map(_.remove).toSeq txn.commit(removedFiles, DeltaOperations.ManualUpdate) } @@ -155,6 +164,7 @@ class TransactionExecutionObserverSuite extends QueryTest with SharedSparkSessio // allow things to progress observer.phases.preparePhase.entryBarrier.unblock() observer.phases.commitPhase.entryBarrier.unblock() + observer.phases.backfillPhase.entryBarrier.unblock() val removedFiles = txn.snapshot.allFiles.collect().map(_.remove).toSeq txn.commit(removedFiles, DeltaOperations.ManualUpdate) } @@ -210,11 +220,14 @@ class TransactionExecutionObserverSuite extends QueryTest with SharedSparkSessio observer.phases.preparePhase.entryBarrier.unblock() busyWaitFor(observer.phases.preparePhase.hasLeft, timeout) assert(!observer.phases.commitPhase.hasEntered) + assert(!observer.phases.backfillPhase.hasEntered) assertOperationNotVisible() observer.phases.commitPhase.entryBarrier.unblock() busyWaitFor(observer.phases.commitPhase.hasLeft, timeout) + observer.phases.backfillPhase.entryBarrier.unblock() + busyWaitFor(observer.phases.backfillPhase.hasLeft, timeout) testThread.join(timeout.toMillis) assert(!testThread.isAlive) // should have passed the barrier and completed @@ -243,7 +256,7 @@ class TransactionExecutionObserverSuite extends QueryTest with SharedSparkSessio TransactionExecutionObserver.withObserver(observer) { deltaLog.withNewTransaction { txn => - observer.phases.commitPhase.exitBarrier.unblock() + 
observer.phases.backfillPhase.exitBarrier.unblock() val removedFiles = txn.snapshot.allFiles.collect().map(_.remove).toSeq txn.commit(removedFiles, DeltaOperations.ManualUpdate) } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/concurrency/TransactionExecutionTestMixin.scala b/spark/src/test/scala/org/apache/spark/sql/delta/concurrency/TransactionExecutionTestMixin.scala index 8290ec6dcc3..d7098b88714 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/concurrency/TransactionExecutionTestMixin.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/concurrency/TransactionExecutionTestMixin.scala @@ -124,6 +124,7 @@ trait TransactionExecutionTestMixin { observer.phases.initialPhase.entryBarrier.unblock() observer.phases.preparePhase.entryBarrier.unblock() observer.phases.commitPhase.entryBarrier.unblock() + observer.phases.backfillPhase.entryBarrier.unblock() } /** @@ -145,11 +146,13 @@ trait TransactionExecutionTestMixin { // B starts and commits unblockAllPhases(observerB) - busyWaitFor(observerB.phases.commitPhase.hasLeft, timeout) + busyWaitFor(observerB.phases.backfillPhase.hasLeft, timeout) // A commits observerA.phases.commitPhase.entryBarrier.unblock() busyWaitFor(observerA.phases.commitPhase.hasLeft, timeout) + observerA.phases.backfillPhase.entryBarrier.unblock() + busyWaitFor(observerA.phases.backfillPhase.hasLeft, timeout) } (usageRecords, futureA, futureB) } @@ -179,15 +182,17 @@ trait TransactionExecutionTestMixin { // B starts and commits unblockAllPhases(observerB) - busyWaitFor(observerB.phases.commitPhase.hasLeft, timeout) + busyWaitFor(observerB.phases.backfillPhase.hasLeft, timeout) // C starts and commits unblockAllPhases(observerC) - busyWaitFor(observerC.phases.commitPhase.hasLeft, timeout) + busyWaitFor(observerC.phases.backfillPhase.hasLeft, timeout) // A commits observerA.phases.commitPhase.entryBarrier.unblock() busyWaitFor(observerA.phases.commitPhase.hasLeft, timeout) + observerA.phases.backfillPhase.entryBarrier.unblock() + busyWaitFor(observerA.phases.backfillPhase.hasLeft, timeout) } (usageRecords, futureA, futureB, futureC) } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitSuite.scala index 46d2aac0085..9cdd62a1295 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitSuite.scala @@ -136,11 +136,11 @@ class ManagedCommitSuite Seq(2).toDF.write.format("delta").mode("overwrite").save(tablePath) // version 1 Seq(3).toDF.write.format("delta").mode("append").save(tablePath) // version 2 DeltaLog.clearCache() - commitOwnerClient.numGetCommitsCalled = 0 + commitOwnerClient.numGetCommitsCalled.set(0) import testImplicits._ val result1 = sql(s"SELECT * FROM delta.`$tablePath`").collect() assert(result1.length === 2 && result1.toSet === Set(Row(2), Row(3))) - assert(commitOwnerClient.numGetCommitsCalled === 2) + assert(commitOwnerClient.numGetCommitsCalled.get === 2) } } @@ -293,8 +293,8 @@ class ManagedCommitSuite resetMetrics() Seq(2).toDF.write.format("delta").mode("append").save(tablePath) // version 4 Seq(3).toDF.write.format("delta").mode("append").save(tablePath) // version 5 - assert((cs1.numCommitsCalled, cs2.numCommitsCalled) === (0, 2)) - assert((cs1.numGetCommitsCalled, cs2.numGetCommitsCalled) === (0, 2)) + assert((cs1.numCommitsCalled.get, cs2.numCommitsCalled.get) === (0, 2)) 
+ assert((cs1.numGetCommitsCalled.get, cs2.numGetCommitsCalled.get) === (0, 2)) // Step-5: Read the table again and assert that the right APIs are used resetMetrics() @@ -305,7 +305,7 @@ class ManagedCommitSuite assert((builder1.numBuildCalled, builder2.numBuildCalled) === (0, 0)) // Since this is dataframe read, so we invoke deltaLog.update() twice and so GetCommits API // is called twice. - assert((cs1.numGetCommitsCalled, cs2.numGetCommitsCalled) === (0, 2)) + assert((cs1.numGetCommitsCalled.get, cs2.numGetCommitsCalled.get) === (0, 2)) // Step-6: Clear cache and simulate cold read again. // We will firstly create snapshot from listing: 0.json, 1.json, 2.json. @@ -316,7 +316,7 @@ class ManagedCommitSuite resetMetrics() assert( sql(s"SELECT * FROM delta.`$tablePath`").collect().toSet === (0 to 3).map(Row(_)).toSet) - assert((cs1.numGetCommitsCalled, cs2.numGetCommitsCalled) === (0, 2)) + assert((cs1.numGetCommitsCalled.get, cs2.numGetCommitsCalled.get) === (0, 2)) assert((builder1.numBuildCalled, builder2.numBuildCalled) === (0, 2)) } } @@ -351,8 +351,8 @@ class ManagedCommitSuite managedCommitTableConf: Map[String, String], startVersion: Option[Long], endVersion: Option[Long]): GetCommitsResponse = { - if (failAttempts.contains(numGetCommitsCalled + 1)) { - numGetCommitsCalled += 1 + if (failAttempts.contains(numGetCommitsCalled.get + 1)) { + numGetCommitsCalled.incrementAndGet() throw new IllegalStateException("Injected failure") } super.getCommits(logPath, managedCommitTableConf, startVersion, endVersion) @@ -429,7 +429,7 @@ class ManagedCommitSuite resetMetrics() cs2.failAttempts = Set(1, 2) // fail 0th and 1st attempt, 2nd attempt will succeed. val ex1 = intercept[CommitOwnerGetCommitsFailedException] { oldDeltaLog.update() } - assert((cs1.numGetCommitsCalled, cs2.numGetCommitsCalled) === (1, 1)) + assert((cs1.numGetCommitsCalled.get, cs2.numGetCommitsCalled.get) === (1, 1)) assert(ex1.getMessage.contains("Injected failure")) assert(oldDeltaLog.unsafeVolatileSnapshot.version == 1) assert(oldDeltaLog.getCapturedSnapshot().updateTimestamp != clock.getTimeMillis()) @@ -437,7 +437,7 @@ class ManagedCommitSuite // Attempt-2 // 2nd update also fails val ex2 = intercept[CommitOwnerGetCommitsFailedException] { oldDeltaLog.update() } - assert((cs1.numGetCommitsCalled, cs2.numGetCommitsCalled) === (2, 2)) + assert((cs1.numGetCommitsCalled.get, cs2.numGetCommitsCalled.get) === (2, 2)) assert(ex2.getMessage.contains("Injected failure")) assert(oldDeltaLog.unsafeVolatileSnapshot.version == 1) assert(oldDeltaLog.getCapturedSnapshot().updateTimestamp != clock.getTimeMillis()) @@ -445,7 +445,7 @@ class ManagedCommitSuite // Attempt-3: 3rd update succeeds clock.advance(500) assert(oldDeltaLog.update().version === 5) - assert((cs1.numGetCommitsCalled, cs2.numGetCommitsCalled) === (3, 3)) + assert((cs1.numGetCommitsCalled.get, cs2.numGetCommitsCalled.get) === (3, 3)) assert(oldDeltaLog.getCapturedSnapshot().updateTimestamp == clock.getTimeMillis()) } } @@ -702,9 +702,9 @@ class ManagedCommitSuite log.checkpoint() log.startTransaction().commitManually(createTestAddFile("f2")) - assert(trackingCommitOwnerClient.numCommitsCalled > 0) - assert(trackingCommitOwnerClient.numGetCommitsCalled > 0) - assert(trackingCommitOwnerClient.numBackfillToVersionCalled > 0) + assert(trackingCommitOwnerClient.numCommitsCalled.get > 0) + assert(trackingCommitOwnerClient.numGetCommitsCalled.get > 0) + assert(trackingCommitOwnerClient.numBackfillToVersionCalled.get > 0) } } @@ -751,8 +751,8 @@ class ManagedCommitSuite 
assert(log.unsafeVolatileSnapshot.metadata.managedCommitTableConf === Map.empty) // upgrade commit always filesystem based assert(fs.exists(FileNames.unsafeDeltaFile(log.logPath, upgradeStartVersion))) - assert(Seq(cs1, cs2).map(_.numCommitsCalled) == Seq(0, 0)) - assert(Seq(cs1, cs2).map(_.numRegisterTableCalled) == Seq(1, 0)) + assert(Seq(cs1, cs2).map(_.numCommitsCalled.get) == Seq(0, 0)) + assert(Seq(cs1, cs2).map(_.numRegisterTableCalled.get) == Seq(1, 0)) // Do couple of commits on the managed-commit table // [upgradeExistingTable = false] Commit-1/2 @@ -765,7 +765,7 @@ class ManagedCommitSuite assert(log.unsafeVolatileSnapshot.metadata.managedCommitOwnerName.nonEmpty) assert(log.unsafeVolatileSnapshot.metadata.managedCommitOwnerConf === Map.empty) assert(log.unsafeVolatileSnapshot.metadata.managedCommitTableConf === Map.empty) - assert(cs1.numCommitsCalled === versionOffset) + assert(cs1.numCommitsCalled.get === versionOffset) val backfillExpected = if (version % backfillInterval == 0) true else false assert(fs.exists(FileNames.unsafeDeltaFile(log.logPath, version)) == backfillExpected) } @@ -782,8 +782,8 @@ class ManagedCommitSuite assert(log.unsafeVolatileSnapshot.metadata.managedCommitTableConf === Map.empty) assert(log.unsafeVolatileSnapshot.metadata === newMetadata2) // This must have increased by 1 as downgrade commit happens via CommitOwnerClient. - assert(Seq(cs1, cs2).map(_.numCommitsCalled) == Seq(3, 0)) - assert(Seq(cs1, cs2).map(_.numRegisterTableCalled) == Seq(1, 0)) + assert(Seq(cs1, cs2).map(_.numCommitsCalled.get) == Seq(3, 0)) + assert(Seq(cs1, cs2).map(_.numRegisterTableCalled.get) == Seq(1, 0)) (0 to 3).foreach { version => assert(fs.exists(FileNames.unsafeDeltaFile(log.logPath, version))) } @@ -804,8 +804,8 @@ class ManagedCommitSuite expectedFileNames.map(name => createTestAddFile(name, dataChange = false))) // commit-owner should not be invoked for commit API. // Register table API should not be called until the end - assert(Seq(cs1, cs2).map(_.numCommitsCalled) == Seq(3, 0)) - assert(Seq(cs1, cs2).map(_.numRegisterTableCalled) == Seq(1, 0)) + assert(Seq(cs1, cs2).map(_.numCommitsCalled.get) == Seq(3, 0)) + assert(Seq(cs1, cs2).map(_.numRegisterTableCalled.get) == Seq(1, 0)) // 4th file is directly written to FS in backfilled way. 
assert(fs.exists(FileNames.unsafeDeltaFile(log.logPath, upgradeStartVersion + 4))) @@ -826,16 +826,16 @@ class ManagedCommitSuite expectedFileNames = Set("1", "2", "post-upgrade-file", "upgrade-2-file") assert(log.unsafeVolatileSnapshot.allFiles.collect().toSet === expectedFileNames.map(name => createTestAddFile(name, dataChange = false))) - assert(Seq(cs1, cs2).map(_.numCommitsCalled) == Seq(3, 0)) - assert(Seq(cs1, cs2).map(_.numRegisterTableCalled) == Seq(1, 1)) + assert(Seq(cs1, cs2).map(_.numCommitsCalled.get) == Seq(3, 0)) + assert(Seq(cs1, cs2).map(_.numRegisterTableCalled.get) == Seq(1, 1)) // Make 1 more commit, this should go to new owner log.startTransaction().commitManually(newMetadata3, createTestAddFile("4")) expectedFileNames = Set("1", "2", "post-upgrade-file", "upgrade-2-file", "4") assert(log.unsafeVolatileSnapshot.allFiles.collect().toSet === expectedFileNames.map(name => createTestAddFile(name, dataChange = false))) - assert(Seq(cs1, cs2).map(_.numCommitsCalled) == Seq(3, 1)) - assert(Seq(cs1, cs2).map(_.numRegisterTableCalled) == Seq(1, 1)) + assert(Seq(cs1, cs2).map(_.numCommitsCalled.get) == Seq(3, 1)) + assert(Seq(cs1, cs2).map(_.numRegisterTableCalled.get) == Seq(1, 1)) assert(log.unsafeVolatileSnapshot.version === upgradeStartVersion + 6) } } @@ -913,8 +913,8 @@ class ManagedCommitSuite Some(oldProtocol.readerFeatures.getOrElse(Set.empty) + V2CheckpointTableFeature.name), writerFeatures = Some(oldProtocol.writerFeatures.getOrElse(Set.empty) + ManagedCommitTableFeature.name)) - assert(cs.numRegisterTableCalled === 0) - assert(cs.numCommitsCalled === 0) + assert(cs.numRegisterTableCalled.get === 0) + assert(cs.numCommitsCalled.get === 0) val txn = log.startTransaction() txn.updateMetadataForNewTable(newMetadata) @@ -926,8 +926,8 @@ class ManagedCommitSuite Map.empty, Map.empty) log = DeltaLog.forTable(spark, tablePath) - assert(cs.numRegisterTableCalled === 1) - assert(cs.numCommitsCalled === 0) + assert(cs.numRegisterTableCalled.get === 1) + assert(cs.numCommitsCalled.get === 0) assert(log.unsafeVolatileSnapshot.version === 2L) Seq(V2CheckpointTableFeature, ManagedCommitTableFeature).foreach { feature => @@ -940,8 +940,8 @@ class ManagedCommitSuite assert(log.unsafeVolatileSnapshot.metadata.managedCommitTableConf === Map.empty) Seq(3).toDF.write.mode("append").format("delta").save(tablePath) - assert(cs.numRegisterTableCalled === 1) - assert(cs.numCommitsCalled === 1) + assert(cs.numRegisterTableCalled.get === 1) + assert(cs.numCommitsCalled.get === 1) assert(log.unsafeVolatileSnapshot.version === 3L) assert(log.unsafeVolatileSnapshot.tableCommitOwnerClientOpt.nonEmpty) diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitTestUtils.scala b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitTestUtils.scala index 477b6ae987f..496803dd454 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitTestUtils.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitTestUtils.scala @@ -16,6 +16,8 @@ package org.apache.spark.sql.delta.managedcommit +import java.util.concurrent.atomic.AtomicInteger + import org.apache.spark.sql.delta.{DeltaConfigs, DeltaLog, DeltaTestUtilsBase} import org.apache.spark.sql.delta.DeltaConfigs.MANAGED_COMMIT_OWNER_NAME import org.apache.spark.sql.delta.actions.{Action, CommitInfo, Metadata, Protocol} @@ -46,7 +48,7 @@ trait ManagedCommitTestUtils * Runs the function `f` with managed commits default properties unset. 
* Any table created in function `f`` won't have managed commits enabled by default. */ - def withoutManagedCommitsDefaultTableProperties(f: => Unit): Unit = { + def withoutManagedCommitsDefaultTableProperties[T](f: => T): T = { val commitOwnerKey = MANAGED_COMMIT_OWNER_NAME.defaultTablePropertyKey val oldCommitOwnerValue = spark.conf.getOption(commitOwnerKey) spark.conf.unset(commitOwnerKey) @@ -128,31 +130,36 @@ class PredictableUuidInMemoryCommitOwnerClient(batchSize: Long) } } +object TrackingCommitOwnerClient { + private val insideOperation = new ThreadLocal[Boolean] { + override def initialValue(): Boolean = false + } +} + class TrackingCommitOwnerClient(delegatingCommitOwnerClient: InMemoryCommitOwner) extends CommitOwnerClient { - var numCommitsCalled: Int = 0 - var numGetCommitsCalled: Int = 0 - var numBackfillToVersionCalled: Int = 0 - var numRegisterTableCalled: Int = 0 - var insideOperation: Boolean = false + val numCommitsCalled = new AtomicInteger(0) + val numGetCommitsCalled = new AtomicInteger(0) + val numBackfillToVersionCalled = new AtomicInteger(0) + val numRegisterTableCalled = new AtomicInteger(0) - def recordOperation[T](op: String)(f: => T): T = synchronized { - val oldInsideOperation = insideOperation + def recordOperation[T](op: String)(f: => T): T = { + val oldInsideOperation = TrackingCommitOwnerClient.insideOperation.get() try { - if (!insideOperation) { + if (!TrackingCommitOwnerClient.insideOperation.get()) { op match { - case "commit" => numCommitsCalled += 1 - case "getCommits" => numGetCommitsCalled += 1 - case "backfillToVersion" => numBackfillToVersionCalled += 1 - case "registerTable" => numRegisterTableCalled += 1 + case "commit" => numCommitsCalled.incrementAndGet() + case "getCommits" => numGetCommitsCalled.incrementAndGet() + case "backfillToVersion" => numBackfillToVersionCalled.incrementAndGet() + case "registerTable" => numRegisterTableCalled.incrementAndGet() case _ => () } } - insideOperation = true + TrackingCommitOwnerClient.insideOperation.set(true) f } finally { - insideOperation = oldInsideOperation + TrackingCommitOwnerClient.insideOperation.set(oldInsideOperation) } } @@ -191,9 +198,9 @@ class TrackingCommitOwnerClient(delegatingCommitOwnerClient: InMemoryCommitOwner override def semanticEquals(other: CommitOwnerClient): Boolean = this == other def reset(): Unit = { - numCommitsCalled = 0 - numGetCommitsCalled = 0 - numBackfillToVersionCalled = 0 + numCommitsCalled.set(0) + numGetCommitsCalled.set(0) + numBackfillToVersionCalled.set(0) } override def registerTable(
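The new `backfillPhase` added by this patch lets a test hold a transaction between publishing
the commit file and backfilling it, which is exactly where multiple managed commits can now
interleave. A minimal sketch of that pattern, modeled on the `TransactionExecutionObserverSuite`
changes above and assuming its `deltaLog`, `busyWaitFor`, and `timeout` test helpers:

```scala
// Sketch only: mirrors the phase-unblocking pattern used in TransactionExecutionObserverSuite.
// `deltaLog`, `busyWaitFor`, and `timeout` are assumed to come from the surrounding test harness.
import org.apache.spark.sql.delta.{DeltaOperations, TransactionExecutionObserver}
import org.apache.spark.sql.delta.fuzzer.{OptimisticTransactionPhases, PhaseLockingTransactionExecutionObserver}

val observer = new PhaseLockingTransactionExecutionObserver(
  OptimisticTransactionPhases.forName("test-txn"))

val testThread = new Thread(() => {
  TransactionExecutionObserver.withObserver(observer) {
    deltaLog.withNewTransaction { txn =>
      val removedFiles = txn.snapshot.allFiles.collect().map(_.remove).toSeq
      txn.commit(removedFiles, DeltaOperations.ManualUpdate)
    }
  }
})
testThread.start()

// Let the transaction run up to and through the commit phase ...
observer.phases.initialPhase.entryBarrier.unblock()
observer.phases.preparePhase.entryBarrier.unblock()
observer.phases.commitPhase.entryBarrier.unblock()
busyWaitFor(observer.phases.commitPhase.hasLeft, timeout)

// ... but hold it before backfill. Other transactions may commit here, and the
// managed-commit owner must still reconstruct the correct table state.
assert(!observer.phases.backfillPhase.hasEntered)

// Release the backfill phase and let the transaction finish.
observer.phases.backfillPhase.entryBarrier.unblock()
busyWaitFor(observer.phases.backfillPhase.hasLeft, timeout)
testThread.join(timeout.toMillis)
```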