Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds tests for geospatial functions #3830

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added api/app/tests/utils/snow_masked_hfi20240810.tif
Binary file not shown.
107 changes: 107 additions & 0 deletions api/app/tests/utils/test_geospatial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import os
import pytest
from osgeo import gdal
import numpy as np

from app.utils.geospatial import raster_mul, warp_to_match_extent

fixture_path = os.path.join(os.path.dirname(__file__), "snow_masked_hfi20240810.tif")


def get_test_tpi_raster(hfi_ds: gdal.Dataset, fill_value: int):
# Get raster dimensions
x_size = hfi_ds.RasterXSize
y_size = hfi_ds.RasterYSize

# Get the geotransform and projection from the first raster
geotransform = hfi_ds.GetGeoTransform()
projection = hfi_ds.GetProjection()

# Create the output raster
driver = gdal.GetDriverByName("MEM")
out_ds: gdal.Dataset = driver.Create("memory", x_size, y_size, 1, gdal.GDT_Byte)

# Set the geotransform and projection
out_ds.SetGeoTransform(geotransform)
out_ds.SetProjection(projection)

filler_data = hfi_ds.GetRasterBand(1).ReadAsArray()
tpi_data = np.full_like(filler_data, fill_value)

# Write the modified data to the new raster
out_band = out_ds.GetRasterBand(1)
out_band.SetNoDataValue(0)
out_band.WriteArray(tpi_data)
return out_ds


def get_tpi_raster_wrong_shape():
driver = gdal.GetDriverByName("MEM")
out_ds: gdal.Dataset = driver.Create("memory", 1, 1, 1, gdal.GDT_Byte)
out_band = out_ds.GetRasterBand(1)
out_band.SetNoDataValue(0)
out_band.WriteArray(np.array([[1]]))
return out_ds


def test_zero_case():
hfi_ds: gdal.Dataset = gdal.Open(fixture_path, gdal.GA_ReadOnly)
tpi_ds: gdal.Dataset = get_test_tpi_raster(hfi_ds, 0)

masked_raster = raster_mul(tpi_ds, hfi_ds)
masked_data = masked_raster.GetRasterBand(1).ReadAsArray()

assert masked_data.shape == hfi_ds.GetRasterBand(1).ReadAsArray().shape
assert np.all(masked_data == 0) == True

hfi_ds = None
tpi_ds = None


def test_identity_case():
hfi_ds: gdal.Dataset = gdal.Open(fixture_path, gdal.GA_ReadOnly)
tpi_ds: gdal.Dataset = get_test_tpi_raster(hfi_ds, 1)

masked_raster = raster_mul(tpi_ds, hfi_ds)
masked_data = masked_raster.GetRasterBand(1).ReadAsArray()
hfi_data = hfi_ds.GetRasterBand(1).ReadAsArray()

# do the simple classification for hfi, pixels >4k are 1
hfi_data[hfi_data >= 1] = 1
hfi_data[hfi_data < 1] = 0

assert masked_data.shape == hfi_data.shape
assert np.all(masked_data == hfi_data) == True

hfi_ds = None
tpi_ds = None


def test_wrong_dimensions():
hfi_ds: gdal.Dataset = gdal.Open(fixture_path, gdal.GA_ReadOnly)
tpi_ds: gdal.Dataset = get_tpi_raster_wrong_shape()

with pytest.raises(ValueError):
raster_mul(tpi_ds, hfi_ds)

hfi_ds = None
tpi_ds = None


def test_warp_to_match_dimension():
hfi_ds: gdal.Dataset = gdal.Open(fixture_path, gdal.GA_ReadOnly)
tpi_ds: gdal.Dataset = get_tpi_raster_wrong_shape()

driver = gdal.GetDriverByName("MEM")
out_dataset: gdal.Dataset = driver.Create("memory", hfi_ds.RasterXSize, hfi_ds.RasterYSize, 1, gdal.GDT_Byte)

warp_to_match_extent(tpi_ds, hfi_ds, out_dataset)
output_data = out_dataset.GetRasterBand(1).ReadAsArray()
hfi_data = hfi_ds.GetRasterBand(1).ReadAsArray()

assert hfi_data.shape == output_data.shape
assert hfi_ds.RasterXSize == out_dataset.RasterXSize
assert hfi_ds.RasterYSize == out_dataset.RasterYSize

hfi_ds = None
tpi_ds = None
46 changes: 22 additions & 24 deletions api/app/utils/geospatial.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,64 @@
from dataclasses import dataclass
import logging
from typing import Any, Optional
from osgeo import gdal


logger = logging.getLogger(__name__)


