fixing compilation
ericm-db committed Jan 6, 2025
1 parent dd82c88 commit 45240f4
Showing 6 changed files with 86 additions and 40 deletions.

@@ -253,7 +253,8 @@ class IncrementalExecution(
         oldMetadata)
       val metadata = ssw.operatorStateMetadata(stateSchemaList)
       oldMetadata match {
-        case Some(oldMetadata) => ssw.validateNewMetadata(oldMetadata, metadata)
+        case Some(oldMetadata) =>
+          ssw.validateNewMetadata(oldMetadata, metadata)
         case None =>
       }
       val metadataWriter = OperatorStateMetadataWriter.createWriter(

@@ -60,6 +60,37 @@ trait StateSchemaProvider extends Serializable {
   def getCurrentStateSchemaId(colFamilyName: String, isKey: Boolean): Short
 }

+// Test implementation that can be dynamically updated
+class TestStateSchemaProvider extends StateSchemaProvider {
+  private var schemas = Map.empty[StateSchemaMetadataKey, StateSchemaMetadataValue]
+
+  def addSchema(
+      colFamilyName: String,
+      keySchema: StructType,
+      valueSchema: StructType,
+      keySchemaId: Short = 0,
+      valueSchemaId: Short = 0): Unit = {
+    schemas ++= Map(
+      StateSchemaMetadataKey(colFamilyName, keySchemaId, isKey = true) ->
+        StateSchemaMetadataValue(keySchema, SchemaConverters.toAvroTypeWithDefaults(keySchema)),
+      StateSchemaMetadataKey(colFamilyName, valueSchemaId, isKey = false) ->
+        StateSchemaMetadataValue(valueSchema, SchemaConverters.toAvroTypeWithDefaults(valueSchema))
+    )
+  }
+
+  override def getSchemaMetadataValue(key: StateSchemaMetadataKey): StateSchemaMetadataValue = {
+    schemas(key)
+  }
+
+  override def getCurrentStateSchemaId(colFamilyName: String, isKey: Boolean): Short = {
+    schemas.keys
+      .filter(key =>
+        key.colFamilyName == colFamilyName &&
+          key.isKey == isKey)
+      .map(_.schemaId).max
+  }
+}
+
 class InMemoryStateSchemaProvider(metadata: StateSchemaMetadata) extends StateSchemaProvider {
   override def getSchemaMetadataValue(key: StateSchemaMetadataKey): StateSchemaMetadataValue = {
     metadata.activeSchemas(key)
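
Note: with `TestStateSchemaProvider` moved into main sources here (it is deleted from `RocksDBStateStoreSuite` further down), other suites can share it. A minimal usage sketch, assuming the provider's package is imported; the schemas and the "default" column-family name are illustrative assumptions, not part of this change:

```scala
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object TestSchemaProviderSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative schemas; the "default" column-family name is an assumption.
    val provider = new TestStateSchemaProvider
    val keySchema = StructType(Seq(StructField("key", StringType)))
    val valueV0 = StructType(Seq(StructField("count", IntegerType)))
    val valueV1 = valueV0.add(StructField("lastUpdated", StringType))

    // Register the initial schemas, then simulate an evolved value schema.
    provider.addSchema("default", keySchema, valueV0)
    provider.addSchema("default", keySchema, valueV1, valueSchemaId = 1)

    // getCurrentStateSchemaId returns the highest id registered per column family.
    assert(provider.getCurrentStateSchemaId("default", isKey = false) == 1)
    assert(provider.getCurrentStateSchemaId("default", isKey = true) == 0)
  }
}
```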

@@ -210,7 +210,7 @@ class StateSchemaCompatibilityChecker(
       case s: SchemaValidationException =>
         throw StateStoreErrors.stateStoreInvalidValueSchemaEvolution(
           valueSchema.toString, s.getMessage)
-      case e: _ => throw e
+      case e: Throwable => throw e
     }

     // Schema evolved - increment value schema ID
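
This hunk is the compilation fix the commit message refers to: `case e: _` is not legal Scala, because `_` cannot stand in for the type in a type pattern, so the catch-all becomes `case e: Throwable`. A standalone sketch of the same rethrow pattern (all names here are illustrative):

```scala
object RethrowSketch {
  // `case e: _ => throw e` fails to compile; `case e: Throwable`
  // (or a bare `case e =>`) is the idiomatic catch-all rethrow.
  def parsePort(s: String): Int =
    try s.toInt
    catch {
      case e: NumberFormatException =>
        // Translate the expected failure into a domain-specific error.
        throw new IllegalArgumentException(s"not a valid port: $s", e)
      case e: Throwable => throw e // propagate anything unexpected unchanged
    }
}
```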

@@ -336,13 +336,13 @@ class StateStoreValueSchemaNotCompatible(
"newValueSchema" -> newValueSchema))

class StateStoreInvalidValueSchemaEvolution(
storedValueSchema: String,
newValueSchema: String)
newValueSchema: String,
avroErrorMessage: String)
extends SparkUnsupportedOperationException(
errorClass = "STATE_STORE_INVALID_VALUE_SCHEMA_EVOLUTION",
messageParameters = Map(
"storedValueSchema" -> storedValueSchema,
"newValueSchema" -> newValueSchema))
"newValueSchema" -> newValueSchema,
"avroErrorMessage" -> avroErrorMessage))

class StateStoreSchemaFileThresholdExceeded(
numSchemaFiles: Int,
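
The error now carries the Avro failure message instead of the stored value schema, which lines up with the two arguments passed at the call site above (`valueSchema.toString` and `s.getMessage`). A sketch of the factory method this implies on `StateStoreErrors`; its body is an assumption, since it is not shown in this diff:

```scala
// Assumed shape of the forwarding helper (not part of this diff):
object StateStoreErrorsSketch {
  def stateStoreInvalidValueSchemaEvolution(
      newValueSchema: String,
      avroErrorMessage: String): StateStoreInvalidValueSchemaEvolution = {
    // Simply forwards both strings into the error's message parameters.
    new StateStoreInvalidValueSchemaEvolution(newValueSchema, avroErrorMessage)
  }
}
```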

@@ -30,7 +30,6 @@ import org.apache.spark.{SparkConf, SparkUnsupportedOperationException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.LocalSparkSession.withSparkSession
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.avro.SchemaConverters
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.util.quietly
@@ -43,37 +42,6 @@ import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils

-// Test implementation that can be dynamically updated
-class TestStateSchemaProvider extends StateSchemaProvider {
-  private var schemas = Map.empty[StateSchemaMetadataKey, StateSchemaMetadataValue]
-
-  def addSchema(
-      colFamilyName: String,
-      keySchema: StructType,
-      valueSchema: StructType,
-      keySchemaId: Short = 0,
-      valueSchemaId: Short = 0): Unit = {
-    schemas ++= Map(
-      StateSchemaMetadataKey(colFamilyName, keySchemaId, isKey = true) ->
-        StateSchemaMetadataValue(keySchema, SchemaConverters.toAvroTypeWithDefaults(keySchema)),
-      StateSchemaMetadataKey(colFamilyName, valueSchemaId, isKey = false) ->
-        StateSchemaMetadataValue(valueSchema, SchemaConverters.toAvroTypeWithDefaults(valueSchema))
-    )
-  }
-
-  override def getSchemaMetadataValue(key: StateSchemaMetadataKey): StateSchemaMetadataValue = {
-    schemas(key)
-  }
-
-  override def getCurrentStateSchemaId(colFamilyName: String, isKey: Boolean): Short = {
-    schemas.keys
-      .filter(key =>
-        key.colFamilyName == colFamilyName &&
-          key.isKey == isKey)
-      .map(_.schemaId).max
-  }
-}
-
 @ExtendedSQLTest
 class RocksDBStateStoreSuite extends StateStoreSuiteBase[RocksDBStateStoreProvider]
   with AlsoTestWithRocksDBFeatures

@@ -1854,8 +1854,7 @@ class TransformWithStateSuite extends StateStoreMetricsTest
       AddData(inputData, "a"),
       ExpectFailure[StateStoreInvalidValueSchemaEvolution] {
         (t: Throwable) => {
-          assert(t.getMessage.contains(
-            "Unable to read schema:"))
+          assert(t.getMessage.contains("Unable to read schema:"))
         }
       }
     )
@@ -2143,6 +2142,53 @@ class TransformWithStateSuite extends StateStoreMetricsTest
     }
   }

+  test("test exceeding schema file threshold throws error") {
+    withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+      classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.SHUFFLE_PARTITIONS.key ->
+        TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
+      SQLConf.STREAMING_STATE_MAX_STATE_SCHEMA_FILES.key -> 1.toString) {
+      withTempDir { checkpointDir =>
+        val clock = new StreamManualClock
+
+        val inputData1 = MemoryStream[String]
+        val result1 = inputData1.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessor(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+
+        testStream(result1, OutputMode.Update())(
+          StartStream(
+            checkpointLocation = checkpointDir.getCanonicalPath,
+            trigger = Trigger.ProcessingTime("1 second"),
+            triggerClock = clock),
+          AddData(inputData1, "a"),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", "1")),
+          StopStream
+        )
+
+        val result2 = inputData1.toDS()
+          .groupByKey(x => x)
+          .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(),
+            TimeMode.ProcessingTime(),
+            OutputMode.Update())
+
+        testStream(result2, OutputMode.Update())(
+          StartStream(
+            checkpointLocation = checkpointDir.getCanonicalPath,
+            trigger = Trigger.ProcessingTime("1 second"),
+            triggerClock = clock),
+          AddData(inputData1, "a"),
+          AdvanceManualClock(1 * 1000),
+          CheckNewAnswer(("a", "2")),
+          StopStream
+        )
+      }
+    }
+  }
+
   test("test query restart succeeds") {
     withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
       classOf[RocksDBStateStoreProvider].getName,