diff --git a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java index 1c7fe7a0ae..f5375009ed 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java +++ b/spark/common/src/main/java/org/apache/sedona/core/joinJudgement/KnnJoinIndexJudgement.java @@ -25,6 +25,7 @@ import org.apache.sedona.core.knnJudgement.EuclideanItemDistance; import org.apache.sedona.core.knnJudgement.HaversineItemDistance; import org.apache.sedona.core.knnJudgement.SpheroidDistance; +import org.apache.sedona.core.wrapper.UniqueGeometry; import org.apache.spark.api.java.function.FlatMapFunction2; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.util.LongAccumulator; @@ -46,35 +47,43 @@ public class KnnJoinIndexJudgement extends JudgementBase implements FlatMapFunction2, Iterator, Pair>, Serializable { private final int k; + private final Double searchRadius; private final DistanceMetric distanceMetric; private final boolean includeTies; - private final Broadcast broadcastedTreeIndex; + private final Broadcast broadcastQueryObjects; + private final Broadcast broadcastObjectsTreeIndex; /** * Constructor for the KnnJoinIndexJudgement class. 
* * @param k the number of nearest neighbors to find + * @param searchRadius the maximum distance, measured with the distance metric, for a neighbor to be kept; null disables the radius filter * @param distanceMetric the distance metric to use + * @param broadcastQueryObjects the broadcast geometries on queries + * @param broadcastObjectsTreeIndex the broadcast spatial index on objects * @param buildCount accumulator for the number of geometries processed from the build side * @param streamCount accumulator for the number of geometries processed from the stream side * @param resultCount accumulator for the number of join results * @param candidateCount accumulator for the number of candidate matches - * @param broadcastedTreeIndex the broadcasted spatial index */ public KnnJoinIndexJudgement( int k, + Double searchRadius, DistanceMetric distanceMetric, boolean includeTies, - Broadcast broadcastedTreeIndex, + Broadcast broadcastQueryObjects, + Broadcast broadcastObjectsTreeIndex, LongAccumulator buildCount, LongAccumulator streamCount, LongAccumulator resultCount, LongAccumulator candidateCount) { super(null, buildCount, streamCount, resultCount, candidateCount); this.k = k; + this.searchRadius = searchRadius; this.distanceMetric = distanceMetric; this.includeTies = includeTies; - this.broadcastedTreeIndex = broadcastedTreeIndex; + this.broadcastQueryObjects = broadcastQueryObjects; + this.broadcastObjectsTreeIndex = broadcastObjectsTreeIndex; } /** @@ -90,7 +99,7 @@ public KnnJoinIndexJudgement( @Override public Iterator> call(Iterator streamShapes, Iterator treeIndexes) throws Exception { - if (!treeIndexes.hasNext() || !streamShapes.hasNext()) { + if (!treeIndexes.hasNext() || (streamShapes != null && !streamShapes.hasNext())) { buildCount.add(0); streamCount.add(0); resultCount.add(0); @@ -99,10 +108,9 @@ public Iterator> call(Iterator streamShapes, Iterator> call(Iterator streamShapes, Iterator> result = new ArrayList<>(); - ItemDistance itemDistance; - while (streamShapes.hasNext()) { - T streamShape = streamShapes.next(); - streamCount.add(1); - - Object[] localK; - switch 
(distanceMetric) { - case EUCLIDEAN: - itemDistance = new EuclideanItemDistance(); - break; - case HAVERSINE: - itemDistance = new HaversineItemDistance(); - break; - case SPHEROID: - itemDistance = new SpheroidDistance(); - break; - default: - itemDistance = new GeometryItemDistance(); - break; - } + List queryItems; + if (broadcastQueryObjects != null) { + // get the broadcast spatial index on queries side if available + queryItems = broadcastQueryObjects.getValue(); + for (Object item : queryItems) { + T queryGeom; + if (item instanceof UniqueGeometry) { + queryGeom = (T) ((UniqueGeometry) item).getOriginalGeometry(); + } else { + queryGeom = (T) item; + } + streamCount.add(1); - localK = - strTree.nearestNeighbour(streamShape.getEnvelopeInternal(), streamShape, itemDistance, k); - if (includeTies) { - localK = getUpdatedLocalKWithTies(streamShape, localK, strTree); + Object[] localK = + strTree.nearestNeighbour( + queryGeom.getEnvelopeInternal(), queryGeom, getItemDistance(), k); + if (includeTies) { + localK = getUpdatedLocalKWithTies(queryGeom, localK, strTree); + } + if (searchRadius != null) { + localK = getInSearchRadius(localK, queryGeom); + } + + for (Object obj : localK) { + T candidate = (T) obj; + Pair pair = Pair.of((U) item, candidate); + result.add(pair); + resultCount.add(1); + } } + return result.iterator(); + } else { + while (streamShapes.hasNext()) { + T streamShape = streamShapes.next(); + streamCount.add(1); + + Object[] localK = + strTree.nearestNeighbour( + streamShape.getEnvelopeInternal(), streamShape, getItemDistance(), k); + if (includeTies) { + localK = getUpdatedLocalKWithTies(streamShape, localK, strTree); + } + if (searchRadius != null) { + localK = getInSearchRadius(localK, streamShape); + } - for (Object obj : localK) { - T candidate = (T) obj; - Pair pair = Pair.of((U) streamShape, candidate); - result.add(pair); - resultCount.add(1); + for (Object obj : localK) { + T candidate = (T) obj; + Pair pair = Pair.of((U) streamShape, 
candidate); + result.add(pair); + resultCount.add(1); + } } + return result.iterator(); } + } - return result.iterator(); + private Object[] getInSearchRadius(Object[] localK, T queryGeom) { + localK = + Arrays.stream(localK) + .filter( + candidate -> { + Geometry candidateGeom = (Geometry) candidate; + return distanceByMetric(queryGeom, candidateGeom, distanceMetric) <= searchRadius; + }) + .toArray(); + return localK; + } + + /** + * This method calculates the distance between two geometries using the specified distance metric. + * + * @param queryGeom the query geometry + * @param candidateGeom the candidate geometry + * @param distanceMetric the distance metric to use + * @return the distance between the two geometries + */ + public static double distanceByMetric( + Geometry queryGeom, Geometry candidateGeom, DistanceMetric distanceMetric) { + switch (distanceMetric) { + case EUCLIDEAN: + EuclideanItemDistance euclideanItemDistance = new EuclideanItemDistance(); + return euclideanItemDistance.distance(queryGeom, candidateGeom); + case HAVERSINE: + HaversineItemDistance haversineItemDistance = new HaversineItemDistance(); + return haversineItemDistance.distance(queryGeom, candidateGeom); + case SPHEROID: + SpheroidDistance spheroidDistance = new SpheroidDistance(); + return spheroidDistance.distance(queryGeom, candidateGeom); + default: + return queryGeom.distance(candidateGeom); + } + } + + private ItemDistance getItemDistance() { + ItemDistance itemDistance; + itemDistance = getItemDistanceByMetric(distanceMetric); + return itemDistance; + } + + /** + * This method returns the ItemDistance object based on the specified distance metric. 
+ * + * @param distanceMetric the distance metric to use + * @return the ItemDistance object + */ + public static ItemDistance getItemDistanceByMetric(DistanceMetric distanceMetric) { + ItemDistance itemDistance; + switch (distanceMetric) { + case EUCLIDEAN: + itemDistance = new EuclideanItemDistance(); + break; + case HAVERSINE: + itemDistance = new HaversineItemDistance(); + break; + case SPHEROID: + itemDistance = new SpheroidDistance(); + break; + default: + itemDistance = new GeometryItemDistance(); + break; + } + return itemDistance; } private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK, STRtree strTree) { @@ -184,4 +281,18 @@ private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK, STRtre } return localK; } + + public static double distance( + U key, T value, DistanceMetric distanceMetric) { + switch (distanceMetric) { + case EUCLIDEAN: + return new EuclideanItemDistance().distance(key, value); + case HAVERSINE: + return new HaversineItemDistance().distance(key, value); + case SPHEROID: + return new SpheroidDistance().distance(key, value); + default: + return new EuclideanItemDistance().distance(key, value); + } + } } diff --git a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/EuclideanItemDistance.java b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/EuclideanItemDistance.java index a27bf543b1..1aba8f87f7 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/EuclideanItemDistance.java +++ b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/EuclideanItemDistance.java @@ -36,4 +36,12 @@ public double distance(ItemBoundable item1, ItemBoundable item2) { return g1.distance(g2); } } + + public double distance(Geometry geometry1, Geometry geometry2) { + if (geometry1 == geometry2) { + return Double.MAX_VALUE; + } else { + return geometry1.distance(geometry2); + } + } } diff --git 
a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/HaversineItemDistance.java b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/HaversineItemDistance.java index 9ad1bfbee4..b04627074e 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/HaversineItemDistance.java +++ b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/HaversineItemDistance.java @@ -37,4 +37,12 @@ public double distance(ItemBoundable item1, ItemBoundable item2) { return Haversine.distance(g1, g2); } } + + public double distance(Geometry geometry1, Geometry geometry2) { + if (geometry1 == geometry2) { + return Double.MAX_VALUE; + } else { + return Haversine.distance(geometry1, geometry2); + } + } } diff --git a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/SpheroidDistance.java b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/SpheroidDistance.java index df22d3565e..4ecdbf84c6 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/SpheroidDistance.java +++ b/spark/common/src/main/java/org/apache/sedona/core/knnJudgement/SpheroidDistance.java @@ -37,4 +37,12 @@ public double distance(ItemBoundable item1, ItemBoundable item2) { return Spheroid.distance(g1, g2); } } + + public double distance(Geometry geometry1, Geometry geometry2) { + if (geometry1 == geometry2) { + return Double.MAX_VALUE; + } else { + return Spheroid.distance(geometry1, geometry2); + } + } } diff --git a/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java b/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java index d20563d279..a5665726e0 100644 --- a/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java +++ b/spark/common/src/main/java/org/apache/sedona/core/spatialOperator/JoinQuery.java @@ -18,10 +18,7 @@ */ package org.apache.sedona.core.spatialOperator; -import java.util.ArrayList; -import java.util.Iterator; -import 
java.util.List; -import java.util.Objects; +import java.util.*; import org.apache.commons.lang3.tuple.Pair; import org.apache.log4j.LogManager; import org.apache.log4j.Logger; @@ -35,15 +32,18 @@ import org.apache.sedona.core.spatialPartitioning.SpatialPartitioner; import org.apache.sedona.core.spatialRDD.CircleRDD; import org.apache.sedona.core.spatialRDD.SpatialRDD; +import org.apache.sedona.core.wrapper.UniqueGeometry; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.util.LongAccumulator; import org.locationtech.jts.geom.Geometry; +import org.locationtech.jts.index.SpatialIndex; import org.locationtech.jts.index.strtree.STRtree; import scala.Tuple2; @@ -784,47 +784,82 @@ public static JavaPairRDD knnJoin LongAccumulator resultCount = Metrics.createMetric(sparkContext, "resultCount"); LongAccumulator candidateCount = Metrics.createMetric(sparkContext, "candidateCount"); - final Broadcast broadcastedTreeIndex; - if (broadcastJoin) { - // adjust auto broadcast threshold to avoid building index on large RDDs + final Broadcast broadcastObjectsTreeIndex; + final Broadcast broadcastQueryObjects; + if (broadcastJoin && objectRDD.indexedRawRDD != null && objectRDD.indexedRDD == null) { + // If broadcastJoin is true and rawIndex is created on object side + // we will broadcast queryRDD to objectRDD + List> uniqueQueryObjects = new ArrayList<>(); + for (U queryObject : queryRDD.rawSpatialRDD.collect()) { + // Wrap the query objects in a UniqueGeometry object to count for duplicate queries in the + // join + uniqueQueryObjects.add(new UniqueGeometry<>(queryObject)); + } + broadcastQueryObjects = + 
JavaSparkContext.fromSparkContext(sparkContext).broadcast(uniqueQueryObjects); + broadcastObjectsTreeIndex = null; + } else if (broadcastJoin && objectRDD.indexedRawRDD == null && objectRDD.indexedRDD == null) { + // If broadcastJoin is true and index and rawIndex are NOT created on object side + // we will broadcast objectRDD to queryRDD STRtree strTree = objectRDD.coalesceAndBuildRawIndex(IndexType.RTREE); - broadcastedTreeIndex = JavaSparkContext.fromSparkContext(sparkContext).broadcast(strTree); + broadcastObjectsTreeIndex = + JavaSparkContext.fromSparkContext(sparkContext).broadcast(strTree); + broadcastQueryObjects = null; } else { - broadcastedTreeIndex = null; + // Regular join does not need to set broadcast index + broadcastQueryObjects = null; + broadcastObjectsTreeIndex = null; } // The reason for using objectRDD as the right side is that the partitions are built on the // right side. final JavaRDD> joinResult; - if (objectRDD.indexedRDD != null) { + if (broadcastObjectsTreeIndex == null && broadcastQueryObjects == null) { + // no broadcast join final KnnJoinIndexJudgement judgement = new KnnJoinIndexJudgement( joinParams.k, + joinParams.searchRadius, joinParams.distanceMetric, includeTies, - broadcastedTreeIndex, + null, + null, buildCount, streamCount, resultCount, candidateCount); joinResult = queryRDD.spatialPartitionedRDD.zipPartitions(objectRDD.indexedRDD, judgement); - } else if (broadcastedTreeIndex != null) { + } else if (broadcastObjectsTreeIndex != null) { + // broadcast join with objectRDD as broadcast side final KnnJoinIndexJudgement judgement = new KnnJoinIndexJudgement( joinParams.k, + joinParams.searchRadius, joinParams.distanceMetric, includeTies, - broadcastedTreeIndex, + null, + broadcastObjectsTreeIndex, buildCount, streamCount, resultCount, candidateCount); - int numPartitionsObjects = objectRDD.rawSpatialRDD.getNumPartitions(); - joinResult = - queryRDD - .rawSpatialRDD - .repartition(numPartitionsObjects) - 
.zipPartitions(objectRDD.rawSpatialRDD, judgement); + // won't need inputs from the shapes in the objectRDD + joinResult = queryRDD.rawSpatialRDD.zipPartitions(queryRDD.rawSpatialRDD, judgement); + } else if (broadcastQueryObjects != null) { + // broadcast join with queryRDD as broadcast side + final KnnJoinIndexJudgement judgement = + new KnnJoinIndexJudgement( + joinParams.k, + joinParams.searchRadius, + joinParams.distanceMetric, + includeTies, + broadcastQueryObjects, + null, + buildCount, + streamCount, + resultCount, + candidateCount); + joinResult = querySideBroadcastKNNJoin(objectRDD, joinParams, judgement, includeTies); } else { throw new IllegalArgumentException("No index found on the input RDDs."); } @@ -833,6 +868,123 @@ public static JavaPairRDD knnJoin (PairFunction, U, T>) pair -> new Tuple2<>(pair.getKey(), pair.getValue())); } + /** + * Performs a KNN join where the query side is broadcasted. + * + *

This function performs a K-Nearest Neighbors (KNN) join operation where the query geometries + * are broadcasted to all partitions of the object geometries. + * + *

The function first maps partitions of the indexed raw RDD to perform the KNN join, then + * groups the results by the query geometry and keeps the top K pair for each query geometry based + * on the distance. + * + * @param objectRDD The set of geometries (neighbors) to be queried. + * @param joinParams The parameters for the join, including index type, number of neighbors (k), + * and distance metric. + * @param judgement The judgement function used to perform the KNN join. + * @param The type of the geometries in the queryRDD set. + * @param The type of the geometries in the objectRDD set. + * @return A JavaRDD of pairs where each pair contains a geometry from the queryRDD and a matching + * geometry from the objectRDD. + */ + private static + JavaRDD> querySideBroadcastKNNJoin( + SpatialRDD objectRDD, + JoinParams joinParams, + KnnJoinIndexJudgement judgement, + boolean includeTies) { + final JavaRDD> joinResult; + JavaRDD> joinResultMapped = + objectRDD.indexedRawRDD.mapPartitions( + iterator -> { + List> results = new ArrayList<>(); + if (iterator.hasNext()) { + SpatialIndex spatialIndex = iterator.next(); + // the broadcast join won't need inputs from the query's shape stream + Iterator> callResult = + judgement.call(null, Collections.singletonList(spatialIndex).iterator()); + callResult.forEachRemaining(results::add); + } + return results.iterator(); + }); + // this is to avoid serializable issues with the broadcast variable + int k = joinParams.k; + DistanceMetric distanceMetric = joinParams.distanceMetric; + + // Transform joinResultMapped to keep the top k pairs for each geometry + // (based on a grouping key and distance) + joinResult = + joinResultMapped + .groupBy(pair -> pair.getKey()) // Group by the first geometry + .flatMap( + (FlatMapFunction>>, Pair>) + pair -> { + Iterable> values = pair._2; + + // Extract and sort values by distance + List> sortedPairs = new ArrayList<>(); + for (Pair p : values) { + Pair newPair = + Pair.of( + (U) 
((UniqueGeometry) p.getKey()).getOriginalGeometry(), + p.getValue()); + sortedPairs.add(newPair); + } + + // Sort pairs based on the distance function between the two geometries + sortedPairs.sort( + (p1, p2) -> { + double distance1 = + KnnJoinIndexJudgement.distance( + p1.getKey(), p1.getValue(), distanceMetric); + double distance2 = + KnnJoinIndexJudgement.distance( + p2.getKey(), p2.getValue(), distanceMetric); + return Double.compare( + distance1, distance2); // Sort ascending by distance + }); + + if (includeTies) { + // Keep the top k pairs, including ties + List> topPairs = new ArrayList<>(); + double kthDistance = -1; + for (int i = 0; i < sortedPairs.size(); i++) { + if (i < k) { + topPairs.add(sortedPairs.get(i)); + if (i == k - 1) { + kthDistance = + KnnJoinIndexJudgement.distance( + sortedPairs.get(i).getKey(), + sortedPairs.get(i).getValue(), + distanceMetric); + } + } else { + double currentDistance = + KnnJoinIndexJudgement.distance( + sortedPairs.get(i).getKey(), + sortedPairs.get(i).getValue(), + distanceMetric); + if (currentDistance == kthDistance) { + topPairs.add(sortedPairs.get(i)); + } else { + break; + } + } + } + return topPairs.iterator(); + } else { + // Keep the top k pairs without ties + List> topPairs = new ArrayList<>(); + for (int i = 0; i < Math.min(k, sortedPairs.size()); i++) { + topPairs.add(sortedPairs.get(i)); + } + return topPairs.iterator(); + } + }); + + return joinResult; + } + public static final class JoinParams { public final boolean useIndex; public final SpatialPredicate spatialPredicate; diff --git a/spark/common/src/main/java/org/apache/sedona/core/wrapper/UniqueGeometry.java b/spark/common/src/main/java/org/apache/sedona/core/wrapper/UniqueGeometry.java new file mode 100644 index 0000000000..01f20f2fa6 --- /dev/null +++ b/spark/common/src/main/java/org/apache/sedona/core/wrapper/UniqueGeometry.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license 
agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sedona.core.wrapper; + +import java.util.UUID; +import org.apache.commons.lang3.NotImplementedException; +import org.locationtech.jts.geom.*; + +public class UniqueGeometry extends Geometry { + private final T originalGeometry; + private final String uniqueId; + + public UniqueGeometry(T originalGeometry) { + super(new GeometryFactory()); + this.originalGeometry = originalGeometry; + this.uniqueId = UUID.randomUUID().toString(); + } + + public T getOriginalGeometry() { + return originalGeometry; + } + + public String getUniqueId() { + return uniqueId; + } + + @Override + public int hashCode() { + return uniqueId.hashCode(); // Uniqueness ensured by uniqueId + } + + @Override + public String getGeometryType() { + throw new NotImplementedException("getGeometryType is not implemented."); + } + + @Override + public Coordinate getCoordinate() { + throw new NotImplementedException("getCoordinate is not implemented."); + } + + @Override + public Coordinate[] getCoordinates() { + throw new NotImplementedException("getCoordinates is not implemented."); + } + + @Override + public int getNumPoints() { + throw new NotImplementedException("getNumPoints is not implemented."); + } + + @Override + public boolean isEmpty() { + throw new 
NotImplementedException("isEmpty is not implemented."); + } + + @Override + public int getDimension() { + throw new NotImplementedException("getDimension is not implemented."); + } + + @Override + public Geometry getBoundary() { + throw new NotImplementedException("getBoundary is not implemented."); + } + + @Override + public int getBoundaryDimension() { + throw new NotImplementedException("getBoundaryDimension is not implemented."); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + UniqueGeometry that = (UniqueGeometry) obj; + return uniqueId.equals(that.uniqueId); + } + + @Override + public String toString() { + return "UniqueGeometry{" + + "originalGeometry=" + + originalGeometry + + ", uniqueId='" + + uniqueId + + '\'' + + '}'; + } + + @Override + protected Geometry reverseInternal() { + throw new NotImplementedException("reverseInternal is not implemented."); + } + + @Override + public boolean equalsExact(Geometry geometry, double v) { + throw new NotImplementedException("equalsExact is not implemented."); + } + + @Override + public void apply(CoordinateFilter coordinateFilter) { + throw new NotImplementedException("apply(CoordinateFilter) is not implemented."); + } + + @Override + public void apply(CoordinateSequenceFilter coordinateSequenceFilter) { + throw new NotImplementedException("apply(CoordinateSequenceFilter) is not implemented."); + } + + @Override + public void apply(GeometryFilter geometryFilter) { + throw new NotImplementedException("apply(GeometryFilter) is not implemented."); + } + + @Override + public void apply(GeometryComponentFilter geometryComponentFilter) { + throw new NotImplementedException("apply(GeometryComponentFilter) is not implemented."); + } + + @Override + protected Geometry copyInternal() { + throw new NotImplementedException("copyInternal is not implemented."); + } + + @Override + public void normalize() { + throw new 
NotImplementedException("normalize is not implemented."); + } + + @Override + protected Envelope computeEnvelopeInternal() { + throw new NotImplementedException("computeEnvelopeInternal is not implemented."); + } + + @Override + protected int compareToSameClass(Object o) { + throw new NotImplementedException("compareToSameClass(Object) is not implemented."); + } + + @Override + protected int compareToSameClass( + Object o, CoordinateSequenceComparator coordinateSequenceComparator) { + throw new NotImplementedException( + "compareToSameClass(Object, CoordinateSequenceComparator) is not implemented."); + } + + @Override + protected int getTypeCode() { + throw new NotImplementedException("getTypeCode is not implemented."); + } +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala index 1b21c79e7c..c5777be3c1 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastObjectSideKNNJoinExec.scala @@ -120,7 +120,7 @@ case class BroadcastObjectSideKNNJoinExec( sedonaConf: SedonaConf): Unit = { require(numPartitions > 0, "The number of partitions must be greater than 0.") val kValue: Int = this.k.eval().asInstanceOf[Int] - require(kValue > 0, "The number of neighbors must be greater than 0.") + require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.") objectsShapes.setNeighborSampleNumber(kValue) broadcastJoin = true } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala index 812bc6e6d6..9ce40c6d42 100644 --- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/BroadcastQuerySideKNNJoinExec.scala @@ -127,22 +127,13 @@ case class BroadcastQuerySideKNNJoinExec( sedonaConf: SedonaConf): Unit = { require(numPartitions > 0, "The number of partitions must be greater than 0.") val kValue: Int = this.k.eval().asInstanceOf[Int] - require(kValue > 0, "The number of neighbors must be greater than 0.") + require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.") objectsShapes.setNeighborSampleNumber(kValue) - val joinPartitions: Integer = numPartitions - broadcastJoin = false - - // expand the boundary for partition to include both RDDs - objectsShapes.analyze() - queryShapes.analyze() - objectsShapes.boundaryEnvelope.expandToInclude(queryShapes.boundaryEnvelope) - - objectsShapes.spatialPartitioning(GridType.QUADTREE_RTREE, joinPartitions) - queryShapes.spatialPartitioning( - objectsShapes.getPartitioner.asInstanceOf[QuadTreeRTPartitioner].nonOverlappedPartitioner()) - - objectsShapes.buildIndex(IndexType.RTREE, true) + // index the objects on regular partitions (not spatial partitions) + // this avoids the cost of spatial partitioning + objectsShapes.buildIndex(IndexType.RTREE, false) + broadcastJoin = true } /** diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala index 825855b88c..b89b1adeda 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala @@ -582,10 +582,21 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { return Nil } + // validate the k value + val kValue: Int = 
distance.eval().asInstanceOf[Int] + require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.") + val leftShape = children.head val rightShape = children.tail.head - val querySide = getKNNQuerySide(left, leftShape) + val querySide = matchExpressionsToPlans(leftShape, rightShape, left, right) match { + case Some((_, _, false)) => + LeftSide + case Some((_, _, true)) => + RightSide + case None => + Nil + } val objectSidePlan = if (querySide == LeftSide) right else left checkObjectPlanFilterPushdown(objectSidePlan) @@ -711,10 +722,21 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { if (spatialPredicate == SpatialPredicate.KNN) { { + // validate the k value for KNN join + val kValue: Int = distance.get.eval().asInstanceOf[Int] + require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.") + val leftShape = children.head val rightShape = children.tail.head - val querySide = getKNNQuerySide(left, leftShape) + val querySide = matchExpressionsToPlans(leftShape, rightShape, left, right) match { + case Some((_, _, false)) => + LeftSide + case Some((_, _, true)) => + RightSide + case None => + Nil + } val objectSidePlan = if (querySide == LeftSide) right else left checkObjectPlanFilterPushdown(objectSidePlan) @@ -731,7 +753,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { k = distance.get, useApproximate = false, spatialPredicate, - isGeography = false, + isGeography, condition = null, extraCondition = None) :: Nil } else { @@ -746,7 +768,7 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { k = distance.get, useApproximate = false, spatialPredicate, - isGeography = false, + isGeography, condition = null, extraCondition = None) :: Nil } @@ -857,27 +879,6 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy { } } - /** - * Gets the query and object plans based on the left shape. 
- * - * This method checks if the left shape is part of the left or right plan and returns the query - * and object plans accordingly. - * - * @param leftShape - * The left shape expression. - * @return - * The join side where the left shape is located. - */ - private def getKNNQuerySide(left: LogicalPlan, leftShape: Expression) = { - val isLeftQuerySide = - left.toString().toLowerCase().contains(leftShape.toString().toLowerCase()) - if (isLeftQuerySide) { - LeftSide - } else { - RightSide - } - } - /** * Check if the given condition is an equi-join between the given plans. This method basically * replicates the logic of diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala index 2b9bbfb50b..fdc53d13ce 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/KNNJoinExec.scala @@ -162,7 +162,7 @@ case class KNNJoinExec( sedonaConf: SedonaConf): Unit = { require(numPartitions > 0, "The number of partitions must be greater than 0.") val kValue: Int = this.k.eval().asInstanceOf[Int] - require(kValue > 0, "The number of neighbors must be greater than 0.") + require(kValue >= 1, "The number of neighbors (k) must be equal or greater than 1.") objectsShapes.setNeighborSampleNumber(kValue) exactSpatialPartitioning(objectsShapes, queryShapes, numPartitions) diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala index 1d6119d02d..f3b07c2501 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/KnnJoinSuite.scala @@ -209,6 +209,22 @@ class KnnJoinSuite extends TestBaseScala with TableDrivenPropertyChecks { 
"[1,3][1,6][1,13][1,16][2,1][2,5][2,11][2,15][3,3][3,9][3,13][3,19]") } + it("KNN Join should verify the correct parameter k is passed to the join function") { + val df = sparkSession + .range(0, 1) + .toDF("id") + .withColumn("geom", expr("ST_Point(id, id)")) + .repartition(1) + df.createOrReplaceTempView("df1") + val exception = intercept[IllegalArgumentException] { + sparkSession + .sql(s"SELECT A.ID, B.ID FROM df1 A JOIN df1 B ON ST_KNN(A.GEOM, B.GEOM, 0, false)") + .collect() + } + exception.getMessage should include( + "The number of neighbors (k) must be equal or greater than 1.") + } + it("KNN Join with exact algorithms with additional join conditions on id") { val df = sparkSession.sql( s"SELECT QUERIES.ID, OBJECTS.ID FROM QUERIES JOIN OBJECTS ON ST_KNN(QUERIES.GEOM, OBJECTS.GEOM, 4, false) AND QUERIES.ID > 1")