diff --git a/openeo_processes_dask/process_implementations/__init__.py b/openeo_processes_dask/process_implementations/__init__.py index 875ee6d5..aa7c3d15 100644 --- a/openeo_processes_dask/process_implementations/__init__.py +++ b/openeo_processes_dask/process_implementations/__init__.py @@ -15,12 +15,12 @@ "Did not load machine learning processes due to missing dependencies: Install them like this: `pip install openeo-processes-dask[implementations, ml]`" ) -try: - from .experimental import * -except ImportError as e: - logger.warning( - "Did not experimental processes due to missing dependencies: Install them like this: `pip install openeo-processes-dask[implementations, experimental]`" - ) +# try: +# from .experimental import * +# except ImportError as e: +# logger.warning( +# "Did not experimental processes due to missing dependencies: Install them like this: `pip install openeo-processes-dask[implementations, experimental]`" +# ) import rioxarray as rio # Required for the .rio accessor on xarrays. 
def mask(data: RasterCube, mask: RasterCube, replacement=None) -> RasterCube:
    """Replace the pixels of ``data`` that are selected by ``mask``.

    Values of ``data`` are kept where ``_not(mask)`` holds and are set to
    ``replacement`` everywhere else (see https://processes.openeo.org/#mask).

    The mask must use the same spatial dimension names as ``data``. If the
    spatial labels differ, the mask is first aligned to ``data`` with
    ``resample_cube_spatial``. The mask may have no temporal dimension, in
    which case it is applied to every temporal label of ``data``. Band
    dimensions are not validated by this function.

    Args:
        data: Cube whose values are masked.
        mask: Cube selecting the pixels to replace.
        replacement: Value written to the selected pixels. ``None`` (the
            default) is treated as ``np.nan``.

    Returns:
        The masked cube, in the dimension order of the input ``data``.

    Raises:
        DimensionMismatch: If the mask has temporal dimensions that differ
            from those of ``data``, or if the spatial dimension names differ.
        DimensionLabelCountMismatch: If a shared temporal dimension has a
            different number of labels in ``data`` and ``mask``.
        LabelMismatch: If a shared temporal dimension has the same number of
            labels but different label values.
    """
    if replacement is None:
        replacement = np.nan

    data_band_dims = data.openeo.band_dims
    mask_band_dims = mask.openeo.band_dims
    data_temporal_dims = data.openeo.temporal_dims
    mask_temporal_dims = mask.openeo.temporal_dims

    if set(data_temporal_dims) != set(mask_temporal_dims):
        # The only valid mismatch is a mask without any temporal dimension:
        # it is then applied to all the temporal labels of data.
        if len(mask_temporal_dims) != 0:
            raise DimensionMismatch(
                f"data and mask temporal dimensions do no match: data has temporal dimensions ({data_temporal_dims}) and mask {mask_temporal_dims}."
            )
    else:
        # Same temporal dimensions on both sides: the labels must correspond.
        for n in data_temporal_dims:
            data_temporal_labels = data[n].values
            mask_temporal_labels = mask[n].values
            data_n_labels = len(data_temporal_labels)
            mask_n_labels = len(mask_temporal_labels)

            if data_n_labels != mask_n_labels:
                raise DimensionLabelCountMismatch(
                    f"data and mask temporal dimensions do no match: data has {data_n_labels} temporal dimensions labels and mask {mask_n_labels}."
                )
            if not np.array_equal(data_temporal_labels, mask_temporal_labels):
                raise LabelMismatch(
                    f"data and mask temporal dimension labels don't match for dimension {n}."
                )

    # From the process definition: https://processes.openeo.org/#mask
    # The data cubes have to be compatible except that the horizontal spatial
    # dimensions (axes x and y) will be aligned implicitly by
    # resample_cube_spatial.
    data_spatial_dims = data.openeo.spatial_dims
    mask_spatial_dims = mask.openeo.spatial_dims
    if set(data_spatial_dims) != set(mask_spatial_dims):
        raise DimensionMismatch(
            f"data and mask spatial dimensions do no match: data has spatial dimensions ({data_spatial_dims}) and mask {mask_spatial_dims}"
        )

    apply_resample_cube_spatial = False
    for n in data_spatial_dims:
        data_spatial_labels = data[n].values
        mask_spatial_labels = mask[n].values
        data_n_labels = len(data_spatial_labels)
        mask_n_labels = len(mask_spatial_labels)

        if data_n_labels != mask_n_labels:
            apply_resample_cube_spatial = True
            logger.info(
                f"data and mask spatial dimension labels don't match: data has ({data_n_labels}) labels and mask has {mask_n_labels} for dimension {n}."
            )
        elif not np.array_equal(data_spatial_labels, mask_spatial_labels):
            apply_resample_cube_spatial = True
            logger.info(
                f"data and mask spatial dimension labels don't match for dimension {n}, i.e. the coordinate values are different."
            )

    if apply_resample_cube_spatial:
        logger.info("mask is aligned to data using resample_cube_spatial.")
        mask = resample_cube_spatial(data=mask, target=data)

    original_dim_order = data.dims
    bands_only_in_data = len(data_band_dims) != 0 and len(mask_band_dims) == 0
    time_only_in_data = (
        len(data_temporal_dims) != 0 and len(mask_temporal_dims) == 0
    )
    reordered = bands_only_in_data or time_only_in_data

    if reordered:
        # Put the dimension that the mask lacks first and the spatial
        # dimensions last. The band case takes precedence when both apply.
        band_dim = tuple(data_band_dims[:1])
        time_dim = tuple(data_temporal_dims[:1])
        leading = band_dim + time_dim if bands_only_in_data else time_dim + band_dim
        required_dim_order = (*leading, data.openeo.y_dim, data.openeo.x_dim)
        data = data.transpose(*required_dim_order, missing_dims="ignore")
        mask = mask.transpose(*required_dim_order, missing_dims="ignore")

    data = data.where(_not(mask), replacement)

    if reordered:
        # Order axes back to how they were before, for *both* reorder cases
        # (previously the temporal-only case returned the transposed cube).
        data = data.transpose(*original_dim_order)

    return data
