Skip to content

Commit

Permalink
name supporting_arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Oct 27, 2024
1 parent b36be84 commit 4d7dee5
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Keep it human-readable, your future self will thank you!
- Control compatibility check in xy/zip
- Add `merge` feature
- Add support for storing `supporting_arrays` in checkpoint files
- Allow naming of datasets components

### Changed

Expand Down
130 changes: 98 additions & 32 deletions src/anemoi/datasets/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,34 @@
LOG = logging.getLogger(__name__)


def _tidy(v):
if isinstance(v, (list, tuple, set)):
return [_tidy(i) for i in v]
if isinstance(v, dict):
return {k: _tidy(v) for k, v in v.items()}
if isinstance(v, str) and v.startswith("/"):
return os.path.basename(v)
if isinstance(v, datetime.datetime):
return v.isoformat()
if isinstance(v, datetime.date):
return v.isoformat()
if isinstance(v, datetime.timedelta):
return frequency_to_string(v)

if isinstance(v, Dataset):
# That can happen in the `arguments`
# if a dataset is passed as an argument
return repr(v)

if isinstance(v, slice):
return (v.start, v.stop, v.step)

return v


class Dataset:
arguments = {}
_name = None

def mutate(self) -> "Dataset":
"""Give an opportunity to a subclass to return a new Dataset
Expand All @@ -41,6 +67,21 @@ def _len(self):
return len(self)

def _subset(self, **kwargs):
    """Build a subset of this dataset from the given keyword options.

    The optional ``name`` keyword is not a subsetting criterion: it is
    removed first and stored on the resulting dataset so the component
    can be identified later (e.g. in metadata and supporting arrays).
    """
    if not kwargs:
        return self.mutate()

    dataset_name = kwargs.pop("name", None)
    subset = self.__subset(**kwargs)
    subset._name = dataset_name

    return subset

@property
def name(self):
    # Optional label given via the `name` keyword to `_subset`;
    # None when this dataset component was not explicitly named.
    return self._name

def __subset(self, **kwargs):
if not kwargs:
return self.mutate()

Expand Down Expand Up @@ -254,41 +295,32 @@ def typed_variables(self):

return result

def _input_sources(self):
    """Return the list of input datasets feeding this dataset.

    Delegates to ``collect_input_sources``, which subclasses override
    to append their sources to the accumulator.
    """
    collected = []
    self.collect_input_sources(collected)
    return collected

def metadata(self):
import anemoi

def tidy(v):
if isinstance(v, (list, tuple, set)):
return [tidy(i) for i in v]
if isinstance(v, dict):
return {k: tidy(v) for k, v in v.items()}
if isinstance(v, str) and v.startswith("/"):
return os.path.basename(v)
if isinstance(v, datetime.datetime):
return v.isoformat()
if isinstance(v, datetime.date):
return v.isoformat()
if isinstance(v, datetime.timedelta):
return frequency_to_string(v)

if isinstance(v, Dataset):
# That can happen in the `arguments`
# if a dataset is passed as an argument
return repr(v)

if isinstance(v, slice):
return (v.start, v.stop, v.step)

return v
_, source_to_arrays = self._supporting_arrays_and_sources()

sources = []
for i, source in enumerate(self._input_sources()):
source_metadata = source.dataset_metadata().copy()
source_metadata["supporting_arrays"] = source_to_arrays[id(source)]
sources.append(source_metadata)

md = dict(
version=anemoi.datasets.__version__,
arguments=self.arguments,
**self.dataset_metadata(),
sources=sources,
supporting_arrays=source_to_arrays[id(self)],
)

try:
return json.loads(json.dumps(tidy(md)))
return json.loads(json.dumps(_tidy(md)))
except Exception:
LOG.exception("Failed to serialize metadata")
pprint.pprint(md)
Expand All @@ -313,29 +345,63 @@ def dataset_metadata(self):
dtype=str(self.dtype),
start_date=self.start_date.astype(str),
end_date=self.end_date.astype(str),
name=self.name,
)

def supporting_arrays(self):
"""Arrays to be saved in the checkpoints"""
def _supporting_arrays(self, *path):
    """Gather the numpy arrays supporting this dataset.

    Returns a dict keyed by slash-joined paths (e.g. ``"name/latitudes"``),
    always including this dataset's latitudes and longitudes, plus any
    arrays that subclasses register via ``collect_supporting_arrays``.
    Raises ValueError on duplicate keys.
    """
    import numpy as np

    def join_key(prefix, leaf):
        # Build "a/b/leaf" from a path prefix and a leaf name.
        return "/".join(str(part) for part in [*prefix, leaf])

    arrays = {
        join_key(path, "latitudes"): self.latitudes,
        join_key(path, "longitudes"): self.longitudes,
    }

    entries = []
    self.collect_supporting_arrays(entries, *path)

    for entry_path, leaf, array in entries:
        assert isinstance(entry_path, tuple) and isinstance(leaf, str)
        assert isinstance(array, np.ndarray)

        key = join_key(entry_path, leaf)

        if key in arrays:
            raise ValueError(f"Duplicate key {key}")

        arrays[key] = array

    return arrays

def support_arrays(self):
    """Arrays to be saved in the checkpoints.

    Returns only the arrays (the mapping of source ids to array keys
    computed alongside them is discarded).
    """
    arrays, _ = self._supporting_arrays_and_sources()
    return arrays

# NOTE(review): the feature is spelled `supporting_arrays` everywhere else
# (changelog, metadata key, `collect_supporting_arrays`), and the method
# this one replaced was named `supporting_arrays`; keep the old spelling
# working for existing callers.
supporting_arrays = support_arrays

def _supporting_arrays_and_sources(self):
    """Return ``(arrays, keys_by_source)``.

    ``arrays`` maps path-like keys to every supporting array of this
    dataset and of its input sources; ``keys_by_source`` maps
    ``id(dataset)`` to the sorted list of keys that dataset contributed.
    Each source's keys are prefixed with its name (or its position when
    it has no name), keeping them distinct from the top-level keys.
    """
    keys_by_source = {}

    # Top-level arrays (no prefix).
    arrays = self._supporting_arrays()
    keys_by_source[id(self)] = sorted(arrays.keys())

    # Arrays from the input sources, one prefix per source.
    for index, source in enumerate(self._input_sources()):
        prefix = source.name if source.name is not None else index
        source_arrays = source._supporting_arrays(prefix)
        keys_by_source[id(source)] = sorted(source_arrays.keys())

        for key in source_arrays:
            assert key not in arrays

        arrays.update(source_arrays)

    return arrays, keys_by_source

def collect_supporting_arrays(self, collected, *path):
# Override this method to add more arrays
pass
Expand Down
14 changes: 13 additions & 1 deletion src/anemoi/datasets/data/forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@


import logging
import warnings
from functools import cached_property

import numpy as np
Expand All @@ -34,6 +35,12 @@ def __len__(self):
def __getitem__(self, n):
return self.forward[n]

@property
def name(self):
    """This dataset's own name if one was set, otherwise the name of
    the wrapped (forwarded) dataset."""
    own_name = self._name
    return own_name if own_name is not None else self.forward.name

@property
def dates(self):
return self.forward.dates
Expand Down Expand Up @@ -105,6 +112,9 @@ def metadata_specific(self, **kwargs):
def collect_supporting_arrays(self, collected, *path):
    # Pure delegation: a forwarding dataset contributes no arrays of its
    # own; it lets the wrapped dataset append its entries to `collected`.
    self.forward.collect_supporting_arrays(collected, *path)

def collect_input_sources(self, collected):
    # Pure delegation: the wrapped dataset decides which input sources
    # to append to `collected`.
    self.forward.collect_input_sources(collected)

def source(self, index):
    # Delegate the lookup for variable `index` to the wrapped dataset.
    return self.forward.source(index)

Expand Down Expand Up @@ -201,8 +211,10 @@ def metadata_specific(self, **kwargs):
)

def collect_supporting_arrays(self, collected, *path):
warnings.warn(f"The behaviour of {self.__class__.__name__}.collect_supporting_arrays() is not well defined")
for i, d in enumerate(self.datasets):
d.collect_supporting_arrays(collected, *path, str(i))
name = d.name if d.name is not None else i
d.collect_supporting_arrays(collected, *path, name)

@property
def missing(self):
Expand Down
22 changes: 15 additions & 7 deletions src/anemoi/datasets/data/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import numpy as np

from ..grids import serialise_mask
from .debug import Node
from .debug import debug_indexing
from .forwards import Combined
Expand Down Expand Up @@ -109,6 +108,17 @@ def check_same_resolution(self, d1, d2):
# We don't check the resolution, because we want to be able to combine
pass

def metadata_specific(self):
    # Mark this combination as spanning multiple grids in the metadata.
    return super().metadata_specific(
        multi_grids=True,
    )

def collect_input_sources(self, collected):
    # We assume that, because they have different grids, the combined
    # datasets are distinct input sources: record each one, then recurse
    # so nested combinations are collected as well.
    for d in self.datasets:
        collected.append(d)
        d.collect_input_sources(collected)


class Grids(GridsBase):
# TODO: select the statistics of the most global grid?
Expand Down Expand Up @@ -159,8 +169,6 @@ def __init__(self, datasets, axis, min_distance_km=None, cropping_distance=2.0,
)

def collect_supporting_arrays(self, collected, *path):
self.lam.collect_supporting_arrays(collected, *path, "lam")
self.globe.collect_supporting_arrays(collected, *path, "global")
collected.append((path, "cutout_mask", self.mask))

@cached_property
Expand Down Expand Up @@ -218,10 +226,10 @@ def grids(self):
def tree(self):
    # Debug/representation tree: this node with one child per combined dataset.
    return Node(self, [d.tree() for d in self.datasets])

def metadata_specific(self):
return super().metadata_specific(
mask=serialise_mask(self.mask),
)
# def metadata_specific(self):
# return super().metadata_specific(
# mask=serialise_mask(self.mask),
# )


def grids_factory(args, kwargs):
Expand Down
7 changes: 3 additions & 4 deletions src/anemoi/datasets/data/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import numpy as np

from ..grids import cropping_mask
from ..grids import serialise_mask
from .dataset import Dataset
from .debug import Node
from .debug import debug_indexing
Expand Down Expand Up @@ -71,7 +70,7 @@ def _get_tuple(self, index):
return result

def collect_supporting_arrays(self, collected, *path):
super().collect_supporting_arrays(collected, *path, "mask")
super().collect_supporting_arrays(collected, *path)
collected.append((path, self.mask_name, self.mask))


Expand Down Expand Up @@ -114,7 +113,7 @@ def tree(self):
return Node(self, [self.forward.tree()], thinning=self.thinning, method=self.method)

def subclass_metadata_specific(self):
return dict(thinning=self.thinning, method=self.method, mask=serialise_mask(self.mask))
return dict(thinning=self.thinning, method=self.method)


class Cropping(Masked):
Expand All @@ -140,4 +139,4 @@ def tree(self):
return Node(self, [self.forward.tree()], area=self.area)

def subclass_metadata_specific(self):
return dict(area=self.area, mask=serialise_mask(self.mask))
return dict(area=self.area)
6 changes: 6 additions & 0 deletions src/anemoi/datasets/data/stores.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,12 @@ def get_dataset_names(self, names):
name, _ = os.path.splitext(os.path.basename(self.path))
names.add(name)

def collect_supporting_arrays(self, collected, *path):
    # A plain zarr store registers no extra supporting arrays; latitudes
    # and longitudes are already provided by the base `_supporting_arrays`.
    pass

def collect_input_sources(self, collected):
    # A zarr store is a leaf dataset: it has no nested input sources of
    # its own to collect (containers append their children themselves).
    pass


class ZarrWithMissingDates(Zarr):
"""A zarr dataset with missing dates."""
Expand Down

0 comments on commit 4d7dee5

Please sign in to comment.