Skip to content

Commit

Permalink
feat: add mask and fix resample_cube_spatial (#165)
Browse files Browse the repository at this point in the history
  • Loading branch information
clausmichele authored Jan 2, 2024
1 parent 8d9e3fb commit 4a011c4
Show file tree
Hide file tree
Showing 9 changed files with 320 additions and 76 deletions.
12 changes: 6 additions & 6 deletions openeo_processes_dask/process_implementations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
"Did not load machine learning processes due to missing dependencies: Install them like this: `pip install openeo-processes-dask[implementations, ml]`"
)

try:
from .experimental import *
except ImportError as e:
logger.warning(
"Did not experimental processes due to missing dependencies: Install them like this: `pip install openeo-processes-dask[implementations, experimental]`"
)
# try:
# from .experimental import *
# except ImportError as e:
# logger.warning(
# "Did not experimental processes due to missing dependencies: Install them like this: `pip install openeo-processes-dask[implementations, experimental]`"
# )

import rioxarray as rio # Required for the .rio accessor on xarrays.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from .general import *
from .indices import *
from .load import *
from .mask import *
from .mask_polygon import *
from .merge import *
from .reduce import *
from .resample import *
120 changes: 120 additions & 0 deletions openeo_processes_dask/process_implementations/cubes/mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import logging
from typing import Callable

import numpy as np

from openeo_processes_dask.process_implementations.cubes.resample import (
resample_cube_spatial,
)
from openeo_processes_dask.process_implementations.cubes.utils import notnull
from openeo_processes_dask.process_implementations.data_model import RasterCube
from openeo_processes_dask.process_implementations.exceptions import (
DimensionLabelCountMismatch,
DimensionMismatch,
LabelMismatch,
)
from openeo_processes_dask.process_implementations.logic import _not

logger = logging.getLogger(__name__)

__all__ = ["mask"]


def mask(data: RasterCube, mask: RasterCube, replacement=None) -> RasterCube:
if replacement is None:
replacement = np.nan

data_band_dims = data.openeo.band_dims
mask_band_dims = mask.openeo.band_dims
# Check if temporal dimensions are present and check the names
data_temporal_dims = data.openeo.temporal_dims
mask_temporal_dims = mask.openeo.temporal_dims

check_temporal_labels = True
if not set(data_temporal_dims) == set(mask_temporal_dims):
check_temporal_labels = False
# To continue with a valid case, mask shouldn't have a temporal dimension, so that the mask will be applied to all the temporal labels
if len(mask_temporal_dims) != 0:
raise DimensionMismatch(
f"data and mask temporal dimensions do no match: data has temporal dimensions ({data_temporal_dims}) and mask {mask_temporal_dims}."
)
if check_temporal_labels:
# Check if temporal labels correspond
for n in data_temporal_dims:
data_temporal_labels = data[n].values
mask_temporal_labels = mask[n].values
data_n_labels = len(data_temporal_labels)
mask_n_labels = len(mask_temporal_labels)

if not data_n_labels == mask_n_labels:
raise DimensionLabelCountMismatch(
f"data and mask temporal dimensions do no match: data has {data_n_labels} temporal dimensions labels and mask {mask_n_labels}."
)
elif not all(data_temporal_labels == mask_temporal_labels):
raise LabelMismatch(
f"data and mask temporal dimension labels don't match for dimension {n}."
)

# From the process definition: https://processes.openeo.org/#mask
# The data cubes have to be compatible except that the horizontal spatial dimensions (axes x and y) will be aligned implicitly by resample_cube_spatial.
apply_resample_cube_spatial = False

# Check if spatial dimensions have the same name
data_spatial_dims = data.openeo.spatial_dims
mask_spatial_dims = mask.openeo.spatial_dims
if not set(data_spatial_dims) == set(mask_spatial_dims):
raise DimensionMismatch(
f"data and mask spatial dimensions do no match: data has spatial dimensions ({data_spatial_dims}) and mask {mask_spatial_dims}"
)

