From 3af433517bf5a42b1774cb63a8bd1d262e7d933d Mon Sep 17 00:00:00 2001 From: Dhruv Arya Date: Fri, 17 May 2024 16:15:34 -0700 Subject: [PATCH] [Spark] Pass sparkSession to commitOwnerBuilder (#3112) #### Which Delta project/connector is this regarding? - [X] Spark - [ ] Standalone - [ ] Flink - [ ] Kernel - [ ] Other (fill in here) ## Description Updates CommitOwnerBuilder.build so that it can take in a sparkSession object. This allows it to read CommitOwner-related dynamic confs from the sparkSession while building it. ## Does this PR introduce _any_ user-facing changes? No --- .../sql/delta/OptimisticTransaction.scala | 2 +- .../org/apache/spark/sql/delta/Snapshot.scala | 2 +- .../managedcommit/CommitOwnerClient.scala | 12 ++++--- .../managedcommit/InMemoryCommitOwner.scala | 4 ++- .../managedcommit/ManagedCommitUtils.scala | 11 +++++-- .../spark/sql/delta/CloneTableSuiteBase.scala | 2 +- .../spark/sql/delta/DeltaLogSuite.scala | 4 +-- .../delta/OptimisticTransactionSuite.scala | 12 ++++--- .../sql/delta/SnapshotManagementSuite.scala | 3 +- .../CommitOwnerClientSuite.scala | 33 ++++++++++--------- .../InMemoryCommitOwnerSuite.scala | 10 +++--- .../managedcommit/ManagedCommitSuite.scala | 11 ++++--- .../ManagedCommitTestUtils.scala | 3 +- 13 files changed, 63 insertions(+), 46 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala b/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala index e1d91af67bb..1a2256248e6 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/OptimisticTransaction.scala @@ -1523,7 +1523,7 @@ trait OptimisticTransactionImpl extends TransactionalWrite var newManagedCommitTableConf: Option[Map[String, String]] = None if (finalMetadata.configuration != snapshot.metadata.configuration || snapshot.version == -1L) { val newCommitOwnerClientOpt = - ManagedCommitUtils.getCommitOwnerClient(finalMetadata, finalProtocol) + ManagedCommitUtils.getCommitOwnerClient(spark, finalMetadata, finalProtocol) (newCommitOwnerClientOpt, readSnapshotTableCommitOwnerClientOpt) match { case (Some(newCommitOwnerClient), None) => // FS -> MC conversion diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/Snapshot.scala b/spark/src/main/scala/org/apache/spark/sql/delta/Snapshot.scala index 08f296752eb..5307354ebd1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/Snapshot.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/Snapshot.scala @@ -234,7 +234,7 @@ class Snapshot( */ val tableCommitOwnerClientOpt: Option[TableCommitOwnerClient] = initializeTableCommitOwner() protected def initializeTableCommitOwner(): Option[TableCommitOwnerClient] = { - ManagedCommitUtils.getTableCommitOwner(this) + ManagedCommitUtils.getTableCommitOwner(spark, this) } /** Number of columns to collect stats on for data skipping */ diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/CommitOwnerClient.scala b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/CommitOwnerClient.scala index 6fa497a0ab6..bb669a007d9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/CommitOwnerClient.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/CommitOwnerClient.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.delta.storage.LogStore import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.spark.sql.SparkSession + /** Representation of a commit file */ case class Commit( private val version: Long, @@ -199,7 +201,7 @@ trait CommitOwnerBuilder { def getName: String /** Returns a commit-owner client based on the given conf */ - def build(conf: Map[String, String]): CommitOwnerClient + def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient } /** Factory to get the correct [[CommitOwnerClient]] for a table */ @@ -218,10 +220,12 @@ object CommitOwnerProvider { } } - /** Returns a [[CommitOwnerClient]] for the given `name` and `conf` */ + /** Returns a [[CommitOwnerClient]] for the given `name`, `conf`, and `spark` */ def getCommitOwnerClient( - name: String, conf: Map[String, String]): CommitOwnerClient = synchronized { - nameToBuilderMapping.get(name).map(_.build(conf)).getOrElse { + name: String, + conf: Map[String, String], + spark: SparkSession): CommitOwnerClient = synchronized { + nameToBuilderMapping.get(name).map(_.build(spark, conf)).getOrElse { throw new IllegalArgumentException(s"Unknown commit-owner: $name") } } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitOwner.scala b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitOwner.scala index 9f8fc87a551..84320fe6779 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitOwner.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitOwner.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.delta.storage.LogStore import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.spark.sql.SparkSession + class InMemoryCommitOwner(val batchSize: Long) extends AbstractBatchBackfillingCommitOwnerClient { @@ -206,7 +208,7 @@ case class InMemoryCommitOwnerBuilder(batchSize: Long) extends CommitOwnerBuilde def getName: String = "in-memory" /** Returns a commit-owner based on the given conf */ - def build(conf: Map[String, String]): CommitOwnerClient = { + def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = { inMemoryStore } } diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitUtils.scala index 6d399ac3f23..67d1164b90d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitUtils.scala @@ -25,6 +25,8 @@ import org.apache.spark.sql.delta.util.FileNames.{DeltaFile, UnbackfilledDeltaFi import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.spark.sql.SparkSession + object ManagedCommitUtils extends DeltaLogging { /** @@ -111,16 +113,19 @@ object ManagedCommitUtils extends DeltaLogging { */ def getTablePath(logPath: Path): Path = logPath.getParent - def getCommitOwnerClient(metadata: Metadata, protocol: Protocol): Option[CommitOwnerClient] = { + def getCommitOwnerClient( + spark: SparkSession, metadata: Metadata, protocol: Protocol): Option[CommitOwnerClient] = { metadata.managedCommitOwnerName.map { commitOwnerStr => assert(protocol.isFeatureSupported(ManagedCommitTableFeature)) - CommitOwnerProvider.getCommitOwnerClient(commitOwnerStr, metadata.managedCommitOwnerConf) + CommitOwnerProvider.getCommitOwnerClient( + commitOwnerStr, metadata.managedCommitOwnerConf, spark) } } def getTableCommitOwner( + spark: SparkSession, snapshotDescriptor: SnapshotDescriptor): Option[TableCommitOwnerClient] = { - getCommitOwnerClient(snapshotDescriptor.metadata, snapshotDescriptor.protocol).map { + getCommitOwnerClient(spark, snapshotDescriptor.metadata, snapshotDescriptor.protocol).map { commitOwner => TableCommitOwnerClient( commitOwner, diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/CloneTableSuiteBase.scala b/spark/src/test/scala/org/apache/spark/sql/delta/CloneTableSuiteBase.scala index e3f4cb15a07..a17ed80fae2 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/CloneTableSuiteBase.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/CloneTableSuiteBase.scala @@ -37,7 +37,7 @@ import org.scalatest.Tag import org.apache.spark.{DebugFilesystem, SparkException, TaskFailedReason} import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd} -import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaLogSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaLogSuite.scala index 70758705094..39d6e6ebbad 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/DeltaLogSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/DeltaLogSuite.scala @@ -503,7 +503,7 @@ class DeltaLogSuite extends QueryTest // For Managed Commit table with a commit that is not backfilled, we can't use // 00000000002.json yet. Contact commit store to get uuid file path to malform json file. val oc = CommitOwnerProvider.getCommitOwnerClient( - "tracking-in-memory", Map.empty[String, String]) + "tracking-in-memory", Map.empty[String, String], spark) val commitResponse = oc.getCommits(deltaLog.logPath, Map.empty, Some(2)) if (!commitResponse.getCommits.isEmpty) { val path = commitResponse.getCommits.last.getFileStatus.getPath @@ -602,7 +602,7 @@ class DeltaLogSuite extends QueryTest // For Managed Commit table with a commit that is not backfilled, we can't use // 00000000001.json yet. Contact commit store to get uuid file path to malform json file. val oc = CommitOwnerProvider.getCommitOwnerClient( - "tracking-in-memory", Map.empty[String, String]) + "tracking-in-memory", Map.empty[String, String], spark) val commitResponse = oc.getCommits(log.logPath, Map.empty, Some(1)) if (!commitResponse.getCommits.isEmpty) { commitFilePath = commitResponse.getCommits.head.getFileStatus.getPath diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/OptimisticTransactionSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/OptimisticTransactionSuite.scala index 540e8bff82e..9e668bb7cf2 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/OptimisticTransactionSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/OptimisticTransactionSuite.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.delta.util.{FileNames, JsonUtils} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.sql.Row -import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.{Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{EqualTo, Literal} import org.apache.spark.sql.functions.lit @@ -520,7 +519,8 @@ class OptimisticTransactionSuite } } } - override def build(conf: Map[String, String]): CommitOwnerClient = commitOwnerClient + override def build( + spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = commitOwnerClient } CommitOwnerProvider.registerBuilder(RetryableNonConflictCommitOwnerBuilder$) @@ -569,7 +569,8 @@ class OptimisticTransactionSuite } } } - override def build(conf: Map[String, String]): CommitOwnerClient = commitOwnerClient + override def build( + spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = commitOwnerClient } CommitOwnerProvider.registerBuilder(FileAlreadyExistsCommitOwnerBuilder) @@ -878,7 +879,8 @@ class OptimisticTransactionSuite object RetryableConflictCommitOwnerBuilder$ extends CommitOwnerBuilder { lazy val commitOwnerClient = new RetryableConflictCommitOwnerClient() override def getName: String = commitOwnerName - override def build(conf: Map[String, String]): CommitOwnerClient = commitOwnerClient + override def build( + spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = commitOwnerClient } CommitOwnerProvider.registerBuilder(RetryableConflictCommitOwnerBuilder$) val conf = Map(DeltaConfigs.MANAGED_COMMIT_OWNER_NAME.key -> commitOwnerName) diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/SnapshotManagementSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/SnapshotManagementSuite.scala index ed641c1d2fc..23380dd5afb 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/SnapshotManagementSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/SnapshotManagementSuite.scala @@ -41,6 +41,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkConf import org.apache.spark.SparkException import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.storage.StorageLevel @@ -587,7 +588,7 @@ object ConcurrentBackfillCommitOwnerBuilder extends CommitOwnerBuilder { private lazy val concurrentBackfillCommitOwnerClient = ConcurrentBackfillCommitOwnerClient(synchronousBackfillThreshold = 2, batchSize) override def getName: String = "awaiting-commit-owner" - override def build(conf: Map[String, String]): CommitOwnerClient = { + override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = { concurrentBackfillCommitOwnerClient } } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/CommitOwnerClientSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/CommitOwnerClientSuite.scala index 325e0bca5e7..5faa9b02ecf 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/CommitOwnerClientSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/CommitOwnerClientSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.delta.test.DeltaSQLTestUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{QueryTest, SparkSession} import org.apache.spark.sql.test.SharedSparkSession class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with SharedSparkSession @@ -72,15 +72,15 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share test("registering multiple commit-owner builders with same name") { object Builder1 extends CommitOwnerBuilder { - override def build(conf: Map[String, String]): CommitOwnerClient = null + override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = null override def getName: String = "builder-1" } object BuilderWithSameName extends CommitOwnerBuilder { - override def build(conf: Map[String, String]): CommitOwnerClient = null + override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = null override def getName: String = "builder-1" } object Builder3 extends CommitOwnerBuilder { - override def build(conf: Map[String, String]): CommitOwnerClient = null + override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = null override def getName: String = "builder-3" } CommitOwnerProvider.registerBuilder(Builder1) @@ -94,7 +94,7 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share object Builder1 extends CommitOwnerBuilder { val cs1 = new TestCommitOwnerClient1() val cs2 = new TestCommitOwnerClient2() - override def build(conf: Map[String, String]): CommitOwnerClient = { + override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = { conf.getOrElse("url", "") match { case "url1" => cs1 case "url2" => cs2 @@ -104,21 +104,22 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share override def getName: String = "cs-x" } CommitOwnerProvider.registerBuilder(Builder1) - val cs1 = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url1")) + val cs1 = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url1"), spark) assert(cs1.isInstanceOf[TestCommitOwnerClient1]) - val cs1_again = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url1")) + val cs1_again = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url1"), spark) assert(cs1 eq cs1_again) - val cs2 = CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url2", "a" -> "b")) + val cs2 = + CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url2", "a" -> "b"), spark) assert(cs2.isInstanceOf[TestCommitOwnerClient2]) // If builder receives a config which doesn't have expected params, then it can throw exception. intercept[IllegalArgumentException] { - CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url3")) + CommitOwnerProvider.getCommitOwnerClient("cs-x", Map("url" -> "url3"), spark) } } test("getCommitOwnerClient - builder returns new object each time") { object Builder1 extends CommitOwnerBuilder { - override def build(conf: Map[String, String]): CommitOwnerClient = { + override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = { conf.getOrElse("url", "") match { case "url1" => new TestCommitOwnerClient1() case _ => throw new IllegalArgumentException("Invalid url") @@ -127,9 +128,9 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share override def getName: String = "cs-name" } CommitOwnerProvider.registerBuilder(Builder1) - val cs1 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("url" -> "url1")) + val cs1 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("url" -> "url1"), spark) assert(cs1.isInstanceOf[TestCommitOwnerClient1]) - val cs1_again = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("url" -> "url1")) + val cs1_again = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("url" -> "url1"), spark) assert(cs1 ne cs1_again) } @@ -202,7 +203,7 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share other.asInstanceOf[TestCommitOwnerClient].key == key } object Builder1 extends CommitOwnerBuilder { - override def build(conf: Map[String, String]): CommitOwnerClient = { + override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = { new TestCommitOwnerClient(conf("key")) } override def getName: String = "cs-name" @@ -210,13 +211,13 @@ class CommitOwnerClientSuite extends QueryTest with DeltaSQLTestUtils with Share CommitOwnerProvider.registerBuilder(Builder1) // Different CommitOwner with same keys should be semantically equal. - val obj1 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url1")) - val obj2 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url1")) + val obj1 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url1"), spark) + val obj2 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url1"), spark) assert(obj1 != obj2) assert(obj1.semanticEquals(obj2)) // Different CommitOwner with different keys should be semantically unequal. - val obj3 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url2")) + val obj3 = CommitOwnerProvider.getCommitOwnerClient("cs-name", Map("key" -> "url2"), spark) assert(obj1 != obj3) assert(!obj1.semanticEquals(obj3)) } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitOwnerSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitOwnerSuite.scala index 82e7698d423..7ba4265926d 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitOwnerSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/InMemoryCommitOwnerSuite.scala @@ -24,7 +24,7 @@ abstract class InMemoryCommitOwnerSuite(batchSize: Int) extends CommitOwnerClien override protected def createTableCommitOwnerClient( deltaLog: DeltaLog): TableCommitOwnerClient = { - val cs = InMemoryCommitOwnerBuilder(batchSize).build(Map.empty) + val cs = InMemoryCommitOwnerBuilder(batchSize).build(spark, Map.empty) TableCommitOwnerClient(cs, deltaLog, Map.empty[String, String]) } @@ -65,22 +65,22 @@ abstract class InMemoryCommitOwnerSuite(batchSize: Int) extends CommitOwnerClien test("InMemoryCommitOwnerBuilder works as expected") { val builder1 = InMemoryCommitOwnerBuilder(5) - val cs1 = builder1.build(Map.empty) + val cs1 = builder1.build(spark, Map.empty) assert(cs1.isInstanceOf[InMemoryCommitOwner]) assert(cs1.asInstanceOf[InMemoryCommitOwner].batchSize == 5) - val cs1_again = builder1.build(Map.empty) + val cs1_again = builder1.build(spark, Map.empty) assert(cs1_again.isInstanceOf[InMemoryCommitOwner]) assert(cs1 == cs1_again) val builder2 = InMemoryCommitOwnerBuilder(10) - val cs2 = builder2.build(Map.empty) + val cs2 = builder2.build(spark, Map.empty) assert(cs2.isInstanceOf[InMemoryCommitOwner]) assert(cs2.asInstanceOf[InMemoryCommitOwner].batchSize == 10) assert(cs2 ne cs1) val builder3 = InMemoryCommitOwnerBuilder(10) - val cs3 = builder3.build(Map.empty) + val cs3 = builder3.build(spark, Map.empty) assert(cs3.isInstanceOf[InMemoryCommitOwner]) assert(cs3.asInstanceOf[InMemoryCommitOwner].batchSize == 10) assert(cs3 ne cs2) diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitSuite.scala index 9cdd62a1295..64627b2071b 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitSuite.scala @@ -40,7 +40,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.spark.SparkConf -import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.ManualClock @@ -71,7 +71,7 @@ class ManagedCommitSuite override def getName: String = commitOwnerName - override def build(conf: Map[String, String]): CommitOwnerClient = + override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = new InMemoryCommitOwner(batchSize = 5) { override def commit( logStore: LogStore, @@ -125,7 +125,7 @@ class ManagedCommitSuite test("cold snapshot initialization") { val builder = TrackingInMemoryCommitOwnerBuilder(batchSize = 10) - val commitOwnerClient = builder.build(Map.empty).asInstanceOf[TrackingCommitOwnerClient] + val commitOwnerClient = builder.build(spark, Map.empty).asInstanceOf[TrackingCommitOwnerClient] CommitOwnerProvider.registerBuilder(builder) withTempDir { tempDir => val tablePath = tempDir.getAbsolutePath @@ -221,7 +221,7 @@ class ManagedCommitSuite name: String, commitOwnerClient: CommitOwnerClient) extends CommitOwnerBuilder { var numBuildCalled = 0 - override def build(conf: Map[String, String]): CommitOwnerClient = { + override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = { numBuildCalled += 1 commitOwnerClient } @@ -361,7 +361,8 @@ class ManagedCommitSuite case class TrackingInMemoryCommitOwnerClientBuilder( name: String, commitOwnerClient: CommitOwnerClient) extends CommitOwnerBuilder { - override def build(conf: Map[String, String]): CommitOwnerClient = commitOwnerClient + override def build( + spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = commitOwnerClient override def getName: String = name } val builder1 = TrackingInMemoryCommitOwnerClientBuilder(name = "in-memory-1", cs1) diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitTestUtils.scala b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitTestUtils.scala index 5698bcbead5..60497c1d063 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitTestUtils.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/managedcommit/ManagedCommitTestUtils.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.test.SharedSparkSession trait ManagedCommitTestUtils @@ -116,7 +117,7 @@ case class TrackingInMemoryCommitOwnerBuilder( } override def getName: String = "tracking-in-memory" - override def build(conf: Map[String, String]): CommitOwnerClient = { + override def build(spark: SparkSession, conf: Map[String, String]): CommitOwnerClient = { trackingInMemoryCommitOwnerClient } }