def warp_to_match_extent(source_raster: gdal.Dataset, raster_to_match: gdal.Dataset, output_path: str) -> gdal.Dataset:
def warp_to_match_extent(source_ds: gdal.Dataset, ds_to_match: gdal.Dataset, output_path: str) -> gdal.Dataset:
"""
Warp the source_raster to match the extent and projection of the other raster.

:param source_raster: the raster to warp
:param raster_to_match: the reference raster to match the source against
:param source_ds: the dataset raster to warp
:param ds_to_match: the reference dataset raster to match the source against
:param output_path: output path of the resulting raster
:return: warped raster dataset
"""
source_geotransform = raster_to_match.GetGeoTransform()
source_geotransform = ds_to_match.GetGeoTransform()
x_res = source_geotransform[1]
y_res = -source_geotransform[5]
minx = source_geotransform[0]
maxy = source_geotransform[3]
maxx = minx + source_geotransform[1] * raster_to_match.RasterXSize
miny = maxy + source_geotransform[5] * raster_to_match.RasterYSize
maxx = minx + source_geotransform[1] * ds_to_match.RasterXSize
miny = maxy + source_geotransform[5] * ds_to_match.RasterYSize
extent = [minx, miny, maxx, maxy]

# Warp to match input option parameters
return gdal.Warp(output_path, source_raster, dstSRS=raster_to_match.GetProjection(), outputBounds=extent, xRes=x_res, yRes=y_res, resampleAlg=gdal.GRA_NearestNeighbour)
return gdal.Warp(output_path, source_ds, dstSRS=ds_to_match.GetProjection(), outputBounds=extent, xRes=x_res, yRes=y_res, resampleAlg=gdal.GRA_NearestNeighbour)


def raster_mul(tpi_raster: gdal.Dataset, hfi_raster: gdal.Dataset, chunk_size=256) -> gdal.Dataset:
def raster_mul(tpi_ds: gdal.Dataset, hfi_ds: gdal.Dataset, chunk_size=256) -> gdal.Dataset:
"""
Multiply rasters together by reading in chunks of pixels at a time to avoid loading
the rasters into memory all at once.

:param tpi_raster: Classified TPI raster to multiply against the classified HFI raster
:param hfi_raster: Classified HFI raster to multiply against the classified TPI raster
:param tpi_ds: Classified TPI dataset raster to multiply against the classified HFI dataset raster
:param hfi_ds: Classified HFI dataset raster to multiply against the classified TPI dataset raster
:raises ValueError: Raised if the dimensions of the rasters don't match
:return: Multiplied raster result as a raster dataset
"""
# Get raster dimensions
x_size = tpi_raster.RasterXSize
y_size = tpi_raster.RasterYSize
x_size = tpi_ds.RasterXSize
y_size = tpi_ds.RasterYSize

# Check if the dimensions of both rasters match
if x_size != hfi_raster.RasterXSize or y_size != hfi_raster.RasterYSize:
if x_size != hfi_ds.RasterXSize or y_size != hfi_ds.RasterYSize:
raise ValueError("The dimensions of the two rasters do not match.")

# Get the geotransform and projection from the first raster
geotransform = tpi_raster.GetGeoTransform()
projection = tpi_raster.GetProjection()
geotransform = tpi_ds.GetGeoTransform()
projection = tpi_ds.GetProjection()

# Create the output raster
driver = gdal.GetDriverByName("MEM")
out_raster: gdal.Dataset = driver.Create("memory", x_size, y_size, 1, gdal.GDT_Byte)
out_ds: gdal.Dataset = driver.Create("memory", x_size, y_size, 1, gdal.GDT_Byte)

# Set the geotransform and projection
out_raster.SetGeoTransform(geotransform)
out_raster.SetProjection(projection)
out_ds.SetGeoTransform(geotransform)
out_ds.SetProjection(projection)

tpi_raster_band = tpi_raster.GetRasterBand(1)
hfi_raster_band = hfi_raster.GetRasterBand(1)
tpi_raster_band = tpi_ds.GetRasterBand(1)
hfi_raster_band = hfi_ds.GetRasterBand(1)

# Process in chunks
for y in range(0, y_size, chunk_size):
Expand All @@ -80,8 +78,8 @@ def raster_mul(tpi_raster: gdal.Dataset, hfi_raster: gdal.Dataset, chunk_size=25
tpi_chunk *= hfi_chunk

# Write the result to the output raster
out_raster.GetRasterBand(1).WriteArray(tpi_chunk, x, y)
out_ds.GetRasterBand(1).WriteArray(tpi_chunk, x, y)
tpi_chunk = None
hfi_chunk = None

return out_raster
return out_ds
Loading