# Check if spatial labels correspond
for n in data_spatial_dims:
data_spatial_labels = data[n].values
mask_spatial_labels = mask[n].values
data_n_labels = len(data_spatial_labels)
mask_n_labels = len(mask_spatial_labels)

if not data_n_labels == mask_n_labels:
apply_resample_cube_spatial = True
logger.info(
f"data and mask spatial dimension labels don't match: data has ({data_n_labels}) labels and mask has {mask_n_labels} for dimension {n}."
)
elif not all(data_spatial_labels == mask_spatial_labels):
apply_resample_cube_spatial = True
logger.info(
f"data and mask spatial dimension labels don't match for dimension {n}, i.e. the coordinate values are different."
)

if apply_resample_cube_spatial:
logger.info(f"mask is aligned to data using resample_cube_spatial.")
mask = resample_cube_spatial(data=mask, target=data)

original_dim_order = data.dims
# If bands dimension in data but not in mask, ensure that it comes first and all the other dimensions at the end
if len(data_band_dims) != 0 and len(mask_band_dims) == 0:
required_dim_order = (
data_band_dims[0] if len(data_band_dims) > 0 else (),
data_temporal_dims[0] if len(data_temporal_dims) > 0 else (),
data.openeo.y_dim,
data.openeo.x_dim,
)
data = data.transpose(*required_dim_order, missing_dims="ignore")
mask = mask.transpose(*required_dim_order, missing_dims="ignore")

elif len(data_temporal_dims) != 0 and len(mask_temporal_dims) == 0:
required_dim_order = (
data_temporal_dims[0] if len(data_temporal_dims) > 0 else (),
data_band_dims[0] if len(data_band_dims) > 0 else (),
data.openeo.y_dim,
data.openeo.x_dim,
)
data = data.transpose(*required_dim_order, missing_dims="ignore")
mask = mask.transpose(*required_dim_order, missing_dims="ignore")

data = data.where(_not(mask), replacement)

if len(data_band_dims) != 0 and len(mask_band_dims) == 0:
# Order axes back to how they were before
data = data.transpose(*original_dim_order)

return data
75 changes: 74 additions & 1 deletion openeo_processes_dask/process_implementations/cubes/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@
from typing import Optional, Union

import odc.geo.xr
import rioxarray # needs to be imported to set .rio accessor on xarray objects.
from odc.geo.geobox import resolution_from_affine
from pyproj.crs import CRS, CRSError

from openeo_processes_dask.process_implementations.data_model import RasterCube
from openeo_processes_dask.process_implementations.exceptions import OpenEOException
from openeo_processes_dask.process_implementations.exceptions import (
DimensionMissing,
OpenEOException,
)

logger = logging.getLogger(__name__)

__all__ = ["resample_spatial", "resample_cube_spatial"]

