Skip to content

Commit

Permalink
[SEDONA-688] Verify KNN parameter K must be equal or larger than 1
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangfengcdt committed Jan 2, 2025
1 parent 31319b0 commit 58c64b7
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ case class BroadcastObjectSideKNNJoinExec(
sedonaConf: SedonaConf): Unit = {
require(numPartitions > 0, "The number of partitions must be greater than 0.")
val kValue: Int = this.k.eval().asInstanceOf[Int]
require(kValue > 0, "The number of neighbors must be greater than 0.")
require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")
objectsShapes.setNeighborSampleNumber(kValue)
broadcastJoin = true
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ case class BroadcastQuerySideKNNJoinExec(
sedonaConf: SedonaConf): Unit = {
require(numPartitions > 0, "The number of partitions must be greater than 0.")
val kValue: Int = this.k.eval().asInstanceOf[Int]
require(kValue > 0, "The number of neighbors must be greater than 0.")
require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")
objectsShapes.setNeighborSampleNumber(kValue)

val joinPartitions: Integer = numPartitions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,10 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
return Nil
}

// validate the k value
val kValue: Int = distance.eval().asInstanceOf[Int]
require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")

val leftShape = children.head
val rightShape = children.tail.head

Expand Down Expand Up @@ -711,6 +715,10 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {

if (spatialPredicate == SpatialPredicate.KNN) {
{
// validate the k value for KNN join
val kValue: Int = distance.get.eval().asInstanceOf[Int]
require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")

val leftShape = children.head
val rightShape = children.tail.head

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ case class KNNJoinExec(
sedonaConf: SedonaConf): Unit = {
require(numPartitions > 0, "The number of partitions must be greater than 0.")
val kValue: Int = this.k.eval().asInstanceOf[Int]
require(kValue > 0, "The number of neighbors must be greater than 0.")
require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.")
objectsShapes.setNeighborSampleNumber(kValue)

exactSpatialPartitioning(objectsShapes, queryShapes, numPartitions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,22 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
"[1,3][1,6][1,13][1,16][2,1][2,5][2,11][2,15][3,3][3,9][3,13][3,19]")
}

it("KNN Join should verify the correct parameter k is passed to the join function") {
val df = sparkSession
.range(0, 1)
.toDF("id")
.withColumn("geom", expr("ST_Point(id, id)"))
.repartition(1)
df.createOrReplaceTempView("df1")
val exception = intercept[IllegalArgumentException] {
sparkSession
.sql(s"SELECT A.ID, B.ID FROM df1 A JOIN df1 B ON ST_KNN(A.GEOM, B.GEOM, 0, false)")
.collect()
}
exception.getMessage should include(
"The number of neighbors (k) must be equal or greater than 1.")
}

it("KNN Join with exact algorithms with additional join conditions on id") {
val df = sparkSession.sql(
s"SELECT QUERIES.ID, OBJECTS.ID FROM QUERIES JOIN OBJECTS ON ST_KNN(QUERIES.GEOM, OBJECTS.GEOM, 4, false) AND QUERIES.ID > 1")
Expand Down

0 comments on commit 58c64b7

Please sign in to comment.