From 0864bbeac399825619375187b6c0ddca3f408f90 Mon Sep 17 00:00:00 2001 From: Daniel Tenedorio Date: Mon, 17 Jun 2024 11:12:41 -0700 Subject: [PATCH] [SPARK-48566][PYTHON] Fix bug where partition indices are incorrect when UDTF analyze() uses both select and partitionColumns ### What changes were proposed in this pull request? This PR fixes a bug that resulted in an internal error with some combination of the Python UDTF "select" and "partitionBy" options of the "analyze" method. Specifically, this logic in `Analyzer.scala` was wrong because it did not update the usage of `partitioningExpressionIndexes` to take the "select" expressions into account when they were introduced in https://github.com/apache/spark/pull/45007: ``` val tvfWithTableColumnIndexes = tvf match { case g @ Generate(pyudtf: PythonUDTF, _, _, _, _, _) if tableArgs.head._1.partitioningExpressionIndexes.nonEmpty => ////////////////////////////////////////////////////////////////////////////// // The bug is here: the 'partitioningExpressionIndexes' are not valid // if the UDTF "select" expressions are non-empty, since that prompts // us to add a new projection (of a possibly different number of // expressions) to evaluate them. 
////////////////////////////////////////////////////////////////////////////// val partitionColumnIndexes = PythonUDTFPartitionColumnIndexes(tableArgs.head._1.partitioningExpressionIndexes) g.copy(generator = pyudtf.copy( pythonUDTFPartitionColumnIndexes = Some(partitionColumnIndexes))) case _ => tvf } ``` To reproduce: ``` from pyspark.sql.functions import ( AnalyzeArgument, AnalyzeResult, PartitioningColumn, SelectedColumn, udtf ) from pyspark.sql.types import ( DoubleType, StringType, StructType, ) @udtf class TestTvf: @staticmethod def analyze(observed: AnalyzeArgument) -> AnalyzeResult: out_schema = StructType() out_schema.add("partition_col", StringType()) out_schema.add("double_col", DoubleType()) return AnalyzeResult( schema=out_schema, partitionBy=[PartitioningColumn("partition_col")], select=[ SelectedColumn("partition_col"), SelectedColumn("double_col"), ], ) def eval(self, *args, **kwargs): pass def terminate(self): for _ in range(10): yield { "partition_col": None, "double_col": 1.0, } spark.udtf.register("serialize_test", TestTvf) # Fails ( spark .sql( """ SELECT * FROM serialize_test( TABLE( SELECT 5 AS unused_col, 'hi' AS partition_col, 1.0 AS double_col UNION ALL SELECT 4 AS unused_col, 'hi' AS partition_col, 1.0 AS double_col ) ) """ ) .toPandas() ) ``` ### Why are the changes needed? The above query returned internal errors before, but works now. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Additional golden file coverage ### Was this patch authored or co-authored using generative AI tooling? Some light GitHub copilot usage Closes #46918 from dtenedor/fix-udtf-bug. 
Authored-by: Daniel Tenedorio Signed-off-by: Takuya Ueshin --- .../sql/catalyst/analysis/Analyzer.scala | 14 ++++-- ...ctionTableSubqueryArgumentExpression.scala | 7 ++- .../analyzer-results/udtf/udtf.sql.out | 20 +++++++++ .../resources/sql-tests/inputs/udtf/udtf.sql | 16 +++++++ .../sql-tests/results/udtf/udtf.sql.out | 26 +++++++++++ .../spark/sql/IntegratedUDFTestUtils.scala | 44 +++++++++++++++++++ 6 files changed, 122 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a233161713c3c..cd7aeb7cd4ac9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2206,11 +2206,19 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}") // Propagate the column indexes for TABLE arguments to the PythonUDTF instance. + val f: FunctionTableSubqueryArgumentExpression = tableArgs.head._1 val tvfWithTableColumnIndexes = tvf match { case g @ Generate(pyudtf: PythonUDTF, _, _, _, _, _) - if tableArgs.head._1.partitioningExpressionIndexes.nonEmpty => - val partitionColumnIndexes = - PythonUDTFPartitionColumnIndexes(tableArgs.head._1.partitioningExpressionIndexes) + if f.extraProjectedPartitioningExpressions.nonEmpty => + val partitionColumnIndexes = if (f.selectedInputExpressions.isEmpty) { + PythonUDTFPartitionColumnIndexes(f.partitioningExpressionIndexes) + } else { + // If the UDTF specified 'select' expression(s), we added a projection to compute + // them plus the 'partitionBy' expression(s) afterwards. 
+ PythonUDTFPartitionColumnIndexes( + (0 until f.extraProjectedPartitioningExpressions.length) + .map(_ + f.selectedInputExpressions.length)) + } g.copy(generator = pyudtf.copy( pythonUDTFPartitionColumnIndexes = Some(partitionColumnIndexes))) case _ => tvf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala index 94465ccff796e..bfd3bc8051dff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala @@ -172,9 +172,12 @@ case class FunctionTableSubqueryArgumentExpression( } } - private lazy val extraProjectedPartitioningExpressions: Seq[Alias] = { + lazy val extraProjectedPartitioningExpressions: Seq[Alias] = { partitionByExpressions.filter { e => - !subqueryOutputs.contains(e) + !subqueryOutputs.contains(e) || + // Skip deduplicating the 'partitionBy' expression(s) against the attributes of the input + // table if the UDTF also specified 'select' expression(s). 
+ selectedInputExpressions.nonEmpty }.zipWithIndex.map { case (expr, index) => Alias(expr, s"partition_by_$index")() } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out index 74ea9261462d6..4b53f1c6f19c4 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out @@ -904,6 +904,26 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +SELECT * FROM UDTFPartitionByIndexingBug( + TABLE( + SELECT + 5 AS unused_col, + 'hi' AS partition_col, + 1.0 AS double_col + + UNION ALL + + SELECT + 4 AS unused_col, + 'hi' AS partition_col, + 1.0 AS double_col + ) +) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + -- !query DROP VIEW t1 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql index c83481f10dca6..a437b1f93b604 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql @@ -143,6 +143,22 @@ SELECT * FROM UDTFWithSinglePartition(1, invalid_arg_name => 2); SELECT * FROM UDTFWithSinglePartition(1, initial_count => 2); SELECT * FROM UDTFWithSinglePartition(initial_count => 1, initial_count => 2); SELECT * FROM UDTFInvalidPartitionByOrderByParseError(TABLE(t2)); +-- Exercise the UDTF partitioning bug. 
+SELECT * FROM UDTFPartitionByIndexingBug( + TABLE( + SELECT + 5 AS unused_col, + 'hi' AS partition_col, + 1.0 AS double_col + + UNION ALL + + SELECT + 4 AS unused_col, + 'hi' AS partition_col, + 1.0 AS double_col + ) +); -- cleanup DROP VIEW t1; diff --git a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out index 78ad8b7c02cd5..f99c6c30c07e2 100644 --- a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out @@ -1069,6 +1069,32 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +SELECT * FROM UDTFPartitionByIndexingBug( + TABLE( + SELECT + 5 AS unused_col, + 'hi' AS partition_col, + 1.0 AS double_col + + UNION ALL + + SELECT + 4 AS unused_col, + 'hi' AS partition_col, + 1.0 AS double_col + ) +) +-- !query schema +struct +-- !query output +NULL 1.0 +NULL 1.0 +NULL 1.0 +NULL 1.0 +NULL 1.0 + + -- !query DROP VIEW t1 -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index c1ca48162d207..957be07607b66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -660,6 +660,49 @@ object IntegratedUDFTestUtils extends SQLHelper { orderBy = "OrderingColumn(\"input\")", select = "SelectedColumn(\"partition_col\")") + object UDTFPartitionByIndexingBug extends TestUDTF { + val pythonScript: String = + s""" + |from pyspark.sql.functions import ( + | AnalyzeArgument, + | AnalyzeResult, + | PartitioningColumn, + | SelectedColumn, + | udtf + |) + |from pyspark.sql.types import ( + | DoubleType, + | StringType, + | StructType, + |) + |class $name: + | @staticmethod + | def analyze(observed: AnalyzeArgument) -> AnalyzeResult: + | out_schema = StructType() + | 
out_schema.add("partition_col", StringType()) + | out_schema.add("double_col", DoubleType()) + | + | return AnalyzeResult( + | schema=out_schema, + | partitionBy=[PartitioningColumn("partition_col")], + | select=[ + | SelectedColumn("partition_col"), + | SelectedColumn("double_col"), + | ], + | ) + | + | def eval(self, *args, **kwargs): + | pass + | + | def terminate(self): + | for _ in range(5): + | yield { + | "partition_col": None, + | "double_col": 1.0, + | } + |""".stripMargin + } + object UDTFInvalidPartitionByOrderByParseError extends TestPythonUDTFPartitionByOrderByBase( partitionBy = "PartitioningColumn(\"unparsable\")", @@ -1216,6 +1259,7 @@ object IntegratedUDFTestUtils extends SQLHelper { UDTFPartitionByOrderBySelectExpr, UDTFPartitionByOrderBySelectComplexExpr, UDTFPartitionByOrderBySelectExprOnlyPartitionColumn, + UDTFPartitionByIndexingBug, InvalidAnalyzeMethodReturnsNonStructTypeSchema, InvalidAnalyzeMethodWithSinglePartitionNoInputTable, InvalidAnalyzeMethodWithPartitionByNoInputTable,