Skip to content

Commit

Permalink
[SEDONA-690] Optimize query side broadcast knn join (#1741)
Browse files Browse the repository at this point in the history
* [SEDONA-688] Verify KNN parameter K must be equal or larger than 1

* [SEDONA-690] Optimize query side broadcast knn join

* fix isGeography parameter
  • Loading branch information
zhangfengcdt authored Jan 5, 2025
1 parent 9497a07 commit af74a17
Show file tree
Hide file tree
Showing 8 changed files with 535 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.sedona.core.knnJudgement.EuclideanItemDistance;
import org.apache.sedona.core.knnJudgement.HaversineItemDistance;
import org.apache.sedona.core.knnJudgement.SpheroidDistance;
import org.apache.sedona.core.wrapper.UniqueGeometry;
import org.apache.spark.api.java.function.FlatMapFunction2;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.LongAccumulator;
Expand All @@ -46,35 +47,43 @@ public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry>
extends JudgementBase<T, U>
implements FlatMapFunction2<Iterator<T>, Iterator<SpatialIndex>, Pair<U, T>>, Serializable {
private final int k;
private final Double searchRadius;
private final DistanceMetric distanceMetric;
private final boolean includeTies;
private final Broadcast<STRtree> broadcastedTreeIndex;
private final Broadcast<List> broadcastQueryObjects;
private final Broadcast<STRtree> broadcastObjectsTreeIndex;

/**
* Constructor for the KnnJoinIndexJudgement class.
*
* @param k the number of nearest neighbors to find
* @param searchRadius
* @param distanceMetric the distance metric to use
* @param broadcastQueryObjects the broadcast geometries on queries
* @param broadcastObjectsTreeIndex the broadcast spatial index on objects
* @param buildCount accumulator for the number of geometries processed from the build side
* @param streamCount accumulator for the number of geometries processed from the stream side
* @param resultCount accumulator for the number of join results
* @param candidateCount accumulator for the number of candidate matches
* @param broadcastedTreeIndex the broadcasted spatial index
*/
public KnnJoinIndexJudgement(
int k,
Double searchRadius,
DistanceMetric distanceMetric,
boolean includeTies,
Broadcast<STRtree> broadcastedTreeIndex,
Broadcast<List> broadcastQueryObjects,
Broadcast<STRtree> broadcastObjectsTreeIndex,
LongAccumulator buildCount,
LongAccumulator streamCount,
LongAccumulator resultCount,
LongAccumulator candidateCount) {
super(null, buildCount, streamCount, resultCount, candidateCount);
this.k = k;
this.searchRadius = searchRadius;
this.distanceMetric = distanceMetric;
this.includeTies = includeTies;
this.broadcastedTreeIndex = broadcastedTreeIndex;
this.broadcastQueryObjects = broadcastQueryObjects;
this.broadcastObjectsTreeIndex = broadcastObjectsTreeIndex;
}

/**
Expand All @@ -90,7 +99,7 @@ public KnnJoinIndexJudgement(
@Override
public Iterator<Pair<U, T>> call(Iterator<T> streamShapes, Iterator<SpatialIndex> treeIndexes)
throws Exception {
if (!treeIndexes.hasNext() || !streamShapes.hasNext()) {
if (!treeIndexes.hasNext() || (streamShapes != null && !streamShapes.hasNext())) {
buildCount.add(0);
streamCount.add(0);
resultCount.add(0);
Expand All @@ -99,10 +108,9 @@ public Iterator<Pair<U, T>> call(Iterator<T> streamShapes, Iterator<SpatialIndex
}

STRtree strTree;
if (broadcastedTreeIndex != null) {
// get the broadcasted spatial index if available
// this is to support the broadcast join
strTree = broadcastedTreeIndex.getValue();
if (broadcastObjectsTreeIndex != null) {
// get the broadcast spatial index on objects side if available
strTree = broadcastObjectsTreeIndex.getValue();
} else {
// get the spatial index from the iterator
SpatialIndex treeIndex = treeIndexes.next();
Expand All @@ -113,44 +121,133 @@ public Iterator<Pair<U, T>> call(Iterator<T> streamShapes, Iterator<SpatialIndex
strTree = (STRtree) treeIndex;
}

// TODO: For future improvement, instead of using a list to store the results,
// we can use lazy evaluation to avoid storing all the results in memory.
List<Pair<U, T>> result = new ArrayList<>();
ItemDistance itemDistance;

while (streamShapes.hasNext()) {
T streamShape = streamShapes.next();
streamCount.add(1);

Object[] localK;
switch (distanceMetric) {
case EUCLIDEAN:
itemDistance = new EuclideanItemDistance();
break;
case HAVERSINE:
itemDistance = new HaversineItemDistance();
break;
case SPHEROID:
itemDistance = new SpheroidDistance();
break;
default:
itemDistance = new GeometryItemDistance();
break;
}
List queryItems;
if (broadcastQueryObjects != null) {
// get the broadcast spatial index on queries side if available
queryItems = broadcastQueryObjects.getValue();
for (Object item : queryItems) {
T queryGeom;
if (item instanceof UniqueGeometry) {
queryGeom = (T) ((UniqueGeometry) item).getOriginalGeometry();
} else {
queryGeom = (T) item;
}
streamCount.add(1);

localK =
strTree.nearestNeighbour(streamShape.getEnvelopeInternal(), streamShape, itemDistance, k);
if (includeTies) {
localK = getUpdatedLocalKWithTies(streamShape, localK, strTree);
Object[] localK =
strTree.nearestNeighbour(
queryGeom.getEnvelopeInternal(), queryGeom, getItemDistance(), k);
if (includeTies) {
localK = getUpdatedLocalKWithTies(queryGeom, localK, strTree);
}
if (searchRadius != null) {
localK = getInSearchRadius(localK, queryGeom);
}

for (Object obj : localK) {
T candidate = (T) obj;
Pair<U, T> pair = Pair.of((U) item, candidate);
result.add(pair);
resultCount.add(1);
}
}
return result.iterator();
} else {
while (streamShapes.hasNext()) {
T streamShape = streamShapes.next();
streamCount.add(1);

Object[] localK =
strTree.nearestNeighbour(
streamShape.getEnvelopeInternal(), streamShape, getItemDistance(), k);
if (includeTies) {
localK = getUpdatedLocalKWithTies(streamShape, localK, strTree);
}
if (searchRadius != null) {
localK = getInSearchRadius(localK, streamShape);
}

for (Object obj : localK) {
T candidate = (T) obj;
Pair<U, T> pair = Pair.of((U) streamShape, candidate);
result.add(pair);
resultCount.add(1);
for (Object obj : localK) {
T candidate = (T) obj;
Pair<U, T> pair = Pair.of((U) streamShape, candidate);
result.add(pair);
resultCount.add(1);
}
}
return result.iterator();
}
}

return result.iterator();
private Object[] getInSearchRadius(Object[] localK, T queryGeom) {
localK =
Arrays.stream(localK)
.filter(
candidate -> {
Geometry candidateGeom = (Geometry) candidate;
return distanceByMetric(queryGeom, candidateGeom, distanceMetric) <= searchRadius;
})
.toArray();
return localK;
}

/**
* This method calculates the distance between two geometries using the specified distance metric.
*
* @param queryGeom the query geometry
* @param candidateGeom the candidate geometry
* @param distanceMetric the distance metric to use
* @return the distance between the two geometries
*/
public static double distanceByMetric(
Geometry queryGeom, Geometry candidateGeom, DistanceMetric distanceMetric) {
switch (distanceMetric) {
case EUCLIDEAN:
EuclideanItemDistance euclideanItemDistance = new EuclideanItemDistance();
return euclideanItemDistance.distance(queryGeom, candidateGeom);
case HAVERSINE:
HaversineItemDistance haversineItemDistance = new HaversineItemDistance();
return haversineItemDistance.distance(queryGeom, candidateGeom);
case SPHEROID:
SpheroidDistance spheroidDistance = new SpheroidDistance();
return spheroidDistance.distance(queryGeom, candidateGeom);
default:
return queryGeom.distance(candidateGeom);
}
}

private ItemDistance getItemDistance() {
ItemDistance itemDistance;
itemDistance = getItemDistanceByMetric(distanceMetric);
return itemDistance;
}

/**
* This method returns the ItemDistance object based on the specified distance metric.
*
* @param distanceMetric the distance metric to use
* @return the ItemDistance object
*/
public static ItemDistance getItemDistanceByMetric(DistanceMetric distanceMetric) {
ItemDistance itemDistance;
switch (distanceMetric) {
case EUCLIDEAN:
itemDistance = new EuclideanItemDistance();
break;
case HAVERSINE:
itemDistance = new HaversineItemDistance();
break;
case SPHEROID:
itemDistance = new SpheroidDistance();
break;
default:
itemDistance = new GeometryItemDistance();
break;
}
return itemDistance;
}

private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK, STRtree strTree) {
Expand Down Expand Up @@ -184,4 +281,18 @@ private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK, STRtre
}
return localK;
}

