Feature/masks (#104)
* add masks

Co-authored-by: Florian Pinault <floriankrb@users.noreply.github.com>
b8raoult and floriankrb authored Oct 28, 2024
1 parent 3620e8d commit 94a89e0
Showing 10 changed files with 247 additions and 25 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -25,6 +25,8 @@ Keep it human-readable, your future self will thank you!
- Various bug fixes
- Control compatibility check in xy/zip
- Add `merge` feature
- Add support for storing `supporting_arrays` in checkpoint files
- Allow naming of datasets components
- Contributors file (#105)

### Changed
23 changes: 23 additions & 0 deletions src/anemoi/datasets/data/__init__.py
@@ -25,7 +25,30 @@ class MissingDateError(Exception):
pass


def _convert(x):
    # Recursively replace OmegaConf DictConfig/ListConfig objects with plain dicts and lists.
if isinstance(x, list):
return [_convert(a) for a in x]

if isinstance(x, tuple):
return tuple(_convert(a) for a in x)

if isinstance(x, dict):
return {k: _convert(v) for k, v in x.items()}

if x.__class__.__name__ in ("DictConfig", "ListConfig"):
from omegaconf import OmegaConf

return OmegaConf.to_container(x, resolve=True)

return x


def open_dataset(*args, **kwargs):

# This will get rid of OmegaConf objects
args, kwargs = _convert(args), _convert(kwargs)

ds = _open_dataset(*args, **kwargs)
ds = ds.mutate()
ds.arguments = {"args": args, "kwargs": kwargs}
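A minimal sketch (not part of the diff) of why `_convert` is needed; it assumes `open_dataset` accepts a config mapping, and the dataset path is a placeholder:

    from omegaconf import OmegaConf

    from anemoi.datasets import open_dataset

    # DictConfig/ListConfig arguments are converted to plain dicts/lists by
    # _convert() before being recorded in ds.arguments, so the recorded
    # arguments stay JSON-serialisable.
    cfg = OmegaConf.create({"dataset": "path/to/dataset.zarr", "frequency": "6h"})
    ds = open_dataset(cfg)
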
139 changes: 115 additions & 24 deletions src/anemoi/datasets/data/dataset.py
@@ -23,8 +23,34 @@
LOG = logging.getLogger(__name__)


def _tidy(v):
    # Make metadata values JSON-serialisable (paths, dates, slices, nested datasets).
if isinstance(v, (list, tuple, set)):
return [_tidy(i) for i in v]
if isinstance(v, dict):
return {k: _tidy(v) for k, v in v.items()}
if isinstance(v, str) and v.startswith("/"):
return os.path.basename(v)
if isinstance(v, datetime.datetime):
return v.isoformat()
if isinstance(v, datetime.date):
return v.isoformat()
if isinstance(v, datetime.timedelta):
return frequency_to_string(v)

if isinstance(v, Dataset):
# That can happen in the `arguments`
# if a dataset is passed as an argument
return repr(v)

if isinstance(v, slice):
return (v.start, v.stop, v.step)

return v


class Dataset:
arguments = {}
_name = None

def mutate(self) -> "Dataset":
"""Give an opportunity to a subclass to return a new Dataset
@@ -41,6 +67,21 @@ def _len(self):
return len(self)

def _subset(self, **kwargs):

if not kwargs:
return self.mutate()

name = kwargs.pop("name", None)
result = self.__subset(**kwargs)
result._name = name

return result

@property
def name(self):
return self._name

def __subset(self, **kwargs):
if not kwargs:
return self.mutate()

@@ -254,41 +295,32 @@ def typed_variables(self):

return result

def _input_sources(self):
sources = []
self.collect_input_sources(sources)
return sources

def metadata(self):
import anemoi

def tidy(v):
if isinstance(v, (list, tuple, set)):
return [tidy(i) for i in v]
if isinstance(v, dict):
return {k: tidy(v) for k, v in v.items()}
if isinstance(v, str) and v.startswith("/"):
return os.path.basename(v)
if isinstance(v, datetime.datetime):
return v.isoformat()
if isinstance(v, datetime.date):
return v.isoformat()
if isinstance(v, datetime.timedelta):
return frequency_to_string(v)

if isinstance(v, Dataset):
# That can happen in the `arguments`
# if a dataset is passed as an argument
return repr(v)

if isinstance(v, slice):
return (v.start, v.stop, v.step)

return v
_, source_to_arrays = self._supporting_arrays_and_sources()

sources = []
for i, source in enumerate(self._input_sources()):
source_metadata = source.dataset_metadata().copy()
source_metadata["supporting_arrays"] = source_to_arrays[id(source)]
sources.append(source_metadata)

md = dict(
version=anemoi.datasets.__version__,
arguments=self.arguments,
**self.dataset_metadata(),
sources=sources,
supporting_arrays=source_to_arrays[id(self)],
)

try:
return json.loads(json.dumps(tidy(md)))
return json.loads(json.dumps(_tidy(md)))
except Exception:
LOG.exception("Failed to serialize metadata")
pprint.pprint(md)
@@ -313,8 +345,67 @@ def dataset_metadata(self):
dtype=str(self.dtype),
start_date=self.start_date.astype(str),
end_date=self.end_date.astype(str),
name=self.name,
)

def _supporting_arrays(self, *path):

import numpy as np

def _path(path, name):
return "/".join(str(_) for _ in [*path, name])

result = {
_path(path, "latitudes"): self.latitudes,
_path(path, "longitudes"): self.longitudes,
}
collected = []

self.collect_supporting_arrays(collected, *path)

for path, name, array in collected:
assert isinstance(path, tuple) and isinstance(name, str)
assert isinstance(array, np.ndarray)

name = _path(path, name)

if name in result:
raise ValueError(f"Duplicate key {name}")

result[name] = array

return result

def supporting_arrays(self):
"""Arrays to be saved in the checkpoints"""
arrays, _ = self._supporting_arrays_and_sources()
return arrays

def _supporting_arrays_and_sources(self):

source_to_arrays = {}

# Top-level arrays
result = self._supporting_arrays()
source_to_arrays[id(self)] = sorted(result.keys())

# Arrays from the input sources
for i, source in enumerate(self._input_sources()):
name = source.name if source.name is not None else i
src_arrays = source._supporting_arrays(name)
source_to_arrays[id(source)] = sorted(src_arrays.keys())

for k in src_arrays:
assert k not in result

result.update(src_arrays)

return result, source_to_arrays

def collect_supporting_arrays(self, collected, *path):
# Override this method to add more arrays
pass

def metadata_specific(self, **kwargs):
action = self.__class__.__name__.lower()
# assert isinstance(self.frequency, datetime.timedelta), (self.frequency, self, action)
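A sketch of the new supporting-arrays API under the same placeholder-path assumption; keys follow the "/"-joined convention built by `_supporting_arrays` above:

    import numpy as np

    from anemoi.datasets import open_dataset

    ds = open_dataset("path/to/dataset.zarr")  # placeholder path

    # "latitudes" and "longitudes" are always present; named input sources
    # contribute prefixed keys such as "lam/latitudes".
    arrays = ds.supporting_arrays()
    assert isinstance(arrays["latitudes"], np.ndarray)

    # The arrays are intended to be stored with the model checkpoint so the
    # masks can be re-applied at inference time.
    np.savez("supporting_arrays.npz", **arrays)
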
19 changes: 19 additions & 0 deletions src/anemoi/datasets/data/forwards.py
@@ -9,6 +9,7 @@


import logging
import warnings
from functools import cached_property

import numpy as np
@@ -34,6 +35,12 @@ def __len__(self):
def __getitem__(self, n):
return self.forward[n]

@property
def name(self):
if self._name is not None:
return self._name
return self.forward.name

@property
def dates(self):
return self.forward.dates
@@ -102,6 +109,12 @@ def metadata_specific(self, **kwargs):
**kwargs,
)

def collect_supporting_arrays(self, collected, *path):
self.forward.collect_supporting_arrays(collected, *path)

def collect_input_sources(self, collected):
self.forward.collect_input_sources(collected)

def source(self, index):
return self.forward.source(index)

@@ -197,6 +210,12 @@ def metadata_specific(self, **kwargs):
**kwargs,
)

def collect_supporting_arrays(self, collected, *path):
warnings.warn(f"The behaviour of {self.__class__.__name__}.collect_supporting_arrays() is not well defined")
for i, d in enumerate(self.datasets):
name = d.name if d.name is not None else i
d.collect_supporting_arrays(collected, *path, name)

@property
def missing(self):
raise NotImplementedError("missing() not implemented for Combined")
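A sketch of how naming interacts with the wrappers above: `_subset` pops the `name` keyword and attaches it to the outermost wrapper, and `Forwards.name` falls back to the wrapped dataset's name (the path is a placeholder):

    from anemoi.datasets import open_dataset

    ds = open_dataset("path/to/lam.zarr", name="lam", start=2020)
    assert ds.name == "lam"  # carried by the outermost (Subset) wrapper
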
21 changes: 20 additions & 1 deletion src/anemoi/datasets/data/grids.py
@@ -108,6 +108,17 @@ def check_same_resolution(self, d1, d2):
# We don't check the resolution, because we want to be able to combine
pass

def metadata_specific(self):
return super().metadata_specific(
multi_grids=True,
)

def collect_input_sources(self, collected):
# We assume that, because they have different grids, they have different input sources
for d in self.datasets:
collected.append(d)
d.collect_input_sources(collected)


class Grids(GridsBase):
# TODO: select the statistics of the most global grid?
@@ -157,6 +168,9 @@ def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0,
self.globe.shape[3],
)

def collect_supporting_arrays(self, collected, *path):
collected.append((path, "cutout_mask", self.mask))

@cached_property
def shape(self):
shape = self.lam.shape
@@ -212,6 +226,11 @@ def grids(self):
def tree(self):
return Node(self, [d.tree() for d in self.datasets])

# def metadata_specific(self):
# return super().metadata_specific(
# mask=serialise_mask(self.mask),
# )


def grids_factory(args, kwargs):
if "ensemble" in kwargs:
@@ -241,7 +260,7 @@ def cutout_factory(args, kwargs):
neighbours = kwargs.pop("neighbours", 5)

assert len(args) == 0
assert isinstance(cutout, (list, tuple))
assert isinstance(cutout, (list, tuple)), "cutout must be a list or tuple"

datasets = [_open(e) for e in cutout]
datasets, kwargs = _auto_adjust(datasets, kwargs)
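A hedged sketch of the cutout case (placeholder paths): `Cutout.collect_supporting_arrays` registers its mask, so it shows up among the checkpoint arrays; at the top level the path prefix is empty, so the key is plain "cutout_mask":

    from anemoi.datasets import open_dataset

    ds = open_dataset(cutout=["lam.zarr", "global.zarr"], min_distance_km=1.0)
    mask = ds.supporting_arrays()["cutout_mask"]
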
8 changes: 8 additions & 0 deletions src/anemoi/datasets/data/masked.py
@@ -33,6 +33,8 @@ def __init__(self, forward, mask):
self.mask = mask
self.axis = 3

self.mask_name = f"{self.__class__.__name__.lower()}_mask"

@cached_property
def shape(self):
return self.forward.shape[:-1] + (np.count_nonzero(self.mask),)
@@ -67,8 +69,13 @@ def _get_tuple(self, index):
result = apply_index_to_slices_changes(result, changes)
return result

def collect_supporting_arrays(self, collected, *path):
super().collect_supporting_arrays(collected, *path)
collected.append((path, self.mask_name, self.mask))


class Thinning(Masked):

def __init__(self, forward, thinning, method):
self.thinning = thinning
self.method = method
@@ -110,6 +117,7 @@ def subclass_metadata_specific(self):


class Cropping(Masked):

def __init__(self, forward, area):
from ..data import open_dataset

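Because `mask_name` is derived from the subclass name, each `Masked` subclass stores its mask under a distinct key; a sketch with a placeholder path:

    from anemoi.datasets import open_dataset

    ds = open_dataset("path/to/dataset.zarr", thinning=4)
    print(sorted(ds.supporting_arrays()))
    # expected to include 'thinning_mask' next to 'latitudes' and 'longitudes'
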
1 change: 1 addition & 0 deletions src/anemoi/datasets/data/misc.py
@@ -270,6 +270,7 @@ def _auto_adjust(datasets, kwargs):


def _open_dataset(*args, **kwargs):

sets = []
for a in args:
sets.append(_open(a))
6 changes: 6 additions & 0 deletions src/anemoi/datasets/data/stores.py
@@ -344,6 +344,12 @@ def get_dataset_names(self, names):
name, _ = os.path.splitext(os.path.basename(self.path))
names.add(name)

def collect_supporting_arrays(self, collected, *path):
pass

def collect_input_sources(self, collected):
pass


class ZarrWithMissingDates(Zarr):
"""A zarr dataset with missing dates."""
2 changes: 2 additions & 0 deletions src/anemoi/datasets/data/subset.py
@@ -135,6 +135,8 @@ def dates(self):
@cached_property
def frequency(self):
dates = self.dates
if len(dates) < 2:
raise ValueError(f"Cannot determine frequency of a subset with fewer than two dates ({self.dates}).")
return frequency_to_timedelta(dates[1].astype(object) - dates[0].astype(object))

def source(self, index):