Skip to content

Commit

Permalink
Adding 'Detecting Deep Moist Convection' to openeo-processes-dask (#263)
Browse files Browse the repository at this point in the history
* add ddmc and test

* update pydantic dependencies

* Adding ddmc to openeo-processes-dask

* Changing the version in the pyproject and changing the import statement to the right folder

* Deleting the unused file example%20file.geoparquet
  • Loading branch information
koenifra authored Aug 12, 2024
1 parent 5d0825b commit f3cd4dd
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ._filter import *
from .aggregate import *
from .apply import *
from .ddmc import *
from .general import *
from .indices import *
from .load import *
Expand Down
86 changes: 86 additions & 0 deletions openeo_processes_dask/process_implementations/cubes/ddmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from openeo_processes_dask.process_implementations.arrays import array_element
from openeo_processes_dask.process_implementations.cubes.general import add_dimension
from openeo_processes_dask.process_implementations.cubes.merge import merge_cubes
from openeo_processes_dask.process_implementations.cubes.reduce import reduce_dimension
from openeo_processes_dask.process_implementations.data_model import RasterCube

__all__ = ["ddmc"]


def ddmc(
data: RasterCube,
nir08="nir08",
nir09="nir09",
cirrus="cirrus",
swir16="swir16",
swir22="swir22",
gain=2.5,
target_band=None,
):
dimension = data.openeo.band_dims[0]
if target_band is None:
target_band = dimension

# Mid-Level Clouds
def MIDCL(data):
# B08 = array_element(data, label=nir08, axis = axis)

B08 = data.sel(**{dimension: nir08})

# B09 = array_element(data, label=nir09, axis = axis)

B09 = data.sel(**{dimension: nir09})

MIDCL = B08 - B09

MIDCL_result = MIDCL * gain

return MIDCL_result

# Deep moist convection
def DC(data):
# B10 = array_element(data, label=cirrus, axis = axis)
# B12 = array_element(data, label=swir22, axis = axis)

B10 = data.sel(**{dimension: cirrus})
B12 = data.sel(**{dimension: swir22})

DC = B10 - B12

DC_result = DC * gain

return DC_result

# low-level cloudiness
def LOWCL(data):
# B10 = array_element(data, label=cirrus, axis = axis)
# B11 = array_element(data, label=swir16, axis = axis)
B10 = data.sel(**{dimension: cirrus})
B11 = data.sel(**{dimension: swir16})

LOWCL = B11 - B10

LOWCL_result = LOWCL * gain

return LOWCL_result

# midcl = reduce_dimension(data, reducer=MIDCL, dimension=dimension)
midcl = MIDCL(data)
midcl = add_dimension(midcl, name=target_band, label="midcl", type=dimension)

# dc = reduce_dimension(data, reducer=DC, dimension=dimension)
dc = DC(data)
# dc = add_dimension(dc, target_band, "dc")
dc = add_dimension(dc, target_band, label="dc", type=dimension)

# lowcl = reduce_dimension(data, reducer=LOWCL, dimension=dimension)
lowcl = LOWCL(data)
lowcl = add_dimension(lowcl, target_band, label="lowcl", type=dimension)

# ddmc = merge_cubes(merge_cubes(midcl, dc), lowcl)
ddmc1 = merge_cubes(midcl, lowcl)
ddmc1.openeo.add_dim_type(name=target_band, type=dimension)
ddmc = merge_cubes(dc, ddmc1, overlap_resolver=target_band)

# return a datacube
return ddmc
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ rasterio = { version = "^1.3.4", optional = true }
dask-geopandas = { version = ">=0.2.0,<1", optional = true }
xgboost = { version = ">=1.5.1", optional = true }
rioxarray = { version = ">=0.12.0,<1", optional = true }
openeo-pg-parser-networkx = { version = ">=2023.5.1", optional = true }
openeo-pg-parser-networkx = { version = ">=2024.7", optional = true }
odc-geo = { version = ">=0.4.1,<1", optional = true }
stac_validator = { version = ">=3.3.1", optional = true }
stackstac = { version = ">=0.4.3", optional = true }
Expand Down
75 changes: 75 additions & 0 deletions tests/test_ddmc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from functools import partial

import numpy as np
import pytest
import xarray as xr
from openeo_pg_parser_networkx.pg_schema import (
BoundingBox,
ParameterReference,
TemporalInterval,
)

from openeo_processes_dask.process_implementations.cubes.ddmc import ddmc
from openeo_processes_dask.process_implementations.cubes.load import load_stac
from openeo_processes_dask.process_implementations.cubes.reduce import (
reduce_dimension,
reduce_spatial,
)
from openeo_processes_dask.process_implementations.exceptions import (
ArrayElementNotAvailable,
)
from tests.general_checks import general_output_checks
from tests.mockdata import create_fake_rastercube


@pytest.mark.parametrize("size", [(30, 30, 20, 5)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_ddmc_instance_dims(
temporal_interval: TemporalInterval, bounding_box: BoundingBox, random_raster_data
):
input_cube = create_fake_rastercube(
data=random_raster_data,
spatial_extent=bounding_box,
temporal_extent=temporal_interval,
bands=["nir08", "nir09", "cirrus", "swir16", "swir22"],
backend="dask",
)

data = ddmc(input_cube)

assert isinstance(data, xr.DataArray)
assert set(input_cube.dims) == set(data.dims)


@pytest.mark.parametrize("size", [(30, 30, 20, 5)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_ddmc_target_band(
temporal_interval: TemporalInterval, bounding_box: BoundingBox, random_raster_data
):
input_cube = create_fake_rastercube(
data=random_raster_data,
spatial_extent=bounding_box,
temporal_extent=temporal_interval,
bands=["nir08", "nir09", "cirrus", "swir16", "swir22"],
backend="dask",
)

data_band = ddmc(data=input_cube, target_band="ddmc")
assert "ddmc" in data_band.dims


@pytest.mark.parametrize("size", [(30, 30, 20, 5)])
@pytest.mark.parametrize("dtype", [np.float32])
def test_ddmc_input_cube_exception(
temporal_interval: TemporalInterval, bounding_box: BoundingBox, random_raster_data
):
input_cube_exception = create_fake_rastercube(
data=random_raster_data,
spatial_extent=bounding_box,
temporal_extent=temporal_interval,
bands=["b04", "nir09", "cirrus", "swir16", "swir22"],
backend="dask",
)

with pytest.raises(KeyError):
data = ddmc(input_cube_exception)

0 comments on commit f3cd4dd

Please sign in to comment.