diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 7a604be919026..e11e9c502c723 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -253,7 +253,8 @@ class IncrementalExecution(
               oldMetadata)
             val metadata = ssw.operatorStateMetadata(stateSchemaList)
             oldMetadata match {
-              case Some(oldMetadata) => ssw.validateNewMetadata(oldMetadata, metadata)
+              case Some(oldMetadata) =>
+                ssw.validateNewMetadata(oldMetadata, metadata)
               case None =>
             }
             val metadataWriter = OperatorStateMetadataWriter.createWriter(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index b79d7a53becf1..c1e702fd7d421 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -60,6 +60,37 @@ trait StateSchemaProvider extends Serializable {
   def getCurrentStateSchemaId(colFamilyName: String, isKey: Boolean): Short
 }
 
+// Test implementation that can be dynamically updated
+class TestStateSchemaProvider extends StateSchemaProvider {
+  private var schemas = Map.empty[StateSchemaMetadataKey, StateSchemaMetadataValue]
+
+  def addSchema(
+      colFamilyName: String,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keySchemaId: Short = 0,
+      valueSchemaId: Short = 0): Unit = {
+    schemas ++= Map(
+      StateSchemaMetadataKey(colFamilyName, keySchemaId, isKey = true) ->
+        StateSchemaMetadataValue(keySchema, SchemaConverters.toAvroTypeWithDefaults(keySchema)),
+      StateSchemaMetadataKey(colFamilyName, valueSchemaId, isKey = false) ->
+        StateSchemaMetadataValue(valueSchema, SchemaConverters.toAvroTypeWithDefaults(valueSchema))
+    )
+  }
+
+  override def getSchemaMetadataValue(key: StateSchemaMetadataKey): StateSchemaMetadataValue = {
+    schemas(key)
+  }
+
+  override def getCurrentStateSchemaId(colFamilyName: String, isKey: Boolean): Short = {
+    schemas.keys
+      .filter(key =>
+        key.colFamilyName == colFamilyName &&
+          key.isKey == isKey)
+      .map(_.schemaId).max
+  }
+}
+
 class InMemoryStateSchemaProvider(metadata: StateSchemaMetadata) extends StateSchemaProvider {
   override def getSchemaMetadataValue(key: StateSchemaMetadataKey): StateSchemaMetadataValue = {
     metadata.activeSchemas(key)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
index 60805fefb66db..6feebc0d078c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateSchemaCompatibilityChecker.scala
@@ -210,7 +210,7 @@ class StateSchemaCompatibilityChecker(
         case s: SchemaValidationException =>
          throw StateStoreErrors.stateStoreInvalidValueSchemaEvolution(
            valueSchema.toString, s.getMessage)
-        case e: _ => throw e
+        case e: Throwable => throw e
      }
 
      // Schema evolved - increment value schema ID
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
index 61ed35c60a2ee..3094d3610335a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreErrors.scala
@@ -336,13 +336,13 @@ class StateStoreValueSchemaNotCompatible(
       "newValueSchema" -> newValueSchema))
 
 class StateStoreInvalidValueSchemaEvolution(
-    storedValueSchema: String,
-    newValueSchema: String)
+    newValueSchema: String,
+    avroErrorMessage: String)
   extends SparkUnsupportedOperationException(
     errorClass = "STATE_STORE_INVALID_VALUE_SCHEMA_EVOLUTION",
     messageParameters = Map(
-      "storedValueSchema" -> storedValueSchema,
-      "newValueSchema" -> newValueSchema))
+      "newValueSchema" -> newValueSchema,
+      "avroErrorMessage" -> avroErrorMessage))
 
 class StateStoreSchemaFileThresholdExceeded(
     numSchemaFiles: Int,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
index 21415721718b4..0c16291437396 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreSuite.scala
@@ -30,7 +30,6 @@ import org.apache.spark.{SparkConf, SparkUnsupportedOperationException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.LocalSparkSession.withSparkSession
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.avro.SchemaConverters
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.quietly
@@ -43,37 +42,6 @@ import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
 
-// Test implementation that can be dynamically updated
-class TestStateSchemaProvider extends StateSchemaProvider {
-  private var schemas = Map.empty[StateSchemaMetadataKey, StateSchemaMetadataValue]
-
-  def addSchema(
-      colFamilyName: String,
-      keySchema: StructType,
-      valueSchema: StructType,
-      keySchemaId: Short = 0,
-      valueSchemaId: Short = 0): Unit = {
-    schemas ++= Map(
-      StateSchemaMetadataKey(colFamilyName, keySchemaId, isKey = true) ->
-        StateSchemaMetadataValue(keySchema, SchemaConverters.toAvroTypeWithDefaults(keySchema)),
-      StateSchemaMetadataKey(colFamilyName, valueSchemaId, isKey = false) ->
-        StateSchemaMetadataValue(valueSchema, SchemaConverters.toAvroTypeWithDefaults(valueSchema))
-    )
-  }
-
-  override def getSchemaMetadataValue(key: StateSchemaMetadataKey): StateSchemaMetadataValue = {
-    schemas(key)
-  }
-
-  override def getCurrentStateSchemaId(colFamilyName: String, isKey: Boolean): Short = {
-    schemas.keys
-      .filter(key =>
-        key.colFamilyName == colFamilyName &&
-          key.isKey == isKey)
-      .map(_.schemaId).max
-  }
-}
-
 @ExtendedSQLTest
 class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider]
   with AlsoTestWithRocksDBFeatures
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 0825a0ada2893..8d0ad2e0d624c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -1854,8 +1854,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
         AddData(inputData, "a"),
         ExpectFailure[StateStoreInvalidValueSchemaEvolution] {
           (t: Throwable) => {
-            assert(t.getMessage.contains(
-              "Unable to read schema:"))
+            assert(t.getMessage.contains("Unable to read schema:"))
           }
         }
       )
@@ -2143,6 +2142,53 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }
 
+  test("test exceeding schema file threshold throws error") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
+      SQLConf.STREAMING_STATE_MAX_STATE_SCHEMA_FILES.key -> 1.toString) {
+      withTempDir { checkpointDir =>
+        val clock = new StreamManualClock
+
+        val inputData1 = MemoryStream[String]
+        val result1 = inputData1.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+
+        testStream(result1, OutputMode.Update())(
+          StartStream(
+            checkpointLocation = checkpointDir.getCanonicalPath,
+            trigger = Trigger.ProcessingTime("1 second"),
+            triggerClock = clock),
+          AddData(inputData1, "a"),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", "1")),
+          StopStream
+        )
+
+        val result2 = inputData1.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+
+        testStream(result2, OutputMode.Update())(
+          StartStream(
+            checkpointLocation = checkpointDir.getCanonicalPath,
+            trigger = Trigger.ProcessingTime("1 second"),
+            triggerClock = clock),
+          AddData(inputData1, "a"),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", "2")),
+          StopStream
+        )
+      }
+    }
+  }
+
   test("test query restart succeeds") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
      classOf[RocksDBStateStoreProvider].getName,