Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add more robust version of ndvi #147

Merged
merged 7 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .aggregate import *
from .apply import *
from .general import *
from .indices import *
from .load import *
from .merge import *
from .reduce import *
59 changes: 59 additions & 0 deletions openeo_processes_dask/process_implementations/cubes/indices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import xarray as xr

from openeo_processes_dask.process_implementations.data_model import RasterCube
from openeo_processes_dask.process_implementations.exceptions import (
BandExists,
DimensionAmbiguous,
NirBandAmbiguous,
RedBandAmbiguous,
)
from openeo_processes_dask.process_implementations.math import normalized_difference

__all__ = ["ndvi"]


def ndvi(data: RasterCube, nir="nir", red="red", target_band=None):
if len(data.openeo.band_dims) == 0:
raise DimensionAmbiguous(
"Dimension of type `bands` is not available or is ambiguous."
)
band_dim = data.openeo.band_dims[0]
available_bands = data.coords[band_dim]

if nir not in available_bands or red not in available_bands:
try:
data = data.set_xindex("common_name")
except (ValueError, KeyError):
pass

if (
nir not in available_bands
and "common_name" in data.xindexes._coord_name_id.keys()
and nir not in data.coords["common_name"].data
):
raise NirBandAmbiguous(
"The NIR band can't be resolved, please specify the specific NIR band name."
)
elif (
red not in available_bands
and "common_name" in data.xindexes._coord_name_id.keys()
and red not in data.coords["common_name"].data
):
raise RedBandAmbiguous(
"The Red band can't be resolved, please specify the specific Red band name."
)

nir_band_dim = "common_name" if nir not in available_bands else band_dim
red_band_dim = "common_name" if red not in available_bands else band_dim
LukeWeidenwalker marked this conversation as resolved.
Show resolved Hide resolved

nir_band = data.sel({nir_band_dim: nir})
red_band = data.sel({red_band_dim: red})

nd = normalized_difference(nir_band, red_band)
if target_band is not None:
if target_band in data.coords:
raise BandExists("A band with the specified target name exists.")
nd = nd.expand_dims(band_dim).assign_coords({band_dim: [target_band]})
nd = xr.merge([data, nd])
nd.attrs = data.attrs
return nd
16 changes: 16 additions & 0 deletions openeo_processes_dask/process_implementations/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,19 @@ class NoDataAvailable(OpenEOException):

class TemporalExtentEmpty(OpenEOException):
pass


class DimensionAmbiguous(OpenEOException):
pass


class NirBandAmbiguous(OpenEOException):
pass


class RedBandAmbiguous(OpenEOException):
pass


class BandExists(OpenEOException):
pass
39 changes: 0 additions & 39 deletions openeo_processes_dask/process_implementations/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
"quantiles",
"product",
"normalized_difference",
"ndvi",
]


Expand Down Expand Up @@ -326,41 +325,3 @@ def product(data, ignore_nodata=True, axis=None, keepdims=False):
def normalized_difference(x, y):
nd = (x - y) / (x + y)
return nd


def ndvi(data, nir="nir", red="red", target_band=None):
r = np.nan
n = np.nan
if "bands" in data.dims:
if red == "red":
if "B04" in data["bands"].values:
r = data.sel(bands="B04")
elif red == "rededge":
if "B05" in data["bands"].values:
r = data.sel(bands="B05")
elif "B06" in data["bands"].values:
r = data.sel(bands="B06")
elif "B07" in data["bands"].values:
r = data.sel(bands="B07")
if nir == "nir":
n = data.sel(bands="B08")
elif nir == "nir08":
if "B8a" in data["bands"].values:
n = data.sel(bands="B8a")
elif "B8A" in data["bands"].values:
n = data.sel(bands="B8A")
elif "B05" in data["bands"].values:
n = data.sel(bands="B05")
elif nir == "nir09":
if "B09" in data["bands"].values:
n = data.sel(bands="B09")
if red in data["bands"].values:
r = data.sel(bands=red)
if nir in data["bands"].values:
n = data.sel(bands=nir)
nd = normalized_difference(n, r)
if target_band is not None:
nd = nd.assign_coords(bands=target_band)
# TODO: Remove this once we have the .openeo accessor
nd.attrs = data.attrs
return nd
8 changes: 6 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,18 @@ def dask_client():
client.shutdown()


