diff --git a/.gitignore b/.gitignore index 1e69b7a5..96570fad 100644 --- a/.gitignore +++ b/.gitignore @@ -188,3 +188,5 @@ _build/ *.sync *.dot _dev/ +*.to_upload +*.tmp diff --git a/CHANGELOG.md b/CHANGELOG.md index 7dd85782..c77268ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,10 @@ Keep it human-readable, your future self will thank you! ### Changed +- Support sub-hourly datasets. +- Change negative variance detection to make it less restrictive. +- Fix cutout bug that left some global grid points in the LAM part. + ### Removed ## [0.4.4] Bug fixes diff --git a/docs/using/combining.rst b/docs/using/combining.rst index 6f2aadea..c882fb62 100644 --- a/docs/using/combining.rst +++ b/docs/using/combining.rst @@ -155,3 +155,28 @@ cutout: :width: 75% :align: center :alt: Cutout + +You can also pass a `min_distance_km` parameter to the `cutout` +function. Any grid points in the global dataset that are closer than +this distance to a grid point in the LAM dataset will be removed. This +can be useful to control the behaviour of the algorithm at the edge of +the cutout area. If no value is provided, the algorithm will compute its +value as the smallest distance between two grid points in the global +dataset over the cutout area. If you do not want to use this feature, +you can set `min_distance_km=0`, or provide your own value. + +The plots below illustrate how the cutout differs if `min_distance_km` +is not given (top) or if `min_distance_km` is set to `0` (bottom). +The difference can be seen at the boundary between the two grids: + +.. image:: images/cutout-5.png + :align: center + :alt: Cutout + +.. image:: images/cutout-6.png + :align: center + :alt: Cutout + +To debug the combination, you can pass `plot=True` to the `cutout` +function (when running from a notebook), or use `plot="prefix"` to save +the plots to a series of PNG files in the current directory.
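+
+For example, assuming two existing datasets named `lam-dataset.zarr` and
+`global-dataset.zarr` (hypothetical names, for illustration only), a cutout
+with an explicit minimum distance could be opened like this:
+
+.. code-block:: python
+
+   from anemoi.datasets import open_dataset
+
+   ds = open_dataset(
+       cutout=["lam-dataset.zarr", "global-dataset.zarr"],
+       min_distance_km=1.0,  # use 0 to disable the distance-based pruning
+       neighbours=5,  # number of nearest LAM points used to build triangles
+   )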
diff --git a/docs/using/images/cutout-5.png b/docs/using/images/cutout-5.png new file mode 100644 index 00000000..b110c9aa Binary files /dev/null and b/docs/using/images/cutout-5.png differ diff --git a/docs/using/images/cutout-6.png b/docs/using/images/cutout-6.png new file mode 100644 index 00000000..2d7ae14a Binary files /dev/null and b/docs/using/images/cutout-6.png differ diff --git a/pyproject.toml b/pyproject.toml index 1b1e2dac..12cbf150 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ dynamic = [ "version", ] dependencies = [ - "anemoi-utils[provenance]>=0.3.13", + "anemoi-utils[provenance]>=0.3.15", "numpy", "pyyaml", "semantic-version", diff --git a/src/anemoi/datasets/create/check.py b/src/anemoi/datasets/create/check.py index 0c262113..068a044a 100644 --- a/src/anemoi/datasets/create/check.py +++ b/src/anemoi/datasets/create/check.py @@ -56,7 +56,7 @@ def raise_if_not_valid(self, print=print): raise ValueError(self.error_message) def _parse(self, name): - pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h)-v(\d+)-?([a-zA-Z0-9-]+)$" + pattern = r"^(\w+)-([\w-]+)-(\w+)-(\w+)-(\d\d\d\d)-(\d\d\d\d)-(\d+h)-v(\d+)-?([a-zA-Z0-9-]+)?$" match = re.match(pattern, name) assert match, (name, pattern) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/__init__.py b/src/anemoi/datasets/create/functions/sources/xarray/__init__.py index 468b7a35..32cb607c 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/__init__.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/__init__.py @@ -52,9 +52,19 @@ def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs) result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates]) if len(result) == 0: - LOG.warning(f"No data found for {dataset} and dates {dates}") + LOG.warning(f"No data found for {dataset} and dates {dates} and {kwargs}") LOG.warning(f"Options: {options}") - LOG.warning(data) + + for i, k in enumerate(fs): + a = ["valid_datetime", k.metadata("valid_datetime", default=None)] + for n in kwargs.keys(): + a.extend([n, k.metadata(n, default=None)]) + print([str(x) for x in a]) + + if i > 16: + break + + # LOG.warning(data) return result diff --git a/src/anemoi/datasets/create/functions/sources/xarray/coordinates.py b/src/anemoi/datasets/create/functions/sources/xarray/coordinates.py index f08a00f7..a9c2d142 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/coordinates.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/coordinates.py @@ -55,6 +55,7 @@ class Coordinate: is_time = False is_step = False is_date = False + is_member = False def __init__(self, variable): self.variable = variable @@ -201,8 +202,14 @@ def normalise(self, value): class EnsembleCoordinate(Coordinate): + is_member = True mars_names = ("number",) + def normalise(self, value): + if int(value) == value: + return int(value) + return value + class LongitudeCoordinate(Coordinate): is_grid = True diff --git a/src/anemoi/datasets/create/functions/sources/xarray/field.py b/src/anemoi/datasets/create/functions/sources/xarray/field.py index d464df78..cdbd061f 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/field.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/field.py @@ -80,7 +80,7 @@ def to_numpy(self, flatten=False, dtype=None): return values.reshape(self.shape) def _make_metadata(self): - return XArrayMetadata(self, self.owner.mapping) + return XArrayMetadata(self) def grid_points(self): return 
self.owner.grid_points() diff --git a/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py b/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py index 90ab44aa..52d4a388 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/fieldlist.py @@ -134,8 +134,6 @@ def sel(self, **kwargs): for v in self.variables: - v.update_metadata_mapping(kwargs) - # First, select matching variables # This will consume 'param' or 'variable' from kwargs # and return the rest diff --git a/src/anemoi/datasets/create/functions/sources/xarray/flavour.py b/src/anemoi/datasets/create/functions/sources/xarray/flavour.py index 373cd963..e339536e 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/flavour.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/flavour.py @@ -9,6 +9,7 @@ from .coordinates import DateCoordinate +from .coordinates import EnsembleCoordinate from .coordinates import LatitudeCoordinate from .coordinates import LevelCoordinate from .coordinates import LongitudeCoordinate @@ -135,6 +136,17 @@ def _guess(self, c, coord): if d is not None: return d + d = self._is_number( + c, + axis=axis, + name=name, + long_name=long_name, + standard_name=standard_name, + units=units, + ) + if d is not None: + return d + if c.shape in ((1,), tuple()): return ScalarCoordinate(c) @@ -249,9 +261,13 @@ def _is_level(self, c, *, axis, name, long_name, standard_name, units): if standard_name == "depth": return LevelCoordinate(c, "depth") - if name == "pressure": + if name == "vertical" and units == "hPa": return LevelCoordinate(c, "pl") + def _is_number(self, c, *, axis, name, long_name, standard_name, units): + if name in ("realization", "number"): + return EnsembleCoordinate(c) + class FlavourCoordinateGuesser(CoordinateGuesser): def __init__(self, ds, flavour): @@ -328,3 +344,7 @@ def _levtype(self, c, *, axis, name, long_name, standard_name, units): return self.flavour["levtype"] raise NotImplementedError(f"levtype for {c=}") + + def _is_number(self, c, *, axis, name, long_name, standard_name, units): + if self._match(c, "number", locals()): + return EnsembleCoordinate(c) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/metadata.py b/src/anemoi/datasets/create/functions/sources/xarray/metadata.py index 877045b8..e98f9ea7 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/metadata.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/metadata.py @@ -10,29 +10,37 @@ import logging from functools import cached_property +from anemoi.utils.dates import as_datetime from earthkit.data.core.geography import Geography from earthkit.data.core.metadata import RawMetadata -from earthkit.data.utils.dates import to_datetime from earthkit.data.utils.projections import Projection LOG = logging.getLogger(__name__) -class MDMapping: +class _MDMapping: - def __init__(self, mapping): - self.user_to_internal = mapping + def __init__(self, variable): + self.variable = variable + self.time = variable.time + self.mapping = dict(param="variable") + for c in variable.coordinates: + for v in c.mars_names: + assert v not in self.mapping, f"Duplicate key '{v}' in {c}" + self.mapping[v] = c.variable.name - def from_user(self, kwargs): - if isinstance(kwargs, str): - return self.user_to_internal.get(kwargs, kwargs) - return {self.user_to_internal.get(k, k): v for k, v in kwargs.items()} + def _from_user(self, key): + return self.mapping.get(key, key) - def __len__(self): - return
len(self.user_to_internal) + def from_user(self, kwargs): + print("from_user", kwargs, self) + return {self._from_user(k): v for k, v in kwargs.items()} def __repr__(self): - return f"MDMapping({self.user_to_internal})" + return f"MDMapping({self.mapping})" + + def fill_time_metadata(self, field, md): + md["valid_datetime"] = as_datetime(self.variable.time.fill_time_metadata(field._md, md)).isoformat() class XArrayMetadata(RawMetadata): @@ -40,23 +48,11 @@ class XArrayMetadata(RawMetadata): NAMESPACES = ["default", "mars"] MARS_KEYS = ["param", "step", "levelist", "levtype", "number", "date", "time"] - def __init__(self, field, mapping): + def __init__(self, field): self._field = field md = field._md.copy() - - self._mapping = mapping - if mapping is None: - time_coord = [c for c in field.owner.coordinates if c.is_time] - if len(time_coord) == 1: - time_key = time_coord[0].name - else: - time_key = "time" - else: - time_key = mapping.from_user("valid_datetime") - self._time = to_datetime(md.pop(time_key)) - self._field.owner.time.fill_time_metadata(self._time, md) - md["valid_datetime"] = self._time.isoformat() - + self._mapping = _MDMapping(field.owner) + self._mapping.fill_time_metadata(field, md) super().__init__(md) @cached_property @@ -88,10 +84,13 @@ def _base_datetime(self): return self._field.forecast_reference_time def _valid_datetime(self): - return self._time + return self._get("valid_datetime") def _get(self, key, **kwargs): + if key in self._d: + return self._d[key] + if key.startswith("mars."): key = key[5:] if key not in self.MARS_KEYS: @@ -100,8 +99,7 @@ def _get(self, key, **kwargs): else: return kwargs.get("default", None) - if self._mapping is not None: - key = self._mapping.from_user(key) + key = self._mapping._from_user(key) return super()._get(key, **kwargs) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/time.py b/src/anemoi/datasets/create/functions/sources/xarray/time.py index 65c97165..6e845d6b 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/time.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/time.py @@ -10,8 +10,11 @@ import datetime +from anemoi.utils.dates import as_datetime + class Time: + @classmethod def from_coordinates(cls, coordinates): time_coordinate = [c for c in coordinates if c.is_time] @@ -19,16 +22,16 @@ def from_coordinates(cls, coordinates): date_coordinate = [c for c in coordinates if c.is_date] if len(date_coordinate) == 0 and len(time_coordinate) == 1 and len(step_coordinate) == 1: - return ForecasstFromValidTimeAndStep(step_coordinate[0]) + return ForecastFromValidTimeAndStep(time_coordinate[0], step_coordinate[0]) if len(date_coordinate) == 0 and len(time_coordinate) == 1 and len(step_coordinate) == 0: - return Analysis() + return Analysis(time_coordinate[0]) if len(date_coordinate) == 0 and len(time_coordinate) == 0 and len(step_coordinate) == 0: return Constant() if len(date_coordinate) == 1 and len(time_coordinate) == 1 and len(step_coordinate) == 0: - return ForecastFromValidTimeAndBaseTime(date_coordinate[0]) + return ForecastFromValidTimeAndBaseTime(date_coordinate[0], time_coordinate[0]) if len(date_coordinate) == 1 and len(time_coordinate) == 0 and len(step_coordinate) == 1: return ForecastFromBaseTimeAndDate(date_coordinate[0], step_coordinate[0]) @@ -38,61 +41,91 @@ def from_coordinates(cls, coordinates): class Constant(Time): - def fill_time_metadata(self, time, metadata): - metadata["date"] = time.strftime("%Y%m%d") - metadata["time"] = time.strftime("%H%M") - metadata["step"] = 0 + 
def fill_time_metadata(self, coords_values, metadata): + raise NotImplementedError("Constant time not implemented") + # print("Constant", coords_values, metadata) + # metadata["date"] = time.strftime("%Y%m%d") + # metadata["time"] = time.strftime("%H%M") + # metadata["step"] = 0 class Analysis(Time): - def fill_time_metadata(self, time, metadata): - metadata["date"] = time.strftime("%Y%m%d") - metadata["time"] = time.strftime("%H%M") + def __init__(self, time_coordinate): + self.time_coordinate_name = time_coordinate.variable.name + + def fill_time_metadata(self, coords_values, metadata): + valid_datetime = coords_values[self.time_coordinate_name] + + metadata["date"] = as_datetime(valid_datetime).strftime("%Y%m%d") + metadata["time"] = as_datetime(valid_datetime).strftime("%H%M") metadata["step"] = 0 + return valid_datetime -class ForecasstFromValidTimeAndStep(Time): - def __init__(self, step_coordinate): - self.step_name = step_coordinate.variable.name - def fill_time_metadata(self, time, metadata): - step = metadata.pop(self.step_name) +class ForecastFromValidTimeAndStep(Time): + + def __init__(self, time_coordinate, step_coordinate): + self.time_coordinate_name = time_coordinate.variable.name + self.step_coordinate_name = step_coordinate.variable.name + + def fill_time_metadata(self, coords_values, metadata): + valid_datetime = coords_values[self.time_coordinate_name] + step = coords_values[self.step_coordinate_name] + assert isinstance(step, datetime.timedelta) - base = time - step + base_datetime = valid_datetime - step hours = step.total_seconds() / 3600 assert int(hours) == hours - metadata["date"] = base.strftime("%Y%m%d") - metadata["time"] = base.strftime("%H%M") + metadata["date"] = as_datetime(base_datetime).strftime("%Y%m%d") + metadata["time"] = as_datetime(base_datetime).strftime("%H%M") metadata["step"] = int(hours) + return valid_datetime class ForecastFromValidTimeAndBaseTime(Time): - def __init__(self, date_coordinate): - self.date_coordinate = date_coordinate - def fill_time_metadata(self, time, metadata): + def __init__(self, date_coordinate, time_coordinate): + self.date_coordinate_name = date_coordinate.name + self.time_coordinate_name = time_coordinate.name + + def fill_time_metadata(self, coords_values, metadata): + valid_datetime = coords_values[self.time_coordinate_name] + base_datetime = coords_values[self.date_coordinate_name] - step = time - self.date_coordinate + step = valid_datetime - base_datetime hours = step.total_seconds() / 3600 assert int(hours) == hours - metadata["date"] = self.date_coordinate.single_value.strftime("%Y%m%d") - metadata["time"] = self.date_coordinate.single_value.strftime("%H%M") + metadata["date"] = as_datetime(base_datetime).strftime("%Y%m%d") + metadata["time"] = as_datetime(base_datetime).strftime("%H%M") metadata["step"] = int(hours) + return valid_datetime + class ForecastFromBaseTimeAndDate(Time): + def __init__(self, date_coordinate, step_coordinate): - self.date_coordinate = date_coordinate - self.step_coordinate = step_coordinate + self.date_coordinate_name = date_coordinate.name + self.step_coordinate_name = step_coordinate.name + + def fill_time_metadata(self, coords_values, metadata): + + date = coords_values[self.date_coordinate_name] + step = coords_values[self.step_coordinate_name] + assert isinstance(step, datetime.timedelta) + + metadata["date"] = as_datetime(date).strftime("%Y%m%d") + metadata["time"] = as_datetime(date).strftime("%H%M") + + hours = step.total_seconds() / 3600 - def fill_time_metadata(self, time,
metadata): - metadata["date"] = time.strftime("%Y%m%d") - metadata["time"] = time.strftime("%H%M") - hours = metadata[self.step_coordinate.name].total_seconds() / 3600 assert int(hours) == hours metadata["step"] = int(hours) + + return date + step diff --git a/src/anemoi/datasets/create/functions/sources/xarray/variable.py b/src/anemoi/datasets/create/functions/sources/xarray/variable.py index e1f0225a..c5f7b869 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/variable.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/variable.py @@ -14,34 +14,32 @@ import numpy as np from earthkit.data.utils.array import ensure_backend -from anemoi.datasets.create.functions.sources.xarray.metadata import MDMapping - from .field import XArrayField LOG = logging.getLogger(__name__) class Variable: - def __init__(self, *, ds, var, coordinates, grid, time, metadata, mapping=None, array_backend=None): + def __init__( + self, + *, + ds, + var, + coordinates, + grid, + time, + metadata, + array_backend=None, + ): self.ds = ds self.var = var self.grid = grid self.coordinates = coordinates - # print("Variable", var.name) - # for c in coordinates: - # print(" ", c) - self._metadata = metadata.copy() - # self._metadata.update(var.attrs) self._metadata.update({"variable": var.name}) - # self._metadata.setdefault("level", None) - # self._metadata.setdefault("number", 0) - # self._metadata.setdefault("levtype", "sfc") - self._mapping = mapping - self.time = time self.shape = tuple(len(c.variable) for c in coordinates if c.is_dim and not c.scalar and not c.is_grid) @@ -51,23 +49,6 @@ def __init__(self, *, ds, var, coordinates, grid, time, metadata, mapping=None, self.length = math.prod(self.shape) self.array_backend = ensure_backend(array_backend) - def update_metadata_mapping(self, kwargs): - - result = {} - - for k, v in kwargs.items(): - if k == "param": - result[k] = "variable" - continue - - for c in self.coordinates: - if k in c.mars_names: - for v in c.mars_names: - result[v] = c.variable.name - break - - self._mapping = MDMapping(result) - @property def name(self): return self.var.name @@ -111,17 +92,11 @@ def __getitem__(self, i): kwargs = {k: v for k, v in zip(self.names, coords)} return XArrayField(self, self.var.isel(kwargs)) - @property - def mapping(self): - return self._mapping - def sel(self, missing, **kwargs): if not kwargs: return self - kwargs = self._mapping.from_user(kwargs) - k, v = kwargs.popitem() c = self.by_name.get(k) @@ -147,13 +122,15 @@ def sel(self, missing, **kwargs): grid=self.grid, time=self.time, metadata=metadata, - mapping=self.mapping, ) return variable.sel(missing, **kwargs) def match(self, **kwargs): - kwargs = self._mapping.from_user(kwargs) + + if "param" in kwargs: + assert "variable" not in kwargs + kwargs["variable"] = kwargs.pop("param") if "variable" in kwargs: name = kwargs.pop("variable") diff --git a/src/anemoi/datasets/create/input.py b/src/anemoi/datasets/create/input.py index 65353088..e696feb3 100644 --- a/src/anemoi/datasets/create/input.py +++ b/src/anemoi/datasets/create/input.py @@ -106,30 +106,32 @@ def _data_request(data): area = grid = None for field in data: - if not hasattr(field, "as_mars"): - continue - - if date is None: - date = field.datetime()["valid_time"] - - if field.datetime()["valid_time"] != date: - continue + try: + if date is None: + date = field.datetime()["valid_time"] - as_mars = field.metadata(namespace="mars") - step = as_mars.get("step") - levtype = as_mars.get("levtype", "sfc") - param = as_mars["param"] - 
levelist = as_mars.get("levelist", None) - area = field.mars_area - grid = field.mars_grid + if field.datetime()["valid_time"] != date: + continue - if levelist is None: - params_levels[levtype].add(param) - else: - params_levels[levtype].add((param, levelist)) + as_mars = field.metadata(namespace="mars") + if not as_mars: + continue + step = as_mars.get("step") + levtype = as_mars.get("levtype", "sfc") + param = as_mars["param"] + levelist = as_mars.get("levelist", None) + area = field.mars_area + grid = field.mars_grid + + if levelist is None: + params_levels[levtype].add(param) + else: + params_levels[levtype].add((param, levelist)) - if step: - params_steps[levtype].add((param, step)) + if step: + params_steps[levtype].add((param, step)) + except Exception: + LOG.error(f"Error in retrieving metadata (cannot build data request info) for {field}", exc_info=True) def sort(old_dic): new_dic = {} @@ -288,7 +290,6 @@ def explain(self, ds, *args, remapping, patches): names += list(a.keys()) print(f"Building a {len(names)}D hypercube using", names) - ds = ds.order_by(*args, remapping=remapping, patches=patches) user_coords = ds.unique_values(*names, remapping=remapping, patches=patches, progress_bar=False) diff --git a/src/anemoi/datasets/create/loaders.py b/src/anemoi/datasets/create/loaders.py index dc2ecaa0..acd6ec4d 100644 --- a/src/anemoi/datasets/create/loaders.py +++ b/src/anemoi/datasets/create/loaders.py @@ -18,6 +18,9 @@ import zarr from anemoi.utils.config import DotDict from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_string +from anemoi.utils.dates import frequency_to_timedelta +from anemoi.utils.humanize import compress_dates from anemoi.utils.humanize import seconds_to_human from anemoi.datasets import MissingDateError @@ -25,7 +28,6 @@ from anemoi.datasets.create.persistent import build_storage from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date -from anemoi.datasets.dates import compress_dates from anemoi.datasets.dates.groups import Groups from .check import DatasetName @@ -39,6 +41,7 @@ from .statistics import check_variance from .statistics import compute_statistics from .statistics import default_statistics_dates +from .statistics import fix_variance from .utils import normalize_and_check_dates from .writer import ViewCacheArray from .zarr import ZarrBuiltRegistry @@ -49,6 +52,20 @@ VERSION = "0.20" +def json_tidy(o): + + if isinstance(o, datetime.datetime): + return o.isoformat() + + if isinstance(o, datetime.date): + return o.isoformat() + + if isinstance(o, datetime.timedelta): + return frequency_to_string(o) + + raise TypeError(repr(o) + " is not JSON serializable") + + def set_to_test_mode(cfg): NUMBER_OF_DATES = 4 @@ -160,7 +177,7 @@ def update_metadata(self, **kwargs): v = v.astype(datetime.datetime) if isinstance(v, datetime.date): v = v.isoformat() - z.attrs[k] = v + z.attrs[k] = json.loads(json.dumps(v, default=json_tidy)) def _add_dataset(self, mode="r+", **kwargs): z = zarr.open(self.path, mode=mode) @@ -279,7 +296,7 @@ def initialise(self, check_name=True): dates = self.groups.dates frequency = dates.frequency - assert isinstance(frequency, int), frequency + assert isinstance(frequency, datetime.timedelta), frequency LOG.info(f"Found {len(dates)} datetimes.") LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ") @@ -741,8 +758,10 @@ def finalise(self): assert sums.shape == mean.shape x = squares / count - mean * mean - # remove negative
variance due to numerical errors # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 + # remove negative variance due to numerical errors + for i, name in enumerate(self.variables): + x[i] = fix_variance(x[i], name, agg["count"][i : i + 1], agg["sums"][i : i + 1], agg["squares"][i : i + 1]) check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares) stdev = np.sqrt(x) @@ -875,16 +894,20 @@ def __init__(self, path, delta=None, **kwargs): full_ds = open_dataset(path) self.variables = full_ds.variables - frequency = full_ds.frequency + frequency = frequency_to_timedelta(full_ds.frequency) if delta is None: delta = frequency - assert isinstance(delta, int), delta - if not delta % frequency == 0: + + delta = frequency_to_timedelta(delta) + + if not delta.total_seconds() % frequency.total_seconds() == 0: raise TendenciesStatisticsDeltaNotMultipleOfFrequency( f"Delta {delta} is not a multiple of frequency {frequency}" ) self.delta = delta - idelta = delta // frequency + idelta = delta.total_seconds() // frequency.total_seconds() + assert int(idelta) == idelta, idelta + idelta = int(idelta) super().__init__(path=path, **kwargs) @@ -908,10 +931,7 @@ def final_storage_name(self, k): @classmethod def final_storage_name_from_delta(_, k, delta): - if isinstance(delta, int): - delta = str(delta) - if not delta.endswith("h"): - delta = delta + "h" + delta = frequency_to_string(delta) return f"statistics_tendencies_{delta}_{k}" def run(self, parts): diff --git a/src/anemoi/datasets/create/statistics/__init__.py b/src/anemoi/datasets/create/statistics/__init__.py index 568bc410..d788c203 100644 --- a/src/anemoi/datasets/create/statistics/__init__.py +++ b/src/anemoi/datasets/create/statistics/__init__.py @@ -79,6 +79,37 @@ def to_datetimes(dates): return [to_datetime(d) for d in dates] +def fix_variance(x, name, count, sums, squares): + assert count.shape == sums.shape == squares.shape + assert isinstance(x, float) + + mean = sums / count + assert mean.shape == count.shape + + if x >= 0: + return x + + LOG.warning(f"Negative variance for {name=}, variance={x}") + magnitude = np.sqrt((squares / count + mean * mean) / 2) + LOG.warning(f"square / count - mean * mean = {squares/count} - {mean*mean} = {squares/count - mean*mean}") + LOG.warning(f"Variable span order of magnitude is {magnitude}.") + LOG.warning(f"Count is {count}.") + + variances = squares / count - mean * mean + assert variances.shape == squares.shape == mean.shape + if all(variances >= 0): + LOG.warning(f"All individual variances for {name} are positive, setting variance to 0.") + return 0 + + # if abs(x) < magnitude * 1e-6 and abs(x) < range * 1e-6: + # LOG.warning("Variance is negative but very small.") + # variances = squares / count - mean * mean + # return 0 + + LOG.warning(f"ERROR at least one individual variance is negative ({np.nanmin(variances)}).") + return x + + def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squares): if (x >= 0).all(): return @@ -292,39 +323,24 @@ def check_type(a, b): def aggregate(self): minimum = np.nanmin(self.minimum, axis=0) maximum = np.nanmax(self.maximum, axis=0) + sums = np.nansum(self.sums, axis=0) squares = np.nansum(self.squares, axis=0) count = np.nansum(self.count, axis=0) has_nans = np.any(self.has_nans, axis=0) - mean = sums / count + assert sums.shape == count.shape == squares.shape == minimum.shape == maximum.shape - assert sums.shape == count.shape == squares.shape == mean.shape == minimum.shape == maximum.shape + mean = sums / 
count + assert mean.shape == minimum.shape x = squares / count - mean * mean - - # def fix_variance(x, name, minimum, maximum, mean, count, sums, squares): - # assert x.shape == minimum.shape == maximum.shape == mean.shape == count.shape == sums.shape == squares.shape - # assert x.shape == (1,) - # x, minimum, maximum, mean, count, sums, squares = x[0], minimum[0], maximum[0], mean[0], count[0], sums[0], squares[0] - # if x >= 0: - # return x - # - # order = np.sqrt((squares / count + mean * mean)/2) - # range = maximum - minimum - # LOG.warning(f"Negative variance for {name=}, variance={x}") - # LOG.warning(f"square / count - mean * mean = {squares / count} - {mean * mean} = {squares / count - mean * mean}") - # LOG.warning(f"Variable order of magnitude is {order}.") - # LOG.warning(f"Range is {range} ({maximum=} - {minimum=}).") - # LOG.warning(f"Count is {count}.") - # if abs(x) < order * 1e-6 and abs(x) < range * 1e-6: - # LOG.warning(f"Variance is negative but very small, setting to 0.") - # return x*0 - # return x + assert x.shape == minimum.shape for i, name in enumerate(self.variables_names): # remove negative variance due to numerical errors - # Not needed for now, fix_variance is disabled - # x[i] = fix_variance(x[i:i+1], name, minimum[i:i+1], maximum[i:i+1], mean[i:i+1], count[i:i+1], sums[i:i+1], squares[i:i+1]) + x[i] = fix_variance(x[i], name, self.count[i : i + 1], self.sums[i : i + 1], self.squares[i : i + 1]) + + for i, name in enumerate(self.variables_names): check_variance( x[i : i + 1], [name], diff --git a/src/anemoi/datasets/create/utils.py b/src/anemoi/datasets/create/utils.py index e4629abd..3dcd86c7 100644 --- a/src/anemoi/datasets/create/utils.py +++ b/src/anemoi/datasets/create/utils.py @@ -7,6 +7,7 @@ # nor does it submit to any jurisdiction. # +import datetime import os from contextlib import contextmanager @@ -61,10 +62,10 @@ def make_list_int(value): def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s]"): - assert isinstance(frequency, int), frequency + assert isinstance(frequency, datetime.timedelta), frequency start = np.datetime64(start) end = np.datetime64(end) - delta = np.timedelta64(frequency, "h") + delta = np.timedelta64(frequency) res = [] while start <= end: diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 51c59376..fde2cf7f 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -10,6 +10,10 @@ import warnings from functools import cached_property +from anemoi.utils.dates import frequency_to_seconds +from anemoi.utils.dates import frequency_to_string +from anemoi.utils.dates import frequency_to_timedelta + LOG = logging.getLogger(__name__) @@ -107,10 +111,9 @@ def _subset(self, **kwargs): raise NotImplementedError("Unsupported arguments: " + ", ".join(kwargs)) def _frequency_to_indices(self, frequency): - from .misc import _frequency_to_hours - requested_frequency = _frequency_to_hours(frequency) - dataset_frequency = _frequency_to_hours(self.frequency) + requested_frequency = frequency_to_seconds(frequency) + dataset_frequency = frequency_to_seconds(self.frequency) assert requested_frequency % dataset_frequency == 0 # Question: where do we start? first date, or first date that is a multiple of the frequency? 
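+ # A worked example (illustrative only): subsetting a 1-hourly dataset
+ # ("1h" -> 3600 s) to frequency="6h" (21600 s) gives
+ # step = 21600 // 3600 = 6, i.e. every 6th date is kept.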
step = requested_frequency // dataset_frequency @@ -211,12 +214,12 @@ def dataset_metadata(self): def metadata_specific(self, **kwargs): action = self.__class__.__name__.lower() - assert isinstance(self.frequency, int), (self.frequency, self, action) + # assert isinstance(self.frequency, datetime.timedelta), (self.frequency, self, action) return dict( action=action, variables=self.variables, shape=self.shape, - frequency=self.frequency, + frequency=frequency_to_string(frequency_to_timedelta(self.frequency)), start_date=self.dates[0].astype(str), end_date=self.dates[-1].astype(str), **kwargs, diff --git a/src/anemoi/datasets/data/grids.py b/src/anemoi/datasets/data/grids.py index e69e0324..c329fdcf 100644 --- a/src/anemoi/datasets/data/grids.py +++ b/src/anemoi/datasets/data/grids.py @@ -128,7 +128,7 @@ def tree(self): class Cutout(GridsBase): - def __init__(self, datasets, axis): + def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0, neighbours=5, plot=False): from anemoi.datasets.grids import cutout_mask super().__init__(datasets, axis) @@ -144,7 +144,10 @@ def __init__(self, datasets, axis): self.lam.longitudes, self.globe.latitudes, self.globe.longitudes, - # plot="cutout", + plot=plot, + min_distance_km=min_distance_km, + cropping_distance=cropping_distance, + neighbours=neighbours, ) assert len(self.mask) == self.globe.shape[3], ( len(self.mask), @@ -229,6 +232,10 @@ def cutout_factory(args, kwargs): cutout = kwargs.pop("cutout") axis = kwargs.pop("axis", 3) + plot = kwargs.pop("plot", None) + min_distance_km = kwargs.pop("min_distance_km", None) + cropping_distance = kwargs.pop("cropping_distance", 2.0) + neighbours = kwargs.pop("neighbours", 5) assert len(args) == 0 assert isinstance(cutout, (list, tuple)) @@ -236,4 +243,11 @@ def cutout_factory(args, kwargs): datasets = [_open(e) for e in cutout] datasets, kwargs = _auto_adjust(datasets, kwargs) - return Cutout(datasets, axis=axis)._subset(**kwargs) + return Cutout( + datasets, + axis=axis, + neighbours=neighbours, + min_distance_km=min_distance_km, + cropping_distance=cropping_distance, + plot=plot, + )._subset(**kwargs) diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py index cad55065..89a671b5 100644 --- a/src/anemoi/datasets/data/misc.py +++ b/src/anemoi/datasets/data/misc.py @@ -8,12 +8,12 @@ import calendar import datetime import logging -import re from pathlib import PurePath import numpy as np import zarr from anemoi.utils.config import load_config as load_settings +from anemoi.utils.dates import frequency_to_timedelta from .dataset import Dataset @@ -39,28 +39,6 @@ def add_dataset_path(path): config["datasets"]["path"].append(path) -def _frequency_to_hours(frequency): - if isinstance(frequency, int): - return frequency - - if isinstance(frequency, float): - assert int(frequency) == frequency - return int(frequency) - - m = re.match(r"(\d+)([dh])?", frequency) - if m is None: - raise ValueError("Invalid frequency: " + frequency) - - frequency = int(m.group(1)) - if m.group(2) == "h": - return frequency - - if m.group(2) == "d": - return frequency * 24 - - raise NotImplementedError() - - def _as_date(d, dates, last): # WARNING, datetime.datetime is a subclass of datetime.date @@ -173,12 +151,12 @@ def _concat_or_join(datasets, kwargs): # For now we should have the datasets in order with no gaps - frequency = _frequency_to_hours(datasets[0].frequency) + frequency = frequency_to_timedelta(datasets[0].frequency) for i in range(len(ranges) - 1): r = ranges[i] s = ranges[i 
+ 1] - if r[1] + datetime.timedelta(hours=frequency) != s[0]: + if r[1] + frequency != s[0]: raise ValueError( "Datasets must be sorted by dates, with no gaps: " f"{r} and {s} ({datasets[i]} {datasets[i+1]})" ) diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 5dc3bb80..062079b2 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -13,6 +13,7 @@ import numpy as np import zarr +from anemoi.utils.dates import frequency_to_timedelta from . import MissingDateError from .dataset import Dataset @@ -268,12 +269,11 @@ def field_shape(self): @property def frequency(self): try: - return self.z.attrs["frequency"] + return frequency_to_timedelta(self.z.attrs["frequency"]) except KeyError: LOG.warning("No 'frequency' in %r, computing from 'dates'", self) dates = self.dates - delta = dates[1].astype(object) - dates[0].astype(object) - return int(delta.total_seconds() / 3600) + return dates[1].astype(object) - dates[0].astype(object) @property def name_to_index(self): diff --git a/src/anemoi/datasets/data/subset.py b/src/anemoi/datasets/data/subset.py index 75952061..3ace75be 100644 --- a/src/anemoi/datasets/data/subset.py +++ b/src/anemoi/datasets/data/subset.py @@ -9,6 +9,7 @@ from functools import cached_property import numpy as np +from anemoi.utils.dates import frequency_to_timedelta from .debug import Node from .debug import Source @@ -133,8 +134,7 @@ def dates(self): @cached_property def frequency(self): dates = self.dates - delta = dates[1].astype(object) - dates[0].astype(object) - return int(delta.total_seconds() / 3600) + return frequency_to_timedelta(dates[1].astype(object) - dates[0].astype(object)) def source(self, index): return Source(self, index, self.forward.source(index)) diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index ed155425..eb7886c8 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -9,64 +9,10 @@ import datetime import warnings +# from anemoi.utils.dates import as_datetime from anemoi.utils.dates import as_datetime - - -def _compress_dates(dates): - dates = sorted(dates) - if len(dates) < 3: - yield dates - return - - prev = first = dates.pop(0) - curr = dates.pop(0) - delta = curr - prev - while curr - prev == delta: - prev = curr - if not dates: - break - curr = dates.pop(0) - - yield (first, prev, delta) - if dates: - yield from _compress_dates([curr] + dates) - - -def compress_dates(dates): - dates = [as_datetime(_) for _ in dates] - result = [] - - for n in _compress_dates(dates): - if isinstance(n, list): - result.extend([str(_) for _ in n]) - else: - result.append(" ".join([str(n[0]), "to", str(n[1]), "by", str(n[2])])) - - return result - - -def print_dates(dates): - print(compress_dates(dates)) - - -def no_time_zone(date): - return date.replace(tzinfo=None) - - -def frequency_to_hours(frequency): - if isinstance(frequency, int): - return frequency - assert isinstance(frequency, str), (type(frequency), frequency) - - unit = frequency[-1].lower() - v = int(frequency[:-1]) - return {"h": v, "d": v * 24}[unit] - - -def normalize_date(x): - if isinstance(x, str): - return no_time_zone(datetime.datetime.fromisoformat(x)) - return x +from anemoi.utils.dates import frequency_to_timedelta +from anemoi.utils.humanize import print_dates def extend(x): @@ -79,15 +25,15 @@ def extend(x): if isinstance(x, str): if "/" in x: start, end, step = x.split("/") - start = normalize_date(start) - end = 
normalize_date(end) - step = frequency_to_hours(step) + start = as_datetime(start) + end = as_datetime(end) + step = frequency_to_timedelta(step) while start <= end: yield start - start += datetime.timedelta(hours=step) + start += step return - yield normalize_date(x) + yield as_datetime(x) class Dates: @@ -145,7 +91,7 @@ def summary(self): class ValuesDates(Dates): def __init__(self, values, **kwargs): - self.values = sorted([no_time_zone(_) for _ in values]) + self.values = sorted([as_datetime(_) for _ in values]) super().__init__(**kwargs) def __repr__(self): @@ -157,7 +103,8 @@ def as_dict(self): class StartEndDates(Dates): def __init__(self, start, end, frequency=1, months=None, **kwargs): - frequency = frequency_to_hours(frequency) + frequency = frequency_to_timedelta(frequency) + assert isinstance(frequency, datetime.timedelta), frequency def _(x): if isinstance(x, str): @@ -173,13 +120,13 @@ def _(x): if isinstance(end, datetime.date) and not isinstance(end, datetime.datetime): end = datetime.datetime(end.year, end.month, end.day) - start = no_time_zone(start) - end = no_time_zone(end) + start = as_datetime(start) + end = as_datetime(end) # if end <= start: # raise ValueError(f"End date {end} must be after start date {start}") - increment = datetime.timedelta(hours=frequency) + increment = frequency self.start = start self.end = end diff --git a/src/anemoi/datasets/dates/groups.py b/src/anemoi/datasets/dates/groups.py index 65fcfa22..36a071f1 100644 --- a/src/anemoi/datasets/dates/groups.py +++ b/src/anemoi/datasets/dates/groups.py @@ -9,7 +9,7 @@ import itertools from anemoi.datasets.dates import Dates -from anemoi.datasets.dates import normalize_date +from anemoi.datasets.dates import as_datetime class Groups: @@ -67,7 +67,7 @@ def __repr__(self): class Filter: def __init__(self, missing): - self.missing = [normalize_date(m) for m in missing] + self.missing = [as_datetime(m) for m in missing] def __call__(self, dates): return [d for d in dates if d not in self.missing] diff --git a/src/anemoi/datasets/grids.py b/src/anemoi/datasets/grids.py index e4df47ee..288a5c94 100644 --- a/src/anemoi/datasets/grids.py +++ b/src/anemoi/datasets/grids.py @@ -7,41 +7,65 @@ # nor does it submit to any jurisdiction.
# +import logging + import numpy as np +LOG = logging.getLogger(__name__) + def plot_mask(path, mask, lats, lons, global_lats, global_lons): import matplotlib.pyplot as plt - middle = (np.amin(lons) + np.amax(lons)) / 2 - print("middle", middle) s = 1 - # gmiddle = (np.amin(global_lons)+ np.amax(global_lons))/2 - - # print('gmiddle', gmiddle) - # global_lons = global_lons-gmiddle+middle global_lons[global_lons >= 180] -= 360 plt.figure(figsize=(10, 5)) plt.scatter(global_lons, global_lats, s=s, marker="o", c="r") - plt.savefig(path + "-global.png") + if isinstance(path, str): + plt.savefig(path + "-global.png") plt.figure(figsize=(10, 5)) plt.scatter(global_lons[mask], global_lats[mask], s=s, c="k") - plt.savefig(path + "-cutout.png") + if isinstance(path, str): + plt.savefig(path + "-cutout.png") plt.figure(figsize=(10, 5)) plt.scatter(lons, lats, s=s) - plt.savefig(path + "-lam.png") + if isinstance(path, str): + plt.savefig(path + "-lam.png") # plt.scatter(lons, lats, s=0.01) plt.figure(figsize=(10, 5)) plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r") plt.scatter(lons, lats, s=s) - plt.savefig(path + "-both.png") + if isinstance(path, str): + plt.savefig(path + "-both.png") # plt.scatter(lons, lats, s=0.01) + plt.figure(figsize=(10, 5)) + plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r") + plt.scatter(lons, lats, s=s) + plt.xlim(np.amin(lons) - 1, np.amax(lons) + 1) + plt.ylim(np.amin(lats) - 1, np.amax(lats) + 1) + if isinstance(path, str): + plt.savefig(path + "-both-zoomed.png") + + plt.figure(figsize=(10, 5)) + plt.scatter(global_lons[mask], global_lats[mask], s=s, c="r") + plt.xlim(np.amin(lons) - 1, np.amax(lons) + 1) + plt.ylim(np.amin(lats) - 1, np.amax(lats) + 1) + if isinstance(path, str): + plt.savefig(path + "-global-zoomed.png") + + +def xyz_to_latlon(x, y, z): + return ( + np.rad2deg(np.arcsin(np.minimum(1.0, np.maximum(-1.0, z)))), + np.rad2deg(np.arctan2(y, x)), + ) + def latlon_to_xyz(lat, lon, radius=1.0): # https://en.wikipedia.org/wiki/Geographic_coordinate_conversion#From_geodetic_to_ECEF_coordinates @@ -121,6 +145,7 @@ def cutout_mask( global_lats, global_lons, cropping_distance=2.0, + neighbours=5, min_distance_km=None, plot=None, ): @@ -164,58 +189,52 @@ def cutout_mask( xyx = latlon_to_xyz(lats, lons) lam_points = np.array(xyx).transpose() - # Use a KDTree to find the nearest points - kdtree = KDTree(lam_points) - distances, indices = kdtree.query(global_points, k=3) - - if min_distance_km is not None: + if isinstance(min_distance_km, (int, float)): min_distance = min_distance_km / 6371.0 else: - # Estimnation of the minimum distance between two grib points - - glats = sorted(set(global_lats_masked)) - glons = sorted(set(global_lons_masked)) - min_dlats = np.min(np.diff(glats)) - min_dlons = np.min(np.diff(glons)) - - # Use the centre of the LAM grid as the reference point - centre = np.mean(lats), np.mean(lons) - centre_xyz = np.array(latlon_to_xyz(*centre)) - - pt1 = np.array(latlon_to_xyz(centre[0] + min_dlats, centre[1])) - pt2 = np.array(latlon_to_xyz(centre[0], centre[1] + min_dlons)) - min_distance = ( - min( - np.linalg.norm(pt1 - centre_xyz), - np.linalg.norm(pt2 - centre_xyz), - ) - / 2.0 - ) + points = {"lam": lam_points, "global": global_points, None: global_points}[min_distance_km] + distances, _ = KDTree(points).query(points, k=2) + min_distance = np.min(distances[:, 1]) + + LOG.info(f"cutout_mask using min_distance = {min_distance * 6371.0} km") + + # Use a KDTree to find the nearest points + distances, indices = 
KDTree(lam_points).query(global_points, k=neighbours) + # Centre of the Earth zero = np.array([0.0, 0.0, 0.0]) - ok = [] + + # After the loop, 'inside_lam' will contain the list of points to EXCLUDE + inside_lam = [] + for i, (global_point, distance, index) in enumerate(zip(global_points, distances, indices)): - t = Triangle3D(lam_points[index[0]], lam_points[index[1]], lam_points[index[2]]) - # distance = np.min(distance) - # The point is inside the triangle if the intersection with the ray - # from the point to the centre of the Earth is not None - # (the direction of the ray is not important) - intersect = t.intersect(zero, global_point) + # We check more than one triangle in case the global point + # is near the edge of a triangle (the LAM points and the global point are collinear) + + inside = False + for j in range(neighbours): + t = Triangle3D( + lam_points[index[j]], lam_points[index[(j + 1) % neighbours]], lam_points[index[(j + 2) % neighbours]] + ) + inside = t.intersect(zero, global_point) + if inside: + break + close = np.min(distance) <= min_distance - ok.append(intersect or close) + inside_lam.append(inside or close) j = 0 - ok = np.array(ok) + inside_lam = np.array(inside_lam) for i, m in enumerate(mask): if not m: continue - mask[i] = ok[j] + mask[i] = inside_lam[j] j += 1 - assert j == len(ok) + assert j == len(inside_lam) # Invert the mask, so we have only the points outside the cutout mask = ~mask @@ -271,8 +290,7 @@ def thinning_mask( points = np.array(xyx).transpose() # Use a KDTree to find the nearest points - kdtree = KDTree(points) - _, indices = kdtree.query(global_points, k=1) + _, indices = KDTree(points).query(global_points, k=1) return np.array([i for i in indices]) diff --git a/tests/create/test_create.py b/tests/create/test_create.py index 50c100f8..97c26007 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -9,23 +9,21 @@ import hashlib import json import os -import shutil -import warnings from functools import wraps from unittest.mock import patch import numpy as np import pytest import requests -from earthkit.data import from_source +from earthkit.data import from_source as original_from_source +from multiurl import download from anemoi.datasets import open_dataset from anemoi.datasets.create import Creator from anemoi.datasets.data.stores import open_zarr -MARS_CLIENT_PRESENT = os.path.exists("/usr/local/bin/mars") - -TEST_DATA_ROOT = "https://object-store.os-api.cci1.ecmwf.int/ml-tests/test-data/anemoi-datasets/create/" +TEST_DATA_ROOT = "https://object-store.os-api.cci1.ecmwf.int/ml-tests/test-data/anemoi-datasets/create" +TEST_DATA_S3_ROOT = "s3://ml-tests/test-data/anemoi-datasets/create" HERE = os.path.dirname(__file__) @@ -46,103 +44,95 @@ def wrapper(*args, **kwargs): class LoadSource: - def __init__(self, read_dir=None, write_dir=None): - self.read_dir = read_dir - self.write_dir = write_dir def filename(self, args, kwargs): - try: - string = json.dumps([args, kwargs]) - except Exception as e: - warnings.warn(f"Could not build hash for {args}, {kwargs}, {e}") - return None + string = json.dumps([args, kwargs], sort_keys=True, default=str) h = hashlib.md5(string.encode("utf8")).hexdigest() - return h + ".copy" - - def write(self, directory, ds, args, kwargs): - if self.write_dir is None: - return - - if not hasattr(ds, "path"): - return + return h + ".grib" + + def get_data(self, args, kwargs, path): + upload_path = os.path.realpath(path + ".to_upload") + ds = original_from_source("mars", *args, **kwargs) + ds.save(upload_path) +
print(f"Mockup: Saving to {upload_path} for {args}, {kwargs}") + exe = os.path.realpath(os.path.join(os.path.dirname(__file__), "../../tools/upload-sample-dataset.py")) + print() + print("⚠️ To upload the test data, run this:") + print() + print(f"{exe} {upload_path} anemoi-datasets/create/{os.path.basename(path)} --overwrite") + print() + exit(1) + raise ValueError("Test data is missing") + + def mars(self, args, kwargs): filename = self.filename(args, kwargs) - path = os.path.join(directory, filename) - print(f"Saving to {path} for {args}, {kwargs}") - shutil.copy(ds.path, path) + dirname = "." + path = os.path.join(dirname, filename) + url = TEST_DATA_ROOT + "/" + filename - def read(self, directory, args, kwargs): - if self.read_dir is None: - return None - filename = self.filename(args, kwargs) - if filename is None: - return None - path = os.path.join(directory, filename) - - if os.path.exists(path): - print(f"Mockup: Loading path {path} for {args}, {kwargs}") - ds = from_source("file", path) - return ds + assert url.startswith("http:") or url.startswith("https:") - elif path.startswith("http:") or path.startswith("https:"): + if not os.path.exists(path): print(f"Mockup: Loading url {path} for {args}, {kwargs}") try: - return from_source("url", path) - except requests.exceptions.HTTPError: - print(f"Mockup: ❌ Cannot load from url for {path} for {args}, {kwargs}") - - return None - - def source_name(self, *args, **kwargs): - if args: - return args[0] - return kwargs["name"] + download(url, path + ".tmp") + os.rename(path + ".tmp", path) + except requests.exceptions.HTTPError as e: + print(e) + if e.response.status_code == 404: + self.get_data(args, kwargs, path) + raise - def __call__(self, *args, **kwargs): - name = self.source_name(*args, **kwargs) + return original_from_source("file", path) - if name != "mars": - return from_source(*args, **kwargs) + def __call__(self, name, *args, **kwargs): + if name == "mars": + return self.mars(args, kwargs) - ds = self.read(self.read_dir, args, kwargs) - if ds is not None: - return ds + return original_from_source(name, *args, **kwargs) - ds = from_source(*args, **kwargs) - self.write(self.write_dir, ds, args, kwargs) +_from_source = LoadSource() - return ds - -_from_source = LoadSource( - read_dir=os.environ.get("LOAD_SOURCE_MOCKUP_READ_DIRECTORY", TEST_DATA_ROOT), - write_dir=os.environ.get("LOAD_SOURCE_MOCKUP_WRITE_DIRECTORY"), -) - - -def compare_dot_zattrs(a, b): +def compare_dot_zattrs(a, b, path, errors): if isinstance(a, dict): a_keys = list(a.keys()) b_keys = list(b.keys()) for k in set(a_keys) & set(b_keys): - if k in ["timestamp", "uuid", "latest_write_timestamp", "yaml_config"]: - assert type(a[k]) == type(b[k]), ( # noqa: E721 - type(a[k]), - type(b[k]), - a[k], - b[k], - ) - assert k in a_keys, (k, a_keys) - assert k in b_keys, (k, b_keys) - return compare_dot_zattrs(a[k], b[k]) + if k in [ + "timestamp", + "uuid", + "latest_write_timestamp", + "yaml_config", + "history", + "provenance", + "provenance_load", + "description", + "config_path", + "dataset_status", + ]: + if type(a[k]) != type(b[k]): # noqa : E721 + errors.append(f"❌ {path}.{k} : type differs {type(a[k])} != {type(b[k])}") + continue + compare_dot_zattrs(a[k], b[k], f"{path}.{k}", errors) + return if isinstance(a, list): - assert len(a) == len(b), (a, b) - for v, w in zip(a, b): - return compare_dot_zattrs(v, w) + if len(a) != len(b): + errors.append(f"❌ {path} : lengths are different {len(a)} != {len(b)}") + return + for i, (v, w) in enumerate(zip(a, b)): + 
compare_dot_zattrs(v, w, f"{path}.{i}", errors) + return - assert type(a) == type(b), (type(a), type(b), a, b) # noqa: E721 - return a == b, (a, b) + if type(a) != type(b): # noqa : E721 + msg = f"❌ {path} actual != expected : {a} ({type(a)}) != {b} ({type(b)})" + errors.append(msg) + return + if a != b: + msg = f"❌ {path} actual != expected : {a} != {b}" + errors.append(msg) def compare_datasets(a, b): @@ -195,23 +185,32 @@ def compare_statistics(ds1, ds2): class Comparer: def __init__(self, name, output_path=None, reference_path=None): self.name = name - self.reference = reference_path or os.path.join(TEST_DATA_ROOT, name + ".zarr") self.output = output_path or os.path.join(name + ".zarr") - print(f"Comparing {self.reference} and {self.output}") + self.reference_path = reference_path + print(f"Comparing {self.output} and {self.reference_path}") - self.z_reference = open_zarr(self.reference) self.z_output = open_zarr(self.output) + self.z_reference = open_zarr(self.reference_path) - self.ds_reference = open_dataset(self.reference) + self.z_reference["data"] self.ds_output = open_dataset(self.output) + self.ds_reference = open_dataset(self.reference_path) def compare(self): - compare_dot_zattrs(self.z_output.attrs, self.z_reference.attrs) + errors = [] + compare_dot_zattrs(dict(self.z_output.attrs), dict(self.z_reference.attrs), "metadata", errors) + if errors: + print("Comparison failed") + print("\n".join(errors)) + + if errors: + raise AssertionError("Comparison failed") + compare_datasets(self.ds_output, self.ds_reference) + compare_statistics(self.ds_output, self.ds_reference) -@pytest.mark.skipif(not MARS_CLIENT_PRESENT, reason="Test requires direct mars access.") @pytest.mark.parametrize("name", NAMES) @mockup_from_source def test_run(name): @@ -226,8 +225,14 @@ def test_run(name): c.additions(delta=[1, 3, 6, 12]) c.cleanup() - comparer = Comparer(name, output_path=output) - comparer.compare() + # reference_path = os.path.join(HERE, name + "-reference.zarr") + s3_uri = TEST_DATA_S3_ROOT + "/" + name + ".zarr" + # if not os.path.exists(reference_path): + # from anemoi.utils.s3 import download as s3_download + # s3_download(s3_uri + '/', reference_path, overwrite=True) + + Comparer(name, output_path=output, reference_path=s3_uri).compare() + # Comparer(name, output_path=output, reference_path=reference_path).compare() if __name__ == "__main__": diff --git a/tests/test_chunks.py b/tests/test_chunks.py index 4dfc38e7..d0fc3aa1 100644 --- a/tests/test_chunks.py +++ b/tests/test_chunks.py @@ -86,4 +86,7 @@ def test_chunk_filter(): if __name__ == "__main__": - test_chunk_filter() + for name, obj in list(globals().items()): + if name.startswith("test_") and callable(obj): + print(f"Running {name}...") + obj() diff --git a/tests/test_data.py b/tests/test_data.py index 83711d07..42e2a001 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -12,13 +12,14 @@ import numpy as np import zarr +from anemoi.utils.dates import frequency_to_string +from anemoi.utils.dates import frequency_to_timedelta from anemoi.datasets import open_dataset from anemoi.datasets.data.concat import Concat from anemoi.datasets.data.ensemble import Ensemble from anemoi.datasets.data.grids import GridsBase from anemoi.datasets.data.join import Join -from anemoi.datasets.data.misc import _frequency_to_hours from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date from anemoi.datasets.data.select import Rename @@ -63,12 +64,13 @@ def create_zarr( missing=False, ): root = 
zarr.group() + assert isinstance(frequency, datetime.timedelta) dates = [] date = datetime.datetime(start, 1, 1) while date.year <= end: dates.append(date) - date += datetime.timedelta(hours=frequency) + date += frequency dates = np.array(dates, dtype="datetime64") @@ -105,7 +107,7 @@ def create_zarr( compressor=None, ) - root.attrs["frequency"] = frequency + root.attrs["frequency"] = frequency_to_string(frequency) root.attrs["resolution"] = resolution root.attrs["name_to_index"] = {k: i for i, k in enumerate(vars)} @@ -170,7 +172,7 @@ def zarr_from_str(name, mode): return create_zarr( start=int(args["start"]), end=int(args["end"]), - frequency=_frequency_to_hours(args["frequency"]), + frequency=frequency_to_timedelta(args["frequency"]), resolution=args["resolution"], vars=[x for x in args["vars"]], k=int(args["k"]), @@ -1111,4 +1113,7 @@ def test_cropping(): if __name__ == "__main__": - test_cropping() + for name, obj in list(globals().items()): + if name.startswith("test_") and callable(obj): + print(f"Running {name}...") + obj() diff --git a/tests/xarray/test_kerchunk.py b/tests/xarray/test_kerchunk.py index 7ca5f6e8..c1bf5667 100644 --- a/tests/xarray/test_kerchunk.py +++ b/tests/xarray/test_kerchunk.py @@ -33,4 +33,7 @@ def dont_test_kerchunk(): if __name__ == "__main__": - dont_test_kerchunk() + for name, obj in list(globals().items()): + if name.startswith("test_") and callable(obj): + print(f"Running {name}...") + obj() diff --git a/tests/xarray/test_opendap.py b/tests/xarray/test_opendap.py index 3f7ce3ad..6ae3981f 100644 --- a/tests/xarray/test_opendap.py +++ b/tests/xarray/test_opendap.py @@ -21,4 +21,7 @@ def test_opendap(): if __name__ == "__main__": - test_opendap() + for name, obj in list(globals().items()): + if name.startswith("test_") and callable(obj): + print(f"Running {name}...") + obj() diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index d138d371..509e3afe 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -22,7 +22,7 @@ def test_arco_era5(): print(len(fs)) print(fs[-1].metadata()) - print(fs[-1].to_numpy()) + # print(fs[-1].to_numpy()) assert len(fs) == 128677526 @@ -47,8 +47,27 @@ def test_weatherbench(): assert len(fs) == 2430240 - assert fs[0].metadata("valid_datetime") == "2020-01-01T00:00:00", fs[0].metadata("valid_datetime") + assert fs[0].metadata("valid_datetime") == "2020-01-01T06:00:00", fs[0].metadata("valid_datetime") + assert fs[-1].metadata("valid_datetime") == "2021-01-10T12:00:00", fs[-1].metadata("valid_datetime") + + +def test_inca_one_date(): + url = "https://object-store.os-api.cci1.ecmwf.int/ml-tests/test-data/example-inca-one-date.zarr" + + ds = xr.open_zarr(url) + fs = XarrayFieldList.from_xarray(ds) + vars = ["DD_10M", "SP_10M", "TD_2M", "TOT_PREC", "T_2M"] + + for i, f in enumerate(fs): + print(f) + assert f.metadata("valid_datetime") == "2023-01-01T00:00:00" + assert f.metadata("step") == 0 + assert f.metadata("number") == 0 + assert f.metadata("variable") == vars[i] if __name__ == "__main__": - test_weatherbench() + for name, obj in list(globals().items()): + if name.startswith("test_") and callable(obj): + print(f"Running {name}...") + obj() diff --git a/tools/upload-sample-dataset.py b/tools/upload-sample-dataset.py new file mode 100755 index 00000000..5d6abadc --- /dev/null +++ b/tools/upload-sample-dataset.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. 
+# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import argparse +import logging +import os + +from anemoi.utils.s3 import upload + +LOG = logging.getLogger(__name__) + +logging.basicConfig(level=logging.INFO) + +parser = argparse.ArgumentParser(description="Upload sample dataset to S3") +parser.add_argument("--bucket", type=str, help="S3 target path", default="s3://ml-tests/test-data/") +parser.add_argument("source", type=str, help="Path to the sample dataset") +parser.add_argument("target", type=str, help="Path to the sample dataset") +parser.add_argument("--overwrite", action="store_true", help="Overwrite existing data") + +args = parser.parse_args() + +source = args.source +target = args.target +bucket = args.bucket + +if not target.startswith("s3://"): + if target.startswith("/"): + target = target[1:] + if bucket.endswith("/"): + bucket = bucket[:-1] + target = os.path.join(bucket, target) + +LOG.info(f"Uploading {source} to {target}") +upload(source, target, overwrite=args.overwrite) +LOG.info("Upload complete")
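+
+# Example invocation (hypothetical file names, for illustration only):
+#
+#   ./tools/upload-sample-dataset.py my-sample.grib anemoi-datasets/create/my-sample.grib --overwrite
+#
+# With the default --bucket, this uploads the file to
+# s3://ml-tests/test-data/anemoi-datasets/create/my-sample.grib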