Skip to content

Commit

Permalink
[SEDONA-433] Improve RS_SummaryStats performance (#1128)
Browse files (browse the repository at this point in the history)
  • Loading branch information
furqaankhan authored Nov 22, 2023
1 parent 1030d0b commit a75fe7b
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.stat.StatUtils;
import org.apache.commons.math3.stat.descriptive.DescriptiveStatistics;
import org.apache.commons.math3.stat.descriptive.moment.StandardDeviation;
import org.apache.sedona.common.Functions;
import org.apache.sedona.common.utils.RasterUtils;
import org.geotools.coverage.GridSampleDimension;
Expand Down Expand Up @@ -277,39 +278,37 @@ public static double[] getSummaryStats(GridCoverage2D rasterGeom, int band, bool
Raster raster = RasterUtils.getRaster(rasterGeom.getRenderedImage());
int height = RasterAccessors.getHeight(rasterGeom), width = RasterAccessors.getWidth(rasterGeom);
double[] pixels = raster.getSamples(0, 0, width, height, band - 1, (double[]) null);
double count = 0, sum = 0, mean = 0, stddev = 0, min = Double.MAX_VALUE, max = -Double.MAX_VALUE;
Double noDataValue = RasterBandAccessors.getBandNoDataValue(rasterGeom, band);
for (double pixel: pixels) {
if(excludeNoDataValue) {
// exclude no data values

List<Double> pixelData = null;

if (excludeNoDataValue) {
pixelData = new ArrayList<>();
Double noDataValue = RasterBandAccessors.getBandNoDataValue(rasterGeom, band);
for (double pixel: pixels) {
if (noDataValue == null || pixel != noDataValue) {
count++;
sum += pixel;
min = Math.min(min, pixel);
max = Math.max(max, pixel);
pixelData.add(pixel);
}
} else {
// include no data values
count = pixels.length;
sum += pixel;
min = Math.min(min, pixel);
max = Math.max(max, pixel);
}
}
if (count == 0) {
return new double[] {0, Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN};
}
mean = sum / count;
for(double pixel: pixels){
if (excludeNoDataValue){
if (noDataValue == null || pixel != noDataValue) {
stddev += Math.pow(pixel - mean, 2);
}
} else {
stddev += Math.pow(pixel - mean, 2);
}

DescriptiveStatistics stats = null;

if (pixelData == null) {
stats = new DescriptiveStatistics(pixels);
} else {
pixels = pixelData.stream().mapToDouble(d -> d).toArray();
stats = new DescriptiveStatistics(pixels);
}
stddev = Math.sqrt(stddev/count);

StandardDeviation sd = new StandardDeviation(false);

double count = stats.getN();
double sum = stats.getSum();
double mean = stats.getMean();
double stddev = sd.evaluate(pixels, mean);
double min = stats.getMin();
double max = stats.getMax();

return new double[]{count, sum, mean, stddev, min, max};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,19 +193,9 @@ public void testSummaryStatsWithAllNoData() throws FactoryException {
GridCoverage2D emptyRaster = RasterConstructors.makeEmptyRaster(1, 5, 5, 0, 0, 1, -1, 0, 0, 0);
double[] values = new double[] {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
emptyRaster = MapAlgebra.addBandFromArray(emptyRaster, values, 1, 0d);
double count = 0.0;
double sum = Double.NaN;
double mean = Double.NaN;
double stddev = Double.NaN;
double min = Double.NaN;
double max = Double.NaN;
double[] result = RasterBandAccessors.getSummaryStats(emptyRaster);
assertEquals(count, result[0], 0.1d);
assertEquals(sum, result[1], 0.1d);
assertEquals(mean, result[2], 0.1d);
assertEquals(stddev, result[3], 0.1d);
assertEquals(min, result[4], 0.1d);
assertEquals(max, result[5], 0.1d);
double[] actual = RasterBandAccessors.getSummaryStats(emptyRaster);
double[] expected = {0.0, 0.0, Double.NaN, Double.NaN, Double.NaN, Double.NaN};
assertArrayEquals(expected, actual, FP_TOLERANCE);
}

@Test
Expand All @@ -215,87 +205,32 @@ public void testSummaryStatsWithEmptyRaster() throws FactoryException {
double[] values2 = new double[] {0,0,28,29,0,0,0,33,34,35,36,37,38,0,0,0,0,43,44,45,46,47,48,49,50};
emptyRaster = MapAlgebra.addBandFromArray(emptyRaster, values1, 1, 0d);
emptyRaster = MapAlgebra.addBandFromArray(emptyRaster, values2, 2, 0d);
double count = 25.0;
double sum = 204.0;
double mean = 8.16;
double stddev = 9.27655108324209;
double min = 0.0;
double max = 25.0;
double[] result = RasterBandAccessors.getSummaryStats(emptyRaster, 1, false);
assertEquals(count, result[0], 0.1d);
assertEquals(sum, result[1], 0.1d);
assertEquals(mean, result[2], 1e-2d);
assertEquals(stddev, result[3], 1e-6);
assertEquals(min, result[4], 0.1d);
assertEquals(max, result[5], 0.1d);

count = 16.0;
sum = 642.0;
mean = 40.125;
stddev = 6.9988838395847095;
min = 28.0;
max = 50.0;
result = RasterBandAccessors.getSummaryStats(emptyRaster, 2);
assertEquals(count, result[0], 0.1d);
assertEquals(sum, result[1], 0.1d);
assertEquals(mean, result[2], 1e-3d);
assertEquals(stddev, result[3], 1e-6d);
assertEquals(min, result[4], 0.1d);
assertEquals(max, result[5], 0.1d);

count = 14.0;
sum = 204.0;
mean = 14.571428571428571;
stddev = 7.761758689832072;
min = 1.0;
max = 25.0;
result = RasterBandAccessors.getSummaryStats(emptyRaster);
assertEquals(count, result[0], 0.1d);
assertEquals(sum, result[1], 0.1d);
assertEquals(mean, result[2], 1e-6d);
assertEquals(stddev, result[3], 1e-6d);
assertEquals(min, result[4], 0.1d);
assertEquals(max, result[5], 0.1d);
double[] actual = RasterBandAccessors.getSummaryStats(emptyRaster, 1, false);
double[] expected = {25.0, 204.0, 8.1600, 9.2765, 0.0, 25.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(emptyRaster, 2);
expected = new double[]{16.0, 642.0, 40.125, 6.9988838395847095, 28.0, 50.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(emptyRaster);
expected = new double[] {14.0, 204.0, 14.5714, 7.7617, 1.0, 25.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);
}

/**
 * Checks RasterBandAccessors.getSummaryStats against the test5.tiff GeoTIFF fixture.
 *
 * The returned array is compared position by position as
 * {count, sum, mean, stddev, min, max}.
 *
 * NOTE(review): this method contains both per-field assertEquals checks and
 * whole-array assertArrayEquals checks for the same calls. It looks like the
 * removed and added sides of a commit diff rendered together without +/- markers;
 * confirm which set of assertions is meant to remain.
 */
@Test
public void testSummaryStatsWithRaster() throws IOException {
GridCoverage2D raster = rasterFromGeoTiff(resourceFolder + "raster/raster_with_no_data/test5.tiff");
// Expected statistics for band 1 when the third argument is false
// (presumably excludeNoDataValue = false, i.e. every pixel is counted; confirm against the signature).
double count = 1036800.0;
double sum = 2.06233487E8;
double mean = 198.91347125771605;
double stddev = 95.09054096106192;
double min = 0.0;
double max = 255.0;
double[] result = RasterBandAccessors.getSummaryStats(raster, 1, false);
assertEquals(count, result[0], 0.1d);
assertEquals(sum, result[1], 0.1d);
assertEquals(mean, result[2], 1e-6d);
assertEquals(stddev, result[3], 1e-6d);
assertEquals(min, result[4], 0.1d);
assertEquals(max, result[5], 0.1d);

// Expected statistics for band 1 via the two-argument overload: the count is lower
// and min is 1.0 while the sum is unchanged (presumably no-data pixels are excluded by default).
count = 928192.0;
sum = 2.06233487E8;
mean = 222.18839097945252;
stddev = 70.20559521132097;
min = 1.0;
max = 255.0;
result = RasterBandAccessors.getSummaryStats(raster, 1);
assertEquals(count, result[0], 0.1d);
assertEquals(sum, result[1], 0.1d);
assertEquals(mean, result[2], 1e-6d);
assertEquals(stddev, result[3], 1e-6d);
assertEquals(min, result[4], 0.1d);
assertEquals(max, result[5], 0.1d);

// The single-argument overload is asserted against the same expected values as the two-argument call.
result = RasterBandAccessors.getSummaryStats(raster);
assertEquals(count, result[0], 0.1d);
assertEquals(sum, result[1], 0.1d);
assertEquals(mean, result[2], 1e-6d);
assertEquals(stddev, result[3], 1e-6d);
assertEquals(min, result[4], 0.1d);
assertEquals(max, result[5], 0.1d);
// Array-based form of the same three checks. Expected mean/stddev are written to four
// decimal places here, so FP_TOLERANCE (defined outside this view) has to absorb that rounding.
double[] actual = RasterBandAccessors.getSummaryStats(raster, 1, false);
double[] expected = {1036800.0, 2.06233487E8, 198.9134, 95.0905, 0.0, 255.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);

actual = RasterBandAccessors.getSummaryStats(raster, 1);
expected = new double[]{928192.0, 2.06233487E8, 222.1883, 70.2055, 1.0, 255.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);

// Reuses the expected array from the two-argument call above.
actual = RasterBandAccessors.getSummaryStats(raster);
assertArrayEquals(expected, actual, FP_TOLERANCE);
}

@Test
Expand Down

0 comments on commit a75fe7b

Please sign in to comment.