Merge branch 'develop'
floriankrb committed May 10, 2024
2 parents 7cc1eea + e9ad943 commit a3dda3a
Showing 9 changed files with 145 additions and 57 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -20,7 +20,7 @@ repos:
- id: debug-statements # Check for debugger imports and py37+ breakpoint()
- id: end-of-file-fixer # Ensure files end in a newline
- id: trailing-whitespace # Trailing whitespace checker
- id: no-commit-to-branch # Prevent committing to main / master
# - id: no-commit-to-branch # Prevent committing to main / master
- id: check-added-large-files # Check for large files added to git
- id: check-merge-conflict # Check for files that contain merge conflict

13 changes: 12 additions & 1 deletion docs/building/handling-missing-values.rst
@@ -2,5 +2,16 @@
Handling missing values
#########################

.. literalinclude:: ../../tests/create/nan.yaml
When preparing data for machine learning models, missing values
(NaNs) can be a problem: models expect complete data and may fail
when they encounter gaps. Ideally, every field is complete. In some
scenarios, however, NaNs occur naturally, such as for variables that
are only defined on land or at sea (sea surface temperature (`sst`),
for example). By default, data containing NaNs is therefore rejected
as invalid. To accept NaNs for selected variables and have their
statistics computed accordingly, add the `allow_nans` key to the
configuration. Here's an example of how to use it:

.. literalinclude:: yaml/nan.yaml
:language: yaml
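
Conceptually, listing a variable under `allow_nans` means its statistics are computed over the non-missing values only. A rough NumPy illustration of the idea (not the actual implementation):

import numpy as np

sst = np.array([290.5, np.nan, 288.1, np.nan, 291.0])  # NaN over land

# NaN-aware statistics simply ignore the missing points
print(np.nanmean(sst), np.nanstd(sst), np.nanmin(sst), np.nanmax(sst))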
2 changes: 1 addition & 1 deletion docs/building/statistics.rst
@@ -17,7 +17,7 @@ algorithm:
- If the dataset covers 10 years or more, the last year is excluded.
- Otherwise, 80% of the dataset is used.

You can override this behaviour by setting the `start` and `end`
You can override this behaviour by setting the `start` or `end`
parameters in the `statistics` config.

.. code:: yaml
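   # An illustrative override (made-up dates, not the snippet from the
   # original file): pick a range inside the dataset's date span
   statistics:
     start: 2015
     end: 2020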
2 changes: 2 additions & 0 deletions docs/building/yaml/nan.yaml
@@ -0,0 +1,2 @@
statistics:
allow_nans: [sst, ci]
108 changes: 80 additions & 28 deletions src/anemoi/datasets/compute/perturbations.py
@@ -7,59 +7,95 @@
# nor does it submit to any jurisdiction.
#

import warnings
import logging

import numpy as np
from climetlab.core.temporary import temp_file
from climetlab.readers.grib.output import new_grib_output

from anemoi.datasets.create.check import check_data_values
from anemoi.datasets.create.functions import assert_is_fieldset

LOG = logging.getLogger(__name__)

CLIP_VARIABLES = (
"q",
"cp",
"lsp",
"tp",
"sf",
"swl4",
"swl3",
"swl2",
"swl1",
)

SKIP = ("class", "stream", "type", "number", "expver", "_leg_number", "anoffset")


def check_compatible(f1, f2, center_field_as_mars, ensemble_field_as_mars):
assert f1.mars_grid == f2.mars_grid, (f1.mars_grid, f2.mars_grid)
assert f1.mars_area == f2.mars_area, (f1.mars_area, f2.mars_area)
assert f1.shape == f2.shape, (f1.shape, f2.shape)

# Not in *_as_mars
assert f1.metadata("valid_datetime") == f2.metadata("valid_datetime"), (
f1.metadata("valid_datetime"),
f2.metadata("valid_datetime"),
)

for k in set(center_field_as_mars.keys()) | set(ensemble_field_as_mars.keys()):
if k in SKIP:
continue
assert center_field_as_mars[k] == ensemble_field_as_mars[k], (
k,
center_field_as_mars[k],
ensemble_field_as_mars[k],
)


def perturbations(
*,
members,
center,
positive_clipping_variables=[
"q",
"cp",
"lsp",
"tp",
], # add "swl4", "swl3", "swl2", "swl1", "swl0", and more ?
clip_variables=CLIP_VARIABLES,
output=None,
):

keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"]

def check_compatible(f1, f2, ignore=["number"]):
for k in keys + ["grid", "shape"]:
if k in ignore:
continue
assert f1.metadata(k) == f2.metadata(k), (k, f1.metadata(k), f2.metadata(k))
number_list = members.unique_values("number")["number"]
n_numbers = len(number_list)

print(f"Retrieving ensemble data with {members}")
print(f"Retrieving center data with {center}")
assert None not in number_list

LOG.info("Ordering fields")
members = members.order_by(*keys)
center = center.order_by(*keys)

number_list = members.unique_values("number")["number"]
n_numbers = len(number_list)
LOG.info("Done")

if len(center) * n_numbers != len(members):
print(len(center), n_numbers, len(members))
LOG.error("%s %s %s", len(center), n_numbers, len(members))
for f in members:
print("Member: ", f)
LOG.error("Member: %r", f)
for f in center:
print("Center: ", f)
LOG.error("Center: %r", f)
raise ValueError(f"Inconsistent number of fields: {len(center)} * {n_numbers} != {len(members)}")

# prepare output tmp file so we can read it back
tmp = temp_file()
path = tmp.path
if output is None:
# prepare output tmp file so we can read it back
tmp = temp_file()
path = tmp.path
else:
tmp = None
path = output

out = new_grib_output(path)

seen = set()

for i, center_field in enumerate(center):
param = center_field.metadata("param")
center_field_as_mars = center_field.as_mars()

# load the center field
center_np = center_field.to_numpy()
@@ -69,9 +105,21 @@ def check_compatible(f1, f2, ignore=["number"]):

for j in range(n_numbers):
ensemble_field = members[i * n_numbers + j]
check_compatible(center_field, ensemble_field)
ensemble_field_as_mars = ensemble_field.as_mars()
check_compatible(center_field, ensemble_field, center_field_as_mars, ensemble_field_as_mars)
members_np[j] = ensemble_field.to_numpy()

ensemble_field_as_mars = tuple(sorted(ensemble_field_as_mars.items()))
assert ensemble_field_as_mars not in seen, ensemble_field_as_mars
seen.add(ensemble_field_as_mars)

# cmin=np.amin(center_np)
# emin=np.amin(members_np)

# if cmin < 0 and emin >= 0:
# LOG.warning(f"Negative values in {param} cmin={cmin} emin={emin}")
# LOG.warning(f"Center: {center_field_as_mars}")

mean_np = members_np.mean(axis=0)

for j in range(n_numbers):
Expand All @@ -84,18 +132,22 @@ def check_compatible(f1, f2, ignore=["number"]):

x = c - m + e

if param in positive_clipping_variables:
warnings.warn(f"Clipping {param} to be positive")
if param in clip_variables:
# LOG.warning(f"Clipping {param} to be positive")
x = np.maximum(x, 0)

assert x.shape == e.shape, (x.shape, e.shape)

check_data_values(x, name=param)
out.write(x, template=template)
template = None

assert len(seen) == len(members), (len(seen), len(members))

out.close()

if output is not None:
return path

from climetlab import load_source

ds = load_source("file", path)
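The heart of the function above is the recentering step `x = c - m + e`: each ensemble member `e` is shifted so that the ensemble is centred on the control field `c` rather than on the ensemble mean `m`, and the variables in `CLIP_VARIABLES` are clipped to stay non-negative. A minimal standalone NumPy sketch of that step (toy shapes and values):

import numpy as np

def recenter(center, members, clip=False):
    # center:  (npoints,) control field
    # members: (n_members, npoints) ensemble fields of the same variable
    mean = members.mean(axis=0)
    x = center - mean + members  # broadcasts over the member axis
    if clip:
        x = np.maximum(x, 0)  # e.g. precipitation or humidity must stay >= 0
    return x

center = np.array([1.0, 2.0, 3.0])
members = np.array([[0.9, 2.1, 3.2], [1.1, 1.9, 2.8]])
print(recenter(center, members, clip=True))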
6 changes: 3 additions & 3 deletions src/anemoi/datasets/create/__init__.py
@@ -94,7 +94,7 @@ def patch(self, **kwargs):

apply_patch(self.path, **kwargs)

def init_additions(self, delta=[1, 3, 6, 12]):
def init_additions(self, delta=[1, 3, 6, 12, 24]):
from .loaders import StatisticsAddition
from .loaders import TendenciesStatisticsAddition
from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency
@@ -109,7 +109,7 @@ def init_additions(self, delta=[1, 3, 6, 12]):
except TendenciesStatisticsDeltaNotMultipleOfFrequency:
self.print(f"Skipping delta={d} as it is not a multiple of the frequency.")

def run_additions(self, parts=None, delta=[1, 3, 6, 12]):
def run_additions(self, parts=None, delta=[1, 3, 6, 12, 24]):
from .loaders import StatisticsAddition
from .loaders import TendenciesStatisticsAddition
from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency
@@ -124,7 +124,7 @@ def run_additions(self, parts=None, delta=[1, 3, 6, 12]):
except TendenciesStatisticsDeltaNotMultipleOfFrequency:
self.print(f"Skipping delta={d} as it is not a multiple of the frequency.")

def finalise_additions(self, delta=[1, 3, 6, 12]):
def finalise_additions(self, delta=[1, 3, 6, 12, 24]):
from .loaders import StatisticsAddition
from .loaders import TendenciesStatisticsAddition
from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency
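The `Skipping delta=...` branch above drops any delta that is not a multiple of the dataset frequency. A small sketch of that filter (the real frequency comes from the dataset; the 6h value here is made up):

def valid_deltas(deltas, frequency):
    # keep only the deltas (in hours) that are multiples of the dataset frequency (in hours)
    return [d for d in deltas if d % frequency == 0]

print(valid_deltas([1, 3, 6, 12, 24], frequency=6))  # [6, 12, 24]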
57 changes: 39 additions & 18 deletions src/anemoi/datasets/create/loaders.py
@@ -546,12 +546,17 @@ def write_stats_to_stdout(self, stats):


class GenericAdditions(GenericDatasetHandler):
def __init__(self, name="", **kwargs):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = name
self.tmp_storage = build_storage(directory=self.tmp_storage_path, create=True)

@property
def tmp_storage_path(self):
raise NotImplementedError

storage_path = f"{self.path}.tmp_storage_{name}"
self.tmp_storage = build_storage(directory=storage_path, create=True)
@property
def final_storage_path(self):
raise NotImplementedError

def initialise(self):
self.tmp_storage.delete()
@@ -589,7 +594,7 @@ def finalise(self):
count=np.full(shape, -1, dtype=np.int64),
has_nans=np.full(shape, False, dtype=np.bool_),
)
LOG.info(f"Aggregating {self.name} statistics on shape={shape}. Variables : {self.variables}")
LOG.info(f"Aggregating {self.__class__.__name__} statistics on shape={shape}. Variables : {self.variables}")

found = set()
ifound = set()
@@ -659,17 +664,18 @@ def finalise(self):

def _write(self, summary):
for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
self._add_dataset(name=k, array=summary[k])
self.registry.add_to_history("compute_statistics_end")
LOG.info(f"Wrote {self.name} additions in {self.path}")
name = self.final_storage_name(k)
self._add_dataset(name=name, array=summary[k])
self.registry.add_to_history(f"compute_statistics_{self.__class__.__name__.lower()}_end")
LOG.info(f"Wrote additions in {self.path} ({self.final_storage_name('*')})")

def check_statistics(self):
pass


class StatisticsAddition(GenericAdditions):
def __init__(self, **kwargs):
super().__init__("statistics_", **kwargs)
super().__init__(**kwargs)

z = zarr.open(self.path, mode="r")
start = z.attrs["statistics_start_date"]
@@ -682,6 +688,13 @@ def __init__(self, **kwargs):
assert len(self.variables) == self.ds.shape[1], self.ds.shape
self.total = len(self.dates)

@property
def tmp_storage_path(self):
return f"{self.path}.tmp_storage_statistics"

def final_storage_name(self, k):
return k

def run(self, parts):
chunk_filter = ChunkFilter(parts=parts, total=self.total)
for i in range(0, self.total):
@@ -725,8 +738,6 @@ class TendenciesStatisticsDeltaNotMultipleOfFrequency(ValueError):


class TendenciesStatisticsAddition(GenericAdditions):
DATASET_NAME_PATTERN = "statistics_tendencies_{delta}"

def __init__(self, path, delta=None, **kwargs):
full_ds = open_dataset(path)
self.variables = full_ds.variables
@@ -739,9 +750,10 @@ def __init__(self, path, delta=None, **kwargs):
raise TendenciesStatisticsDeltaNotMultipleOfFrequency(
f"Delta {delta} is not a multiple of frequency {frequency}"
)
self.delta = delta
idelta = delta // frequency

super().__init__(path=path, name=self.DATASET_NAME_PATTERN.format(delta=f"{delta}h"), **kwargs)
super().__init__(path=path, **kwargs)

z = zarr.open(self.path, mode="r")
start = z.attrs["statistics_start_date"]
@@ -754,6 +766,21 @@ def __init__(self, **kwargs):
ds = open_dataset(self.path, start=start, end=end)
self.ds = DeltaDataset(ds, idelta)

@property
def tmp_storage_path(self):
return f"{self.path}.tmp_storage_statistics_{self.delta}h"

def final_storage_name(self, k):
return self.final_storage_name_from_delta(k, delta=self.delta)

@classmethod
def final_storage_name_from_delta(_, k, delta):
if isinstance(delta, int):
delta = str(delta)
if not delta.endswith("h"):
delta = delta + "h"
return f"statistics_tendencies_{delta}_{k}"

def run(self, parts):
chunk_filter = ChunkFilter(parts=parts, total=self.total)
for i in range(0, self.total):
@@ -768,9 +795,3 @@ def run(self, parts):
self.tmp_storage.add([date, i, "missing"], key=date)
self.tmp_storage.flush()
LOG.info(f"Dataset {self.path} additions run.")

def _write(self, summary):
for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
self._add_dataset(name=f"{self.name}_{k}", array=summary[k])
self.registry.add_to_history(f"compute_{self.name}_end")
LOG.info(f"Wrote {self.name} additions in {self.path}")
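
The refactor replaces the old `name` constructor argument with two hooks that every concrete addition now provides: `tmp_storage_path` says where intermediate results are staged, and `final_storage_name` maps a statistic key to its dataset name. A hypothetical subclass, just to show the contract:

class MyAddition(GenericAdditions):
    @property
    def tmp_storage_path(self):
        # staging area created next to the dataset being built
        return f"{self.path}.tmp_storage_my_addition"

    def final_storage_name(self, k):
        # e.g. "my_addition_mean", "my_addition_stdev", ...
        return f"my_addition_{k}"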
2 changes: 2 additions & 0 deletions src/anemoi/datasets/create/size.py
@@ -10,6 +10,8 @@
import logging
import os

from anemoi.utils.humanize import bytes

from anemoi.datasets.create.utils import progress_bar

LOG = logging.getLogger(__name__)
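The newly imported `bytes` helper from `anemoi.utils.humanize` turns raw byte counts into readable strings for size reporting; a usage sketch (the exact formatting is up to anemoi-utils):

from anemoi.utils.humanize import bytes  # note: shadows the built-in bytes

print(bytes(123_456_789))  # something like "117.7 MiB"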
10 changes: 5 additions & 5 deletions src/anemoi/datasets/data/stores.py
@@ -225,12 +225,12 @@ def statistics_tendencies(self, delta=None):
delta = f"{delta}h"
from anemoi.datasets.create.loaders import TendenciesStatisticsAddition

prefix = TendenciesStatisticsAddition.DATASET_NAME_PATTERN.format(delta=delta) + "_"
func = TendenciesStatisticsAddition.final_storage_name_from_delta
return dict(
mean=self.z[f"{prefix}mean"][:],
stdev=self.z[f"{prefix}stdev"][:],
maximum=self.z[f"{prefix}maximum"][:],
minimum=self.z[f"{prefix}minimum"][:],
mean=self.z[func("mean", delta)][:],
stdev=self.z[func("stdev", delta)][:],
maximum=self.z[func("maximum", delta)][:],
minimum=self.z[func("minimum", delta)][:],
)

@property
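With this change the tendency-statistics key naming lives in a single place, so reader and writer stay in sync. Given the definition of `final_storage_name_from_delta` shown above, the keys read from the zarr store look like this:

from anemoi.datasets.create.loaders import TendenciesStatisticsAddition

func = TendenciesStatisticsAddition.final_storage_name_from_delta
print(func("mean", 12))     # statistics_tendencies_12h_mean
print(func("stdev", "6h"))  # statistics_tendencies_6h_stdev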
