From 58c64b77c37cece5c539e4eaffa04e7d1d9071c8 Mon Sep 17 00:00:00 2001 From: zhangfengcdt Date: Thu, 2 Jan 2025 09:54:44 -0800 Subject: [PATCH] [SEDONA-688] Verify KNN parameter K must be equal or larger than 1 --- .../join/BroadcastObjectSideKNNJoinExec.scala | 2 +- .../join/BroadcastQuerySideKNNJoinExec.scala | 2 +- .../strategy/join/JoinQueryDetector.scala | 8 ++++++++ .../sedona_sql/strategy/join/KNNJoinExec.scala | 2 +- .../org/apache/sedona/sql/KnnJoinSuite.scala | 16 ++++++++++++++++ 5 files changed, 27 insertions(+), 3 deletions(-) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala index 1b21c79e7c..c5777be3c1 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala @@ -120,7 +120,7 @@ case class BroadcastObjectSideKNNJoinExec( sedonaConf: SedonaConf): Unit = { require(numPartitions > 0, "The number of partitions must be greater than 0.") val kValue: Int = this.k.eval().asInstanceOf[Int] - require(kValue > 0, "The number of neighbors must be greater than 0.") + require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.") objectsShapes.setNeighborSampleNumber(kValue) broadcastJoin = true } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala index 812bc6e6d6..001c0a1ca3 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala @@ -127,7 +127,7 @@ case class BroadcastQuerySideKNNJoinExec( sedonaConf: SedonaConf): Unit = { require(numPartitions > 0, "The number of partitions must be greater than 0.") val kValue: Int = this.k.eval().asInstanceOf[Int] - require(kValue > 0, "The number of neighbors must be greater than 0.") + require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.") objectsShapes.setNeighborSampleNumber(kValue) val joinPartitions: Integer = numPartitions diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala index 825855b88c..da9bd5359b 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala @@ -582,6 +582,10 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { return Nil } + // validate the k value + val kValue: Int = distance.eval().asInstanceOf[Int] + require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.") + val leftShape = children.head val rightShape = children.tail.head @@ -711,6 +715,10 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { if (spatialPredicate == SpatialPredicate.KNN) { { + // validate the k value for KNN join + val kValue: Int = distance.get.eval().asInstanceOf[Int] + require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.") + val leftShape = children.head val rightShape = children.tail.head diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala index 2b9bbfb50b..fdc53d13ce 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala @@ -162,7 +162,7 @@ case class KNNJoinExec( sedonaConf: SedonaConf): Unit = { require(numPartitions > 0, "The number of partitions must be greater than 0.") val kValue: Int = this.k.eval().asInstanceOf[Int] - require(kValue > 0, "The number of neighbors must be greater than 0.") + require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.") objectsShapes.setNeighborSampleNumber(kValue) exactSpatialPartitioning(objectsShapes, queryShapes, numPartitions) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala index 1d6119d02d..f3b07c2501 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala @@ -209,6 +209,22 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks { "[1,3][1,6][1,13][1,16][2,1][2,5][2,11][2,15][3,3][3,9][3,13][3,19]") } + it("KNN Join should verify the correct parameter k is passed to the join function") { + val df = sparkSession + .range(0, 1) + .toDF("id") + .withColumn("geom", expr("ST_Point(id, id)")) + .repartition(1) + df.createOrReplaceTempView("df1") + val exception = intercept[IllegalArgumentException] { + sparkSession + .sql(s"SELECT A.ID, B.ID FROM df1 A JOIN df1 B ON ST_KNN(A.GEOM, B.GEOM, 0, false)") + .collect() + } + exception.getMessage should include( + "The number of neighbors (k) must be equal or greater than 1.") + } + it("KNN Join with exact algorithms with additional join conditions on id") { val df = sparkSession.sql( s"SELECT QUERIES.ID, OBJECTS.ID FROM QUERIES JOIN OBJECTS ON ST_KNN(QUERIES.GEOM, OBJECTS.GEOM, 4, false) AND QUERIES.ID > 1")