def resample_cube_spatial(
    data: RasterCube, target: RasterCube, method="near", options=None
) -> RasterCube:
    """Resample ``data`` onto the spatial grid of ``target``.

    The cube is reprojected with ``.odc.reproject`` onto the geobox of
    ``target`` and the CRS of ``target`` is written to the result. The
    dimension order of ``data`` and its attributes (except ``crs``) are
    carried over to the returned cube.

    Args:
        data: Cube to resample.
        target: Cube whose geobox and CRS define the output grid.
        method: Resampling method, one of ``near``, ``bilinear``, ``cubic``,
            ``cubicspline``, ``lanczos``, ``average``, ``mode``, ``max``,
            ``min``, ``med``, ``q1``, ``q3``.
        options: Currently unused by this implementation.

    Returns:
        The resampled cube.

    Raises:
        DimensionMissing: If ``data`` or ``target`` lacks an x or y dimension.
        OpenEOException: If ``method`` is not one of the supported methods.
    """
    methods_list = [
        "near",
        "bilinear",
        "cubic",
        "cubicspline",
        "lanczos",
        "average",
        "mode",
        "max",
        "min",
        "med",
        "q1",
        "q3",
    ]

    if (
        data.openeo.y_dim is None
        or data.openeo.x_dim is None
        or target.openeo.y_dim is None
        or target.openeo.x_dim is None
    ):
        raise DimensionMissing(
            f"Spatial dimension missing from data or target. Available dimensions for data: {data.dims} for target: {target.dims}"
        )

    # Validate the method before doing any work on the cubes.
    if method not in methods_list:
        raise OpenEOException(
            f'Selected resampling method "{method}" is not available! Please select one of '
            f"[{', '.join(methods_list)}]"
        )
    # The openEO name "near" is passed to the reprojection as "nearest".
    resampling = "nearest" if method == "near" else method

    # ODC reproject requires y to be before x
    required_dim_order = (..., data.openeo.y_dim, data.openeo.x_dim)
    data_reordered = data.transpose(*required_dim_order, missing_dims="ignore")
    target_reordered = target.transpose(*required_dim_order, missing_dims="ignore")

    resampled_data = data_reordered.odc.reproject(
        target_reordered.odc.geobox, resampling=resampling
    )
    resampled_data.rio.write_crs(target_reordered.rio.crs, inplace=True)

    # odc.reproject renames the coordinates according to the geobox, this
    # undoes that. Only names that are actually present are renamed, so no
    # exception has to be swallowed when the geobox kept the original names.
    renames = {
        old: new
        for old, new in (
            ("longitude", data.openeo.x_dim),
            ("latitude", data.openeo.y_dim),
        )
        if old in resampled_data.dims and old != new
    }
    if renames:
        resampled_data = resampled_data.rename(renames)

    # Order axes back to how they were before
    resampled_data = resampled_data.transpose(*data.dims)

    # Ensure that attrs except crs are copied over
    resampled_data.attrs.update(
        {k: v for k, v in data.attrs.items() if k.lower() != "crs"}
    )
    return resampled_data