@pytest.fixture
def random_raster_data(size, dtype, seed=42):
def _random_raster_data(size, dtype, seed=42):
rng = np.random.default_rng(seed)
data = rng.integers(-100, 100, size=size)
data = data.astype(dtype)
return data


@pytest.fixture
def random_raster_data(size, dtype, seed=42):
return _random_raster_data(size, dtype, seed=seed)


@pytest.fixture
def bounding_box(
west=10.45, east=10.5, south=46.1, north=46.2, crs="EPSG:4326"
Expand Down
85 changes: 85 additions & 0 deletions tests/test_indices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np
import pytest

from openeo_processes_dask.process_implementations.cubes.indices import ndvi
from openeo_processes_dask.process_implementations.cubes.load import load_stac
from openeo_processes_dask.process_implementations.exceptions import (
BandExists,
DimensionAmbiguous,
NirBandAmbiguous,
RedBandAmbiguous,
)
from tests.conftest import _random_raster_data
from tests.general_checks import general_output_checks


def test_ndvi(bounding_box):
url = "./tests/data/stac/s2_l2a_test_item.json"
input_cube = load_stac(
url=url,
spatial_extent=bounding_box,
bands=["red", "nir"],
).isel({"x": slice(0, 20), "y": slice(0, 20)})

# Test whether this works with different band names
input_cube = input_cube.rename({"band": "b"})

import dask.array as da

numpy_data = _random_raster_data(input_cube.data.shape, dtype=np.float64)

input_cube.data = da.from_array(numpy_data, chunks=("auto", "auto", "auto", -1))

output = ndvi(input_cube)

band_dim = input_cube.openeo.band_dims[0]
assert band_dim not in output.dims

expected_results = (
input_cube.sel({band_dim: "nir"}) - input_cube.sel({band_dim: "red"})
) / (input_cube.sel({band_dim: "nir"}) + input_cube.sel({band_dim: "red"}))

general_output_checks(
input_cube=input_cube, output_cube=output, expected_results=expected_results
)

cube_with_resolvable_coords = input_cube.assign_coords(
{band_dim: ["blue", "yellow"]}
)
output = ndvi(cube_with_resolvable_coords)
general_output_checks(
input_cube=cube_with_resolvable_coords,
output_cube=output,
expected_results=expected_results,
)

with pytest.raises(DimensionAmbiguous):
ndvi(output)

cube_with_nir_unresolvable = cube_with_resolvable_coords
cube_with_nir_unresolvable.common_name.data = np.array(["blue", "red"])

with pytest.raises(NirBandAmbiguous):
ndvi(cube_with_nir_unresolvable)

cube_with_red_unresolvable = cube_with_resolvable_coords
cube_with_red_unresolvable.common_name.data = np.array(["nir", "yellow"])

with pytest.raises(RedBandAmbiguous):
ndvi(cube_with_red_unresolvable)

cube_with_nothing_resolvable = cube_with_resolvable_coords
cube_with_nothing_resolvable = cube_with_nothing_resolvable.drop_vars("common_name")
with pytest.raises(KeyError):
ndvi(cube_with_nothing_resolvable)

target_band = "yay"
output_with_extra_dim = ndvi(input_cube, target_band=target_band)
assert len(output_with_extra_dim.dims) == len(output.dims) + 1
assert (
len(output_with_extra_dim.coords[band_dim])
== len(input_cube.coords[band_dim]) + 1
)

with pytest.raises(BandExists):
output_with_extra_dim = ndvi(input_cube, target_band="time")