diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cfcfcc8e..0cf27489 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,7 +20,7 @@ repos:
       - id: debug-statements # Check for debugger imports and py37+ breakpoint()
       - id: end-of-file-fixer # Ensure files end in a newline
       - id: trailing-whitespace # Trailing whitespace checker
-      - id: no-commit-to-branch # Prevent committing to main / master
+#      - id: no-commit-to-branch # Prevent committing to main / master
       - id: check-added-large-files # Check for large files added to git
       - id: check-merge-conflict # Check for files that contain merge conflict
diff --git a/docs/building/handling-missing-values.rst b/docs/building/handling-missing-values.rst
index 7898cfd9..adc07cba 100644
--- a/docs/building/handling-missing-values.rst
+++ b/docs/building/handling-missing-values.rst
@@ -2,5 +2,16 @@
 Handling missing values
 #########################
 
-.. literalinclude:: ../../tests/create/nan.yaml
+Missing values (NaNs) can pose a challenge when preparing data for
+machine learning models, which require complete data and may crash
+otherwise. Ideally, all fields contain complete data. However, there
+are scenarios where NaNs occur naturally, such as with variables that
+are only defined on land or at sea (sea surface temperature (`sst`),
+for example). In such cases, the default behaviour is to reject data
+containing NaNs as invalid. To accept NaNs, and to compute statistics
+that take them into account, add the `allow_nans` key to the
+configuration, as in the following example:
+
+.. literalinclude:: yaml/nan.yaml
    :language: yaml
diff --git a/docs/building/statistics.rst b/docs/building/statistics.rst
index 6cc2aacc..f2187dc8 100644
--- a/docs/building/statistics.rst
+++ b/docs/building/statistics.rst
@@ -17,7 +17,7 @@ algorithm:
 - If the dataset covers 10 years or more, the last year is excluded.
 - Otherwise, 80% of the dataset is used.
 
-You can override this behaviour by setting the `start` and `end`
+You can override this behaviour by setting the `start` or `end`
 parameters in the `statistics` config.
 
 .. code:: yaml
diff --git a/docs/building/yaml/nan.yaml b/docs/building/yaml/nan.yaml
new file mode 100644
index 00000000..a9ce827b
--- /dev/null
+++ b/docs/building/yaml/nan.yaml
@@ -0,0 +1,2 @@
+statistics:
+  allow_nans: [sst, ci]
diff --git a/src/anemoi/datasets/compute/perturbations.py b/src/anemoi/datasets/compute/perturbations.py
index c09041f3..f86c1256 100644
--- a/src/anemoi/datasets/compute/perturbations.py
+++ b/src/anemoi/datasets/compute/perturbations.py
@@ -7,59 +7,95 @@
 # nor does it submit to any jurisdiction.
 #
 
-import warnings
+import logging
 
 import numpy as np
 from climetlab.core.temporary import temp_file
 from climetlab.readers.grib.output import new_grib_output
 
-from anemoi.datasets.create.check import check_data_values
 from anemoi.datasets.create.functions import assert_is_fieldset
 
+LOG = logging.getLogger(__name__)
+
+CLIP_VARIABLES = (
+    "q",
+    "cp",
+    "lsp",
+    "tp",
+    "sf",
+    "swl4",
+    "swl3",
+    "swl2",
+    "swl1",
+)
+
+SKIP = ("class", "stream", "type", "number", "expver", "_leg_number", "anoffset")
+
+
+def check_compatible(f1, f2, center_field_as_mars, ensemble_field_as_mars):
+    assert f1.mars_grid == f2.mars_grid, (f1.mars_grid, f2.mars_grid)
+    assert f1.mars_area == f2.mars_area, (f1.mars_area, f2.mars_area)
+    assert f1.shape == f2.shape, (f1.shape, f2.shape)
+
+    # Not in *_as_mars
+    assert f1.metadata("valid_datetime") == f2.metadata("valid_datetime"), (
+        f1.metadata("valid_datetime"),
+        f2.metadata("valid_datetime"),
+    )
+
+    for k in set(center_field_as_mars.keys()) | set(ensemble_field_as_mars.keys()):
+        if k in SKIP:
+            continue
+        assert center_field_as_mars[k] == ensemble_field_as_mars[k], (
+            k,
+            center_field_as_mars[k],
+            ensemble_field_as_mars[k],
+        )
+
 
 def perturbations(
+    *,
     members,
     center,
-    positive_clipping_variables=[
-        "q",
-        "cp",
-        "lsp",
-        "tp",
-    ],  # add "swl4", "swl3", "swl2", "swl1", "swl0", and more ?
+    clip_variables=CLIP_VARIABLES,
+    output=None,
 ):
 
     keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"]
 
-    def check_compatible(f1, f2, ignore=["number"]):
-        for k in keys + ["grid", "shape"]:
-            if k in ignore:
-                continue
-            assert f1.metadata(k) == f2.metadata(k), (k, f1.metadata(k), f2.metadata(k))
+    number_list = members.unique_values("number")["number"]
+    n_numbers = len(number_list)
 
-    print(f"Retrieving ensemble data with {members}")
-    print(f"Retrieving center data with {center}")
+    assert None not in number_list
 
+    LOG.info("Ordering fields")
     members = members.order_by(*keys)
     center = center.order_by(*keys)
-
-    number_list = members.unique_values("number")["number"]
-    n_numbers = len(number_list)
+    LOG.info("Done")
 
     if len(center) * n_numbers != len(members):
-        print(len(center), n_numbers, len(members))
+        LOG.error("%s %s %s", len(center), n_numbers, len(members))
         for f in members:
-            print("Member: ", f)
+            LOG.error("Member: %r", f)
         for f in center:
-            print("Center: ", f)
+            LOG.error("Center: %r", f)
         raise ValueError(f"Inconsistent number of fields: {len(center)} * {n_numbers} != {len(members)}")
 
-    # prepare output tmp file so we can read it back
-    tmp = temp_file()
-    path = tmp.path
+    if output is None:
+        # prepare output tmp file so we can read it back
+        tmp = temp_file()
+        path = tmp.path
+    else:
+        tmp = None
+        path = output
+
     out = new_grib_output(path)
 
+    seen = set()
+
     for i, center_field in enumerate(center):
         param = center_field.metadata("param")
+        center_field_as_mars = center_field.as_mars()
 
         # load the center field
         center_np = center_field.to_numpy()
@@ -69,9 +105,21 @@ def check_compatible(f1, f2, ignore=["number"]):
 
         for j in range(n_numbers):
             ensemble_field = members[i * n_numbers + j]
-            check_compatible(center_field, ensemble_field)
+            ensemble_field_as_mars = ensemble_field.as_mars()
+            check_compatible(center_field, ensemble_field, center_field_as_mars, ensemble_field_as_mars)
             members_np[j] = ensemble_field.to_numpy()
 
+            ensemble_field_as_mars = tuple(sorted(ensemble_field_as_mars.items()))
+            assert ensemble_field_as_mars not in seen, ensemble_field_as_mars
+            seen.add(ensemble_field_as_mars)
+
+        # cmin=np.amin(center_np)
+        # emin=np.amin(members_np)
+
+        # if cmin < 0 and emin >= 0:
+        #     LOG.warning(f"Negative values in {param} cmin={cmin} emin={emin}")
+        #     LOG.warning(f"Center: {center_field_as_mars}")
+
         mean_np = members_np.mean(axis=0)
 
         for j in range(n_numbers):
@@ -84,18 +132,22 @@ def check_compatible(f1, f2, ignore=["number"]):
 
             x = c - m + e
 
-            if param in positive_clipping_variables:
-                warnings.warn(f"Clipping {param} to be positive")
+            if param in clip_variables:
+                # LOG.warning(f"Clipping {param} to be positive")
                 x = np.maximum(x, 0)
 
             assert x.shape == e.shape, (x.shape, e.shape)
 
-            check_data_values(x, name=param)
             out.write(x, template=template)
             template = None
 
+    assert len(seen) == len(members), (len(seen), len(members))
+
     out.close()
 
+    if output is not None:
+        return path
+
     from climetlab import load_source
 
     ds = load_source("file", path)
diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py
index b829f641..dbaacc7f 100644
--- a/src/anemoi/datasets/create/__init__.py
+++ b/src/anemoi/datasets/create/__init__.py
@@ -94,7 +94,7 @@ def patch(self, **kwargs):
 
         apply_patch(self.path, **kwargs)
 
-    def init_additions(self, delta=[1, 3, 6, 12]):
+    def init_additions(self, delta=[1, 3, 6, 12, 24]):
         from .loaders import StatisticsAddition
         from .loaders import TendenciesStatisticsAddition
         from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency
@@ -109,7 +109,7 @@ def init_additions(self, delta=[1, 3, 6, 12]):
         except TendenciesStatisticsDeltaNotMultipleOfFrequency:
             self.print(f"Skipping delta={d} as it is not a multiple of the frequency.")
 
-    def run_additions(self, parts=None, delta=[1, 3, 6, 12]):
+    def run_additions(self, parts=None, delta=[1, 3, 6, 12, 24]):
         from .loaders import StatisticsAddition
         from .loaders import TendenciesStatisticsAddition
         from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency
@@ -124,7 +124,7 @@ def run_additions(self, parts=None, delta=[1, 3, 6, 12]):
         except TendenciesStatisticsDeltaNotMultipleOfFrequency:
             self.print(f"Skipping delta={d} as it is not a multiple of the frequency.")
 
-    def finalise_additions(self, delta=[1, 3, 6, 12]):
+    def finalise_additions(self, delta=[1, 3, 6, 12, 24]):
         from .loaders import StatisticsAddition
         from .loaders import TendenciesStatisticsAddition
         from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency
diff --git a/src/anemoi/datasets/create/loaders.py b/src/anemoi/datasets/create/loaders.py
index 780a078e..2e359c37 100644
--- a/src/anemoi/datasets/create/loaders.py
+++ b/src/anemoi/datasets/create/loaders.py
@@ -546,12 +546,17 @@ def write_stats_to_stdout(self, stats):
 
 
 class GenericAdditions(GenericDatasetHandler):
-    def __init__(self, name="", **kwargs):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.name = name
+        self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True)
+
+    @property
+    def tmp_storage_path(self):
+        raise NotImplementedError
 
-        storage_path = f"{self.path}.tmp_storage_{name}"
-        self.tmp_storage = build_storage(directory=storage_path, create=True)
+    @property
+    def final_storage_path(self):
+        raise NotImplementedError
 
     def initialise(self):
         self.tmp_storage.delete()
@@ -589,7 +594,7 @@ def finalise(self):
             count=np.full(shape, -1, dtype=np.int64),
             has_nans=np.full(shape, False, dtype=np.bool_),
         )
-        LOG.info(f"Aggregating {self.name} statistics on shape={shape}. Variables : {self.variables}")
+        LOG.info(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}")
 
         found = set()
         ifound = set()
@@ -659,9 +664,10 @@ def finalise(self):
 
     def _write(self, summary):
         for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
-            self._add_dataset(name=k, array=summary[k])
-        self.registry.add_to_history("compute_statistics_end")
-        LOG.info(f"Wrote {self.name} additions in {self.path}")
+            name = self.final_storage_name(k)
+            self._add_dataset(name=name, array=summary[k])
+        self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end")
+        LOG.info(f"Wrote additions in {self.path} ({self.final_storage_name('*')})")
 
     def check_statistics(self):
         pass
@@ -669,7 +675,7 @@
 
 class StatisticsAddition(GenericAdditions):
     def __init__(self, **kwargs):
-        super().__init__("statistics_", **kwargs)
+        super().__init__(**kwargs)
 
         z = zarr.open(self.path, mode="r")
         start = z.attrs["statistics_start_date"]
@@ -682,6 +688,13 @@ def __init__(self, **kwargs):
         assert len(self.variables) == self.ds.shape[1], self.ds.shape
         self.total = len(self.dates)
 
+    @property
+    def tmp_storage_path(self):
+        return f"{self.path}.tmp_storage_statistics"
+
+    def final_storage_name(self, k):
+        return k
+
     def run(self, parts):
         chunk_filter = ChunkFilter(parts=parts, total=self.total)
         for i in range(0, self.total):
@@ -725,8 +738,6 @@ class TendenciesStatisticsDeltaNotMultipleOfFrequency(ValueError):
 
 
 class TendenciesStatisticsAddition(GenericAdditions):
-    DATASET_NAME_PATTERN = "statistics_tendencies_{delta}"
-
     def __init__(self, path, delta=None, **kwargs):
         full_ds = open_dataset(path)
         self.variables = full_ds.variables
@@ -739,9 +750,10 @@ def __init__(self, path, delta=None, **kwargs):
             raise TendenciesStatisticsDeltaNotMultipleOfFrequency(
                 f"Delta {delta} is not a multiple of frequency {frequency}"
             )
+        self.delta = delta
         idelta = delta // frequency
 
-        super().__init__(path=path, name=self.DATASET_NAME_PATTERN.format(delta=f"{delta}h"), **kwargs)
+        super().__init__(path=path, **kwargs)
 
         z = zarr.open(self.path, mode="r")
         start = z.attrs["statistics_start_date"]
@@ -754,6 +766,21 @@ def __init__(self, path, delta=None, **kwargs):
         ds = open_dataset(self.path, start=start, end=end)
         self.ds = DeltaDataset(ds, idelta)
 
+    @property
+    def tmp_storage_path(self):
+        return f"{self.path}.tmp_storage_statistics_{self.delta}h"
+
+    def final_storage_name(self, k):
+        return self.final_storage_name_from_delta(k, delta=self.delta)
+
+    @classmethod
+    def final_storage_name_from_delta(_, k, delta):
+        if isinstance(delta, int):
+            delta = str(delta)
+        if not delta.endswith("h"):
+            delta = delta + "h"
+        return f"statistics_tendencies_{delta}_{k}"
+
     def run(self, parts):
         chunk_filter = ChunkFilter(parts=parts, total=self.total)
         for i in range(0, self.total):
@@ -768,9 +795,3 @@ def run(self, parts):
                 self.tmp_storage.add([date, i, "missing"], key=date)
         self.tmp_storage.flush()
         LOG.info(f"Dataset {self.path} additions run.")
-
-    def _write(self, summary):
-        for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
-            self._add_dataset(name=f"{self.name}_{k}", array=summary[k])
-        self.registry.add_to_history(f"compute_{self.name}_end")
-        LOG.info(f"Wrote {self.name} additions in {self.path}")
diff --git a/src/anemoi/datasets/create/size.py b/src/anemoi/datasets/create/size.py
index 1671290f..35c20994 100644
--- a/src/anemoi/datasets/create/size.py
+++ b/src/anemoi/datasets/create/size.py
@@ -10,6 +10,8 @@
 import logging
 import os
 
+from anemoi.utils.humanize import bytes
+
 from anemoi.datasets.create.utils import progress_bar
 
 LOG = logging.getLogger(__name__)
diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py
index 62173e08..d5d737fa 100644
--- a/src/anemoi/datasets/data/stores.py
+++ b/src/anemoi/datasets/data/stores.py
@@ -225,12 +225,12 @@ def statistics_tendencies(self, delta=None):
             delta = f"{delta}h"
         from anemoi.datasets.create.loaders import TendenciesStatisticsAddition
 
-        prefix = TendenciesStatisticsAddition.DATASET_NAME_PATTERN.format(delta=delta) + "_"
+        func = TendenciesStatisticsAddition.final_storage_name_from_delta
         return dict(
-            mean=self.z[f"{prefix}mean"][:],
-            stdev=self.z[f"{prefix}stdev"][:],
-            maximum=self.z[f"{prefix}maximum"][:],
-            minimum=self.z[f"{prefix}minimum"][:],
+            mean=self.z[func("mean", delta)][:],
+            stdev=self.z[func("stdev", delta)][:],
+            maximum=self.z[func("maximum", delta)][:],
+            minimum=self.z[func("minimum", delta)][:],
         )
 
     @property
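
Note on the recentring in `perturbations.py`: the core computation is unchanged by this refactor. Each output field is the centre field plus a member's anomaly with respect to the ensemble mean, `x = c - m + e`, and variables listed in `CLIP_VARIABLES` are then clipped to be non-negative. Below is a minimal numpy sketch of that arithmetic; the array shapes and random values are hypothetical toy data, whereas the real code reads and writes GRIB fields via climetlab.

```python
import numpy as np

# Hypothetical toy data: 4 ensemble members on a flattened 6-point grid.
rng = np.random.default_rng(0)
members_np = rng.normal(loc=280.0, scale=2.0, size=(4, 6))  # ensemble fields
center_np = rng.normal(loc=281.0, scale=1.0, size=(6,))     # centre field

mean_np = members_np.mean(axis=0)  # ensemble mean at each grid point

# Recentre each member on the centre field: x = c - m + e
x = center_np[None, :] - mean_np[None, :] + members_np

# Before clipping, the mean of the recentred ensemble equals the centre
# field, while the member anomalies (e - m) are preserved.
assert np.allclose(x.mean(axis=0), center_np)

# For physically non-negative quantities (humidity, precipitation, snow,
# soil moisture), the shift can produce small negative values, hence the
# clipping applied to variables in CLIP_VARIABLES.
x = np.maximum(x, 0)
```

The statistics renaming keeps the write side (`loaders.py`) and the read side (`stores.py`) in sync through one helper: `TendenciesStatisticsAddition.final_storage_name_from_delta("mean", 6)` yields the key `statistics_tendencies_6h_mean`, replacing the old `DATASET_NAME_PATTERN` prefix lookup.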