Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SEDONA-690] Optimize query side broadcast knn join #1741

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.sedona.core.knnJudgement.EuclideanItemDistance;
import org.apache.sedona.core.knnJudgement.HaversineItemDistance;
import org.apache.sedona.core.knnJudgement.SpheroidDistance;
import org.apache.sedona.core.wrapper.UniqueGeometry;
import org.apache.spark.api.java.function.FlatMapFunction2;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.util.LongAccumulator;
Expand All @@ -46,35 +47,43 @@ public class KnnJoinIndexJudgement<T extends Geometry, U extends Geometry>
extends JudgementBase<T, U>
implements FlatMapFunction2<Iterator<T>, Iterator<SpatialIndex>, Pair<U, T>>, Serializable {
private final int k;
private final Double searchRadius;
private final DistanceMetric distanceMetric;
private final boolean includeTies;
private final Broadcast<STRtree> broadcastedTreeIndex;
private final Broadcast<List> broadcastQueryObjects;
private final Broadcast<STRtree> broadcastObjectsTreeIndex;

/**
* Constructor for the KnnJoinIndexJudgement class.
*
* @param k the number of nearest neighbors to find
* @param searchRadius
* @param distanceMetric the distance metric to use
* @param broadcastQueryObjects the broadcast geometries on queries
* @param broadcastObjectsTreeIndex the broadcast spatial index on objects
* @param buildCount accumulator for the number of geometries processed from the build side
* @param streamCount accumulator for the number of geometries processed from the stream side
* @param resultCount accumulator for the number of join results
* @param candidateCount accumulator for the number of candidate matches
* @param broadcastedTreeIndex the broadcasted spatial index
*/
public KnnJoinIndexJudgement(
int k,
Double searchRadius,
DistanceMetric distanceMetric,
boolean includeTies,
Broadcast<STRtree> broadcastedTreeIndex,
Broadcast<List> broadcastQueryObjects,
Broadcast<STRtree> broadcastObjectsTreeIndex,
LongAccumulator buildCount,
LongAccumulator streamCount,
LongAccumulator resultCount,
LongAccumulator candidateCount) {
super(null, buildCount, streamCount, resultCount, candidateCount);
this.k = k;
this.searchRadius = searchRadius;
this.distanceMetric = distanceMetric;
this.includeTies = includeTies;
this.broadcastedTreeIndex = broadcastedTreeIndex;
this.broadcastQueryObjects = broadcastQueryObjects;
this.broadcastObjectsTreeIndex = broadcastObjectsTreeIndex;
}

/**
Expand All @@ -90,7 +99,7 @@ public KnnJoinIndexJudgement(
@Override
public Iterator<Pair<U, T>> call(Iterator<T> streamShapes, Iterator<SpatialIndex> treeIndexes)
throws Exception {
if (!treeIndexes.hasNext() || !streamShapes.hasNext()) {
if (!treeIndexes.hasNext() || (streamShapes != null && !streamShapes.hasNext())) {
buildCount.add(0);
streamCount.add(0);
resultCount.add(0);
Expand All @@ -99,10 +108,9 @@ public Iterator<Pair<U, T>> call(Iterator<T> streamShapes, Iterator<SpatialIndex
}

STRtree strTree;
if (broadcastedTreeIndex != null) {
// get the broadcasted spatial index if available
// this is to support the broadcast join
strTree = broadcastedTreeIndex.getValue();
if (broadcastObjectsTreeIndex != null) {
// get the broadcast spatial index on objects side if available
strTree = broadcastObjectsTreeIndex.getValue();
} else {
// get the spatial index from the iterator
SpatialIndex treeIndex = treeIndexes.next();
Expand All @@ -113,44 +121,133 @@ public Iterator<Pair<U, T>> call(Iterator<T> streamShapes, Iterator<SpatialIndex
strTree = (STRtree) treeIndex;
}

// TODO: For future improvement, instead of using a list to store the results,
// we can use lazy evaluation to avoid storing all the results in memory.
List<Pair<U, T>> result = new ArrayList<>();
ItemDistance itemDistance;

while (streamShapes.hasNext()) {
T streamShape = streamShapes.next();
streamCount.add(1);

Object[] localK;
switch (distanceMetric) {
case EUCLIDEAN:
itemDistance = new EuclideanItemDistance();
break;
case HAVERSINE:
itemDistance = new HaversineItemDistance();
break;
case SPHEROID:
itemDistance = new SpheroidDistance();
break;
default:
itemDistance = new GeometryItemDistance();
break;
}
List queryItems;
if (broadcastQueryObjects != null) {
// get the broadcast spatial index on queries side if available
queryItems = broadcastQueryObjects.getValue();
for (Object item : queryItems) {
T queryGeom;
if (item instanceof UniqueGeometry) {
queryGeom = (T) ((UniqueGeometry) item).getOriginalGeometry();
} else {
queryGeom = (T) item;
}
streamCount.add(1);

localK =
strTree.nearestNeighbour(streamShape.getEnvelopeInternal(), streamShape, itemDistance, k);
if (includeTies) {
localK = getUpdatedLocalKWithTies(streamShape, localK, strTree);
Object[] localK =
strTree.nearestNeighbour(
queryGeom.getEnvelopeInternal(), queryGeom, getItemDistance(), k);
if (includeTies) {
localK = getUpdatedLocalKWithTies(queryGeom, localK, strTree);
}
if (searchRadius != null) {
localK = getInSearchRadius(localK, queryGeom);
}

for (Object obj : localK) {
T candidate = (T) obj;
Pair<U, T> pair = Pair.of((U) item, candidate);
result.add(pair);
resultCount.add(1);
}
}
return result.iterator();
} else {
while (streamShapes.hasNext()) {
T streamShape = streamShapes.next();
streamCount.add(1);

Object[] localK =
strTree.nearestNeighbour(
streamShape.getEnvelopeInternal(), streamShape, getItemDistance(), k);
if (includeTies) {
localK = getUpdatedLocalKWithTies(streamShape, localK, strTree);
}
if (searchRadius != null) {
localK = getInSearchRadius(localK, streamShape);
}

for (Object obj : localK) {
T candidate = (T) obj;
Pair<U, T> pair = Pair.of((U) streamShape, candidate);
result.add(pair);
resultCount.add(1);
for (Object obj : localK) {
T candidate = (T) obj;
Pair<U, T> pair = Pair.of((U) streamShape, candidate);
result.add(pair);
resultCount.add(1);
}
}
return result.iterator();
}
}

return result.iterator();
private Object[] getInSearchRadius(Object[] localK, T queryGeom) {
localK =
Arrays.stream(localK)
.filter(
candidate -> {
Geometry candidateGeom = (Geometry) candidate;
return distanceByMetric(queryGeom, candidateGeom, distanceMetric) <= searchRadius;
})
.toArray();
return localK;
}

/**
* This method calculates the distance between two geometries using the specified distance metric.
*
* @param queryGeom the query geometry
* @param candidateGeom the candidate geometry
* @param distanceMetric the distance metric to use
* @return the distance between the two geometries
*/
public static double distanceByMetric(
Geometry queryGeom, Geometry candidateGeom, DistanceMetric distanceMetric) {
switch (distanceMetric) {
case EUCLIDEAN:
EuclideanItemDistance euclideanItemDistance = new EuclideanItemDistance();
return euclideanItemDistance.distance(queryGeom, candidateGeom);
case HAVERSINE:
HaversineItemDistance haversineItemDistance = new HaversineItemDistance();
return haversineItemDistance.distance(queryGeom, candidateGeom);
case SPHEROID:
SpheroidDistance spheroidDistance = new SpheroidDistance();
return spheroidDistance.distance(queryGeom, candidateGeom);
default:
return queryGeom.distance(candidateGeom);
}
}

private ItemDistance getItemDistance() {
ItemDistance itemDistance;
itemDistance = getItemDistanceByMetric(distanceMetric);
return itemDistance;
}

/**
* This method returns the ItemDistance object based on the specified distance metric.
*
* @param distanceMetric the distance metric to use
* @return the ItemDistance object
*/
public static ItemDistance getItemDistanceByMetric(DistanceMetric distanceMetric) {
ItemDistance itemDistance;
switch (distanceMetric) {
case EUCLIDEAN:
itemDistance = new EuclideanItemDistance();
break;
case HAVERSINE:
itemDistance = new HaversineItemDistance();
break;
case SPHEROID:
itemDistance = new SpheroidDistance();
break;
default:
itemDistance = new GeometryItemDistance();
break;
}
return itemDistance;
}

private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK, STRtree strTree) {
Expand Down Expand Up @@ -184,4 +281,18 @@ private Object[] getUpdatedLocalKWithTies(T streamShape, Object[] localK, STRtre
}
return localK;
}

