Skip to content

Commit

Permalink
fit_curve with dimension conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
ValentinaHutter committed Oct 4, 2023
1 parent 7ddf6c2 commit 75f63fb
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
29 changes: 27 additions & 2 deletions openeo_processes_dask/process_implementations/ml/curve_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@ def fit_curve(
f"Provided dimension ({dimension}) not found in data.dims: {data.dims}"
)

try:
# Try parsing as datetime first
dates = data[dimension].values
dates = np.asarray(dates, dtype=np.datetime64)
except ValueError:
dates = np.asarray(data[dimension].values)

Check warning on line 34 in openeo_processes_dask/process_implementations/ml/curve_fitting.py

View check run for this annotation

Codecov / codecov/patch

openeo_processes_dask/process_implementations/ml/curve_fitting.py#L33-L34

Added lines #L33 - L34 were not covered by tests

if np.issubdtype(dates.dtype, np.datetime64):
timestep = [
(
(np.datetime64(x) - np.datetime64("1970-01-01", "s"))
/ np.timedelta64(1, "s")
)
for x in dates
]
data[dimension] = np.array(timestep)

dims_before = list(data.dims)

# In the spec, parameters is a list, but xr.curvefit requires names for them,
Expand Down Expand Up @@ -87,8 +104,16 @@ def predict_curve(
labels = np.asarray(labels)

if np.issubdtype(labels.dtype, np.datetime64):
labels = labels.astype(int)
labels_were_datetime = True
initial_labels = labels
timestep = [
(
(np.datetime64(x) - np.datetime64("1970-01-01", "s"))
/ np.timedelta64(1, "s")
)
for x in labels
]
labels = np.array(timestep)

# This is necessary to pipe the arguments correctly through @process
def wrapper(f):
Expand Down Expand Up @@ -122,6 +147,6 @@ def _wrap(*args, **kwargs):
predictions = predictions.assign_coords({dimension: labels.data})

if labels_were_datetime:
predictions[dimension] = pd.DatetimeIndex(predictions[dimension].values)
predictions[dimension] = initial_labels

return predictions
13 changes: 13 additions & 0 deletions tests/test_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,19 @@ def fitFunction(x, parameters):
assert len(result.coords["param"]) == len(parameters)

labels = dimension_labels(origin_cube, origin_cube.openeo.temporal_dims[0])
labels = [float(l) for l in labels]
predictions = predict_curve(
result,
_process,
origin_cube.openeo.temporal_dims[0],
labels=labels,
).compute()

assert len(predictions.coords[origin_cube.openeo.temporal_dims[0]]) == len(labels)
assert "param" not in predictions.dims
assert result.rio.crs == predictions.rio.crs

labels = ["2020-02-02", "2020-03-02", "2020-04-02", "2020-05-02"]
predictions = predict_curve(
result,
_process,
Expand Down

0 comments on commit 75f63fb

Please sign in to comment.