From a75fe7bb73342e6340ff528a00e6c1d7c368a176 Mon Sep 17 00:00:00 2001 From: Furqaan Khan <46216254+furqaankhan@users.noreply.github.com> Date: Wed, 22 Nov 2023 12:07:52 -0500 Subject: [PATCH] [SEDONA-433] Improve RS_SummaryStats performance (#1128) --- .../common/raster/RasterBandAccessors.java | 55 +++++---- .../raster/RasterBandAccessorsTest.java | 113 ++++-------------- 2 files changed, 51 insertions(+), 117 deletions(-) diff --git a/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java b/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java index d1a0b2f424..dc5bd38324 100644 --- a/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java +++ b/common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java @@ -21,6 +21,7 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.math3.stat.StatUtils; import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics; +import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation; import org.apache.sedona.common.Functions; import org.apache.sedona.common.utils.RasterUtils; import org.geotools.coverage.GridSampleDimension; @@ -277,39 +278,37 @@ public static double[] getSummaryStats(GridCoverage2D rasterGeom, int band, bool Raster raster = RasterUtils.getRaster(rasterGeom.getRenderedImage()); int height = RasterAccessors.getHeight(rasterGeom), width = RasterAccessors.getWidth(rasterGeom); double[] pixels = raster.getSamples(0, 0, width, height, band - 1, (double[]) null); - double count = 0, sum = 0, mean = 0, stddev = 0, min = Double.MAX_VALUE, max = -Double.MAX_VALUE; - Double noDataValue = RasterBandAccessors.getBandNoDataValue(rasterGeom, band); - for (double pixel: pixels) { - if(excludeNoDataValue) { - // exclude no data values + + List pixelData = null; + + if (excludeNoDataValue) { + pixelData = new ArrayList<>(); + Double noDataValue = RasterBandAccessors.getBandNoDataValue(rasterGeom, band); + for (double pixel: pixels) { if (noDataValue == null || pixel != noDataValue) { - count++; - sum += pixel; - min = Math.min(min, pixel); - max = Math.max(max, pixel); + pixelData.add(pixel); } - } else { - // include no data values - count = pixels.length; - sum += pixel; - min = Math.min(min, pixel); - max = Math.max(max, pixel); } } - if (count == 0) { - return new double[] {0, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN}; - } - mean = sum / count; - for(double pixel: pixels){ - if (excludeNoDataValue){ - if (noDataValue == null || pixel != noDataValue) { - stddev += Math.pow(pixel - mean, 2); - } - } else { - stddev += Math.pow(pixel - mean, 2); - } + + DescriptiveStatistics stats = null; + + if (pixelData == null) { + stats = new DescriptiveStatistics(pixels); + } else { + pixels = pixelData.stream().mapToDouble(d -> d).toArray(); + stats = new DescriptiveStatistics(pixels); } - stddev = Math.sqrt(stddev/count); + + StandardDeviation sd = new StandardDeviation(false); + + double count = stats.getN(); + double sum = stats.getSum(); + double mean = stats.getMean(); + double stddev = sd.evaluate(pixels, mean); + double min = stats.getMin(); + double max = stats.getMax(); + return new double[]{count, sum, mean, stddev, min, max}; } diff --git a/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java b/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java index 1d2e14dcb7..1a39862300 100644 --- a/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java +++ b/common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java @@ -193,19 +193,9 @@ public void testSummaryStatsWithAllNoData() throws FactoryException { GridCoverage2D emptyRaster = RasterConstructors.makeEmptyRaster(1, 5, 5, 0, 0, 1, -1, 0, 0, 0); double[] values = new double[] {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0}; emptyRaster = MapAlgebra.addBandFromArray(emptyRaster, values, 1, 0d); - double count = 0.0; - double sum = Double.NaN; - double mean = Double.NaN; - double stddev = Double.NaN; - double min = Double.NaN; - double max = Double.NaN; - double[] result = RasterBandAccessors.getSummaryStats(emptyRaster); - assertEquals(count, result[0], 0.1d); - assertEquals(sum, result[1], 0.1d); - assertEquals(mean, result[2], 0.1d); - assertEquals(stddev, result[3], 0.1d); - assertEquals(min, result[4], 0.1d); - assertEquals(max, result[5], 0.1d); + double[] actual = RasterBandAccessors.getSummaryStats(emptyRaster); + double[] expected = {0.0, 0.0, Double.NaN, Double.NaN, Double.NaN, Double.NaN}; + assertArrayEquals(expected, actual, FP_TOLERANCE); } @Test @@ -215,87 +205,32 @@ public void testSummaryStatsWithEmptyRaster() throws FactoryException { double[] values2 = new double[] {0,0,28,29,0,0,0,33,34,35,36,37,38,0,0,0,0,43,44,45,46,47,48,49,50}; emptyRaster = MapAlgebra.addBandFromArray(emptyRaster, values1, 1, 0d); emptyRaster = MapAlgebra.addBandFromArray(emptyRaster, values2, 2, 0d); - double count = 25.0; - double sum = 204.0; - double mean = 8.16; - double stddev = 9.27655108324209; - double min = 0.0; - double max = 25.0; - double[] result = RasterBandAccessors.getSummaryStats(emptyRaster, 1, false); - assertEquals(count, result[0], 0.1d); - assertEquals(sum, result[1], 0.1d); - assertEquals(mean, result[2], 1e-2d); - assertEquals(stddev, result[3], 1e-6); - assertEquals(min, result[4], 0.1d); - assertEquals(max, result[5], 0.1d); - - count = 16.0; - sum = 642.0; - mean = 40.125; - stddev = 6.9988838395847095; - min = 28.0; - max = 50.0; - result = RasterBandAccessors.getSummaryStats(emptyRaster, 2); - assertEquals(count, result[0], 0.1d); - assertEquals(sum, result[1], 0.1d); - assertEquals(mean, result[2], 1e-3d); - assertEquals(stddev, result[3], 1e-6d); - assertEquals(min, result[4], 0.1d); - assertEquals(max, result[5], 0.1d); - - count = 14.0; - sum = 204.0; - mean = 14.571428571428571; - stddev = 7.761758689832072; - min = 1.0; - max = 25.0; - result = RasterBandAccessors.getSummaryStats(emptyRaster); - assertEquals(count, result[0], 0.1d); - assertEquals(sum, result[1], 0.1d); - assertEquals(mean, result[2], 1e-6d); - assertEquals(stddev, result[3], 1e-6d); - assertEquals(min, result[4], 0.1d); - assertEquals(max, result[5], 0.1d); + double[] actual = RasterBandAccessors.getSummaryStats(emptyRaster, 1, false); + double[] expected = {25.0, 204.0, 8.1600, 9.2765, 0.0, 25.0}; + assertArrayEquals(expected, actual, FP_TOLERANCE); + + actual = RasterBandAccessors.getSummaryStats(emptyRaster, 2); + expected = new double[]{16.0, 642.0, 40.125, 6.9988838395847095, 28.0, 50.0}; + assertArrayEquals(expected, actual, FP_TOLERANCE); + + actual = RasterBandAccessors.getSummaryStats(emptyRaster); + expected = new double[] {14.0, 204.0, 14.5714, 7.7617, 1.0, 25.0}; + assertArrayEquals(expected, actual, FP_TOLERANCE); } @Test public void testSummaryStatsWithRaster() throws IOException { GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster/raster_with_no_data/test5.tiff"); - double count = 1036800.0; - double sum = 2.06233487E8; - double mean = 198.91347125771605; - double stddev = 95.09054096106192; - double min = 0.0; - double max = 255.0; - double[] result = RasterBandAccessors.getSummaryStats(raster, 1, false); - assertEquals(count, result[0], 0.1d); - assertEquals(sum, result[1], 0.1d); - assertEquals(mean, result[2], 1e-6d); - assertEquals(stddev, result[3], 1e-6d); - assertEquals(min, result[4], 0.1d); - assertEquals(max, result[5], 0.1d); - - count = 928192.0; - sum = 2.06233487E8; - mean = 222.18839097945252; - stddev = 70.20559521132097; - min = 1.0; - max = 255.0; - result = RasterBandAccessors.getSummaryStats(raster, 1); - assertEquals(count, result[0], 0.1d); - assertEquals(sum, result[1], 0.1d); - assertEquals(mean, result[2], 1e-6d); - assertEquals(stddev, result[3], 1e-6d); - assertEquals(min, result[4], 0.1d); - assertEquals(max, result[5], 0.1d); - - result = RasterBandAccessors.getSummaryStats(raster); - assertEquals(count, result[0], 0.1d); - assertEquals(sum, result[1], 0.1d); - assertEquals(mean, result[2], 1e-6d); - assertEquals(stddev, result[3], 1e-6d); - assertEquals(min, result[4], 0.1d); - assertEquals(max, result[5], 0.1d); + double[] actual = RasterBandAccessors.getSummaryStats(raster, 1, false); + double[] expected = {1036800.0, 2.06233487E8, 198.9134, 95.0905, 0.0, 255.0}; + assertArrayEquals(expected, actual, FP_TOLERANCE); + + actual = RasterBandAccessors.getSummaryStats(raster, 1); + expected = new double[]{928192.0, 2.06233487E8, 222.1883, 70.2055, 1.0, 255.0}; + assertArrayEquals(expected, actual, FP_TOLERANCE); + + actual = RasterBandAccessors.getSummaryStats(raster); + assertArrayEquals(expected, actual, FP_TOLERANCE); } @Test