From 511f8085ed48de54daa7d7f74e68a0bf1908c625 Mon Sep 17 00:00:00 2001
From: cashmand
Date: Fri, 13 Dec 2024 10:11:15 -0800
Subject: [PATCH] [SPARK-48898][SQL] Set nullability correctly in the Variant
 schema

### What changes were proposed in this pull request?

The variantShreddingSchema method converts a human-readable schema for
Variant to one that's a valid shredding schema. According to the shredding
spec in https://github.com/apache/parquet-format/pull/461, each shredded
field in an object should be a required group - i.e. a non-nullable struct.
This PR fixes variantShreddingSchema to mark that struct as non-nullable.

### Why are the changes needed?

If we use variantShreddingSchema to construct a schema for Parquet, the
schema would be technically non-conformant with the spec by setting the
group as optional. I don't think this should really matter to readers, but
it would waste a bit of space in the Parquet file by adding an extra
definition level.

### Does this PR introduce _any_ user-facing change?

No, this code is not used yet.

### How was this patch tested?

Added a test to do some minimal validation of the variantShreddingSchema
function.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #49151 from cashmand/SPARK-48898-nullability-again.

Authored-by: cashmand
Signed-off-by: Wenchen Fan
---
 .../parquet/SparkShreddingUtils.scala         | 11 +++++--
 .../spark/sql/VariantShreddingSuite.scala     | 11 ++++++-
 .../sql/VariantWriteShreddingSuite.scala      | 30 +++++++++++++++++++
 3 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
index 41244e20c369f..f38e188ed042c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
@@ -73,16 +73,21 @@ case object SparkShreddingUtils {
    */
   def variantShreddingSchema(dataType: DataType, isTopLevel: Boolean = true): StructType = {
     val fields = dataType match {
-      case ArrayType(elementType, containsNull) =>
+      case ArrayType(elementType, _) =>
+        // Always set containsNull to false. One of value or typed_value must always be set for
+        // array elements.
         val arrayShreddingSchema =
-          ArrayType(variantShreddingSchema(elementType, false), containsNull)
+          ArrayType(variantShreddingSchema(elementType, false), containsNull = false)
         Seq(
           StructField(VariantValueFieldName, BinaryType, nullable = true),
           StructField(TypedValueFieldName, arrayShreddingSchema, nullable = true)
         )
       case StructType(fields) =>
+        // The field name level is always non-nullable: Variant null values are represented in the
+        // "value" column as "00", and missing values are represented by setting both "value" and
+        // "typed_value" to null.
         val objectShreddingSchema = StructType(fields.map(f =>
-          f.copy(dataType = variantShreddingSchema(f.dataType, false))))
+          f.copy(dataType = variantShreddingSchema(f.dataType, false), nullable = false)))
         Seq(
           StructField(VariantValueFieldName, BinaryType, nullable = true),
           StructField(TypedValueFieldName, objectShreddingSchema, nullable = true)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
index 4ff346b957aa0..5d5c441052558 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala
@@ -155,7 +155,16 @@ class VariantShreddingSuite extends QueryTest with SharedSparkSession with Parqu
       Row(metadata(Nil), null, Array(Row(null, null))))
     checkException(path, "v", "MALFORMED_VARIANT")
     // Shredded field must not be null.
-    writeRows(path, writeSchema(StructType.fromDDL("a int")),
+    // Construct the schema manually, because SparkShreddingUtils.variantShreddingSchema will make
+    // `a` non-nullable, which would prevent us from writing the file.
+    val schema = StructType(Seq(StructField("v", StructType(Seq(
+      StructField("metadata", BinaryType),
+      StructField("value", BinaryType),
+      StructField("typed_value", StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("value", BinaryType),
+          StructField("typed_value", BinaryType))))))))))))
+    writeRows(path, schema,
       Row(metadata(Seq("a")), null, Row(null)))
     checkException(path, "v", "MALFORMED_VARIANT")
     // `value` must not contain any shredded field.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantWriteShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantWriteShreddingSuite.scala
index a62c6e4462464..d31bf109af6c7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/VariantWriteShreddingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantWriteShreddingSuite.scala
@@ -67,6 +67,36 @@ class VariantWriteShreddingSuite extends SparkFunSuite with ExpressionEvalHelper
 
   private val emptyMetadata: Array[Byte] = parseJson("null").getMetadata
 
+  test("variantShreddingSchema") {
+    // Validate the schema produced by SparkShreddingUtils.variantShreddingSchema for a few simple
+    // cases.
+    // metadata is always non-nullable.
+    assert(SparkShreddingUtils.variantShreddingSchema(IntegerType) ==
+      StructType(Seq(
+        StructField("metadata", BinaryType, nullable = false),
+        StructField("value", BinaryType, nullable = true),
+        StructField("typed_value", IntegerType, nullable = true))))
+
+    val fieldA = StructType(Seq(
+      StructField("value", BinaryType, nullable = true),
+      StructField("typed_value", TimestampNTZType, nullable = true)))
+    val arrayType = ArrayType(StructType(Seq(
+      StructField("value", BinaryType, nullable = true),
+      StructField("typed_value", StringType, nullable = true))), containsNull = false)
+    val fieldB = StructType(Seq(
+      StructField("value", BinaryType, nullable = true),
+      StructField("typed_value", arrayType, nullable = true)))
+    val objectType = StructType(Seq(
+      StructField("a", fieldA, nullable = false),
+      StructField("b", fieldB, nullable = false)))
+    val structSchema = DataType.fromDDL("a timestamp_ntz, b array<string>")
+    assert(SparkShreddingUtils.variantShreddingSchema(structSchema) ==
+      StructType(Seq(
+        StructField("metadata", BinaryType, nullable = false),
+        StructField("value", BinaryType, nullable = true),
+        StructField("typed_value", objectType, nullable = true))))
+  }
+
   test("shredding as fixed numeric types") {
     /* Cast integer to any wider numeric type. */
     testWithSchema("1", IntegerType, Row(emptyMetadata, null, 1))
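
For readers who want to see the transformation outside the patch context, the sketch below mirrors the scalar, array, and object cases of `variantShreddingSchema` as patched above. It is a standalone illustration under stated assumptions, not the Spark implementation: the `shred` name and the `ShreddingSchemaSketch` object are hypothetical, and the real method lives in `SparkShreddingUtils` and handles cases (such as top-level variants and unsupported types) that this sketch omits.

```scala
import org.apache.spark.sql.types._

object ShreddingSchemaSketch {
  // Hypothetical stand-in mirroring the patched variantShreddingSchema logic.
  def shred(dataType: DataType, isTopLevel: Boolean = true): StructType = {
    val fields = dataType match {
      case ArrayType(elementType, _) =>
        // Array elements never allow null: variant null and missing values are
        // encoded through the value/typed_value pair, so containsNull is false.
        Seq(
          StructField("value", BinaryType, nullable = true),
          StructField("typed_value",
            ArrayType(shred(elementType, isTopLevel = false), containsNull = false),
            nullable = true))
      case StructType(objectFields) =>
        // The fix in this PR: each shredded object field becomes a non-nullable
        // struct, i.e. a required group in Parquet terms.
        val objectSchema = StructType(objectFields.map(f =>
          f.copy(dataType = shred(f.dataType, isTopLevel = false), nullable = false)))
        Seq(
          StructField("value", BinaryType, nullable = true),
          StructField("typed_value", objectSchema, nullable = true))
      case scalar =>
        // Scalars shred directly into typed_value.
        Seq(
          StructField("value", BinaryType, nullable = true),
          StructField("typed_value", scalar, nullable = true))
    }
    // Only the top-level schema carries the variant metadata, and it is never null.
    if (isTopLevel) {
      StructType(StructField("metadata", BinaryType, nullable = false) +: fields)
    } else {
      StructType(fields)
    }
  }

  def main(args: Array[String]): Unit = {
    // For "a int", typed_value.a comes out as a non-nullable (required) struct.
    println(shred(StructType.fromDDL("a int")).treeString)
  }
}
```

Running `main` prints a schema tree in which `typed_value.a` is non-nullable while its nested `value` and `typed_value` remain optional, matching the required-group rule from the shredding spec and the expectations asserted in the new `variantShreddingSchema` test.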