diff --git a/openeo_processes_dask/process_implementations/ml/curve_fitting.py b/openeo_processes_dask/process_implementations/ml/curve_fitting.py index aad967be..98ef6db4 100644 --- a/openeo_processes_dask/process_implementations/ml/curve_fitting.py +++ b/openeo_processes_dask/process_implementations/ml/curve_fitting.py @@ -25,6 +25,10 @@ def fit_curve( raise DimensionNotAvailable( f"Provided dimension ({dimension}) not found in data.dims: {data.dims}" ) + bands_required = False + if "bands" in data.dims: + if len(data["bands"].values) == 1: + bands_required = data["bands"].values[0] try: # Try parsing as datetime first @@ -81,11 +85,15 @@ def _wrap(*args, **kwargs): .drop_dims(["cov_i", "cov_j"]) .to_array() .squeeze() - .transpose(*expected_dims_after) ) fit_result.attrs = data.attrs fit_result = fit_result.rio.write_crs(rechunked_data.rio.crs) + if bands_required and not "bands" in fit_result.dims: + fit_result = fit_result.assign_coords(**{"bands": bands_required}) + fit_result = fit_result.expand_dims(dim="bands") + + fit_result = fit_result.transpose(*expected_dims_after) return fit_result @@ -99,6 +107,7 @@ def predict_curve( ): labels_were_datetime = False dims_before = list(parameters.dims) + initial_labels = labels try: # Try parsing as datetime first @@ -108,7 +117,6 @@ def predict_curve( if np.issubdtype(labels.dtype, np.datetime64): labels_were_datetime = True - initial_labels = labels timestep = [ ( (np.datetime64(x) - np.datetime64("1970-01-01", "s")) diff --git a/pyproject.toml b/pyproject.toml index a2ccff80..426ffefa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "openeo-processes-dask" -version = "2023.10.2" +version = "2023.10.3" description = "Python implementations of many OpenEO processes, dask-friendly by default." authors = ["Lukas Weidenholzer ", "Sean Hoyal ", "Valentina Hutter "] maintainers = ["EODC Staff "] diff --git a/tests/test_ml.py b/tests/test_ml.py index 4248b212..3f185fd1 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -83,8 +83,14 @@ def fitFunction(x, parameters): assert len(result.coords["y"]) == len(origin_cube.coords["y"]) assert len(result.coords["param"]) == len(parameters) + origin_cube_B02 = origin_cube.sel(bands=["B02"]) + result_B02 = fit_curve( + origin_cube_B02, parameters=parameters, function=_process, dimension="t" + ) + assert "bands" in result_B02.dims + assert result_B02["bands"].values == "B02" + labels = dimension_labels(origin_cube, origin_cube.openeo.temporal_dims[0]) - labels = [float(l) for l in labels] predictions = predict_curve( result, _process, @@ -96,7 +102,7 @@ def fitFunction(x, parameters): assert "param" not in predictions.dims assert result.rio.crs == predictions.rio.crs - labels = ["2020-02-02", "2020-03-02", "2020-04-02", "2020-05-02"] + labels = [0, 1, 2, 3] predictions = predict_curve( result, _process,