resample_methods_list = [
"near",
"bilinear",
Expand Down Expand Up @@ -78,3 +84,70 @@ def resample_spatial(
reprojected.attrs["crs"] = data_cp.rio.crs

return reprojected


def resample_cube_spatial(
data: RasterCube, target: RasterCube, method="near", options=None
) -> RasterCube:
methods_list = [
"near",
"bilinear",
"cubic",
"cubicspline",
"lanczos",
"average",
"mode",
"max",
"min",
"med",
"q1",
"q3",
]

if (
data.openeo.y_dim is None
or data.openeo.x_dim is None
or target.openeo.y_dim is None
or target.openeo.x_dim is None
):
raise DimensionMissing(
f"Spatial dimension missing from data or target. Available dimensions for data: {data.dims} for target: {target.dims}"
)

# ODC reproject requires y to be before x
required_dim_order = (..., data.openeo.y_dim, data.openeo.x_dim)

data_reordered = data.transpose(*required_dim_order, missing_dims="ignore")
target_reordered = target.transpose(*required_dim_order, missing_dims="ignore")

if method == "near":
method = "nearest"

elif method not in methods_list:
raise Exception(
f'Selected resampling method "{method}" is not available! Please select one of '
f"[{', '.join(methods_list)}]"
)

resampled_data = data_reordered.odc.reproject(
target_reordered.odc.geobox, resampling=method
)

resampled_data.rio.write_crs(target_reordered.rio.crs, inplace=True)

try:
# odc.reproject renames the coordinates according to the geobox, this undoes that.
resampled_data = resampled_data.rename(
{"longitude": data.openeo.x_dim, "latitude": data.openeo.y_dim}
)
except ValueError:
pass

# Order axes back to how they were before
resampled_data = resampled_data.transpose(*data.dims)

# Ensure that attrs except crs are copied over
for k, v in data.attrs.items():
if k.lower() != "crs":
resampled_data.attrs[k] = v
return resampled_data
8 changes: 8 additions & 0 deletions openeo_processes_dask/process_implementations/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,11 @@ class RedBandAmbiguous(OpenEOException):

class BandExists(OpenEOException):
pass


class DimensionMismatch(OpenEOException):
pass


class LabelMismatch(OpenEOException):
pass
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from .resample import *

This file was deleted.

62 changes: 61 additions & 1 deletion tests/test_mask.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from functools import partial

import numpy as np
import pytest
from openeo_pg_parser_networkx.pg_schema import TemporalInterval
from openeo_pg_parser_networkx.pg_schema import ParameterReference, TemporalInterval

from openeo_processes_dask.process_implementations.cubes.mask import mask
from openeo_processes_dask.process_implementations.cubes.mask_polygon import (
mask_polygon,
)
from openeo_processes_dask.process_implementations.cubes.reduce import (
reduce_dimension,
reduce_spatial,
)
from tests.mockdata import create_fake_rastercube


Expand All @@ -29,3 +36,56 @@ def test_mask_polygon(
assert np.isnan(output_cube).sum() > np.isnan(input_cube).sum()
assert len(output_cube.y) == len(input_cube.y)
assert len(output_cube.x) == len(input_cube.x)


@pytest.mark.parametrize("size", [(30, 30, 20, 2)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_mask(
temporal_interval,
bounding_box,
random_raster_data,
process_registry,
):
"""Test to ensure resolution gets changed correctly."""
input_cube = create_fake_rastercube(
data=random_raster_data,
spatial_extent=bounding_box,
temporal_extent=temporal_interval,
bands=["B02", "B03"],
backend="dask",
)

mask_cube = input_cube > 0
output_cube = mask(data=input_cube, mask=mask_cube)

assert np.isnan(output_cube).sum() > np.isnan(input_cube).sum()
assert len(output_cube.y) == len(input_cube.y)
assert len(output_cube.x) == len(input_cube.x)

_process = partial(
process_registry["max"].implementation,
ignore_nodata=True,
data=ParameterReference(from_parameter="data"),
)

mask_cube_no_x = reduce_dimension(data=mask_cube, dimension="x", reducer=_process)
with pytest.raises(Exception):
output_cube = mask(data=input_cube, mask=mask_cube_no_x)

# Mask should work without bands
mask_cube_no_bands = reduce_dimension(
data=mask_cube, dimension="bands", reducer=_process
)
output_cube = mask(data=input_cube, mask=mask_cube_no_bands)

# Mask should work without time
mask_cube_no_time = reduce_dimension(
data=mask_cube, dimension="t", reducer=_process
)
output_cube = mask(data=input_cube, mask=mask_cube_no_time)

# Mask should work without time and bands
mask_cube_no_time_bands = reduce_dimension(
data=mask_cube_no_bands, dimension="t", reducer=_process
)
output_cube = mask(data=input_cube, mask=mask_cube_no_time_bands)
Loading

0 comments on commit 4a011c4

Please sign in to comment.