diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
index 1bbc9a1e..32074d93 100644
--- a/.github/workflows/python-publish.yml
+++ b/.github/workflows/python-publish.yml
@@ -49,38 +49,6 @@ jobs:
      - name: Tests
        run: pytest

-  # notify-failure:
-  #   if: failure()
-  #   runs-on: ubuntu-latest
-  #   needs: [quality, checks]
-  #   name: Notify failure
-  #   steps:
-  #     - uses: jdcargile/ms-teams-notification@v1.4
-  #       with:
-  #         github-token: ${{ github.token }}
-  #         ms-teams-webhook-uri: ${{ secrets.MS_TEAMS_WEBHOOK_URI_F }}
-  #         # notification-summary: ${{ steps.qa.outputs.status }}
-  #         notification-summary: ❌ Build failed on anemoi.datasets!
-  #         notification-color: dc3545
-  #         timezone: Europe/Paris
-  #         verbose-logging: true
-
-  # notify-success:
-  #   if: success()
-  #   runs-on: ubuntu-latest
-  #   needs: [quality, checks]
-  #   name: Notify success
-  #   steps:
-  #     - uses: jdcargile/ms-teams-notification@v1.4
-  #       with:
-  #         github-token: ${{ github.token }}
-  #         ms-teams-webhook-uri: ${{ secrets.MS_TEAMS_WEBHOOK_URI_F }}
-  #         # notification-summary: ${{ steps.qa.outputs.status }}
-  #         notification-summary: ✅ New commit on anemoi.datasets
-  #         notification-color: 17a2b8
-  #         timezone: Europe/Paris
-  #         verbose-logging: true
-
  deploy:
    if: ${{ github.event_name == 'release' }}
@@ -93,24 +61,16 @@ jobs:
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
-          python-version: '3.10'
-
-      - name: Check that tag version matches code version
-        run: |
-          tag=${GITHUB_REF#refs/tags/}
-          version=$(python setup.py --version)
-          echo 'tag='$tag
-          echo "version file="$version
-          test "$tag" == "$version"
+          python-version: 3.x

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
-          pip install setuptools wheel twine
+          pip install build wheel twine

      - name: Build and publish
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
        run: |
-          python setup.py sdist
+          python -m build
          twine upload dist/*
diff --git a/.gitignore b/.gitignore
index fdc24f30..64958f06 100644
--- a/.gitignore
+++ b/.gitignore
@@ -159,6 +159,8 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ +_version.py + *.grib *.onnx *.ckpt @@ -184,3 +186,4 @@ _build/ ?.* ~* *.sync +*.dot diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 05d460cb..cfcfcc8e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -51,23 +51,6 @@ repos: - --exit-non-zero-on-fix - --preview -# - repo: https://github.com/thclark/pre-commit-sphinx -# rev: 0.0.1 -# hooks: -# - id: build-docs -# additional_dependencies: -# - sphinx -# - sphinx_rtd_theme -# - nbsphinx -# - pandoc -# args: -# - --cache-dir -# - docs/_build/doctrees -# - --html-dir -# - docs/_build/html -# - --source-dir -# - docs -# language_version: python3 - repo: https://github.com/sphinx-contrib/sphinx-lint rev: v0.9.1 @@ -80,8 +63,8 @@ repos: hooks: - id: rstfmt -# - repo: https://github.com/b8raoult/pre-commit-docconvert -# rev: "0.1.0" -# hooks: -# - id: docconvert -# args: ["-o", "numpy"] +- repo: https://github.com/b8raoult/pre-commit-docconvert + rev: "0.1.4" + hooks: + - id: docconvert + args: ["numpy"] diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 45958d1f..c03429e5 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,11 +3,15 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.10" + python: "3.11" sphinx: configuration: docs/conf.py python: install: - - requirements: docs/requirements.txt + - requirements: docs/requirements.txt + - method: pip + path: . + extra_requirements: + - docs diff --git a/anemoi/datasets/create/__init__.py b/anemoi/datasets/create/__init__.py deleted file mode 100644 index 3d85d2a0..00000000 --- a/anemoi/datasets/create/__init__.py +++ /dev/null @@ -1,119 +0,0 @@ -# (C) Copyright 2023 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -# - -import os - - -class Creator: - def __init__( - self, - path, - config=None, - cache=None, - print=print, - statistics_tmp=None, - overwrite=False, - **kwargs, - ): - self.path = path # Output path - self.config = config - self.cache = cache - self.print = print - self.statistics_tmp = statistics_tmp - self.overwrite = overwrite - - def init(self, check_name=False): - # check path - _, ext = os.path.splitext(self.path) - assert ext != "zarr", f"Unsupported extension={ext}" - from .loaders import InitialiseLoader - - if self._path_readable() and not self.overwrite: - raise Exception(f"{self.path} already exists. 
Use overwrite=True to overwrite.") - - with self._cache_context(): - obj = InitialiseLoader.from_config( - path=self.path, - config=self.config, - statistics_tmp=self.statistics_tmp, - print=self.print, - ) - obj.initialise(check_name=check_name) - - def load(self, parts=None): - from .loaders import ContentLoader - - with self._cache_context(): - loader = ContentLoader.from_dataset_config( - path=self.path, - statistics_tmp=self.statistics_tmp, - print=self.print, - parts=parts, - ) - loader.load() - - def statistics(self, force=False, output=None, start=None, end=None): - from .loaders import StatisticsLoader - - loader = StatisticsLoader.from_dataset( - path=self.path, - print=self.print, - force=force, - statistics_tmp=self.statistics_tmp, - statistics_output=output, - recompute=False, - statistics_start=start, - statistics_end=end, - ) - loader.run() - - def size(self): - from .loaders import SizeLoader - - loader = SizeLoader.from_dataset(path=self.path, print=self.print) - loader.add_total_size() - - def cleanup(self): - from .loaders import CleanupLoader - - loader = CleanupLoader.from_dataset( - path=self.path, - print=self.print, - statistics_tmp=self.statistics_tmp, - ) - loader.run() - - def patch(self, **kwargs): - from .patch import apply_patch - - apply_patch(self.path, **kwargs) - - def finalise(self, **kwargs): - self.statistics(**kwargs) - self.size() - - def create(self): - self.init() - self.load() - self.finalise() - self.cleanup() - - def _cache_context(self): - from .utils import cache_context - - return cache_context(self.cache) - - def _path_readable(self): - import zarr - - try: - zarr.open(self.path, "r") - return True - except zarr.errors.PathNotFoundError: - return False diff --git a/docs/building/handling-missing-dates.rst b/docs/building/handling-missing-dates.rst index 3d68b377..0473cd11 100644 --- a/docs/building/handling-missing-dates.rst +++ b/docs/building/handling-missing-dates.rst @@ -2,6 +2,8 @@ Handling missing dates ######################## +By default, the package will raise an error if there are missing dates. + Missing dates can be handled by specifying a list of dates in the configuration file. The dates should be in the same format as the dates in the time series. The missing dates will be filled ``np.nan`` values. diff --git a/docs/conf.py b/docs/conf.py index 62a1f298..c7d40d58 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,7 +12,7 @@ # # import os # import sys -# sys.path.insert(0, os.path.abspath('.')) +# sys.path.insert(0, os.path.join(os.path.abspath('.'), 'src')) import datetime import os diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..619ba9f5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +# https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ + +[build-system] +requires = ["setuptools>=60", "setuptools-scm>=8.0"] + +[project] +description = "A package to hold various functions to support training of ML models on ECMWF data." 
+name = "anemoi-dataset" + +dynamic = ["version"] +license = { file = "LICENSE" } +requires-python = ">=3.9" + +authors = [ + { name = "European Centre for Medium-Range Weather Forecasts (ECMWF)", email = "software.support@ecmwf.int" }, +] + +keywords = ["tools", "datasets", "ai"] + +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Operating System :: OS Independent", +] + +dependencies = [ + "anemoi-utils[provenance]>=0.1.7", + "zarr", + "pyyaml", + "numpy", + "tqdm", + "semantic-version", +] + +[project.optional-dependencies] +remote = ["boto3", "requests", "s3fs"] + +create = [ + "climetlab>=0.22.1", # "earthkit-data" + "earthkit-meteo", + "pyproj", + "ecmwflibs>=0.6.3", +] + +docs = ["sphinx", "sphinx_rtd_theme", "nbsphinx", "pandoc"] + +all = [ + "boto3", + "requests", + "s3fs", + "climetlab>=0.22.1", # "earthkit-data" + "earthkit-meteo", + "pyproj", + "ecmwflibs>=0.6.3", +] + +dev = [ + "boto3", + "requests", + "s3fs", + "climetlab>=0.22.1", # "earthkit-data" + "earthkit-meteo", + "pyproj", + "ecmwflibs>=0.6.3", + "sphinx", + "sphinx_rtd_theme", + "nbsphinx", + "pandoc", +] + +[project.urls] +Homepage = "https://github.com/ecmwf/anemoi-datasets/" +Documentation = "https://anemoi-datasets.readthedocs.io/" +Repository = "https://github.com/ecmwf/anemoi-datasets/" +Issues = "https://github.com/ecmwf/anemoi-datasets/issues" +# Changelog = "https://github.com/ecmwf/anemoi-datasets/CHANGELOG.md" + +[project.scripts] +anemoi-datasets = "anemoi.datasets.__main__:main" + +[tool.setuptools_scm] +version_file = "src/anemoi/datasets/_version.py" + +[tool.setuptools.package-data] +"anemoi.datasets.data" = ["*.css"] diff --git a/setup.py b/setup.py deleted file mode 100644 index 6f02b4d3..00000000 --- a/setup.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python -# (C) Copyright 2023 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
-# - - -import io -import os - -import setuptools - - -def read(fname): - file_path = os.path.join(os.path.dirname(__file__), fname) - return io.open(file_path, encoding="utf-8").read() - - -version = None -for line in read("anemoi/datasets/__init__.py").split("\n"): - if line.startswith("__version__"): - version = line.split("=")[-1].strip()[1:-1] - - -assert version - - -data_requires = [ - "anemoi-utils[provenance]", - "zarr", - "pyyaml", - "numpy", - "tqdm", - "semantic-version", -] - -remote_requires = [ - "boto3", - "requests", - "s3fs", # prepml copy only -] - - -create_requires = [ - "zarr", - "numpy", - "tqdm", - "climetlab", # "earthkit-data" - "earthkit-meteo", - "pyproj", - "ecmwflibs>=0.6.3", -] - - -all_requires = data_requires + create_requires + remote_requires -dev_requires = ["sphinx", "sphinx_rtd_theme", "nbsphinx", "pandoc"] + all_requires - -setuptools.setup( - name="anemoi-datasets", - version=version, - description="A package to hold various functions to support training of ML models on ECMWF data.", - long_description=read("README.md"), - long_description_content_type="text/markdown", - author="European Centre for Medium-Range Weather Forecasts (ECMWF)", - author_email="software.support@ecmwf.int", - license="Apache License Version 2.0", - url="https://github.com/ecmwf/anemoi-datasets", - packages=setuptools.find_namespace_packages(include=["anemoi.*"]), - include_package_data=True, - install_requires=data_requires, - extras_require={ - "data": [], - "remote": data_requires + remote_requires, - "create": create_requires, - "dev": dev_requires, - "all": all_requires, - }, - zip_safe=True, - keywords="tool", - classifiers=[ - "Development Status :: 3 - Alpha", - "Intended Audience :: Developers", - "License :: OSI Approved :: Apache Software License", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - "Operating System :: OS Independent", - ], - entry_points={"console_scripts": ["anemoi-datasets=anemoi.datasets.__main__:main"]}, -) diff --git a/anemoi/datasets/__init__.py b/src/anemoi/datasets/__init__.py similarity index 95% rename from anemoi/datasets/__init__.py rename to src/anemoi/datasets/__init__.py index 258394e9..0e29a95b 100644 --- a/anemoi/datasets/__init__.py +++ b/src/anemoi/datasets/__init__.py @@ -5,13 +5,12 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+from ._version import __version__ from .data import MissingDateError from .data import add_dataset_path from .data import add_named_dataset from .data import open_dataset -__version__ = "0.1.4" - __all__ = [ "open_dataset", "MissingDateError", diff --git a/anemoi/datasets/__main__.py b/src/anemoi/datasets/__main__.py similarity index 100% rename from anemoi/datasets/__main__.py rename to src/anemoi/datasets/__main__.py diff --git a/anemoi/datasets/commands/__init__.py b/src/anemoi/datasets/commands/__init__.py similarity index 100% rename from anemoi/datasets/commands/__init__.py rename to src/anemoi/datasets/commands/__init__.py diff --git a/anemoi/datasets/commands/compare.py b/src/anemoi/datasets/commands/compare.py similarity index 100% rename from anemoi/datasets/commands/compare.py rename to src/anemoi/datasets/commands/compare.py diff --git a/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py similarity index 100% rename from anemoi/datasets/commands/copy.py rename to src/anemoi/datasets/commands/copy.py diff --git a/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py similarity index 100% rename from anemoi/datasets/commands/create.py rename to src/anemoi/datasets/commands/create.py diff --git a/anemoi/datasets/commands/inspect/__init__.py b/src/anemoi/datasets/commands/inspect/__init__.py similarity index 100% rename from anemoi/datasets/commands/inspect/__init__.py rename to src/anemoi/datasets/commands/inspect/__init__.py index d62ab192..d9937118 100644 --- a/anemoi/datasets/commands/inspect/__init__.py +++ b/src/anemoi/datasets/commands/inspect/__init__.py @@ -9,9 +9,9 @@ import os from .. import Command +from .zarr import InspectZarr # from .checkpoint import InspectCheckpoint -from .zarr import InspectZarr class Inspect(Command, InspectZarr): diff --git a/anemoi/datasets/commands/inspect/zarr.py b/src/anemoi/datasets/commands/inspect/zarr.py similarity index 100% rename from anemoi/datasets/commands/inspect/zarr.py rename to src/anemoi/datasets/commands/inspect/zarr.py diff --git a/anemoi/datasets/commands/scan.py b/src/anemoi/datasets/commands/scan.py similarity index 100% rename from anemoi/datasets/commands/scan.py rename to src/anemoi/datasets/commands/scan.py diff --git a/anemoi/datasets/create/functions/actions/perturbations.py b/src/anemoi/datasets/compute/perturbations.py similarity index 66% rename from anemoi/datasets/create/functions/actions/perturbations.py rename to src/anemoi/datasets/compute/perturbations.py index a678afed..c09041f3 100644 --- a/anemoi/datasets/create/functions/actions/perturbations.py +++ b/src/anemoi/datasets/compute/perturbations.py @@ -6,8 +6,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
# + import warnings -from copy import deepcopy import numpy as np from climetlab.core.temporary import temp_file @@ -15,49 +15,18 @@ from anemoi.datasets.create.check import check_data_values from anemoi.datasets.create.functions import assert_is_fieldset -from anemoi.datasets.create.functions.actions.mars import mars - - -def to_list(x): - if isinstance(x, (list, tuple)): - return x - if isinstance(x, str): - return x.split("/") - return [x] - - -def normalise_number(number): - number = to_list(number) - - if len(number) > 4 and (number[1] == "to" and number[3] == "by"): - return list(range(int(number[0]), int(number[2]) + 1, int(number[4]))) - - if len(number) > 2 and number[1] == "to": - return list(range(int(number[0]), int(number[2]) + 1)) - - return number -def normalise_request(request): - request = deepcopy(request) - if "number" in request: - request["number"] = normalise_number(request["number"]) - if "time" in request: - request["time"] = to_list(request["time"]) - request["param"] = to_list(request["param"]) - return request - - -def load_if_needed(context, dates, dict_or_dataset): - if isinstance(dict_or_dataset, dict): - dict_or_dataset = normalise_request(dict_or_dataset) - dict_or_dataset = mars(context, dates, dict_or_dataset) - return dict_or_dataset - - -def perturbations(context, dates, members, center, remapping={}, patches={}): - members = load_if_needed(context, dates, members) - center = load_if_needed(context, dates, center) +def perturbations( + members, + center, + positive_clipping_variables=[ + "q", + "cp", + "lsp", + "tp", + ], # add "swl4", "swl3", "swl2", "swl1", "swl0", and more ? +): keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"] @@ -113,16 +82,9 @@ def check_compatible(f1, f2, ignore=["number"]): assert e.shape == c.shape == m.shape, (e.shape, c.shape, m.shape) - FORCED_POSITIVE = [ - "q", - "cp", - "lsp", - "tp", - ] # add "swl4", "swl3", "swl2", "swl1", "swl0", and more ? - x = c - m + e - if param in FORCED_POSITIVE: + if param in positive_clipping_variables: warnings.warn(f"Clipping {param} to be positive") x = np.maximum(x, 0) @@ -145,6 +107,3 @@ def check_compatible(f1, f2, ignore=["number"]): assert len(ds) == len(members), (len(ds), len(members)) return ds - - -execute = perturbations diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py new file mode 100644 index 00000000..b829f641 --- /dev/null +++ b/src/anemoi/datasets/create/__init__.py @@ -0,0 +1,170 @@ +# (C) Copyright 2023 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import os + + +class Creator: + def __init__( + self, + path, + config=None, + cache=None, + print=print, + statistics_tmp=None, + overwrite=False, + **kwargs, + ): + self.path = path # Output path + self.config = config + self.cache = cache + self.print = print + self.statistics_tmp = statistics_tmp + self.overwrite = overwrite + + def init(self, check_name=False): + # check path + _, ext = os.path.splitext(self.path) + assert ext != "zarr", f"Unsupported extension={ext}" + from .loaders import InitialiserLoader + + if self._path_readable() and not self.overwrite: + raise Exception(f"{self.path} already exists. 
Use overwrite=True to overwrite.") + + with self._cache_context(): + obj = InitialiserLoader.from_config( + path=self.path, + config=self.config, + statistics_tmp=self.statistics_tmp, + print=self.print, + ) + obj.initialise(check_name=check_name) + + def load(self, parts=None): + from .loaders import ContentLoader + + with self._cache_context(): + loader = ContentLoader.from_dataset_config( + path=self.path, + statistics_tmp=self.statistics_tmp, + print=self.print, + parts=parts, + ) + loader.load() + + def statistics(self, force=False, output=None, start=None, end=None): + from .loaders import StatisticsAdder + + loader = StatisticsAdder.from_dataset( + path=self.path, + print=self.print, + statistics_tmp=self.statistics_tmp, + statistics_output=output, + recompute=False, + statistics_start=start, + statistics_end=end, + ) + loader.run() + + def size(self): + from .loaders import DatasetHandler + from .size import compute_directory_sizes + + metadata = compute_directory_sizes(self.path) + handle = DatasetHandler.from_dataset(path=self.path, print=self.print) + handle.update_metadata(**metadata) + + def cleanup(self): + from .loaders import DatasetHandlerWithStatistics + + cleaner = DatasetHandlerWithStatistics.from_dataset( + path=self.path, print=self.print, statistics_tmp=self.statistics_tmp + ) + cleaner.tmp_statistics.delete() + cleaner.registry.clean() + + def patch(self, **kwargs): + from .patch import apply_patch + + apply_patch(self.path, **kwargs) + + def init_additions(self, delta=[1, 3, 6, 12]): + from .loaders import StatisticsAddition + from .loaders import TendenciesStatisticsAddition + from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency + + a = StatisticsAddition.from_dataset(path=self.path, print=self.print) + a.initialise() + + for d in delta: + try: + a = TendenciesStatisticsAddition.from_dataset(path=self.path, print=self.print, delta=d) + a.initialise() + except TendenciesStatisticsDeltaNotMultipleOfFrequency: + self.print(f"Skipping delta={d} as it is not a multiple of the frequency.") + + def run_additions(self, parts=None, delta=[1, 3, 6, 12]): + from .loaders import StatisticsAddition + from .loaders import TendenciesStatisticsAddition + from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency + + a = StatisticsAddition.from_dataset(path=self.path, print=self.print) + a.run(parts) + + for d in delta: + try: + a = TendenciesStatisticsAddition.from_dataset(path=self.path, print=self.print, delta=d) + a.run(parts) + except TendenciesStatisticsDeltaNotMultipleOfFrequency: + self.print(f"Skipping delta={d} as it is not a multiple of the frequency.") + + def finalise_additions(self, delta=[1, 3, 6, 12]): + from .loaders import StatisticsAddition + from .loaders import TendenciesStatisticsAddition + from .loaders import TendenciesStatisticsDeltaNotMultipleOfFrequency + + a = StatisticsAddition.from_dataset(path=self.path, print=self.print) + a.finalise() + + for d in delta: + try: + a = TendenciesStatisticsAddition.from_dataset(path=self.path, print=self.print, delta=d) + a.finalise() + except TendenciesStatisticsDeltaNotMultipleOfFrequency: + self.print(f"Skipping delta={d} as it is not a multiple of the frequency.") + + def finalise(self, **kwargs): + self.statistics(**kwargs) + self.size() + + def create(self): + self.init() + self.load() + self.finalise() + self.additions() + self.cleanup() + + def additions(self): + self.init_additions() + self.run_additions() + self.finalise_additions() + + def _cache_context(self): + from .utils import 
cache_context + + return cache_context(self.cache) + + def _path_readable(self): + import zarr + + try: + zarr.open(self.path, "r") + return True + except zarr.errors.PathNotFoundError: + return False diff --git a/anemoi/datasets/create/check.py b/src/anemoi/datasets/create/check.py similarity index 93% rename from anemoi/datasets/create/check.py rename to src/anemoi/datasets/create/check.py index 20398356..e24c7e02 100644 --- a/anemoi/datasets/create/check.py +++ b/src/anemoi/datasets/create/check.py @@ -8,29 +8,14 @@ # import logging -import os import re import warnings import numpy as np -import tqdm LOG = logging.getLogger(__name__) -def compute_directory_size(path): - if not os.path.isdir(path): - return None - size = 0 - n = 0 - for dirpath, _, filenames in tqdm.tqdm(os.walk(path), desc="Computing size", leave=False): - for filename in filenames: - file_path = os.path.join(dirpath, filename) - size += os.path.getsize(file_path) - n += 1 - return size, n - - class DatasetName: def __init__( self, diff --git a/src/anemoi/datasets/create/chunks.py b/src/anemoi/datasets/create/chunks.py new file mode 100644 index 00000000..4dc988f6 --- /dev/null +++ b/src/anemoi/datasets/create/chunks.py @@ -0,0 +1,78 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# +import logging +import warnings + +LOG = logging.getLogger(__name__) + +ALL = object() + + +class ChunkFilter: + def __init__(self, *, parts, total): + self.total = total + + if isinstance(parts, list): + if len(parts) == 1: + parts = parts[0] + elif len(parts) == 0: + parts = None + else: + raise ValueError(f"Invalid parts format: {parts}. Must be in the form 'i/n'.") + + if not parts: + parts = "all" + + assert isinstance(parts, str), f"Argument parts must be a string, got {parts}." + + if parts.lower() == "all" or parts == "*": + self.allowed = ALL + return + + assert "/" in parts, f"Invalid parts format: {parts}. Must be in the form 'i/n'." + + i, n = parts.split("/") + i, n = int(i), int(n) + + assert i > 0, f"Chunk number {i} must be positive." + assert i <= n, f"Chunk number {i} must be less than total chunks {n}." + if n > total: + warnings.warn( + f"Number of chunks {n} is larger than the total number of chunks: {total}. " + "Some chunks will be empty." + ) + + chunk_size = total / n + parts = [x for x in range(total) if x >= (i - 1) * chunk_size and x < i * chunk_size] + + for i in parts: + if i < 0 or i >= total: + raise AssertionError(f"Invalid chunk number {i}. Must be between 0 and {total - 1}.") + if not parts: + warnings.warn(f"Nothing to do for chunk {i}/{n}.") + + LOG.info(f"Running parts: {parts}") + + self.allowed = parts + + def __call__(self, i): + if i < 0 or i >= self.total: + raise AssertionError(f"Invalid chunk number {i}. 
Must be between 0 and {self.total - 1}.") + + if self.allowed == ALL: + return True + return i in self.allowed + + def __iter__(self): + for i in range(self.total): + if self(i): + yield i + + def __len__(self): + return len([_ for _ in self]) diff --git a/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py similarity index 99% rename from anemoi/datasets/create/config.py rename to src/anemoi/datasets/create/config.py index 7ff760fa..6acc17ac 100644 --- a/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -153,6 +153,7 @@ def __init__(self, config, *args, **kwargs): self.setdefault("build", Config()) self.build.setdefault("group_by", "monthly") + self.build.setdefault("use_grib_paramid", False) self.setdefault("output", Config()) self.output.setdefault("order_by", ["valid_datetime", "param_level", "number"]) diff --git a/anemoi/datasets/create/functions/__init__.py b/src/anemoi/datasets/create/functions/__init__.py similarity index 100% rename from anemoi/datasets/create/functions/__init__.py rename to src/anemoi/datasets/create/functions/__init__.py diff --git a/anemoi/datasets/create/functions/actions/__init__.py b/src/anemoi/datasets/create/functions/filters/__init__.py similarity index 100% rename from anemoi/datasets/create/functions/actions/__init__.py rename to src/anemoi/datasets/create/functions/filters/__init__.py diff --git a/anemoi/datasets/create/functions/filters/empty.py b/src/anemoi/datasets/create/functions/filters/empty.py similarity index 100% rename from anemoi/datasets/create/functions/filters/empty.py rename to src/anemoi/datasets/create/functions/filters/empty.py diff --git a/anemoi/datasets/create/functions/filters/noop.py b/src/anemoi/datasets/create/functions/filters/noop.py similarity index 100% rename from anemoi/datasets/create/functions/filters/noop.py rename to src/anemoi/datasets/create/functions/filters/noop.py diff --git a/anemoi/datasets/create/functions/filters/rename.py b/src/anemoi/datasets/create/functions/filters/rename.py similarity index 100% rename from anemoi/datasets/create/functions/filters/rename.py rename to src/anemoi/datasets/create/functions/filters/rename.py diff --git a/anemoi/datasets/create/functions/filters/rotate_winds.py b/src/anemoi/datasets/create/functions/filters/rotate_winds.py similarity index 98% rename from anemoi/datasets/create/functions/filters/rotate_winds.py rename to src/anemoi/datasets/create/functions/filters/rotate_winds.py index a41e2d78..f8ca190a 100644 --- a/anemoi/datasets/create/functions/filters/rotate_winds.py +++ b/src/anemoi/datasets/create/functions/filters/rotate_winds.py @@ -13,9 +13,7 @@ def rotate_winds(lats, lons, x_wind, y_wind, source_projection, target_projection): - """ - Code provided by MetNO - """ + """Code provided by MetNO""" import numpy as np import pyproj diff --git a/anemoi/datasets/create/functions/filters/unrotate_winds.py b/src/anemoi/datasets/create/functions/filters/unrotate_winds.py similarity index 98% rename from anemoi/datasets/create/functions/filters/unrotate_winds.py rename to src/anemoi/datasets/create/functions/filters/unrotate_winds.py index c5fcde08..074d0806 100644 --- a/anemoi/datasets/create/functions/filters/unrotate_winds.py +++ b/src/anemoi/datasets/create/functions/filters/unrotate_winds.py @@ -78,9 +78,7 @@ def __getattr__(self, name): def execute(context, input, u, v): - """ - Unrotate the wind components of a GRIB file. 
- """ + """Unrotate the wind components of a GRIB file.""" result = FieldArray() wind_params = (u, v) diff --git a/anemoi/datasets/create/functions/filters/__init__.py b/src/anemoi/datasets/create/functions/sources/__init__.py similarity index 100% rename from anemoi/datasets/create/functions/filters/__init__.py rename to src/anemoi/datasets/create/functions/sources/__init__.py diff --git a/anemoi/datasets/create/functions/actions/accumulations.py b/src/anemoi/datasets/create/functions/sources/accumulations.py similarity index 100% rename from anemoi/datasets/create/functions/actions/accumulations.py rename to src/anemoi/datasets/create/functions/sources/accumulations.py diff --git a/anemoi/datasets/create/functions/actions/constants.py b/src/anemoi/datasets/create/functions/sources/constants.py similarity index 100% rename from anemoi/datasets/create/functions/actions/constants.py rename to src/anemoi/datasets/create/functions/sources/constants.py diff --git a/anemoi/datasets/create/functions/actions/empty.py b/src/anemoi/datasets/create/functions/sources/empty.py similarity index 100% rename from anemoi/datasets/create/functions/actions/empty.py rename to src/anemoi/datasets/create/functions/sources/empty.py diff --git a/anemoi/datasets/create/functions/actions/forcings.py b/src/anemoi/datasets/create/functions/sources/forcings.py similarity index 100% rename from anemoi/datasets/create/functions/actions/forcings.py rename to src/anemoi/datasets/create/functions/sources/forcings.py diff --git a/anemoi/datasets/create/functions/actions/grib.py b/src/anemoi/datasets/create/functions/sources/grib.py similarity index 100% rename from anemoi/datasets/create/functions/actions/grib.py rename to src/anemoi/datasets/create/functions/sources/grib.py diff --git a/anemoi/datasets/create/functions/actions/mars.py b/src/anemoi/datasets/create/functions/sources/mars.py similarity index 88% rename from anemoi/datasets/create/functions/actions/mars.py rename to src/anemoi/datasets/create/functions/sources/mars.py index d133e9df..e2a9fe24 100644 --- a/anemoi/datasets/create/functions/actions/mars.py +++ b/src/anemoi/datasets/create/functions/sources/mars.py @@ -82,6 +82,20 @@ def factorise_requests(dates, *requests): yield r +def use_grib_paramid(r): + from anemoi.utils.grib import shortname_to_paramid + + params = r["param"] + if isinstance(params, str): + params = params.split("/") + assert isinstance(params, (list, tuple)), params + + params = [shortname_to_paramid(p) for p in params] + r["param"] = "/".join(str(p) for p in params) + + return r + + def mars(context, dates, *requests, **kwargs): if not requests: requests = [kwargs] @@ -90,6 +104,10 @@ def mars(context, dates, *requests, **kwargs): ds = load_source("empty") for r in requests: r = {k: v for k, v in r.items() if v != ("-",)} + + if context.use_grib_paramid and "param" in r: + r = use_grib_paramid(r) + if DEBUG: context.trace("✅", f"load_source(mars, {r}") diff --git a/anemoi/datasets/create/functions/actions/netcdf.py b/src/anemoi/datasets/create/functions/sources/netcdf.py similarity index 100% rename from anemoi/datasets/create/functions/actions/netcdf.py rename to src/anemoi/datasets/create/functions/sources/netcdf.py diff --git a/anemoi/datasets/create/functions/actions/opendap.py b/src/anemoi/datasets/create/functions/sources/opendap.py similarity index 100% rename from anemoi/datasets/create/functions/actions/opendap.py rename to src/anemoi/datasets/create/functions/sources/opendap.py diff --git 
a/src/anemoi/datasets/create/functions/sources/perturbations.py b/src/anemoi/datasets/create/functions/sources/perturbations.py new file mode 100644 index 00000000..53428da8 --- /dev/null +++ b/src/anemoi/datasets/create/functions/sources/perturbations.py @@ -0,0 +1,59 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# +from copy import deepcopy + +from anemoi.datasets.compute.perturbations import perturbations as compute_perturbations + +from .mars import mars + + +def to_list(x): + if isinstance(x, (list, tuple)): + return x + if isinstance(x, str): + return x.split("/") + return [x] + + +def normalise_number(number): + number = to_list(number) + + if len(number) > 4 and (number[1] == "to" and number[3] == "by"): + return list(range(int(number[0]), int(number[2]) + 1, int(number[4]))) + + if len(number) > 2 and number[1] == "to": + return list(range(int(number[0]), int(number[2]) + 1)) + + return number + + +def normalise_request(request): + request = deepcopy(request) + if "number" in request: + request["number"] = normalise_number(request["number"]) + if "time" in request: + request["time"] = to_list(request["time"]) + request["param"] = to_list(request["param"]) + return request + + +def load_if_needed(context, dates, dict_or_dataset): + if isinstance(dict_or_dataset, dict): + dict_or_dataset = normalise_request(dict_or_dataset) + dict_or_dataset = mars(context, dates, dict_or_dataset) + return dict_or_dataset + + +def perturbations(context, dates, members, center, remapping={}, patches={}): + members = load_if_needed(context, dates, members) + center = load_if_needed(context, dates, center) + return compute_perturbations(members, center) + + +execute = perturbations diff --git a/anemoi/datasets/create/functions/actions/source.py b/src/anemoi/datasets/create/functions/sources/source.py similarity index 100% rename from anemoi/datasets/create/functions/actions/source.py rename to src/anemoi/datasets/create/functions/sources/source.py diff --git a/anemoi/datasets/create/functions/actions/tendencies.py b/src/anemoi/datasets/create/functions/sources/tendencies.py similarity index 100% rename from anemoi/datasets/create/functions/actions/tendencies.py rename to src/anemoi/datasets/create/functions/sources/tendencies.py diff --git a/anemoi/datasets/create/input.py b/src/anemoi/datasets/create/input.py similarity index 98% rename from anemoi/datasets/create/input.py rename to src/anemoi/datasets/create/input.py index 68854666..ab08798d 100644 --- a/anemoi/datasets/create/input.py +++ b/src/anemoi/datasets/create/input.py @@ -561,7 +561,7 @@ def select(self, dates): @property def function(self): # name, delta = parse_function_name(self.name) - return import_function(self.name, "actions") + return import_function(self.name, "sources") def __repr__(self): content = "" @@ -825,7 +825,7 @@ def action_factory(config, context, action_path): }.get(key) if cls is None: - if not is_function(key, "actions"): + if not is_function(key, "sources"): raise ValueError(f"Unknown action '{key}' in {config}") cls = FunctionAction args = [key] + args @@ -869,21 +869,24 @@ def step_factory(config, context, action_path, previous_step): class FunctionContext: """A FunctionContext is 
passed to all functions, it will be used to pass information - to the functions from the other actions and filters and results.""" + to the functions from the other actions and filters and results. + """ def __init__(self, owner): self.owner = owner + self.use_grib_paramid = owner.context.use_grib_paramid def trace(self, emoji, *args): trace(emoji, *args) class ActionContext(Context): - def __init__(self, /, order_by, flatten_grid, remapping): + def __init__(self, /, order_by, flatten_grid, remapping, use_grib_paramid): super().__init__() self.order_by = order_by self.flatten_grid = flatten_grid self.remapping = build_remapping(remapping) + self.use_grib_paramid = use_grib_paramid class InputBuilder: diff --git a/anemoi/datasets/create/loaders.py b/src/anemoi/datasets/create/loaders.py similarity index 61% rename from anemoi/datasets/create/loaders.py rename to src/anemoi/datasets/create/loaders.py index f20f230c..780a078e 100644 --- a/anemoi/datasets/create/loaders.py +++ b/src/anemoi/datasets/create/loaders.py @@ -9,29 +9,33 @@ import os import time import uuid +import warnings from functools import cached_property import numpy as np import zarr +from anemoi.datasets import MissingDateError from anemoi.datasets import open_dataset +from anemoi.datasets.create.persistent import build_storage from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date from anemoi.datasets.dates.groups import Groups from .check import DatasetName from .check import check_data_values +from .chunks import ChunkFilter from .config import build_output from .config import loader_config from .input import build_input -from .statistics import TempStatistics +from .statistics import Summary +from .statistics import TmpStatistics +from .statistics import check_variance from .statistics import compute_statistics -from .utils import bytes -from .utils import compute_directory_sizes +from .statistics import default_statistics_dates from .utils import normalize_and_check_dates from .utils import progress_bar from .utils import seconds -from .writer import CubesFilter from .writer import ViewCacheArray from .zarr import ZarrBuiltRegistry from .zarr import add_zarr_dataset @@ -41,49 +45,7 @@ VERSION = "0.20" -def default_statistics_dates(dates): - """ - Calculate default statistics dates based on the given list of dates. - - Args: - dates (list): List of datetime objects representing dates. - - Returns: - tuple: A tuple containing the default start and end dates. - """ - - def to_datetime(d): - if isinstance(d, np.datetime64): - return d.tolist() - assert isinstance(d, datetime.datetime), d - return d - - first = dates[0] - last = dates[-1] - - first = to_datetime(first) - last = to_datetime(last) - - n_years = round((last - first).total_seconds() / (365.25 * 24 * 60 * 60)) - - if n_years < 10: - # leave out 20% of the data - k = int(len(dates) * 0.8) - end = dates[k - 1] - LOG.info(f"Number of years {n_years} < 10, leaving out 20%. 
{end=}") - return dates[0], end - - delta = 1 - if n_years >= 20: - delta = 3 - LOG.info(f"Number of years {n_years}, leaving out {delta} years.") - end_year = last.year - delta - - end = max(d for d in dates if to_datetime(d).year == end_year) - return dates[0], end - - -class Loader: +class GenericDatasetHandler: def __init__(self, *, path, print=print, **kwargs): # Catch all floating point errors, including overflow, sqrt(<0), etc np.seterr(all="raise", under="warn") @@ -94,10 +56,6 @@ def __init__(self, *, path, print=print, **kwargs): self.kwargs = kwargs self.print = print - statistics_tmp = kwargs.get("statistics_tmp") or self.path + ".statistics" - - self.statistics_registry = TempStatistics(statistics_tmp) - @classmethod def from_config(cls, *, config, path, print=print, **kwargs): # config is the path to the config file or a dict with the config @@ -117,37 +75,6 @@ def from_dataset(cls, *, path, **kwargs): assert os.path.exists(path), f"Path {path} does not exist." return cls(path=path, **kwargs) - def build_input(self): - from climetlab.core.order import build_remapping - - builder = build_input( - self.main_config.input, - data_sources=self.main_config.get("data_sources", {}), - order_by=self.output.order_by, - flatten_grid=self.output.flatten_grid, - remapping=build_remapping(self.output.remapping), - ) - LOG.info("✅ INPUT_BUILDER") - LOG.info(builder) - return builder - - def build_statistics_dates(self, start, end): - ds = open_dataset(self.path) - dates = ds.dates - - default_start, default_end = default_statistics_dates(dates) - if start is None: - start = default_start - if end is None: - end = default_end - - start = as_first_date(start, dates) - end = as_last_date(end, dates) - - start = start.astype(datetime.datetime) - end = end.astype(datetime.datetime) - return (start.isoformat(), end.isoformat()) - def read_dataset_metadata(self): ds = open_dataset(self.path) self.dataset_shape = ds.shape @@ -155,21 +82,17 @@ def read_dataset_metadata(self): assert len(self.variables_names) == ds.shape[1], self.dataset_shape self.dates = ds.dates - z = zarr.open(self.path, "r") - self.missing_dates = z.attrs.get("missing_dates", []) - self.missing_dates = [np.datetime64(d) for d in self.missing_dates] + self.missing_dates = sorted(list([self.dates[i] for i in ds.missing])) - def allow_nan(self, name): - return name in self.main_config.statistics.get("allow_nans", []) + z = zarr.open(self.path, "r") + missing_dates = z.attrs.get("missing_dates", []) + missing_dates = sorted([np.datetime64(d) for d in missing_dates]) + assert missing_dates == self.missing_dates, (missing_dates, self.missing_dates) @cached_property def registry(self): return ZarrBuiltRegistry(self.path) - def initialise_dataset_backend(self): - z = zarr.open(self.path, mode="w") - z.create_group("_build") - def update_metadata(self, **kwargs): LOG.info(f"Updating metadata {kwargs}") z = zarr.open(self.path, mode="w+") @@ -196,12 +119,43 @@ def print_info(self): LOG.info(e) -class InitialiseLoader(Loader): +class DatasetHandler(GenericDatasetHandler): + pass + + +class DatasetHandlerWithStatistics(GenericDatasetHandler): + def __init__(self, statistics_tmp=None, **kwargs): + super().__init__(**kwargs) + statistics_tmp = kwargs.get("statistics_tmp") or os.path.join(self.path + ".tmp_data", "statistics") + self.tmp_statistics = TmpStatistics(statistics_tmp) + + +class Loader(DatasetHandlerWithStatistics): + def build_input(self): + from climetlab.core.order import build_remapping + + builder = build_input( + 
self.main_config.input, + data_sources=self.main_config.get("data_sources", {}), + order_by=self.output.order_by, + flatten_grid=self.output.flatten_grid, + remapping=build_remapping(self.output.remapping), + use_grib_paramid=self.main_config.build.use_grib_paramid, + ) + LOG.info("✅ INPUT_BUILDER") + LOG.info(builder) + return builder + + def allow_nan(self, name): + return name in self.main_config.statistics.get("allow_nans", []) + + +class InitialiserLoader(Loader): def __init__(self, config, **kwargs): super().__init__(**kwargs) self.main_config = loader_config(config) - self.statistics_registry.delete() + self.tmp_statistics.delete() LOG.info(self.main_config.dates) self.groups = Groups(**self.main_config.dates) @@ -217,6 +171,27 @@ def __init__(self, config, **kwargs): LOG.info("MINIMAL INPUT :") LOG.info(self.minimal_input) + def build_statistics_dates(self, start, end): + ds = open_dataset(self.path) + dates = ds.dates + + default_start, default_end = default_statistics_dates(dates) + if start is None: + start = default_start + if end is None: + end = default_end + + start = as_first_date(start, dates) + end = as_last_date(end, dates) + + start = start.astype(datetime.datetime) + end = end.astype(datetime.datetime) + return (start.isoformat(), end.isoformat()) + + def initialise_dataset_backend(self): + z = zarr.open(self.path, mode="w") + z.create_group("_build") + def initialise(self, check_name=True): """Create empty dataset.""" @@ -330,8 +305,8 @@ def initialise(self, check_name=True): self._add_dataset(name="longitudes", array=grid_points[1]) self.registry.create(lengths=lengths) - self.statistics_registry.create(exist_ok=False) - self.registry.add_to_history("statistics_registry_initialised", version=self.statistics_registry.version) + self.tmp_statistics.create(exist_ok=False) + self.registry.add_to_history("tmp_statistics_initialised", version=self.tmp_statistics.version) statistics_start, statistics_end = self.build_statistics_dates( self.main_config.statistics.get("start"), @@ -360,7 +335,7 @@ def __init__(self, config, parts, **kwargs): self.parts = parts total = len(self.registry.get_flags()) - self.cube_filter = CubesFilter(parts=self.parts, total=total) + self.chunk_filter = ChunkFilter(parts=self.parts, total=total) self.data_array = zarr.open(self.path, mode="r+")["data"] self.n_groups = len(self.groups) @@ -369,12 +344,12 @@ def load(self): self.registry.add_to_history("loading_data_start", parts=self.parts) for igroup, group in enumerate(self.groups): - if not self.cube_filter(igroup): + if not self.chunk_filter(igroup): continue if self.registry.get_flag(igroup): LOG.info(f" -> Skipping {igroup} total={len(self.groups)} (already done)") continue - self.print(f" -> Processing {igroup} total={len(self.groups)}") + # self.print(f" -> Processing {igroup} total={len(self.groups)}") # print("========", group) assert isinstance(group[0], datetime.datetime), group @@ -392,7 +367,7 @@ def load(self): self.registry.add_to_history("loading_data_end", parts=self.parts) self.registry.add_provenance(name="provenance_load") - self.statistics_registry.add_provenance(name="provenance_load", config=self.main_config) + self.tmp_statistics.add_provenance(name="provenance_load", config=self.main_config) self.print_info() @@ -430,7 +405,7 @@ def dates_to_indexes(dates, all_dates): self.load_cube(cube, array) stats = compute_statistics(array.cache, self.variables_names, allow_nan=self.allow_nan) - self.statistics_registry.write(indexes, stats, dates=dates_in_data) + 
self.tmp_statistics.write(indexes, stats, dates=dates_in_data) array.flush() @@ -476,16 +451,12 @@ def load_cube(self, cube, array): LOG.info(msg) -class StatisticsLoader(Loader): - main_config = {} - +class StatisticsAdder(DatasetHandlerWithStatistics): def __init__( self, - config=None, statistics_output=None, statistics_start=None, statistics_end=None, - force=False, **kwargs, ): super().__init__(**kwargs) @@ -499,11 +470,16 @@ def __init__( "-": self.write_stats_to_stdout, }.get(self.statistics_output, self.write_stats_to_file) - if config: - self.main_config = loader_config(config) - self.read_dataset_metadata() + def allow_nan(self, name): + z = zarr.open(self.path, mode="r") + if "variables_with_nans" in z.attrs: + return name in z.attrs["variables_with_nans"] + + warnings.warn(f"Cannot find 'variables_with_nans' in {self.path}. Assuming nans allowed for {name}.") + return True + def _get_statistics_dates(self): dates = self.dates dtype = type(dates[0]) @@ -513,10 +489,7 @@ def assert_dtype(d): # remove missing dates if self.missing_dates: - assert type(self.missing_dates[0]) is dtype, ( - type(self.missing_dates[0]), - dtype, - ) + assert_dtype(self.missing_dates[0]) dates = [d for d in dates if d not in self.missing_dates] # filter dates according the the start and end dates in the metadata @@ -543,11 +516,11 @@ def assert_dtype(d): def run(self): dates = self._get_statistics_dates() - stats = self.statistics_registry.get_aggregated(dates, self.variables_names, self.allow_nan) + stats = self.tmp_statistics.get_aggregated(dates, self.variables_names, self.allow_nan) self.output_writer(stats) def write_stats_to_file(self, stats): - stats.save(self.statistics_output, provenance=dict(config=self.main_config)) + stats.save(self.statistics_output) LOG.info(f"✅ Statistics written in {self.statistics_output}") def write_stats_to_dataset(self, stats): @@ -572,24 +545,232 @@ def write_stats_to_stdout(self, stats): LOG.info(stats) -class SizeLoader(Loader): - def __init__(self, path, print): - self.path = path - self.print = print +class GenericAdditions(GenericDatasetHandler): + def __init__(self, name="", **kwargs): + super().__init__(**kwargs) + self.name = name - def add_total_size(self): - dic = compute_directory_sizes(self.path) + storage_path = f"{self.path}.tmp_storage_{name}" + self.tmp_storage = build_storage(directory=storage_path, create=True) - size = dic["total_size"] - n = dic["total_number_of_files"] + def initialise(self): + self.tmp_storage.delete() + self.tmp_storage.create() + LOG.info(f"Dataset {self.path} additions initialized.") - LOG.info(f"Total size: {bytes(size)}") - LOG.info(f"Total number of files: {n}") + @cached_property + def _variables_with_nans(self): + z = zarr.open(self.path, mode="r") + if "variables_with_nans" in z.attrs: + return z.attrs["variables_with_nans"] + return None - self.update_metadata(total_size=size, total_number_of_files=n) + def allow_nan(self, name): + if self._variables_with_nans is not None: + return name in self._variables_with_nans + warnings.warn(f"❗Cannot find 'variables_with_nans' in {self.path}, Assuming nans allowed for {name}.") + return True + @classmethod + def _check_type_equal(cls, a, b): + a = list(a) + b = list(b) + a = a[0] if a else None + b = b[0] if b else None + assert type(a) is type(b), (type(a), type(b)) + + def finalise(self): + shape = (len(self.dates), len(self.variables)) + agg = dict( + minimum=np.full(shape, np.nan, dtype=np.float64), + maximum=np.full(shape, np.nan, dtype=np.float64), + sums=np.full(shape, 
np.nan, dtype=np.float64), + squares=np.full(shape, np.nan, dtype=np.float64), + count=np.full(shape, -1, dtype=np.int64), + has_nans=np.full(shape, False, dtype=np.bool_), + ) + LOG.info(f"Aggregating {self.name} statistics on shape={shape}. Variables : {self.variables}") + + found = set() + ifound = set() + missing = set() + for _date, (date, i, stats) in self.tmp_storage.items(): + assert _date == date + if stats == "missing": + missing.add(date) + continue -class CleanupLoader(Loader): - def run(self): - self.statistics_registry.delete() - self.registry.clean() + assert date not in found, f"Duplicates found {date}" + found.add(date) + ifound.add(i) + + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + agg[k][i, ...] = stats[k] + + assert len(found) + len(missing) == len(self.dates), (len(found), len(missing), len(self.dates)) + assert found.union(missing) == set(self.dates), (found, missing, set(self.dates)) + + mask = sorted(list(ifound)) + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + agg[k] = agg[k][mask, ...] + + for k in ["minimum", "maximum", "sums", "squares", "count", "has_nans"]: + assert agg[k].shape == agg["count"].shape, (agg[k].shape, agg["count"].shape) + + minimum = np.nanmin(agg["minimum"], axis=0) + maximum = np.nanmax(agg["maximum"], axis=0) + sums = np.nansum(agg["sums"], axis=0) + squares = np.nansum(agg["squares"], axis=0) + count = np.nansum(agg["count"], axis=0) + has_nans = np.any(agg["has_nans"], axis=0) + + assert sums.shape == count.shape + assert sums.shape == squares.shape + assert sums.shape == minimum.shape + assert sums.shape == maximum.shape + assert sums.shape == has_nans.shape + + mean = sums / count + assert sums.shape == mean.shape + + x = squares / count - mean * mean + # remove negative variance due to numerical errors + # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 + check_variance(x, self.variables, minimum, maximum, mean, count, sums, squares) + + stdev = np.sqrt(x) + assert sums.shape == stdev.shape + + self.summary = Summary( + minimum=minimum, + maximum=maximum, + mean=mean, + count=count, + sums=sums, + squares=squares, + stdev=stdev, + variables_names=self.variables, + has_nans=has_nans, + ) + LOG.info(f"Dataset {self.path} additions finalized.") + self.check_statistics() + self._write(self.summary) + self.tmp_storage.delete() + + def _write(self, summary): + for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: + self._add_dataset(name=k, array=summary[k]) + self.registry.add_to_history("compute_statistics_end") + LOG.info(f"Wrote {self.name} additions in {self.path}") + + def check_statistics(self): + pass + + +class StatisticsAddition(GenericAdditions): + def __init__(self, **kwargs): + super().__init__("statistics_", **kwargs) + + z = zarr.open(self.path, mode="r") + start = z.attrs["statistics_start_date"] + end = z.attrs["statistics_end_date"] + self.ds = open_dataset(self.path, start=start, end=end) + + self.variables = self.ds.variables + self.dates = self.ds.dates + + assert len(self.variables) == self.ds.shape[1], self.ds.shape + self.total = len(self.dates) + + def run(self, parts): + chunk_filter = ChunkFilter(parts=parts, total=self.total) + for i in range(0, self.total): + if not chunk_filter(i): + continue + date = self.dates[i] + try: + arr = self.ds[i : i + 1, ...] 
+ stats = compute_statistics(arr, self.variables, allow_nan=self.allow_nan) + self.tmp_storage.add([date, i, stats], key=date) + except MissingDateError: + self.tmp_storage.add([date, i, "missing"], key=date) + self.tmp_storage.flush() + LOG.info(f"Dataset {self.path} additions run.") + + def check_statistics(self): + ds = open_dataset(self.path) + ref = ds.statistics + for k in ds.statistics: + assert np.all(np.isclose(ref[k], self.summary[k], rtol=1e-4, atol=1e-4)), ( + k, + ref[k], + self.summary[k], + ) + + +class DeltaDataset: + def __init__(self, ds, idelta): + self.ds = ds + self.idelta = idelta + + def __getitem__(self, i): + j = i - self.idelta + if j < 0: + raise MissingDateError(f"Missing date {j}") + return self.ds[i : i + 1, ...] - self.ds[j : j + 1, ...] + + +class TendenciesStatisticsDeltaNotMultipleOfFrequency(ValueError): + pass + + +class TendenciesStatisticsAddition(GenericAdditions): + DATASET_NAME_PATTERN = "statistics_tendencies_{delta}" + + def __init__(self, path, delta=None, **kwargs): + full_ds = open_dataset(path) + self.variables = full_ds.variables + + frequency = full_ds.frequency + if delta is None: + delta = frequency + assert isinstance(delta, int), delta + if not delta % frequency == 0: + raise TendenciesStatisticsDeltaNotMultipleOfFrequency( + f"Delta {delta} is not a multiple of frequency {frequency}" + ) + idelta = delta // frequency + + super().__init__(path=path, name=self.DATASET_NAME_PATTERN.format(delta=f"{delta}h"), **kwargs) + + z = zarr.open(self.path, mode="r") + start = z.attrs["statistics_start_date"] + end = z.attrs["statistics_end_date"] + start = datetime.datetime.fromisoformat(start) + ds = open_dataset(self.path, start=start + datetime.timedelta(hours=delta), end=end) + self.dates = ds.dates + self.total = len(self.dates) + + ds = open_dataset(self.path, start=start, end=end) + self.ds = DeltaDataset(ds, idelta) + + def run(self, parts): + chunk_filter = ChunkFilter(parts=parts, total=self.total) + for i in range(0, self.total): + if not chunk_filter(i): + continue + date = self.dates[i] + try: + arr = self.ds[i] + stats = compute_statistics(arr, self.variables, allow_nan=self.allow_nan) + self.tmp_storage.add([date, i, stats], key=date) + except MissingDateError: + self.tmp_storage.add([date, i, "missing"], key=date) + self.tmp_storage.flush() + LOG.info(f"Dataset {self.path} additions run.") + + def _write(self, summary): + for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: + self._add_dataset(name=f"{self.name}_{k}", array=summary[k]) + self.registry.add_to_history(f"compute_{self.name}_end") + LOG.info(f"Wrote {self.name} additions in {self.path}") diff --git a/anemoi/datasets/create/patch.py b/src/anemoi/datasets/create/patch.py similarity index 100% rename from anemoi/datasets/create/patch.py rename to src/anemoi/datasets/create/patch.py diff --git a/src/anemoi/datasets/create/persistent.py b/src/anemoi/datasets/create/persistent.py new file mode 100644 index 00000000..29950553 --- /dev/null +++ b/src/anemoi/datasets/create/persistent.py @@ -0,0 +1,152 @@ +# (C) Copyright 2023 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +import glob +import hashlib +import json +import logging +import os +import pickle +import shutil +import socket + +import numpy as np +from anemoi.utils.provenance import gather_provenance_info + +LOG = logging.getLogger(__name__) + + +class PersistentDict: + version = 3 + + # Used in parrallel, during data loading, + # to write data in pickle files. + def __init__(self, directory, create=True): + """dirname: str + The directory where the data will be stored. + """ + self.dirname = directory + self.name, self.ext = os.path.splitext(os.path.basename(self.dirname)) + if create: + self.create() + + def create(self): + os.makedirs(self.dirname, exist_ok=True) + + def delete(self): + try: + shutil.rmtree(self.dirname) + except FileNotFoundError: + pass + + def __str__(self): + return f"{self.__class__.__name__}({self.dirname})" + + def items(self): + # use glob to read all pickles + files = glob.glob(self.dirname + "/*.pickle") + LOG.info(f"Reading {self.name} data, found {len(files)} files in {self.dirname}") + assert len(files) > 0, f"No files found in {self.dirname}" + for f in files: + with open(f, "rb") as f: + yield pickle.load(f) + + def add_provenance(self, **kwargs): + out = dict(provenance=gather_provenance_info(), **kwargs) + with open(os.path.join(self.dirname, "provenance.json"), "w") as f: + json.dump(out, f) + + def add(self, elt, *, key): + self[key] = elt + + def __setitem__(self, key, elt): + h = hashlib.sha256(str(key).encode("utf-8")).hexdigest() + path = os.path.join(self.dirname, f"{h}.pickle") + + if os.path.exists(path): + LOG.warn(f"{path} already exists") + + tmp_path = path + f".tmp-{os.getpid()}-on-{socket.gethostname()}" + with open(tmp_path, "wb") as f: + pickle.dump((key, elt), f) + shutil.move(tmp_path, path) + + LOG.debug(f"Written {self.name} data for len {key} in {path}") + + def flush(self): + pass + + +class BufferedPersistentDict(PersistentDict): + def __init__(self, buffer_size=1000, **kwargs): + self.buffer_size = buffer_size + self.elements = [] + self.keys = [] + self.storage = PersistentDict(**kwargs) + + def add(self, elt, *, key): + self.elements.append(elt) + self.keys.append(key) + if len(self.keys) > self.buffer_size: + self.flush() + + def flush(self): + k = sorted(self.keys) + self.storage.add(self.elements, key=k) + self.elements = [] + self.keys = [] + + def items(self): + for keys, elements in self.storage.items(): + for key, elt in zip(keys, elements): + yield key, elt + + def delete(self): + self.storage.delete() + + def create(self): + self.storage.create() + + +def build_storage(directory, create=True): + return BufferedPersistentDict(directory=directory, create=create) + + +if __name__ == "__main__": + N = 3 + P = 2 + directory = "h" + p = PersistentDict(directory=directory) + print(p) + assert os.path.exists(directory) + import numpy as np + + arrs = [np.random.randint(1, 101, size=(P,)) for _ in range(N)] + dates = [np.array([np.datetime64(f"2021-01-0{_+1}") + np.timedelta64(i, "h") for i in range(P)]) for _ in range(N)] + + print() + print("Writing the data") + for i in range(N): + _arr = arrs[i] + _dates = dates[i] + print(f"Writing : {i=}, {_arr=} {_dates=}") + p[_dates] = (i, _arr) + + print() + print("Reading the data back") + + p = PersistentDict(directory="h") + for _dates, (i, _arr) in p.items(): + print(f"{i=}, {_arr=}, {_dates=}") + + assert np.allclose(_arr, arrs[i]) + + assert len(_dates) == len(dates[i]) + for a, b in zip(_dates, dates[i]): + assert a == b diff --git a/src/anemoi/datasets/create/size.py 
b/src/anemoi/datasets/create/size.py new file mode 100644 index 00000000..1671290f --- /dev/null +++ b/src/anemoi/datasets/create/size.py @@ -0,0 +1,34 @@ +# (C) Copyright 2023 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +import os + +from anemoi.datasets.create.utils import bytes +from anemoi.datasets.create.utils import progress_bar + +LOG = logging.getLogger(__name__) + + +def compute_directory_sizes(path): + if not os.path.isdir(path): + return None + + size, n = 0, 0 + bar = progress_bar(iterable=os.walk(path), desc=f"Computing size of {path}") + for dirpath, _, filenames in bar: + for filename in filenames: + file_path = os.path.join(dirpath, filename) + size += os.path.getsize(file_path) + n += 1 + + LOG.info(f"Total size: {bytes(size)}") + LOG.info(f"Total number of files: {n}") + + return dict(total_size=size, total_number_of_files=n)
diff --git a/anemoi/datasets/create/statistics.py b/src/anemoi/datasets/create/statistics/__init__.py similarity index 61% rename from anemoi/datasets/create/statistics.py rename to src/anemoi/datasets/create/statistics/__init__.py index 9abe2e7a..cc3d8095 100644 --- a/anemoi/datasets/create/statistics.py +++ b/src/anemoi/datasets/create/statistics/__init__.py @@ -15,18 +15,58 @@ import pickle import shutil import socket -from collections import defaultdict import numpy as np from anemoi.utils.provenance import gather_provenance_info -from .check import StatisticsValueError -from .check import check_data_values -from .check import check_stats +from ..check import check_data_values +from .summary import Summary LOG = logging.getLogger(__name__) +def default_statistics_dates(dates): + """ + Calculate default statistics dates based on the given list of dates. + + Args: + dates (list): List of datetime objects representing dates. + + Returns: + tuple: A tuple containing the default start and end dates. + """ + + def to_datetime(d): + if isinstance(d, np.datetime64): + return d.tolist() + assert isinstance(d, datetime.datetime), d + return d + + first = dates[0] + last = dates[-1] + + first = to_datetime(first) + last = to_datetime(last) + + n_years = round((last - first).total_seconds() / (365.25 * 24 * 60 * 60)) + + if n_years < 10: + # leave out 20% of the data + k = int(len(dates) * 0.8) + end = dates[k - 1] + LOG.info(f"Number of years {n_years} < 10, leaving out 20%.
{end=}") + return dates[0], end + + delta = 1 + if n_years >= 20: + delta = 3 + LOG.info(f"Number of years {n_years}, leaving out {delta} years.") + end_year = last.year - delta + + end = max(d for d in dates if to_datetime(d).year == end_year) + return dates[0], end + + def to_datetime(date): if isinstance(date, str): return np.datetime64(date) @@ -45,11 +85,12 @@ def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squa print(x) print(variables_names) print(count) - for i, (var, y) in enumerate(zip(variables_names, x)): + for i, (name, y) in enumerate(zip(variables_names, x)): if y >= 0: continue + print("---") print( - var, + name, y, maximum[i], minimum[i], @@ -59,9 +100,9 @@ def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squa squares[i], ) - print(var, np.min(sums[i]), np.max(sums[i]), np.argmin(sums[i])) - print(var, np.min(squares[i]), np.max(squares[i]), np.argmin(squares[i])) - print(var, np.min(count[i]), np.max(count[i]), np.argmin(count[i])) + print(name, np.min(sums[i]), np.max(sums[i]), np.argmin(sums[i])) + print(name, np.min(squares[i]), np.max(squares[i]), np.argmin(squares[i])) + print(name, np.min(count[i]), np.max(count[i]), np.argmin(count[i])) raise ValueError("Negative variance") @@ -69,7 +110,7 @@ def check_variance(x, variables_names, minimum, maximum, mean, count, sums, squa def compute_statistics(array, check_variables_names=None, allow_nan=False): nvars = array.shape[1] - LOG.info("Stats %s", (nvars, array.shape, check_variables_names)) + LOG.info(f"Stats {nvars}, {array.shape}, {check_variables_names}") if check_variables_names: assert nvars == len(check_variables_names), (nvars, check_variables_names) stats_shape = (array.shape[0], nvars) @@ -108,7 +149,7 @@ def compute_statistics(array, check_variables_names=None, allow_nan=False): } -class TempStatistics: +class TmpStatistics: version = 3 # Used in parrallel, during data loading, # to write statistics in pickled npz files. @@ -162,7 +203,7 @@ def get_aggregated(self, *args, **kwargs): return aggregator.aggregate() def __str__(self): - return f"TempStatistics({self.dirname})" + return f"TmpStatistics({self.dirname})" def normalise_date(d): @@ -289,7 +330,7 @@ def aggregate(self): allow_nan=False, ) - return Statistics( + return Summary( minimum=minimum, maximum=maximum, mean=mean, @@ -302,82 +343,119 @@ def aggregate(self): ) -class Statistics(dict): - STATS_NAMES = ["minimum", "maximum", "mean", "stdev", "has_nans"] # order matter for __str__. +class SummaryAggregator: + NAMES = ["minimum", "maximum", "sums", "squares", "count", "has_nans"] - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.check() + def __init__(self, owner, dates, variables_names, allow_nan): + dates = sorted(dates) + dates = to_datetimes(dates) + assert dates, "No dates selected" + self.owner = owner + self.dates = dates + self.variables_names = variables_names + self.allow_nan = allow_nan - @property - def size(self): - return len(self["variables_names"]) + self.shape = (len(self.dates), len(self.variables_names)) + LOG.info(f"Aggregating statistics on shape={self.shape}. 
Variables : {self.variables_names}") - def check(self): - for k, v in self.items(): - if k == "variables_names": - assert len(v) == self.size - continue - assert v.shape == (self.size,) - if k == "count": - assert (v >= 0).all(), (k, v) - assert v.dtype == np.int64, (k, v) - continue - if k == "has_nans": - assert v.dtype == np.bool_, (k, v) - continue - if k == "stdev": - assert (v >= 0).all(), (k, v) - assert v.dtype == np.float64, (k, v) - - for i, name in enumerate(self["variables_names"]): - try: - check_stats(**{k: v[i] for k, v in self.items()}, msg=f"{i} {name}") - check_data_values(self["minimum"][i], name=name) - check_data_values(self["maximum"][i], name=name) - check_data_values(self["mean"][i], name=name) - except StatisticsValueError as e: - e.args += (i, name) - raise + self.minimum = np.full(self.shape, np.nan, dtype=np.float64) + self.maximum = np.full(self.shape, np.nan, dtype=np.float64) + self.sums = np.full(self.shape, np.nan, dtype=np.float64) + self.squares = np.full(self.shape, np.nan, dtype=np.float64) + self.count = np.full(self.shape, -1, dtype=np.int64) + self.has_nans = np.full(self.shape, False, dtype=np.bool_) - def __str__(self): - header = ["Variables"] + self.STATS_NAMES - out = [" ".join(header)] - - out += [ - " ".join([v] + [f"{self[n][i]:.2f}" for n in self.STATS_NAMES]) - for i, v in enumerate(self["variables_names"]) - ] - return "\n".join(out) - - def save(self, filename, provenance=None): - assert filename.endswith(".json"), filename - dic = {} - for k in self.STATS_NAMES: - dic[k] = list(self[k]) - - out = dict(data=defaultdict(dict)) - for i, name in enumerate(self["variables_names"]): - for k in self.STATS_NAMES: - out["data"][name][k] = dic[k][i] - - out["provenance"] = provenance - - with open(filename, "w") as f: - json.dump(out, f, indent=2) - - def load(self, filename): - assert filename.endswith(".json"), filename - with open(filename) as f: - dic = json.load(f) - - dic_ = {} - for k, v in dic.items(): - if k == "count": - dic_[k] = np.array(v, dtype=np.int64) - continue - if k == "variables": - dic_[k] = v + self._read() + + def _read(self): + def check_type(a, b): + a = list(a) + b = list(b) + a = a[0] if a else None + b = b[0] if b else None + assert type(a) is type(b), (type(a), type(b)) + + found = set() + offset = 0 + for _, _dates, stats in self.owner._gather_data(): + for n in self.NAMES: + assert n in stats, (n, list(stats.keys())) + _dates = to_datetimes(_dates) + check_type(_dates, self.dates) + if found: + check_type(found, self.dates) + assert found.isdisjoint(_dates), "Duplicate dates found in precomputed statistics" + + # filter dates + dates = set(_dates) & set(self.dates) + + if not dates: + # dates have been completely filtered for this chunk continue - dic_[k] = np.array(v, dtype=np.float64) - return Statistics(dic_) + + # filter data + bitmap = np.isin(_dates, self.dates) + for k in self.NAMES: + stats[k] = stats[k][bitmap] + + assert stats["minimum"].shape[0] == len(dates), ( + stats["minimum"].shape, + len(dates), + ) + + # store data in self + found |= set(dates) + for name in self.NAMES: + array = getattr(self, name) + assert stats[name].shape[0] == len(dates), ( + stats[name].shape, + len(dates), + ) + array[offset : offset + len(dates)] = stats[name] + offset += len(dates) + + for d in self.dates: + assert d in found, f"Statistics for date {d} not precomputed." 
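+        # Sanity checks: every requested date must be covered exactly once by the precomputed chunks, and the aggregation buffers must be completely filled.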
+ assert len(self.dates) == len(found), "Not all dates found in precomputed statistics" + assert len(self.dates) == offset, "Not all dates found in precomputed statistics." + LOG.info(f"Statistics for {len(found)} dates found.") + + def aggregate(self): + minimum = np.nanmin(self.minimum, axis=0) + maximum = np.nanmax(self.maximum, axis=0) + sums = np.nansum(self.sums, axis=0) + squares = np.nansum(self.squares, axis=0) + count = np.nansum(self.count, axis=0) + has_nans = np.any(self.has_nans, axis=0) + mean = sums / count + + assert sums.shape == count.shape == squares.shape == mean.shape == minimum.shape == maximum.shape + + x = squares / count - mean * mean + # remove negative variance due to numerical errors + # x[- 1e-15 < (x / (np.sqrt(squares / count) + np.abs(mean))) < 0] = 0 + check_variance(x, self.variables_names, minimum, maximum, mean, count, sums, squares) + stdev = np.sqrt(x) + + for j, name in enumerate(self.variables_names): + check_data_values( + np.array( + [ + mean[j], + ] + ), + name=name, + allow_nan=False, + ) + + return Summary( + minimum=minimum, + maximum=maximum, + mean=mean, + count=count, + sums=sums, + squares=squares, + stdev=stdev, + variables_names=self.variables_names, + has_nans=has_nans, + ) diff --git a/src/anemoi/datasets/create/statistics/summary.py b/src/anemoi/datasets/create/statistics/summary.py new file mode 100644 index 00000000..1434688e --- /dev/null +++ b/src/anemoi/datasets/create/statistics/summary.py @@ -0,0 +1,108 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# +import json +from collections import defaultdict + +import numpy as np + +from ..check import StatisticsValueError +from ..check import check_data_values +from ..check import check_stats + + +class Summary(dict): + """This class is used to store the summary statistics of a dataset. + It can be saved and loaded from a json file. + And does some basic checks on the data. + """ + + STATS_NAMES = [ + "minimum", + "maximum", + "mean", + "stdev", + "has_nans", + ] # order matter for __str__. 
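+    # Usage sketch (names and values below are illustrative only): a Summary is
+    # built from per-variable numpy arrays and can be persisted to JSON, e.g.
+    #   summary = Summary(minimum=mins, maximum=maxs, mean=means, stdev=stdevs,
+    #                     count=counts, sums=sums, squares=squares, has_nans=nans,
+    #                     variables_names=["2t", "msl"])
+    #   summary.save("statistics.json", created_by="example")  # extra kwargs are stored under "metadata"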
+ + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.check() + + @property + def size(self): + return len(self["variables_names"]) + + def check(self): + for k, v in self.items(): + if k == "variables_names": + assert len(v) == self.size + continue + assert v.shape == (self.size,) + if k == "count": + assert (v >= 0).all(), (k, v) + assert v.dtype == np.int64, (k, v) + continue + if k == "has_nans": + assert v.dtype == np.bool_, (k, v) + continue + if k == "stdev": + assert (v >= 0).all(), (k, v) + assert v.dtype == np.float64, (k, v) + + for i, name in enumerate(self["variables_names"]): + try: + check_stats(**{k: v[i] for k, v in self.items()}, msg=f"{i} {name}") + check_data_values(self["minimum"][i], name=name) + check_data_values(self["maximum"][i], name=name) + check_data_values(self["mean"][i], name=name) + except StatisticsValueError as e: + e.args += (i, name) + raise + + def __str__(self): + header = ["Variables"] + self.STATS_NAMES + out = [" ".join(header)] + + out += [ + " ".join([v] + [f"{self[n][i]:.2f}" for n in self.STATS_NAMES]) + for i, v in enumerate(self["variables_names"]) + ] + return "\n".join(out) + + def save(self, filename, **metadata): + assert filename.endswith(".json"), filename + dic = {} + for k in self.STATS_NAMES: + dic[k] = list(self[k]) + + out = dict(data=defaultdict(dict)) + for i, name in enumerate(self["variables_names"]): + for k in self.STATS_NAMES: + out["data"][name][k] = dic[k][i] + + out["metadata"] = metadata + + with open(filename, "w") as f: + json.dump(out, f, indent=2) + + def load(self, filename): + assert filename.endswith(".json"), filename + with open(filename) as f: + dic = json.load(f) + + dic_ = {} + for k, v in dic.items(): + if k == "count": + dic_[k] = np.array(v, dtype=np.int64) + continue + if k == "variables": + dic_[k] = v + continue + dic_[k] = np.array(v, dtype=np.float64) + return Summary(dic_) diff --git a/anemoi/datasets/create/template.py b/src/anemoi/datasets/create/template.py similarity index 100% rename from anemoi/datasets/create/template.py rename to src/anemoi/datasets/create/template.py diff --git a/anemoi/datasets/create/utils.py b/src/anemoi/datasets/create/utils.py similarity index 88% rename from anemoi/datasets/create/utils.py rename to src/anemoi/datasets/create/utils.py index 8579038f..71c273b6 100644 --- a/anemoi/datasets/create/utils.py +++ b/src/anemoi/datasets/create/utils.py @@ -31,8 +31,7 @@ def no_cache_context(): def bytes(n): - """ - >>> bytes(4096) + """>>> bytes(4096) '4 KiB' >>> bytes(4000) '3.9 KiB' @@ -72,21 +71,6 @@ def load_json_or_yaml(path): raise ValueError(f"Cannot read file {path}. 
Need json or yaml with appropriate extension.") -def compute_directory_sizes(path): - if not os.path.isdir(path): - return None - - size, n = 0, 0 - bar = progress_bar(iterable=os.walk(path), desc=f"Computing size of {path}") - for dirpath, _, filenames in bar: - for filename in filenames: - file_path = os.path.join(dirpath, filename) - size += os.path.getsize(file_path) - n += 1 - - return dict(total_size=size, total_number_of_files=n) - - def make_list_int(value): if isinstance(value, str): if "/" not in value: diff --git a/anemoi/datasets/create/writer.py b/src/anemoi/datasets/create/writer.py similarity index 51% rename from anemoi/datasets/create/writer.py rename to src/anemoi/datasets/create/writer.py index 2182f8af..3117f5d1 100644 --- a/anemoi/datasets/create/writer.py +++ b/src/anemoi/datasets/create/writer.py @@ -8,52 +8,12 @@ # import logging -import warnings import numpy as np LOG = logging.getLogger(__name__) -class CubesFilter: - def __init__(self, *, parts, total): - if parts is None: - self.parts = None - return - - if len(parts) == 1: - part = parts[0] - if part.lower() in ["all", "*"]: - self.parts = None - return - - if "/" in part: - i_chunk, n_chunks = part.split("/") - i_chunk, n_chunks = int(i_chunk), int(n_chunks) - - assert i_chunk > 0, f"Chunk number {i_chunk} must be positive." - if n_chunks > total: - warnings.warn( - f"Number of chunks {n_chunks} is larger than the total number of chunks: {total}+1. " - "Some chunks will be empty." - ) - - chunk_size = total / n_chunks - parts = [x for x in range(total) if x >= (i_chunk - 1) * chunk_size and x < i_chunk * chunk_size] - - parts = [int(_) for _ in parts] - LOG.info(f"Running parts: {parts}") - if not parts: - warnings.warn(f"Nothing to do for chunk {i_chunk}/{n_chunks}.") - - self.parts = parts - - def __call__(self, i): - if self.parts is None: - return True - return i in self.parts - - class ViewCacheArray: """A class that provides a caching mechanism for writing to a NumPy-like array. 
diff --git a/anemoi/datasets/create/zarr.py b/src/anemoi/datasets/create/zarr.py similarity index 100% rename from anemoi/datasets/create/zarr.py rename to src/anemoi/datasets/create/zarr.py diff --git a/anemoi/datasets/data/__init__.py b/src/anemoi/datasets/data/__init__.py similarity index 100% rename from anemoi/datasets/data/__init__.py rename to src/anemoi/datasets/data/__init__.py diff --git a/anemoi/datasets/data/concat.py b/src/anemoi/datasets/data/concat.py similarity index 100% rename from anemoi/datasets/data/concat.py rename to src/anemoi/datasets/data/concat.py diff --git a/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py similarity index 98% rename from anemoi/datasets/data/dataset.py rename to src/anemoi/datasets/data/dataset.py index ae9f3e3f..879af680 100644 --- a/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -222,3 +222,7 @@ def _check(ds): def _repr_html_(self): return self.tree().html() + + @property + def label(self): + return self.__class__.__name__.lower() diff --git a/src/anemoi/datasets/data/debug.css b/src/anemoi/datasets/data/debug.css new file mode 100644 index 00000000..edde87d0 --- /dev/null +++ b/src/anemoi/datasets/data/debug.css @@ -0,0 +1,12 @@ +table.dataset td { + vertical-align: top; + text-align: left !important; +} + +table.dataset span.dataset { + font-weight: bold !important; +} + +table.dataset span.values { + font-style: italic !important; +} diff --git a/anemoi/datasets/data/debug.py b/src/anemoi/datasets/data/debug.py similarity index 87% rename from anemoi/datasets/data/debug.py rename to src/anemoi/datasets/data/debug.py index c03c08a7..98ea3d9d 100644 --- a/anemoi/datasets/data/debug.py +++ b/src/anemoi/datasets/data/debug.py @@ -21,6 +21,12 @@ # a.flags.writeable = False +def css(name): + path = os.path.join(os.path.dirname(__file__), f"{name}.css") + with open(path) as f: + return f"" + + class Node: def __init__(self, dataset, kids, **kwargs): self.dataset = dataset @@ -46,7 +52,7 @@ def __repr__(self): return "\n".join(result) def graph(self, digraph, nodes): - label = self.dataset.__class__.__name__.lower() + label = self.dataset.label # dataset.__class__.__name__.lower() if self.kwargs: param = [] for k, v in self.kwargs.items(): @@ -107,7 +113,7 @@ def _html(self, indent, rows): if k == "path": v = v[::-1] kwargs[k] = v - label = self.dataset.__class__.__name__.lower() + label = self.dataset.label label = f'{label}' if len(kwargs) == 1: k, v = list(kwargs.items())[0] @@ -116,25 +122,14 @@ def _html(self, indent, rows): rows.append([indent] + [label]) for k, v in kwargs.items(): - rows.append([indent] + [k, v]) + rows.append([indent] + [f"{k}", f"{v}"]) for kid in self.kids: kid._html(indent + " ", rows) def html(self): - result = [ - """ - - """ - ] + result = [css("debug")] + result.append('