From 75f63fb2fcb4a064597547dc6cf215639b771609 Mon Sep 17 00:00:00 2001 From: ValentinaHutter Date: Wed, 4 Oct 2023 15:41:43 +0200 Subject: [PATCH] fit_curve with dimension convertion --- .../ml/curve_fitting.py | 29 +++++++++++++++++-- tests/test_ml.py | 13 +++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/openeo_processes_dask/process_implementations/ml/curve_fitting.py b/openeo_processes_dask/process_implementations/ml/curve_fitting.py index 81736ea7..8f5f400d 100644 --- a/openeo_processes_dask/process_implementations/ml/curve_fitting.py +++ b/openeo_processes_dask/process_implementations/ml/curve_fitting.py @@ -26,6 +26,23 @@ def fit_curve( f"Provided dimension ({dimension}) not found in data.dims: {data.dims}" ) + try: + # Try parsing as datetime first + dates = data[dimension].values + dates = np.asarray(dates, dtype=np.datetime64) + except ValueError: + dates = np.asarray(data[dimension].values) + + if np.issubdtype(dates.dtype, np.datetime64): + timestep = [ + ( + (np.datetime64(x) - np.datetime64("1970-01-01", "s")) + / np.timedelta64(1, "s") + ) + for x in dates + ] + data[dimension] = np.array(timestep) + dims_before = list(data.dims) # In the spec, parameters is a list, but xr.curvefit requires names for them, @@ -87,8 +104,16 @@ def predict_curve( labels = np.asarray(labels) if np.issubdtype(labels.dtype, np.datetime64): - labels = labels.astype(int) labels_were_datetime = True + initial_labels = labels + timestep = [ + ( + (np.datetime64(x) - np.datetime64("1970-01-01", "s")) + / np.timedelta64(1, "s") + ) + for x in labels + ] + labels = np.array(timestep) # This is necessary to pipe the arguments correctly through @process def wrapper(f): @@ -122,6 +147,6 @@ def _wrap(*args, **kwargs): predictions = predictions.assign_coords({dimension: labels.data}) if labels_were_datetime: - predictions[dimension] = pd.DatetimeIndex(predictions[dimension].values) + predictions[dimension] = initial_labels return predictions diff --git a/tests/test_ml.py b/tests/test_ml.py index c9157010..4248b212 100644 --- a/tests/test_ml.py +++ b/tests/test_ml.py @@ -84,6 +84,19 @@ def fitFunction(x, parameters): assert len(result.coords["param"]) == len(parameters) labels = dimension_labels(origin_cube, origin_cube.openeo.temporal_dims[0]) + labels = [float(l) for l in labels] + predictions = predict_curve( + result, + _process, + origin_cube.openeo.temporal_dims[0], + labels=labels, + ).compute() + + assert len(predictions.coords[origin_cube.openeo.temporal_dims[0]]) == len(labels) + assert "param" not in predictions.dims + assert result.rio.crs == predictions.rio.crs + + labels = ["2020-02-02", "2020-03-02", "2020-04-02", "2020-05-02"] predictions = predict_curve( result, _process,