- -from openeo_processes_dask.process_implementations.data_model import RasterCube - -__all__ = ["resample_cube_spatial"] - - -def resample_cube_spatial( - data: RasterCube, target: RasterCube, method="near", options=None -) -> RasterCube: - methods_list = [ - "near", - "bilinear", - "cubic", - "cubicspline", - "lanczos", - "average", - "mode", - "max", - "min", - "med", - "q1", - "q3", - ] - - # ODC reproject requires y to be before x - required_dim_order = ( - data.openeo.band_dims - + data.openeo.temporal_dims - + tuple(data.openeo.y_dim) - + tuple(data.openeo.x_dim) - ) - - data_reordered = data.transpose(*required_dim_order, missing_dims="ignore") - target_reordered = target.transpose(*required_dim_order, missing_dims="ignore") - - if method == "near": - method = "nearest" - - elif method not in methods_list: - raise Exception( - f'Selected resampling method "{method}" is not available! Please select one of ' - f"[{', '.join(methods_list)}]" - ) - - resampled_data = data_reordered.odc.reproject( - target_reordered.odc.geobox, resampling=method - ) - - resampled_data.rio.write_crs(target_reordered.rio.crs, inplace=True) - - try: - # odc.reproject renames the coordinates according to the geobox, this undoes that. 
- resampled_data = resampled_data.rename( - {"longitude": data.openeo.x_dim, "latitude": data.openeo.y_dim} - ) - except ValueError: - pass - - # Order axes back to how they were before - resampled_data = resampled_data.transpose(*data.dims) - - # Ensure that attrs except crs are copied over - for k, v in data.attrs.items(): - if k.lower() != "crs": - resampled_data.attrs[k] = v - return resampled_data diff --git a/tests/test_mask.py b/tests/test_mask.py index c5c6148e..efda7f85 100644 --- a/tests/test_mask.py +++ b/tests/test_mask.py @@ -1,10 +1,17 @@ +from functools import partial + import numpy as np import pytest -from openeo_pg_parser_networkx.pg_schema import TemporalInterval +from openeo_pg_parser_networkx.pg_schema import ParameterReference, TemporalInterval +from openeo_processes_dask.process_implementations.cubes.mask import mask from openeo_processes_dask.process_implementations.cubes.mask_polygon import ( mask_polygon, ) +from openeo_processes_dask.process_implementations.cubes.reduce import ( + reduce_dimension, + reduce_spatial, +) from tests.mockdata import create_fake_rastercube @@ -29,3 +36,56 @@ def test_mask_polygon( assert np.isnan(output_cube).sum() > np.isnan(input_cube).sum() assert len(output_cube.y) == len(input_cube.y) assert len(output_cube.x) == len(input_cube.x) + + +@pytest.mark.parametrize("size", [(30, 30, 20, 2)]) +@pytest.mark.parametrize("dtype", [np.float32]) +def test_mask( + temporal_interval, + bounding_box, + random_raster_data, + process_registry, +): + """Test to ensure resolution gets changed correctly.""" + input_cube = create_fake_rastercube( + data=random_raster_data, + spatial_extent=bounding_box, + temporal_extent=temporal_interval, + bands=["B02", "B03"], + backend="dask", + ) + + mask_cube = input_cube > 0 + output_cube = mask(data=input_cube, mask=mask_cube) + + assert np.isnan(output_cube).sum() > np.isnan(input_cube).sum() + assert len(output_cube.y) == len(input_cube.y) + assert len(output_cube.x) == 
len(input_cube.x) + + _process = partial( + process_registry["max"].implementation, + ignore_nodata=True, + data=ParameterReference(from_parameter="data"), + ) + + mask_cube_no_x = reduce_dimension(data=mask_cube, dimension="x", reducer=_process) + with pytest.raises(Exception): + output_cube = mask(data=input_cube, mask=mask_cube_no_x) + + # Mask should work without bands + mask_cube_no_bands = reduce_dimension( + data=mask_cube, dimension="bands", reducer=_process + ) + output_cube = mask(data=input_cube, mask=mask_cube_no_bands) + + # Mask should work without time + mask_cube_no_time = reduce_dimension( + data=mask_cube, dimension="t", reducer=_process + ) + output_cube = mask(data=input_cube, mask=mask_cube_no_time) + + # Mask should work without time and bands + mask_cube_no_time_bands = reduce_dimension( + data=mask_cube_no_bands, dimension="t", reducer=_process + ) + output_cube = mask(data=input_cube, mask=mask_cube_no_time_bands) diff --git a/tests/test_resample.py b/tests/test_resample.py index 44632495..1a70e671 100644 --- a/tests/test_resample.py +++ b/tests/test_resample.py @@ -4,6 +4,7 @@ from pyproj.crs import CRS from openeo_processes_dask.process_implementations.cubes.resample import ( + resample_cube_spatial, resample_spatial, ) from tests.general_checks import general_output_checks @@ -63,3 +64,51 @@ def test_resample_spatial( assert min(output_cube.y) >= -90 assert max(output_cube.y) <= 90 + + +@pytest.mark.parametrize( + "output_crs", + [ + 3587, + "32633", + "+proj=aeqd +lat_0=53 +lon_0=24 +x_0=5837287.81977 +y_0=2121415.69617 +datum=WGS84 +units=m +no_defs", + "4326", + ], +) +@pytest.mark.parametrize("output_res", [5, 30, 60]) +@pytest.mark.parametrize("size", [(30, 30, 20, 4)]) +@pytest.mark.parametrize("dtype", [np.float32]) +def test_resample_cube_spatial( + output_crs, output_res, temporal_interval, bounding_box, random_raster_data +): + """Test to ensure resolution gets changed correctly.""" + input_cube = create_fake_rastercube( + 
data=random_raster_data, + spatial_extent=bounding_box, + temporal_extent=temporal_interval, + bands=["B02", "B03", "B04", "B08"], + backend="dask", + ) + + resampled_cube = resample_spatial( + data=input_cube, projection=output_crs, resolution=output_res + ) + + with pytest.raises(Exception): + output_cube = resample_cube_spatial( + data=input_cube, target=resampled_cube, method="bad" + ) + + output_cube = resample_cube_spatial( + data=input_cube, target=resampled_cube, method="average" + ) + + general_output_checks( + input_cube=input_cube, + output_cube=output_cube, + expected_dims=input_cube.dims, + verify_attrs=False, + verify_crs=False, + ) + + assert output_cube.odc.spatial_dims == ("y", "x")