From 406c306edbf0ba6f4a48dcb43fb95f31257ee084 Mon Sep 17 00:00:00 2001 From: Lukas Weidenholzer <17790923+LukeWeidenwalker@users.noreply.github.com> Date: Mon, 28 Aug 2023 13:45:24 +0200 Subject: [PATCH] fix: crs being lost on fit_curve (#156) fix crs being lost on fit_curve --- .../process_implementations/ml/curve_fitting.py | 1 + tests/test_ml.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/openeo_processes_dask/process_implementations/ml/curve_fitting.py b/openeo_processes_dask/process_implementations/ml/curve_fitting.py index 425ce030..ac13b9ca 100644 --- a/openeo_processes_dask/process_implementations/ml/curve_fitting.py +++ b/openeo_processes_dask/process_implementations/ml/curve_fitting.py @@ -65,6 +65,7 @@ def _wrap(*args, **kwargs): ) fit_result.attrs = data.attrs + fit_result = fit_result.rio.write_crs(rechunked_data.rio.crs) return fit_result diff --git a/tests/test_ml.py b/tests/test_ml.py index 1325bace..c9157010 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -76,6 +76,7 @@ def fitFunction(x, parameters): ) assert len(result.param) == 3 assert isinstance(result.data, dask.array.Array) + assert result.rio.crs == origin_cube.rio.crs assert len(result.coords["bands"]) == len(origin_cube.coords["bands"]) assert len(result.coords["x"]) == len(origin_cube.coords["x"]) @@ -92,3 +93,4 @@ def fitFunction(x, parameters): assert len(predictions.coords[origin_cube.openeo.temporal_dims[0]]) == len(labels) assert "param" not in predictions.dims + assert result.rio.crs == predictions.rio.crs