public static <U extends Geometry, T extends Geometry> double distance(
U key, T value, DistanceMetric distanceMetric) {
switch (distanceMetric) {
case EUCLIDEAN:
return new EuclideanItemDistance().distance(key, value);
case HAVERSINE:
return new HaversineItemDistance().distance(key, value);
case SPHEROID:
return new SpheroidDistance().distance(key, value);
default:
return new EuclideanItemDistance().distance(key, value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,12 @@ public double distance(ItemBoundable item1, ItemBoundable item2) {
return g1.distance(g2);
}
}

public double distance(Geometry geometry1, Geometry geometry2) {
if (geometry1 == geometry2) {
return Double.MAX_VALUE;
} else {
return geometry1.distance(geometry2);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,12 @@ public double distance(ItemBoundable item1, ItemBoundable item2) {
return Haversine.distance(g1, g2);
}
}

public double distance(Geometry geometry1, Geometry geometry2) {
if (geometry1 == geometry2) {
return Double.MAX_VALUE;
} else {
return Haversine.distance(geometry1, geometry2);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,12 @@ public double distance(ItemBoundable item1, ItemBoundable item2) {
return Spheroid.distance(g1, g2);
}
}

public double distance(Geometry geometry1, Geometry geometry2) {
if (geometry1 == geometry2) {
return Double.MAX_VALUE;
} else {
return Spheroid.distance(geometry1, geometry2);
}
}
}
Loading
Loading