diff --git a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
index 4c8fcab692..a61a4780d4 100644
--- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
@@ -65,9 +65,6 @@ object SedonaContext {
     RasterRegistrator.registerAll(sparkSession)
     UdtRegistrator.registerAll()
     UdfRegistrator.registerAll(sparkSession)
-    if (sparkSession.conf.get("spark.sedona.enableParserExtension", "true").toBoolean) {
-      ParserRegistrator.register(sparkSession)
-    }
     sparkSession
   }
 
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/ParserRegistrator.scala b/spark/common/src/main/scala/org/apache/sedona/sql/ParserRegistrator.scala
index db3c623a09..0414742aa2 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/ParserRegistrator.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/ParserRegistrator.scala
@@ -38,21 +38,6 @@ object ParserRegistrator {
       val field = sparkSession.sessionState.getClass.getDeclaredField("sqlParser")
       field.setAccessible(true)
       field.set(sparkSession.sessionState, parser)
-      return // return if the new constructor is available
-    } catch {
-      case _: Exception =>
-    }
-
-    // try to register the parser with the legacy constructor for spark 3.0
-    try {
-      val parserClassName = "org.apache.sedona.sql.parser.SedonaSqlParser"
-      val delegate: ParserInterface = sparkSession.sessionState.sqlParser
-
-      val parser =
-        ParserFactory.getParser(parserClassName, sparkSession.sessionState.conf, delegate)
-      val field = sparkSession.sessionState.getClass.getDeclaredField("sqlParser")
-      field.setAccessible(true)
-      field.set(sparkSession.sessionState, parser)
     } catch {
       case _: Exception =>
     }
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/SedonaSqlExtensions.scala b/spark/common/src/main/scala/org/apache/sedona/sql/SedonaSqlExtensions.scala
index be0774ac90..fbc3567192 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/SedonaSqlExtensions.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/SedonaSqlExtensions.scala
@@ -19,13 +19,24 @@
 package org.apache.sedona.sql
 
 import org.apache.sedona.spark.SedonaContext
+import org.apache.spark.SparkContext
 import org.apache.spark.sql.SparkSessionExtensions
+import org.apache.spark.sql.parser.ParserFactory
 
 class SedonaSqlExtensions extends (SparkSessionExtensions => Unit) {
+  private lazy val enableParser =
+    SparkContext.getOrCreate().getConf.get("spark.sedona.enableParserExtension", "true").toBoolean
+
   def apply(e: SparkSessionExtensions): Unit = {
     e.injectCheckRule(spark => {
       SedonaContext.create(spark)
       _ => ()
     })
+
+    if (enableParser) {
+      e.injectParser { case (_, parser) =>
+        ParserFactory.getParser("org.apache.sedona.sql.parser.SedonaSqlParser", parser)
+      }
+    }
   }
 }
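
With the reflection-based registration removed, the Sedona parser now arrives through the standard `SparkSessionExtensions.injectParser` hook, gated by `spark.sedona.enableParserExtension`. A minimal sketch of how a user session would opt out of the parser while keeping the rest of the extension; the config key and extension class name come from the diff above, while the master and app name are illustrative:

```scala
import org.apache.spark.sql.SparkSession

// Sketch: opting out of Sedona's SQL parser while keeping the extension.
// Note the flag is read from the SparkConf by SedonaSqlExtensions (via
// SparkContext.getOrCreate().getConf), so it must be set before the
// session starts; it cannot be toggled with SET at query time.
val spark = SparkSession
  .builder()
  .master("local[*]") // illustrative
  .appName("sedona-parser-toggle") // illustrative
  .config("spark.sql.extensions", "org.apache.sedona.sql.SedonaSqlExtensions")
  .config("spark.sedona.enableParserExtension", "false")
  .getOrCreate()
```
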
diff --git a/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala
index 6c70419122..56c27ba76b 100644
--- a/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala
+++ b/spark/spark-3.3/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala
@@ -35,12 +35,8 @@ class SedonaSqlParser(delegate: ParserInterface) extends SparkSqlParser {
   override def parsePlan(sqlText: String): LogicalPlan =
     try {
       parse(sqlText) { parser =>
-        parserBuilder.visit(parser.singleStatement()) match {
-          case plan: LogicalPlan => plan
-          case _ =>
-            delegate.parsePlan(sqlText)
-        }
-      }
+        parserBuilder.visit(parser.singleStatement())
+      }.asInstanceOf[LogicalPlan]
     } catch {
       case _: Exception =>
         delegate.parsePlan(sqlText)
diff --git a/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala
index 72680aacd4..6f873d0a08 100644
--- a/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala
+++ b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala
@@ -44,14 +44,29 @@ class SQLSyntaxTestScala extends TestBaseScala with TableDrivenPropertyChecks {
 
     it(
      "should be able to create a regular table with geometry column should work without a workaround") {
-      sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY (GEO_COL GEOMETRY)")
-      sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY") should be(true)
+      try {
+        sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY (GEO_COL GEOMETRY)")
+        sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY") should be(true)
+        sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true")
+      } catch {
+        case ex: Exception =>
+          ex.getClass.getName.endsWith("ParseException") should be(true)
+          sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false")
+      }
     }
 
     it(
      "should be able to create a regular table with regular and geometry column should work without a workaround") {
-      sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY_2 (INT_COL INT, GEO_COL GEOMETRY)")
-      sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY_2") should be(true)
+      try {
+        sparkSession.sql(
+          "CREATE TABLE T_TEST_EXPLICIT_GEOMETRY_2 (INT_COL INT, GEO_COL GEOMETRY)")
+        sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY_2") should be(true)
+        sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true")
+      } catch {
+        case ex: Exception =>
+          ex.getClass.getName.endsWith("ParseException") should be(true)
+          sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false")
+      }
     }
   }
 }
diff --git a/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index f629648b29..8d13f6138d 100644
--- a/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/spark-3.3/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -23,6 +23,8 @@ import org.apache.sedona.spark.SedonaContext
 import org.apache.spark.sql.DataFrame
 import org.scalatest.{BeforeAndAfterAll, FunSpec}
 
+import java.util.concurrent.ThreadLocalRandom
+
 trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   Logger.getRootLogger().setLevel(Level.WARN)
   Logger.getLogger("org.apache").setLevel(Level.WARN)
@@ -30,6 +32,7 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   Logger.getLogger("akka").setLevel(Level.WARN)
   Logger.getLogger("org.apache.sedona.core").setLevel(Level.WARN)
 
+  val keyParserExtension = "spark.sedona.enableParserExtension"
   val warehouseLocation = System.getProperty("user.dir") + "/target/"
   val sparkSession = SedonaContext
     .builder()
@@ -38,6 +41,8 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
     .config("spark.sql.warehouse.dir", warehouseLocation)
     .config("sedona.join.autoBroadcastJoinThreshold", "-1")
     .config("spark.sql.session.timeZone", "UTC")
+    .config("spark.sql.extensions", "org.apache.sedona.sql.SedonaSqlExtensions")
+    .config(keyParserExtension, ThreadLocalRandom.current().nextBoolean())
     .getOrCreate()
 
   val sparkSessionMinio = SedonaContext
diff --git a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala
index 6c70419122..56c27ba76b 100644
--- a/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala
+++ b/spark/spark-3.4/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala
@@ -35,12 +35,8 @@ class SedonaSqlParser(delegate: ParserInterface) extends SparkSqlParser {
   override def parsePlan(sqlText: String): LogicalPlan =
     try {
       parse(sqlText) { parser =>
-        parserBuilder.visit(parser.singleStatement()) match {
-          case plan: LogicalPlan => plan
-          case _ =>
-            delegate.parsePlan(sqlText)
-        }
-      }
+        parserBuilder.visit(parser.singleStatement())
+      }.asInstanceOf[LogicalPlan]
     } catch {
       case _: Exception =>
         delegate.parsePlan(sqlText)
diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala
index 72680aacd4..6f873d0a08 100644
--- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala
+++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala
@@ -44,14 +44,29 @@ class SQLSyntaxTestScala extends TestBaseScala with TableDrivenPropertyChecks {
 
     it(
      "should be able to create a regular table with geometry column should work without a workaround") {
-      sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY (GEO_COL GEOMETRY)")
-      sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY") should be(true)
+      try {
+        sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY (GEO_COL GEOMETRY)")
+        sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY") should be(true)
+        sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true")
+      } catch {
+        case ex: Exception =>
+          ex.getClass.getName.endsWith("ParseException") should be(true)
+          sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false")
+      }
     }
 
     it(
      "should be able to create a regular table with regular and geometry column should work without a workaround") {
-      sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY_2 (INT_COL INT, GEO_COL GEOMETRY)")
-      sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY_2") should be(true)
+      try {
+        sparkSession.sql(
+          "CREATE TABLE T_TEST_EXPLICIT_GEOMETRY_2 (INT_COL INT, GEO_COL GEOMETRY)")
+        sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY_2") should be(true)
+        sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true")
+      } catch {
+        case ex: Exception =>
+          ex.getClass.getName.endsWith("ParseException") should be(true)
+          sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false")
+      }
     }
   }
 }
diff --git a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 34746d0b28..ae1ed5d091 100644
--- a/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/spark-3.4/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -23,6 +23,8 @@ import org.apache.sedona.spark.SedonaContext
 import org.apache.spark.sql.DataFrame
 import org.scalatest.{BeforeAndAfterAll, FunSpec}
 
+import java.util.concurrent.ThreadLocalRandom
+
 trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   Logger.getRootLogger().setLevel(Level.WARN)
   Logger.getLogger("org.apache").setLevel(Level.WARN)
Logger.getLogger("org.apache").setLevel(Level.WARN) @@ -30,6 +32,7 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { Logger.getLogger("akka").setLevel(Level.WARN) Logger.getLogger("org.apache.sedona.core").setLevel(Level.WARN) + val keyParserExtension = "spark.sedona.enableParserExtension" val warehouseLocation = System.getProperty("user.dir") + "/target/" val sparkSession = SedonaContext .builder() @@ -38,6 +41,8 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { .config("spark.sql.warehouse.dir", warehouseLocation) // We need to be explicit about broadcasting in tests. .config("sedona.join.autoBroadcastJoinThreshold", "-1") + .config("spark.sql.extensions", "org.apache.sedona.sql.SedonaSqlExtensions") + .config(keyParserExtension, ThreadLocalRandom.current().nextBoolean()) .getOrCreate() val sparkSessionMinio = SedonaContext diff --git a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala index 6c70419122..56c27ba76b 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/sedona/sql/parser/SedonaSqlParser.scala @@ -35,12 +35,8 @@ class SedonaSqlParser(delegate: ParserInterface) extends SparkSqlParser { override def parsePlan(sqlText: String): LogicalPlan = try { parse(sqlText) { parser => - parserBuilder.visit(parser.singleStatement()) match { - case plan: LogicalPlan => plan - case _ => - delegate.parsePlan(sqlText) - } - } + parserBuilder.visit(parser.singleStatement()) + }.asInstanceOf[LogicalPlan] } catch { case _: Exception => delegate.parsePlan(sqlText) diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala index 72680aacd4..6f873d0a08 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/SQLSyntaxTestScala.scala @@ -44,14 +44,29 @@ class SQLSyntaxTestScala extends TestBaseScala with TableDrivenPropertyChecks { it( "should be able to create a regular table with geometry column should work without a workaround") { - sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY (GEO_COL GEOMETRY)") - sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY") should be(true) + try { + sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY (GEO_COL GEOMETRY)") + sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true") + } catch { + case ex: Exception => + ex.getClass.getName.endsWith("ParseException") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false") + } } it( "should be able to create a regular table with regular and geometry column should work without a workaround") { - sparkSession.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY_2 (INT_COL INT, GEO_COL GEOMETRY)") - sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY_2") should be(true) + try { + sparkSession.sql( + "CREATE TABLE T_TEST_EXPLICIT_GEOMETRY_2 (INT_COL INT, GEO_COL GEOMETRY)") + sparkSession.catalog.tableExists("T_TEST_EXPLICIT_GEOMETRY_2") should be(true) + sparkSession.sparkContext.getConf.get(keyParserExtension) should be("true") + } catch { + case ex: Exception => + ex.getClass.getName.endsWith("ParseException") should be(true) + 
+          sparkSession.sparkContext.getConf.get(keyParserExtension) should be("false")
+      }
     }
   }
 }
diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
index 34746d0b28..ae1ed5d091 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala
@@ -23,6 +23,8 @@ import org.apache.sedona.spark.SedonaContext
 import org.apache.spark.sql.DataFrame
 import org.scalatest.{BeforeAndAfterAll, FunSpec}
 
+import java.util.concurrent.ThreadLocalRandom
+
 trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   Logger.getRootLogger().setLevel(Level.WARN)
   Logger.getLogger("org.apache").setLevel(Level.WARN)
@@ -30,6 +32,7 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
   Logger.getLogger("akka").setLevel(Level.WARN)
   Logger.getLogger("org.apache.sedona.core").setLevel(Level.WARN)
 
+  val keyParserExtension = "spark.sedona.enableParserExtension"
   val warehouseLocation = System.getProperty("user.dir") + "/target/"
   val sparkSession = SedonaContext
     .builder()
@@ -38,6 +41,8 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll {
     .config("spark.sql.warehouse.dir", warehouseLocation)
     // We need to be explicit about broadcasting in tests.
     .config("sedona.join.autoBroadcastJoinThreshold", "-1")
+    .config("spark.sql.extensions", "org.apache.sedona.sql.SedonaSqlExtensions")
+    .config(keyParserExtension, ThreadLocalRandom.current().nextBoolean())
     .getOrCreate()
 
   val sparkSessionMinio = SedonaContext
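
The test bases now build their sessions through `spark.sql.extensions` and flip `spark.sedona.enableParserExtension` at random per JVM, so each suite run exercises one of the two code paths and asserts the outcome that matches the flag. A sketch of the behavior the randomized assertions pin down, assuming a session configured as in the earlier example:

```scala
// With the parser extension enabled, GEOMETRY is accepted as a DDL column
// type; with it disabled, Spark's stock parser rejects the statement. The
// tests only check that the thrown class name ends in "ParseException".
spark.sql("CREATE TABLE T_TEST_EXPLICIT_GEOMETRY (GEO_COL GEOMETRY)")
// enableParserExtension=true  -> table is created, tableExists returns true
// enableParserExtension=false -> a ParseException is thrown
```
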