public static <U extends Geometry, T extends Geometry> double distance(
U key, T value, DistanceMetric distanceMetric) {
switch (distanceMetric) {
case EUCLIDEAN:
return new EuclideanItemDistance().distance(key, value);
case HAVERSINE:
return new HaversineItemDistance().distance(key, value);
case SPHEROID:
return new SpheroidDistance().distance(key, value);
default:
return new EuclideanItemDistance().distance(key, value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,12 @@ public double distance(ItemBoundable item1, ItemBoundable item2) {
return g1.distance(g2);
}
}

public double distance(Geometry geometry1, Geometry geometry2) {
if (geometry1 == geometry2) {
return Double.MAX_VALUE;
} else {
return geometry1.distance(geometry2);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,12 @@ public double distance(ItemBoundable item1, ItemBoundable item2) {
return Haversine.distance(g1, g2);
}
}

public double distance(Geometry geometry1, Geometry geometry2) {
if (geometry1 == geometry2) {
return Double.MAX_VALUE;
} else {
return Haversine.distance(geometry1, geometry2);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,12 @@ public double distance(ItemBoundable item1, ItemBoundable item2) {
return Spheroid.distance(g1, g2);
}
}

public double distance(Geometry geometry1, Geometry geometry2) {
if (geometry1 == geometry2) {
return Double.MAX_VALUE;
} else {
return Spheroid.distance(geometry1, geometry2);
}
}
}
Loading

0 comments on commit af74a17

Please sign in to comment.