diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 00000000..0211a4e3 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,6 @@ +# CODEOWNERS file + +# Protect workflow files +/.github/ @theissenhelen @jesperdramsch @gmertes @b8raoult @floriankrb +/.pre-commit-config.yaml @theissenhelen @jesperdramsch @gmertes @b8raoult @floriankrb +/pyproject.toml @theissenhelen @jesperdramsch @gmertes @b8raoult @floriankrb diff --git a/.github/ci-hpc-config.yml b/.github/ci-hpc-config.yml new file mode 100644 index 00000000..b6e65e42 --- /dev/null +++ b/.github/ci-hpc-config.yml @@ -0,0 +1,18 @@ +build: + python: '3.10' + modules: + - ninja + dependencies: + - ecmwf/ecbuild@develop + - ecmwf/eccodes@develop + - ecmwf/eckit@develop + - ecmwf/odc@develop + python_dependencies: + - ecmwf/anemoi-utils@develop + - ecmwf/earthkit-data@develop + - ecmwf/earthkit-meteo@develop + - ecmwf/earthkit-geo@develop + parallel: 64 + + pytest_cmd: | + python -m pytest -vv -m 'not notebook and not no_cache_init' --cov=. --cov-report=xml diff --git a/.github/workflows/changelog-pr-update.yml b/.github/workflows/changelog-pr-update.yml index 4bc51df6..73cb1ebf 100644 --- a/.github/workflows/changelog-pr-update.yml +++ b/.github/workflows/changelog-pr-update.yml @@ -5,6 +5,9 @@ on: branches: - main - develop + paths-ignore: + - .pre-commit-config.yaml + - .readthedocs.yaml jobs: Check-Changelog: name: Check Changelog Action diff --git a/.github/workflows/changelog-release-update.yml b/.github/workflows/changelog-release-update.yml index 79b85add..17d95250 100644 --- a/.github/workflows/changelog-release-update.yml +++ b/.github/workflows/changelog-release-update.yml @@ -4,6 +4,7 @@ name: "Update Changelog" on: release: types: [released] + workflow_dispatch: ~ permissions: pull-requests: write diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7dde189b..e9eb40fe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,9 +8,17 @@ on: - 'develop' tags-ignore: - '**' + paths-ignore: + - "docs/**" + - "CHANGELOG.md" + - "README.md" # Trigger the workflow on pull request - pull_request: ~ + pull_request: + paths-ignore: + - "docs/**" + - "CHANGELOG.md" + - "README.md" # Trigger the workflow manually workflow_dispatch: ~ @@ -18,6 +26,11 @@ on: # Trigger after public PR approved for CI pull_request_target: types: [labeled] + paths-ignore: + - "docs/**" + - "CHANGELOG.md" + - "README.md" + jobs: # Run CI including downstream packages on self-hosted runners @@ -34,7 +47,7 @@ jobs: downstream-ci-hpc: name: downstream-ci-hpc if: ${{ !github.event.pull_request.head.repo.fork && github.event.action != 'labeled' || github.event.label.name == 'approved-for-ci' }} - uses: ecmwf-actions/downstream-ci/.github/workflows/downstream-ci.yml@main + uses: ecmwf-actions/downstream-ci/.github/workflows/downstream-ci-hpc.yml@main with: anemoi-datasets: ecmwf/anemoi-datasets@${{ github.event.pull_request.head.sha || github.sha }} secrets: inherit diff --git a/.github/workflows/python-pull-request.yml b/.github/workflows/python-pull-request.yml index 0ebecb13..21852644 100644 --- a/.github/workflows/python-pull-request.yml +++ b/.github/workflows/python-pull-request.yml @@ -5,7 +5,7 @@ name: Code Quality checks for PRs on: push: - pull_request_target: + pull_request: types: [opened, synchronize, reopened] jobs: diff --git a/.gitignore b/.gitignore index eefa3c25..d7d615b1 100644 --- a/.gitignore +++ b/.gitignore @@ -120,6 +120,7 @@ celerybeat.pid *.sage.py # Environments +.envrc .env .venv 
env/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a6c923a7..6e4341a3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,6 +20,12 @@ repos:
- id: no-commit-to-branch # Prevent committing to main / master
- id: check-added-large-files # Check for large files added to git
- id: check-merge-conflict # Check for files that contain merge conflict
+- repo: https://github.com/pre-commit/pygrep-hooks
+  rev: v1.10.0 # Use the ref you want to point at
+  hooks:
+  - id: python-use-type-annotations # Check for missing type annotations
+  - id: python-check-blanket-noqa # Check for # noqa: all
+  - id: python-no-log-warn # Check for log.warn
- repo: https://github.com/psf/black-pre-commit-mirror rev: 24.8.0 hooks:
@@ -37,15 +43,15 @@ repos: rev: v0.6.4 hooks:
- id: ruff
-      # Next line is to exclude for documentation code snippets
-      exclude: 'docs/(.*/)?[a-z]\w+_.py$'
+      # Next line is for documentation code snippets
+      exclude: '^[^_].*_\.py$'
args: - --line-length=120 - --fix - --exit-non-zero-on-fix - --preview
- repo: https://github.com/sphinx-contrib/sphinx-lint
-  rev: v0.9.1
+  rev: v1.0.0
hooks: - id: sphinx-lint # For now, we use it. But it does not support a lot of sphinx features
@@ -59,12 +65,10 @@ repos: hooks: - id: docconvert args: ["numpy"]
-- repo: https://github.com/b8raoult/optional-dependencies-all
-  rev: "0.0.6"
-  hooks:
-  - id: optional-dependencies-all
-    args: ["--inplace", "--exclude-keys=dev,docs,tests", "--group=dev=all,docs,tests"]
- repo: https://github.com/tox-dev/pyproject-fmt rev: "2.2.3" hooks: - id: pyproject-fmt
+
+ci:
+  autoupdate_schedule: monthly
diff --git a/CHANGELOG.md b/CHANGELOG.md index e12e0d8f..ac86f1b2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,11 +11,21 @@ Keep it human-readable, your future self will thank you!
## [Unreleased]
### Added
+
+- New `rescale` keyword in `open_dataset` to change units of variables #36
+- Simplify imports
+
### Changed
- Added incremental building of datasets
+- Add missing dependency for documentation building
+- Fix failing test due to previous merge
+- Bug fix when creating dataset from zarr
+- Bug fix with area selection in cutout operation
+- Add paths-ignore to CI workflow
### Removed
+- pytest for notebooks
## [0.4.5]
@@ -25,6 +35,7 @@ Keep it human-readable, your future self will thank you!
- CI workflow to update the changelog on release
- adds the reusable cd pypi workflow
- merge strategy for changelog in .gitattributes #25
+- adds ci hpc config (#43)
### Changed
@@ -74,6 +85,9 @@ Keep it human-readable, your future self will thank you!
- combine datasets
## Git Diffs:
+[Unreleased]: https://github.com/ecmwf/anemoi-datasets/compare/0.4.5...HEAD
+[0.4.5]: https://github.com/ecmwf/anemoi-datasets/compare/0.4.4...0.4.5
+[0.4.4]: https://github.com/ecmwf/anemoi-datasets/compare/0.4.0...0.4.4
[0.4.0]: https://github.com/ecmwf/anemoi-datasets/compare/0.3.0...0.4.0
[0.3.0]: https://github.com/ecmwf/anemoi-datasets/compare/0.2.0...0.3.0
[0.2.0]: https://github.com/ecmwf/anemoi-datasets/compare/0.1.0...0.2.0
diff --git a/docs/building/incremental.rst b/docs/building/incremental.rst index b93d0920..a42bef08 100644 --- a/docs/building/incremental.rst +++ b/docs/building/incremental.rst @@ -86,8 +86,8 @@ To add statistics for 6h increments:
..
code:: bash
-   anemoi-datasets init-additions dataset.zarr --delta 6h anemoi-datasets
-   anemoi-datasets load-additions dataset.zarr --part 1/2 --delta 6h anemoi-datasets
+   anemoi-datasets init-additions dataset.zarr --delta 6h
+   anemoi-datasets load-additions dataset.zarr --part 1/2 --delta 6h
anemoi-datasets load-additions dataset.zarr --part 2/2 --delta 6h
anemoi-datasets finalise-additions dataset.zarr --delta 6h
@@ -96,7 +96,7 @@ To add statistics for 12h increments: .. code:: bash
anemoi-datasets init-additions dataset.zarr --delta 12h
-   anemoi-datasets load-additions dataset.zarr --part 1/2 --delta 12h anemoi-datasets
+   anemoi-datasets load-additions dataset.zarr --part 1/2 --delta 12h
anemoi-datasets load-additions dataset.zarr --part 2/2 --delta 12h
anemoi-datasets finalise-additions dataset.zarr --delta 12h
diff --git a/docs/using/code/rescale_.py b/docs/using/code/rescale_.py new file mode 100644 index 00000000..d7746d9e --- /dev/null +++ b/docs/using/code/rescale_.py @@ -0,0 +1,30 @@
+# Scale and offset can be passed as a dictionary...
+
+ds = open_dataset(
+    dataset,
+    rescale={"2t": {"scale": 1.0, "offset": -273.15}},
+)
+
+# ... a tuple of floats ...
+
+ds = open_dataset(
+    dataset,
+    rescale={"2t": (1.0, -273.15)},
+)
+
+# ... or a tuple of strings representing units.
+
+ds = open_dataset(
+    dataset,
+    rescale={"2t": ("K", "degC")},
+)
+
+# Several variables can be rescaled at once.
+
+ds = open_dataset(
+    dataset,
+    rescale={
+        "2t": ("K", "degC"),
+        "tp": ("m", "mm"),
+    },
+)
diff --git a/docs/using/selecting.rst b/docs/using/selecting.rst index dc27dcbd..ef21d9a3 100644 --- a/docs/using/selecting.rst +++ b/docs/using/selecting.rst @@ -66,3 +66,28 @@ You can also rename variables: This will be useful when you join datasets and do not want variables from one dataset to override the ones from the other.
+
+*********
+ rescale
+*********
+
+When combining datasets, you may want to rescale the variables so that
+they have matching units. This can be done with the `rescale` option:
+
+.. literalinclude:: code/rescale_.py
+   :language: python
+
+The `rescale` option will also rescale the statistics. The rescaling is
+currently limited to simple linear conversions.
+
+When provided with units, the `rescale` option uses the cfunits_ package
+to find the `scale` and `offset` attributes of the units and uses these to
+rescale the data.
+
+.. warning::
+
+   When providing units, the library assumes that the mapping between
+   them is a linear transformation. No check is done to ensure this is
+   the case.
+
+..
_cfunits: https://github.com/NCAS-CMS/cfunits diff --git a/pyproject.toml b/pyproject.toml index eb6b18e1..5210d87c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,44 +51,42 @@ dynamic = [ ] dependencies = [ "anemoi-utils[provenance]>=0.3.15", + "cfunits", "numpy", "pyyaml", "semantic-version", "tqdm", - "zarr", + "zarr<=2.17", ] optional-dependencies.all = [ - "aiohttp", "boto3", "earthkit-data[mars]>=0.9", "earthkit-geo>=0.2", "earthkit-meteo", - "eccodes>=2.37", + "ecmwflibs>=0.6.3", "entrypoints", "gcsfs", "kerchunk", "pyproj", "requests", - "s3fs", ] optional-dependencies.create = [ "earthkit-data[mars]>=0.9", "earthkit-geo>=0.2", "earthkit-meteo", - "eccodes>=2.37", + "ecmwflibs>=0.6.3", "entrypoints", "pyproj", ] optional-dependencies.dev = [ - "aiohttp", "boto3", "earthkit-data[mars]>=0.9", "earthkit-geo>=0.2", "earthkit-meteo", - "eccodes>=2.37", + "ecmwflibs>=0.6.3", "entrypoints", "gcsfs", "kerchunk", @@ -97,39 +95,32 @@ optional-dependencies.dev = [ "pyproj", "pytest", "requests", - "rstfmt", - "s3fs", "sphinx", - "sphinx-argparse<0.5", "sphinx-rtd-theme", ] optional-dependencies.docs = [ "nbsphinx", "pandoc", - "rstfmt", "sphinx", - "sphinx-argparse<0.5", + "sphinx-argparse", "sphinx-rtd-theme", ] -optional-dependencies.kerchunk = [ - "gcsfs", - "kerchunk", - "s3fs", -] - optional-dependencies.remote = [ - "aiohttp", "boto3", "requests", - "s3fs", ] optional-dependencies.tests = [ "pytest", ] +optional-dependencies.xarray = [ + "gcsfs", + "kerchunk", +] + urls.Documentation = "https://anemoi-datasets.readthedocs.io/" urls.Homepage = "https://github.com/ecmwf/anemoi-datasets/" urls.Issues = "https://github.com/ecmwf/anemoi-datasets/issues" @@ -145,3 +136,6 @@ scripts.anemoi-datasets = "anemoi.datasets.__main__:main" [tool.setuptools_scm] version_file = "src/anemoi/datasets/_version.py" + +[tool.isort] +profile = "black" diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 17bd7a87..2b44a1f3 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -19,7 +19,7 @@ def task(what, options, *args, **kwargs): """ now = datetime.datetime.now() - LOG.info(f"Task {what}({args},{kwargs}) starting") + LOG.info(f"🎬 Task {what}({args},{kwargs}) starting") from anemoi.datasets.create import creator_factory @@ -28,7 +28,7 @@ def task(what, options, *args, **kwargs): c = creator_factory(what.replace("-", "_"), **options) result = c.run() - LOG.debug(f"Task {what}({args},{kwargs}) completed ({datetime.datetime.now()-now})") + LOG.info(f"🏁 Task {what}({args},{kwargs}) completed ({datetime.datetime.now()-now})") return result @@ -57,6 +57,7 @@ def add_arguments(self, command_parser): command_parser.add_argument("--trace", action="store_true") def run(self, args): + now = time.time() if args.threads + args.processes: self.parallel_create(args) diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index d12df1ba..461aad73 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -132,7 +132,7 @@ def update_metadata(self, **kwargs): v = v.isoformat() z.attrs[k] = json.loads(json.dumps(v, default=json_tidy)) - @property + @cached_property def anemoi_dataset(self): return open_dataset(self.path) @@ -245,9 +245,9 @@ def check_missing_dates(expected): missing_dates = z.attrs.get("missing_dates", []) missing_dates = sorted([np.datetime64(d) for d in missing_dates]) if missing_dates != expected: - LOG.warn("Missing dates given in 
recipe do not match the actual missing dates in the dataset.") - LOG.warn(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}") - LOG.warn(f"Missing dates in dataset: {sorted(str(x) for x in expected)}") + LOG.warning("Missing dates given in recipe do not match the actual missing dates in the dataset.") + LOG.warning(f"Missing dates in recipe: {sorted(str(x) for x in missing_dates)}") + LOG.warning(f"Missing dates in dataset: {sorted(str(x) for x in expected)}") raise ValueError("Missing dates given in recipe do not match the actual missing dates in the dataset.") check_missing_dates(self.missing_dates) @@ -327,7 +327,7 @@ class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixi dataset_class = NewDataset def __init__(self, path, config, check_name=False, overwrite=False, use_threads=False, statistics_temp_dir=None, progress=None, test=False, cache=None, **kwargs): # fmt: skip if _path_readable(path) and not overwrite: - raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.") + raise Exception(f"{path} already exists. Use overwrite=True to overwrite.") super().__init__(path, cache=cache) self.config = config @@ -345,9 +345,12 @@ def __init__(self, path, config, check_name=False, overwrite=False, use_threads= assert isinstance(self.main_config.output.order_by, dict), self.main_config.output.order_by self.create_elements(self.main_config) - first_date = self.groups.dates[0] - self.minimal_input = self.input.select([first_date]) - LOG.info("Minimal input for 'init' step (using only the first date) :") + LOG.info(f"Groups: {self.groups}") + + one_date = self.groups.one_date() + # assert False, (type(one_date), type(self.groups)) + self.minimal_input = self.input.select(one_date) + LOG.info(f"Minimal input for 'init' step (using only the first date) : {one_date}") LOG.info(self.minimal_input) def run(self): @@ -363,13 +366,15 @@ def _run(self): LOG.info("Config loaded ok:") # LOG.info(self.main_config) - dates = self.groups.dates - frequency = dates.frequency + dates = self.groups.provider.values + frequency = self.groups.provider.frequency + missing = self.groups.provider.missing + assert isinstance(frequency, datetime.timedelta), frequency LOG.info(f"Found {len(dates)} datetimes.") LOG.info(f"Dates: Found {len(dates)} datetimes, in {len(self.groups)} groups: ") - LOG.info(f"Missing dates: {len(dates.missing)}") + LOG.info(f"Missing dates: {len(missing)}") lengths = tuple(len(g) for g in self.groups) variables = self.minimal_input.variables @@ -426,7 +431,7 @@ def _run(self): metadata["start_date"] = dates[0].isoformat() metadata["end_date"] = dates[-1].isoformat() metadata["frequency"] = frequency - metadata["missing_dates"] = [_.isoformat() for _ in dates.missing] + metadata["missing_dates"] = [_.isoformat() for _ in missing] metadata["version"] = VERSION @@ -481,17 +486,6 @@ def _run(self): assert chunks == self.dataset.get_zarr_chunks(), (chunks, self.dataset.get_zarr_chunks()) - def sanity_check_config(a, b): - a = json.dumps(a, sort_keys=True, default=str) - b = json.dumps(b, sort_keys=True, default=str) - b = b.replace("T", " ") # dates are expected to be different because - if a != b: - print("❌❌❌ FIXME: Config serialisation to be checked") - print(a) - print(b) - - sanity_check_config(self.main_config, self.dataset.get_main_config()) - # Return the number of groups to process, so we can show a nice progress bar return len(lengths) @@ -527,11 +521,11 @@ def _run(self): LOG.info(f" -> Skipping {igroup} 
total={len(self.groups)} (already done)") continue - assert isinstance(group[0], datetime.datetime), group + # assert isinstance(group[0], datetime.datetime), type(group[0]) LOG.debug(f"Building data for group {igroup}/{self.n_groups}") result = self.input.select(dates=group) - assert result.dates == group, (len(result.dates), len(group)) + assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) # There are several groups. # There is one result to load for each group. @@ -545,7 +539,7 @@ def _run(self): def load_result(self, result): # There is one cube to load for each result. - dates = result.dates + dates = list(result.group_of_dates) cube = result.get_cube() shape = cube.extended_user_shape @@ -555,7 +549,9 @@ def load_result(self, result): def check_shape(cube, dates, dates_in_data): if cube.extended_user_shape[0] != len(dates): - print(f"Cube shape does not match the number of dates {cube.extended_user_shape[0]}, {len(dates)}") + print( + f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" + ) print("Requested dates", compress_dates(dates)) print("Cube dates", compress_dates(dates_in_data)) @@ -566,7 +562,7 @@ def check_shape(cube, dates, dates_in_data): print("Extra dates", compress_dates(b - a)) raise ValueError( - f"Cube shape does not match the number of dates {cube.extended_user_shape[0]}, {len(dates)}" + f"Cube shape does not match the number of dates got {cube.extended_user_shape[0]}, expected {len(dates)}" ) check_shape(cube, dates, dates_in_data) @@ -846,7 +842,7 @@ def run(self): ) if len(ifound) < 2: - LOG.warn(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.") + LOG.warning(f"Not enough data found in {self.path} to compute {self.__class__.__name__}. Skipped.") self.tmp_storage.delete() return @@ -919,7 +915,7 @@ def __init__(self, *args, **kwargs): self.actors.append(cls(*args, delta=k, **kwargs)) if not self.actors: - LOG.warning("No delta found in kwargs, no addtions will be computed.") + LOG.warning("No delta found in kwargs, no additions will be computed.") def run(self): for actor in self.actors: @@ -947,7 +943,9 @@ def run(self): ) start, end = np.datetime64(start), np.datetime64(end) dates = self.dataset.anemoi_dataset.dates - assert type(dates[0]) == type(start), (type(dates[0]), type(start)) # noqa + + assert type(dates[0]) is type(start), (type(dates[0]), type(start)) + dates = [d for d in dates if d >= start and d <= end] dates = [d for i, d in enumerate(dates) if i not in self.dataset.anemoi_dataset.missing] variables = self.dataset.anemoi_dataset.variables @@ -956,7 +954,7 @@ def run(self): LOG.info(stats) if not all(self.registry.get_flags(sync=False)): - raise Exception(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.") + raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.") for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]: self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",)) diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index e292a6e5..d64bdb7f 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -215,8 +215,9 @@ def set_to_test_mode(cfg): NUMBER_OF_DATES = 4 dates = cfg["dates"] - LOG.warn(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.") + LOG.warning(f"Running in test mode. 
Changing the list of dates to use only {NUMBER_OF_DATES}.") groups = Groups(**LoadersConfig(cfg).dates) + dates = groups.dates cfg["dates"] = dict( start=dates[0], @@ -234,12 +235,12 @@ def set_element_to_test(obj): if "grid" in obj: previous = obj["grid"] obj["grid"] = "20./20." - LOG.warn(f"Running in test mode. Setting grid to {obj['grid']} instead of {previous}") + LOG.warning(f"Running in test mode. Setting grid to {obj['grid']} instead of {previous}") if "number" in obj: if isinstance(obj["number"], (list, tuple)): previous = obj["number"] obj["number"] = previous[0:3] - LOG.warn(f"Running in test mode. Setting number to {obj['number']} instead of {previous}") + LOG.warning(f"Running in test mode. Setting number to {obj['number']} instead of {previous}") for k, v in obj.items(): set_element_to_test(v) if "constants" in obj: diff --git a/src/anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py b/src/anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py new file mode 100644 index 00000000..28d9ed6a --- /dev/null +++ b/src/anemoi/datasets/create/functions/filters/pressure_level_relative_humidity_to_specific_humidity.py @@ -0,0 +1,57 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +from collections import defaultdict + +from earthkit.data.indexing.fieldlist import FieldArray +from earthkit.meteo import thermo + +from .single_level_specific_humidity_to_relative_humidity import NewDataField + + +def execute(context, input, t, rh, q="q"): + """Convert relative humidity on pressure levels to specific humidity""" + result = FieldArray() + + params = (t, rh) + pairs = defaultdict(dict) + + # Gather all necessary fields + for f in input: + key = f.metadata(namespace="mars") + param = key.pop("param") + if param in params: + key = tuple(key.items()) + + if param in pairs[key]: + raise ValueError(f"Duplicate field {param} for {key}") + + pairs[key][param] = f + if param == t: + result.append(f) + # all other parameters + else: + result.append(f) + + for keys, values in pairs.items(): + # some checks + + if len(values) != 2: + raise ValueError("Missing fields") + + t_pl = values[t].to_numpy(flatten=True) + rh_pl = values[rh].to_numpy(flatten=True) + pressure = keys[4][1] * 100 # TODO: REMOVE HARDCODED INDICES + # print(f"Handling fields for pressure level {pressure}...") + + # actual conversion from rh --> q_v + q_pl = thermo.specific_humidity_from_relative_humidity(t_pl, rh_pl, pressure) + result.append(NewDataField(values[rh], q_pl, q)) + + return result diff --git a/src/anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py b/src/anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py new file mode 100644 index 00000000..1c817fb4 --- /dev/null +++ b/src/anemoi/datasets/create/functions/filters/pressure_level_specific_humidity_to_relative_humidity.py @@ -0,0 +1,57 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
+
+from collections import defaultdict
+
+from earthkit.data.indexing.fieldlist import FieldArray
+from earthkit.meteo import thermo
+
+from .single_level_specific_humidity_to_relative_humidity import NewDataField
+
+
+def execute(context, input, t, q, rh="r"):
+    """Convert specific humidity on pressure levels to relative humidity"""
+    result = FieldArray()
+
+    params = (t, q)
+    pairs = defaultdict(dict)
+
+    # Gather all necessary fields
+    for f in input:
+        key = f.metadata(namespace="mars")
+        param = key.pop("param")
+        if param in params:
+            key = tuple(key.items())
+
+            if param in pairs[key]:
+                raise ValueError(f"Duplicate field {param} for {key}")
+
+            pairs[key][param] = f
+            if param == t:
+                result.append(f)
+        # all other parameters
+        else:
+            result.append(f)
+
+    for keys, values in pairs.items():
+        # some checks
+
+        if len(values) != 2:
+            raise ValueError("Missing fields")
+
+        t_pl = values[t].to_numpy(flatten=True)
+        q_pl = values[q].to_numpy(flatten=True)
+        pressure = keys[4][1] * 100  # TODO: REMOVE HARDCODED INDICES
+        # print(f"Handling fields for pressure level {pressure}...")
+
+        # actual conversion from q --> rh
+        rh_pl = thermo.relative_humidity_from_specific_humidity(t_pl, q_pl, pressure)
+        result.append(NewDataField(values[q], rh_pl, rh))
+
+    return result
diff --git a/src/anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py b/src/anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py new file mode 100644 index 00000000..2afd53d2 --- /dev/null +++ b/src/anemoi/datasets/create/functions/filters/single_level_dewpoint_to_relative_humidity.py @@ -0,0 +1,54 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+#
+
+from collections import defaultdict
+
+from earthkit.data.indexing.fieldlist import FieldArray
+from earthkit.meteo import thermo
+
+from .single_level_specific_humidity_to_relative_humidity import NewDataField
+
+
+def execute(context, input, t, td, rh="d"):
+    """Convert dewpoint on single levels to relative humidity"""
+    result = FieldArray()
+
+    params = (t, td)
+    pairs = defaultdict(dict)
+
+    # Gather all necessary fields
+    for f in input:
+        key = f.metadata(namespace="mars")
+        param = key.pop("param")
+        if param in params:
+            key = tuple(key.items())
+
+            if param in pairs[key]:
+                raise ValueError(f"Duplicate field {param} for {key}")
+
+            pairs[key][param] = f
+            if param == t:
+                result.append(f)
+        # all other parameters
+        else:
+            result.append(f)
+
+    for keys, values in pairs.items():
+        # some checks
+
+        if len(values) != 2:
+            raise ValueError("Missing fields")
+
+        t_values = values[t].to_numpy(flatten=True)
+        td_values = values[td].to_numpy(flatten=True)
+        # actual conversion from td --> rh
+        rh_values = thermo.relative_humidity_from_dewpoint(t=t_values, td=td_values)
+        result.append(NewDataField(values[td], rh_values, rh))
+
+    return result
diff --git a/src/anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py b/src/anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py new file mode 100644 index 00000000..116feaf3 --- /dev/null +++ b/src/anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_dewpoint.py @@ -0,0 +1,59 @@
+# (C) Copyright 2024 ECMWF.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+# + +from collections import defaultdict + +from earthkit.data.indexing.fieldlist import FieldArray +from earthkit.meteo import thermo + +from .single_level_specific_humidity_to_relative_humidity import NewDataField + +EPS = 1.0e-4 + + +def execute(context, input, t, rh, td="d"): + """Convert relative humidity on single levels to dewpoint""" + result = FieldArray() + + params = (t, rh) + pairs = defaultdict(dict) + + # Gather all necessary fields + for f in input: + key = f.metadata(namespace="mars") + param = key.pop("param") + if param in params: + key = tuple(key.items()) + + if param in pairs[key]: + raise ValueError(f"Duplicate field {param} for {key}") + + pairs[key][param] = f + if param == t: + result.append(f) + # all other parameters + else: + result.append(f) + + for keys, values in pairs.items(): + # some checks + + if len(values) != 2: + raise ValueError("Missing fields") + + t_values = values[t].to_numpy(flatten=True) + rh_values = values[rh].to_numpy(flatten=True) + # Prevent 0 % Relative humidity which cannot be converted to dewpoint + # Seems to happen over Egypt in the CERRA dataset + rh_values[rh_values == 0] = EPS + # actual conversion from rh --> td + td_values = thermo.dewpoint_from_relative_humidity(t=t_values, r=rh_values) + result.append(NewDataField(values[rh], td_values, td)) + + return result diff --git a/src/anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py b/src/anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py new file mode 100644 index 00000000..f7be41a3 --- /dev/null +++ b/src/anemoi/datasets/create/functions/filters/single_level_relative_humidity_to_specific_humidity.py @@ -0,0 +1,115 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + + +import numpy as np +from earthkit.data.indexing.fieldlist import FieldArray +from earthkit.meteo import thermo + +from .single_level_specific_humidity_to_relative_humidity import AutoDict +from .single_level_specific_humidity_to_relative_humidity import NewDataField +from .single_level_specific_humidity_to_relative_humidity import pressure_at_height_level + + +def execute(context, input, height, t, rh, sp, new_name="2q", **kwargs): + """Convert the single (height) level relative humidity to specific humidity""" + result = FieldArray() + + MANDATORY_KEYS = ["A", "B"] + OPTIONAL_KEYS = ["t_ml", "q_ml"] + MISSING_KEYS = [] + DEFAULTS = dict(t_ml="t", q_ml="q") + + for key in OPTIONAL_KEYS: + if key not in kwargs: + print(f"key {key} not found in yaml-file, using default key: {DEFAULTS[key]}") + kwargs[key] = DEFAULTS[key] + + for key in MANDATORY_KEYS: + if key not in kwargs: + MISSING_KEYS.append(key) + + if MISSING_KEYS: + raise KeyError(f"Following keys are missing: {', '.join(MISSING_KEYS)}") + + single_level_params = (t, rh, sp) + model_level_params = (kwargs["t_ml"], kwargs["q_ml"]) + + needed_fields = AutoDict() + + # Gather all necessary fields + for f in input: + key = f.metadata(namespace="mars") + param = key.pop("param") + # check single level parameters + if param in single_level_params: + levtype = key.pop("levtype") + key = tuple(key.items()) + + if param in needed_fields[key][levtype]: + raise ValueError(f"Duplicate single level field {param} for {key}") + + needed_fields[key][levtype][param] = f + if param == rh: + if kwargs.get("keep_rh", False): + result.append(f) + else: + result.append(f) + + # check model level parameters + elif param in model_level_params: + levtype = key.pop("levtype") + levelist = key.pop("levelist") + key = tuple(key.items()) + + if param in needed_fields[key][levtype][levelist]: + raise ValueError(f"Duplicate model level field {param} for {key} at level {levelist}") + + needed_fields[key][levtype][levelist][param] = f + + # all other parameters + else: + result.append(f) + + for _, values in needed_fields.items(): + # some checks + if len(values["sfc"]) != 3: + raise ValueError("Missing surface fields") + + rh_sl = values["sfc"][rh].to_numpy(flatten=True) + t_sl = values["sfc"][t].to_numpy(flatten=True) + sp_sl = values["sfc"][sp].to_numpy(flatten=True) + + nlevels = len(kwargs["A"]) - 1 + if len(values["ml"]) != nlevels: + raise ValueError("Missing model levels") + + for key in values["ml"].keys(): + if len(values["ml"][key]) != 2: + raise ValueError(f"Missing field on level {key}") + + # create 3D arrays for upper air fields + levels = list(values["ml"].keys()) + levels.sort() + t_ml = [] + q_ml = [] + for level in levels: + t_ml.append(values["ml"][level][kwargs["t_ml"]].to_numpy(flatten=True)) + q_ml.append(values["ml"][level][kwargs["q_ml"]].to_numpy(flatten=True)) + + t_ml = np.stack(t_ml) + q_ml = np.stack(q_ml) + + # actual conversion from rh --> q_v + p_sl = pressure_at_height_level(height, q_ml, t_ml, sp_sl, np.array(kwargs["A"]), np.array(kwargs["B"])) + q_sl = thermo.specific_humidity_from_relative_humidity(t_sl, rh_sl, p_sl) + + result.append(NewDataField(values["sfc"][rh], q_sl, new_name)) + + return result diff --git a/src/anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py b/src/anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py new file mode 100644 index 00000000..aac55f03 --- /dev/null +++ 
b/src/anemoi/datasets/create/functions/filters/single_level_specific_humidity_to_relative_humidity.py @@ -0,0 +1,390 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + + +import numpy as np +from earthkit.data.indexing.fieldlist import FieldArray +from earthkit.meteo import constants +from earthkit.meteo import thermo + + +# Alternative proposed by Baudouin Raoult +class AutoDict(dict): + def __missing__(self, key): + value = self[key] = type(self)() + return value + + +class NewDataField: + def __init__(self, field, data, new_name): + self.field = field + self.data = data + self.new_name = new_name + + def to_numpy(self, *args, **kwargs): + return self.data + + def metadata(self, key=None, **kwargs): + if key is None: + return self.field.metadata(**kwargs) + + value = self.field.metadata(key, **kwargs) + if key == "param": + return self.new_name + return value + + def __getattr__(self, name): + return getattr(self.field, name) + + +def model_level_pressure(A, B, surface_pressure): + """Calculates: + - pressure at the model full- and half-levels + - delta: depth of log(pressure) at full levels + - alpha: alpha term #TODO: more descriptive information + + Parameters + ---------- + A : ndarray + A-coefficients defining the model levels + B : ndarray + B-coefficients defining the model levels + surface_pressure: number or ndarray + surface pressure (Pa) + + Returns + ------- + ndarray + pressure at model full-levels + ndarray + pressure at model half-levels + ndarray + delta at full-levels + ndarray + alpha at full levels + """ + + # constants + PRESSURE_TOA = 0.1 # safety when highest pressure level = 0.0 + + # make the calculation agnostic to the number of dimensions + ndim = surface_pressure.ndim + new_shape_half = (A.shape[0],) + (1,) * ndim + A_reshaped = A.reshape(new_shape_half) + B_reshaped = B.reshape(new_shape_half) + + # calculate pressure on model half-levels + p_half_level = A_reshaped + B_reshaped * surface_pressure[np.newaxis, ...] + + # calculate delta + new_shape_full = (A.shape[0] - 1,) + surface_pressure.shape + delta = np.zeros(new_shape_full) + delta[1:, ...] = np.log(p_half_level[2:, ...] / p_half_level[1:-1, ...]) + + # pressure at highest half level<= 0.1 + if np.any(p_half_level[0, ...] <= PRESSURE_TOA): + delta[0, ...] = np.log(p_half_level[1, ...] / PRESSURE_TOA) + # pressure at highest half level > 0.1 + else: + delta[0, ...] = np.log(p_half_level[1, ...] / p_half_level[0, ...]) + + # calculate alpha + alpha = np.zeros(new_shape_full) + + alpha[1:, ...] = 1.0 - p_half_level[1:-1, ...] / (p_half_level[2:, ...] - p_half_level[1:-1, ...]) * delta[1:, ...] + + # pressure at highest half level <= 0.1 + if np.any(p_half_level[0, ...] <= PRESSURE_TOA): + alpha[0, ...] = 1.0 # ARPEGE choice, ECMWF IFS uses log(2) + # pressure at highest half level > 0.1 + else: + alpha[0, ...] = 1.0 - p_half_level[0, ...] / (p_half_level[1, ...] - p_half_level[0, ...]) * delta[0, ...] + + # calculate pressure on model full levels + # TODO: is there a faster way to calculate the averages? 
+ # TODO: introduce option to calculate full levels in more complicated way + p_full_level = np.apply_along_axis(lambda m: np.convolve(m, np.ones(2) / 2, mode="valid"), axis=0, arr=p_half_level) + + return p_full_level, p_half_level, delta, alpha + + +def calc_specific_gas_constant(q): + """Calculates the specific gas constant of moist air + (specific content of cloud particles and hydrometeors are neglected) + + Parameters + ---------- + q : number or ndarray + specific humidity + + Returns + ------- + number or ndarray + specific gas constant of moist air + """ + + R = constants.Rd + (constants.Rv - constants.Rd) * q + return R + + +def relative_geopotential_thickness(alpha, q, T): + """Calculates the geopotential thickness w.r.t the surface on model full-levels + + Parameters + ---------- + alpha : ndarray + alpha term of pressure calculations + q : ndarray + specific humidity (in kg/kg) on model full-levels + T : ndarray + temperature (in Kelvin) on model full-levels + + Returns + ------- + ndarray + geopotential thickness of model full-levels w.r.t. the surface + """ + + R = calc_specific_gas_constant(q) + dphi = np.cumsum(np.flip(alpha * R * T, axis=0), axis=0) + dphi = np.flip(dphi, axis=0) + + return dphi + + +def pressure_at_height_level(height, q, T, sp, A, B): + """Calculates the pressure at a height level given in meters above surface. + This is done by finding the model level above and below the specified height + and interpolating the pressure + + Parameters + ---------- + height : number + height (in meters) above the surface for which the pressure is wanted + q : ndarray + specific humidity (kg/kg) at model full-levels + T : ndarray + temperature (K) at model full-levels + sp : ndarray + surface pressure (Pa) + A : ndarray + A-coefficients defining the model levels + B : ndarray + B-coefficients defining the model levels + + Returns + ------- + number or ndarray + pressure (Pa) at the given height level + """ + + # geopotential thickness of the height level + tdphi = height * constants.g + + # pressure(-related) variables + p_full, p_half, _, alpha = model_level_pressure(A, B, sp) + + # relative geopot. thickness of full levels + dphi = relative_geopotential_thickness(alpha, q, T) + + # find the model full level right above the height level + i_phi = (tdphi > dphi).sum(0) + + # initialize the output array + p_height = np.zeros_like(i_phi, dtype=np.float64) + + # define mask: requested height is below the lowest model full-level + mask = i_phi == 0 + + # CASE 1: requested height is below the lowest model full-level + # --> interpolation between surface pressure and lowest model full-level + p_height[mask] = (p_half[-1, ...] + tdphi / dphi[-1, ...] * (p_full[-1, ...] 
- p_half[-1, ...]))[mask] + + # CASE 2: requested height is above the lowest model full-level + # --> interpolation between between model full-level above and below + + # define some indices for masking and readability + i_lev = alpha.shape[0] - i_phi - 1 # convert phi index to model level index + indices = np.indices(i_lev.shape) + masked_indices = tuple(dim[~mask] for dim in indices) + above = (i_lev[~mask],) + masked_indices + below = (i_lev[~mask] + 1,) + masked_indices + + dphi_above = dphi[above] + dphi_below = dphi[below] + + factor = (tdphi - dphi_above) / (dphi_below - dphi_above) + p_height[~mask] = p_full[above] + factor * (p_full[below] - p_full[above]) + + return p_height + + +def execute(context, input, height, t, q, sp, new_name="2r", **kwargs): + """Convert the single (height) level specific humidity to relative humidity""" + result = FieldArray() + + MANDATORY_KEYS = ["A", "B"] + OPTIONAL_KEYS = ["t_ml", "q_ml"] + MISSING_KEYS = [] + DEFAULTS = dict(t_ml="t", q_ml="q") + + for key in OPTIONAL_KEYS: + if key not in kwargs: + print(f"key {key} not found in yaml-file, using default key: {DEFAULTS[key]}") + kwargs[key] = DEFAULTS[key] + + for key in MANDATORY_KEYS: + if key not in kwargs: + MISSING_KEYS.append(key) + + if MISSING_KEYS: + raise KeyError(f"Following keys are missing: {', '.join(MISSING_KEYS)}") + + single_level_params = (t, q, sp) + model_level_params = (kwargs["t_ml"], kwargs["q_ml"]) + + needed_fields = AutoDict() + + # Gather all necessary fields + for f in input: + key = f.metadata(namespace="mars") + param = key.pop("param") + # check single level parameters + if param in single_level_params: + levtype = key.pop("levtype") + key = tuple(sorted(key.items())) + + if param in needed_fields[key][levtype]: + raise ValueError(f"Duplicate single level field {param} for {key}") + + needed_fields[key][levtype][param] = f + if param == q: + if kwargs.get("keep_q", False): + result.append(f) + else: + result.append(f) + + # check model level parameters + elif param in model_level_params: + levtype = key.pop("levtype") + levelist = key.pop("levelist") + key = tuple(sorted(key.items())) + + if param in needed_fields[key][levtype][levelist]: + raise ValueError(f"Duplicate model level field {param} for {key} at level {levelist}") + + needed_fields[key][levtype][levelist][param] = f + + # all other parameters + else: + result.append(f) + + for _, values in needed_fields.items(): + # some checks + if len(values["sfc"]) != 3: + raise ValueError("Missing surface fields") + + q_sl = values["sfc"][q].to_numpy(flatten=True) + t_sl = values["sfc"][t].to_numpy(flatten=True) + sp_sl = values["sfc"][sp].to_numpy(flatten=True) + + nlevels = len(kwargs["A"]) - 1 + if len(values["ml"]) != nlevels: + raise ValueError("Missing model levels") + + for key in values["ml"].keys(): + if len(values["ml"][key]) != 2: + raise ValueError(f"Missing field on level {key}") + + # create 3D arrays for upper air fields + levels = list(values["ml"].keys()) + levels.sort() + t_ml = [] + q_ml = [] + for level in levels: + t_ml.append(values["ml"][level][kwargs["t_ml"]].to_numpy(flatten=True)) + q_ml.append(values["ml"][level][kwargs["q_ml"]].to_numpy(flatten=True)) + + t_ml = np.stack(t_ml) + q_ml = np.stack(q_ml) + + # actual conversion from qv --> rh + # FIXME: + # For now We need to go from qv --> td --> rh to take into account + # the mixed / ice phase when T ~ 0C / T < 0C + # See https://github.com/ecmwf/earthkit-meteo/issues/15 + p_sl = pressure_at_height_level(height, q_ml, t_ml, sp_sl, 
np.array(kwargs["A"]), np.array(kwargs["B"])) + td_sl = thermo.dewpoint_from_specific_humidity(q=q_sl, p=p_sl) + rh_sl = thermo.relative_humidity_from_dewpoint(t=t_sl, td=td_sl) + + result.append(NewDataField(values["sfc"][q], rh_sl, new_name)) + + return result + + +def test(): + from earthkit.data import from_source + from earthkit.data.readers.grib.index import GribFieldList + + # IFS forecasts have both specific humidity and dewpoint + sl = from_source( + "mars", + { + "date": "2022-01-01", + "class": "od", + "expver": "1", + "stream": "oper", + "levtype": "sfc", + "param": "96.174/134.128/167.128/168.128", + "time": "00:00:00", + "type": "fc", + "step": "2", + "grid": "O640", + }, + ) + + ml = from_source( + "mars", + { + "date": "2022-01-01", + "class": "od", + "expver": "1", + "stream": "oper", + "levtype": "ml", + "levelist": "130/131/132/133/134/135/136/137", + "param": "130/133", + "time": "00:00:00", + "type": "fc", + "step": "2", + "grid": "O640", + }, + ) + source = GribFieldList.merge([sl, ml]) + + # IFS A and B coeffients for level 137 - 129 + kwargs = { + "A": [424.414063, 302.476563, 202.484375, 122.101563, 62.781250, 22.835938, 3.757813, 0.0, 0.0], + "B": [0.969513, 0.975078, 0.980072, 0.984542, 0.988500, 0.991984, 0.995003, 0.997630, 1.000000], + } + source = execute(None, source, 2, "2t", "2sh", "sp", "2r", **kwargs) + + temperature = source[2].to_numpy(flatten=True) + dewpoint = source[3].to_numpy(flatten=True) + relhum = source[4].to_numpy() + newdew = thermo.dewpoint_from_relative_humidity(temperature, relhum) + + print(f"Mean difference in dewpoint temperature: {np.abs(newdew - dewpoint).mean():02f} degC") + print(f"Median difference in dewpoint temperature: {np.median(np.abs(newdew - dewpoint)):02f} degC") + print(f"Maximum difference in dewpoint temperature: {np.abs(newdew - dewpoint).max():02f} degC") + + # source.save("source.grib") + + +if __name__ == "__main__": + test() diff --git a/src/anemoi/datasets/create/functions/filters/speeddir_to_uv.py b/src/anemoi/datasets/create/functions/filters/speeddir_to_uv.py new file mode 100644 index 00000000..80603d5b --- /dev/null +++ b/src/anemoi/datasets/create/functions/filters/speeddir_to_uv.py @@ -0,0 +1,77 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +from collections import defaultdict + +import numpy as np +from earthkit.data.indexing.fieldlist import FieldArray +from earthkit.meteo.wind.array import polar_to_xy + + +class NewDataField: + def __init__(self, field, data, new_name): + self.field = field + self.data = data + self.new_name = new_name + + def to_numpy(self, *args, **kwargs): + return self.data + + def metadata(self, key=None, **kwargs): + if key is None: + return self.field.metadata(**kwargs) + + value = self.field.metadata(key, **kwargs) + if key == "param": + return self.new_name + return value + + def __getattr__(self, name): + return getattr(self.field, name) + + +def execute(context, input, wind_speed, wind_dir, u_component="u", v_component="v", in_radians=False): + + result = FieldArray() + + wind_params = (wind_speed, wind_dir) + wind_pairs = defaultdict(dict) + + for f in input: + key = f.metadata(namespace="mars") + param = key.pop("param") + + if param not in wind_params: + result.append(f) + continue + + key = tuple(key.items()) + + if param in wind_pairs[key]: + raise ValueError(f"Duplicate wind component {param} for {key}") + + wind_pairs[key][param] = f + + for _, pairs in wind_pairs.items(): + if len(pairs) != 2: + raise ValueError("Missing wind component") + + magnitude = pairs[wind_speed] + direction = pairs[wind_dir] + + # assert speed.grid_mapping == dir.grid_mapping + if in_radians: + direction = np.rad2deg(direction) + + u, v = polar_to_xy(magnitude.to_numpy(flatten=True), direction.to_numpy(flatten=True)) + + result.append(NewDataField(magnitude, u, u_component)) + result.append(NewDataField(direction, v, v_component)) + + return result diff --git a/src/anemoi/datasets/create/functions/filters/uv_to_speeddir.py b/src/anemoi/datasets/create/functions/filters/uv_to_speeddir.py new file mode 100644 index 00000000..bfba4e18 --- /dev/null +++ b/src/anemoi/datasets/create/functions/filters/uv_to_speeddir.py @@ -0,0 +1,55 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +from collections import defaultdict + +import numpy as np +from earthkit.data.indexing.fieldlist import FieldArray +from earthkit.meteo.wind.array import xy_to_polar + +from anemoi.datasets.create.functions.filters.speeddir_to_uv import NewDataField + + +def execute(context, input, u_component, v_component, wind_speed, wind_dir, in_radians=False): + result = FieldArray() + + wind_params = (u_component, v_component) + wind_pairs = defaultdict(dict) + + for f in input: + key = f.metadata(namespace="mars") + param = key.pop("param") + + if param not in wind_params: + result.append(f) + continue + + key = tuple(key.items()) + + if param in wind_pairs[key]: + raise ValueError(f"Duplicate wind component {param} for {key}") + + wind_pairs[key][param] = f + + for _, pairs in wind_pairs.items(): + if len(pairs) != 2: + raise ValueError("Missing wind component") + + u = pairs[u_component] + v = pairs[v_component] + + # assert speed.grid_mapping == dir.grid_mapping + magnitude, direction = xy_to_polar(u.to_numpy(flatten=True), v.to_numpy(flatten=True)) + if in_radians: + direction = np.deg2rad(direction) + + result.append(NewDataField(u, magnitude, wind_speed)) + result.append(NewDataField(v, direction, wind_dir)) + + return result diff --git a/src/anemoi/datasets/create/functions/sources/grib.py b/src/anemoi/datasets/create/functions/sources/grib.py index 1ddca353..8b2b7e35 100644 --- a/src/anemoi/datasets/create/functions/sources/grib.py +++ b/src/anemoi/datasets/create/functions/sources/grib.py @@ -11,9 +11,87 @@ import glob from earthkit.data import from_source +from earthkit.data.indexing.fieldlist import FieldArray from earthkit.data.utils.patterns import Pattern +def _load(context, name, record): + ds = None + + param = record["param"] + + if "path" in record: + context.info(f"Using {name} from {record['path']} (param={param})") + ds = from_source("file", record["path"]) + + if "url" in record: + context.info(f"Using {name} from {record['url']} (param={param})") + ds = from_source("url", record["url"]) + + ds = ds.sel(param=param) + + assert len(ds) == 1, f"{name} {param}, expected one field, got {len(ds)}" + ds = ds[0] + + return ds.to_numpy(flatten=True), ds.metadata("uuidOfHGrid") + + +class Geography: + """This class retrieve the latitudes and longitudes of unstructured grids, + and checks if the fields are compatible with the grid. 
+ """ + + def __init__(self, context, latitudes, longitudes): + + latitudes, uuidOfHGrid_lat = _load(context, "latitudes", latitudes) + longitudes, uuidOfHGrid_lon = _load(context, "longitudes", longitudes) + + assert ( + uuidOfHGrid_lat == uuidOfHGrid_lon + ), f"uuidOfHGrid mismatch: lat={uuidOfHGrid_lat} != lon={uuidOfHGrid_lon}" + + context.info(f"Latitudes: {len(latitudes)}, Longitudes: {len(longitudes)}") + assert len(latitudes) == len(longitudes) + + self.uuidOfHGrid = uuidOfHGrid_lat + self.latitudes = latitudes + self.longitudes = longitudes + self.first = True + + def check(self, field): + if self.first: + # We only check the first field, for performance reasons + assert ( + field.metadata("uuidOfHGrid") == self.uuidOfHGrid + ), f"uuidOfHGrid mismatch: {field.metadata('uuidOfHGrid')} != {self.uuidOfHGrid}" + self.first = False + + +class AddGrid: + """An earth-kit.data.Field wrapper that adds grid information.""" + + def __init__(self, field, geography): + self._field = field + + geography.check(field) + + self._latitudes = geography.latitudes + self._longitudes = geography.longitudes + + def __getattr__(self, name): + return getattr(self._field, name) + + def __repr__(self) -> str: + return repr(self._field) + + def grid_points(self): + return self._latitudes, self._longitudes + + @property + def resolution(self): + return "unknown" + + def check(ds, paths, **kwargs): count = 1 for k, v in kwargs.items(): @@ -34,9 +112,13 @@ def _expand(paths): yield path -def execute(context, dates, path, *args, **kwargs): +def execute(context, dates, path, latitudes=None, longitudes=None, *args, **kwargs): given_paths = path if isinstance(path, list) else [path] + geography = None + if latitudes is not None and longitudes is not None: + geography = Geography(context, latitudes, longitudes) + ds = from_source("empty") dates = [d.isoformat() for d in dates] @@ -56,4 +138,7 @@ def execute(context, dates, path, *args, **kwargs): if kwargs: check(ds, given_paths, valid_datetime=dates, **kwargs) + if geography is not None: + ds = FieldArray([AddGrid(_, geography) for _ in ds]) + return ds diff --git a/src/anemoi/datasets/create/functions/sources/hindcasts.py b/src/anemoi/datasets/create/functions/sources/hindcasts.py index f4d6b31b..0dc5c9c4 100644 --- a/src/anemoi/datasets/create/functions/sources/hindcasts.py +++ b/src/anemoi/datasets/create/functions/sources/hindcasts.py @@ -6,7 +6,6 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
# -import datetime import logging from earthkit.data.core.fieldlist import MultiFieldList @@ -14,7 +13,6 @@ from anemoi.datasets.create.functions.sources.mars import mars LOGGER = logging.getLogger(__name__) -DEBUG = True def _to_list(x): @@ -23,91 +21,34 @@ def _to_list(x): return [x] -class HindcastCompute: - def __init__(self, base_times, available_steps, request): - self.base_times = base_times - self.available_steps = available_steps - self.request = request - - def compute_hindcast(self, date): - result = [] - for step in sorted(self.available_steps): # Use the shortest step - start_date = date - datetime.timedelta(hours=step) - hours = start_date.hour - if hours in self.base_times: - r = self.request.copy() - r["date"] = start_date - r["time"] = f"{start_date.hour:02d}00" - r["step"] = step - result.append(r) - - if not result: - raise ValueError( - f"Cannot find data for {self.request} for {date} (base_times={self.base_times}, " - f"available_steps={self.available_steps})" - ) - - if len(result) > 1: - raise ValueError( - f"Multiple requests for {self.request} for {date} (base_times={self.base_times}, " - f"available_steps={self.available_steps})" - ) - - return result[0] - - -def use_reference_year(reference_year, request): - request = request.copy() - hdate = request.pop("date") - - if hdate.year >= reference_year: - return None, False +def hindcasts(context, dates, **request): - try: - date = datetime.datetime(reference_year, hdate.month, hdate.day) - except ValueError: - if hdate.month == 2 and hdate.day == 29: - return None, False - raise + from anemoi.datasets.dates import HindcastsDates - request.update(date=date.strftime("%Y-%m-%d"), hdate=hdate.strftime("%Y-%m-%d")) - return request, True + provider = context.dates_provider + assert isinstance(provider, HindcastsDates) + context.trace("H️", f"hindcasts {len(dates)=}") -def hindcasts(context, dates, **request): request["param"] = _to_list(request["param"]) - request["step"] = _to_list(request["step"]) + request["step"] = _to_list(request.get("step", 0)) request["step"] = [int(_) for _ in request["step"]] - if request.get("stream") == "enfh" and "base_times" not in request: - request["base_times"] = [0] - - available_steps = request.pop("step") - available_steps = _to_list(available_steps) - - base_times = request.pop("base_times") - - reference_year = request.pop("reference_year") + context.trace("H️", f"hindcast {request}") - context.trace("H️", f"hindcast {request} {base_times} {available_steps} {reference_year}") - - c = HindcastCompute(base_times, available_steps, request) requests = [] for d in dates: - req = c.compute_hindcast(d) - req, ok = use_reference_year(reference_year, req) - if ok: - requests.append(req) - - # print("HINDCASTS requests", reference_year, base_times, available_steps) - # print("HINDCASTS dates", compress_dates(dates)) + r = request.copy() + hindcast = provider.mapping[d] + r["hdate"] = hindcast.hdate.strftime("%Y-%m-%d") + r["date"] = hindcast.refdate.strftime("%Y-%m-%d") + r["time"] = hindcast.refdate.strftime("%H") + r["step"] = hindcast.step + requests.append(r) if len(requests) == 0: - # print("HINDCASTS no requests") return MultiFieldList([]) - # print("HINDCASTS requests", requests) - return mars( context, dates, diff --git a/src/anemoi/datasets/create/functions/sources/mars.py b/src/anemoi/datasets/create/functions/sources/mars.py index a36ccf21..0e1eebbb 100644 --- a/src/anemoi/datasets/create/functions/sources/mars.py +++ b/src/anemoi/datasets/create/functions/sources/mars.py @@ -203,16 
+203,22 @@ def mars(context, dates, *requests, request_already_using_valid_datetime=False, request_already_using_valid_datetime=request_already_using_valid_datetime, date_key=date_key, ) + + requests = list(requests) + ds = from_source("empty") + context.trace("βœ…", f"{[str(d) for d in dates]}") + context.trace("βœ…", f"Will run {len(requests)} requests") + for r in requests: + r = {k: v for k, v in r.items() if v != ("-",)} + context.trace("βœ…", f"mars {r}") + for r in requests: r = {k: v for k, v in r.items() if v != ("-",)} if context.use_grib_paramid and "param" in r: r = use_grib_paramid(r) - if DEBUG: - context.trace("βœ…", f"from_source(mars, {r}") - for k, v in r.items(): if k not in MARS_KEYS: raise ValueError( diff --git a/src/anemoi/datasets/create/functions/sources/xarray/field.py b/src/anemoi/datasets/create/functions/sources/xarray/field.py index cdbd061f..f737e04c 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/field.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/field.py @@ -7,6 +7,7 @@ # nor does it submit to any jurisdiction. # +import datetime import logging from earthkit.data.core.fieldlist import Field @@ -103,7 +104,12 @@ def longitudes(self): @property def forecast_reference_time(self): - return self.owner.forecast_reference_time + date, time = self.metadata("date", "time") + assert len(time) == 4, time + assert len(date) == 8, date + yyyymmdd = int(date) + time = int(time) // 100 + return datetime.datetime(yyyymmdd // 10000, yyyymmdd // 100 % 100, yyyymmdd % 100, time) def __repr__(self): return repr(self._metadata) diff --git a/src/anemoi/datasets/create/functions/sources/xarray/metadata.py b/src/anemoi/datasets/create/functions/sources/xarray/metadata.py index e98f9ea7..85ca95d9 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/metadata.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/metadata.py @@ -70,15 +70,17 @@ def as_namespace(self, namespace=None): return self._as_mars() def _as_mars(self): - return dict( - param=self["variable"], - step=self["step"], - levelist=self["level"], - levtype=self["levtype"], - number=self["number"], - date=self["date"], - time=self["time"], - ) + return {} + # p = dict( + # param=self.get("variable", self.get("param")), + # step=self.get("step"), + # levelist=self.get("levelist", self.get("level")), + # levtype=self.get("levtype"), + # number=self.get("number"), + # date=self.get("date"), + # time=self.get("time"), + # ) + # return {k: v for k, v in p.items() if v is not None} def _base_datetime(self): return self._field.forecast_reference_time @@ -135,12 +137,12 @@ def resolution(self): # TODO: implement resolution return None - @property + # @property def mars_grid(self): # TODO: implement mars_grid return None - @property + # @property def mars_area(self): # TODO: code me # return [self.north, self.west, self.south, self.east] diff --git a/src/anemoi/datasets/create/input.py b/src/anemoi/datasets/create/input.py index e696feb3..d26faa07 100644 --- a/src/anemoi/datasets/create/input.py +++ b/src/anemoi/datasets/create/input.py @@ -23,7 +23,7 @@ from earthkit.data.core.fieldlist import MultiFieldList from earthkit.data.core.order import build_remapping -from anemoi.datasets.dates import Dates +from anemoi.datasets.dates import DatesProvider from .functions import import_function from .template import Context @@ -75,7 +75,7 @@ def time_delta_to_string(delta): def is_function(name, kind): - name, delta = parse_function_name(name) # noqa + name, _ = 
parse_function_name(name) try: import_function(name, kind) return True @@ -204,11 +204,15 @@ class Result: _coords_already_built = False def __init__(self, context, action_path, dates): + from anemoi.datasets.dates.groups import GroupOfDates + + assert isinstance(dates, GroupOfDates), dates + assert isinstance(context, ActionContext), type(context) assert isinstance(action_path, list), action_path self.context = context - self.dates = dates + self.group_of_dates = dates self.action_path = action_path @property @@ -405,10 +409,10 @@ def __repr__(self, *args, _indent_="\n", **kwargs): more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) dates = " no-dates" - if self.dates is not None: - dates = f" {len(self.dates)} dates" + if self.group_of_dates is not None: + dates = f" {len(self.group_of_dates)} dates" dates += " (" - dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.dates) + dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.group_of_dates) if len(dates) > 100: dates = dates[:100] + "..." dates += ")" @@ -423,7 +427,7 @@ def _raise_not_implemented(self): raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") def _trace_datasource(self, *args, **kwargs): - return f"{self.__class__.__name__}({shorten(self.dates)})" + return f"{self.__class__.__name__}({self.group_of_dates})" def build_coords(self): if self._coords_already_built: @@ -513,7 +517,7 @@ def proj_string(self): @cached_property def shape(self): return [ - len(self.dates), + len(self.group_of_dates), len(self.variables), len(self.ensembles), len(self.grid_values), @@ -522,7 +526,7 @@ def shape(self): @cached_property def coords(self): return { - "dates": self.dates, + "dates": list(self.group_of_dates), "variables": self.variables, "ensembles": self.ensembles, "values": self.grid_values, @@ -573,7 +577,7 @@ def __init__(self, context, action_path, dates, action): self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs)) def _trace_datasource(self, *args, **kwargs): - return f"{self.action.name}({shorten(self.dates)})" + return f"{self.action.name}({self.group_of_dates})" @cached_property @assert_fieldlist @@ -583,14 +587,21 @@ def datasource(self): args, kwargs = resolve(self.context, (self.args, self.kwargs)) try: - return _tidy(self.action.function(FunctionContext(self), self.dates, *args, **kwargs)) + return _tidy( + self.action.function( + FunctionContext(self), + list(self.group_of_dates), # Will provide a list of datetime objects + *args, + **kwargs, + ) + ) except Exception: LOG.error(f"Error in {self.action.function.__name__}", exc_info=True) raise def __repr__(self): try: - return f"{self.action.name}({shorten(self.dates)})" + return f"{self.action.name}({self.group_of_dates})" except Exception: return f"{self.__class__.__name__}(unitialised)" @@ -609,7 +620,7 @@ def __init__(self, context, action_path, dates, results, **kwargs): @notify_result @trace_datasource def datasource(self): - ds = EmptyResult(self.context, self.action_path, self.dates).datasource + ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource for i in self.results: ds += i.datasource return _tidy(ds) @@ -824,7 +835,7 @@ def __init__(self, context, action_path, dates, results, **kwargs): @notify_result @trace_datasource def datasource(self): - ds = EmptyResult(self.context, self.action_path, self.dates).datasource + ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource for i in self.results: ds += i.datasource return _tidy(ds) 
@@ -904,7 +915,7 @@ def __init__(self, context, action_path, *configs): cfg = deepcopy(cfg) dates_cfg = cfg.pop("dates") assert isinstance(dates_cfg, dict), dates_cfg - filtering_dates = Dates.from_config(**dates_cfg) + filtering_dates = DatesProvider.from_config(**dates_cfg) action = action_factory(cfg, context, action_path + [str(i)]) parts.append((filtering_dates, action)) self.parts = parts @@ -915,9 +926,11 @@ def __repr__(self): @trace_select def select(self, dates): + from anemoi.datasets.dates.groups import GroupOfDates + results = [] for filtering_dates, action in self.parts: - newdates = sorted(set(dates) & set(filtering_dates)) + newdates = GroupOfDates(sorted(set(dates) & set(filtering_dates)), dates.provider) if newdates: results.append(action.select(newdates)) if not results: @@ -953,8 +966,10 @@ def action_factory(config, context, action_path): if isinstance(config[key], list): args, kwargs = config[key], {} - if isinstance(config[key], dict): + elif isinstance(config[key], dict): args, kwargs = [], config[key] + else: + raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}") cls = { # "date_shift": DateShiftAction, @@ -1021,6 +1036,13 @@ def __init__(self, owner): def trace(self, emoji, *args): trace(emoji, *args) + def info(self, *args, **kwargs): + LOG.info(*args, **kwargs) + + @property + def dates_provider(self): + return self.owner.group_of_dates.provider + class ActionContext(Context): def __init__(self, /, order_by, flatten_grid, remapping, use_grib_paramid): diff --git a/src/anemoi/datasets/create/persistent.py b/src/anemoi/datasets/create/persistent.py index 207553e7..ac3a0c15 100644 --- a/src/anemoi/datasets/create/persistent.py +++ b/src/anemoi/datasets/create/persistent.py @@ -68,7 +68,7 @@ def __setitem__(self, key, elt): path = os.path.join(self.dirname, f"{h}.pickle") if os.path.exists(path): - LOG.warn(f"{path} already exists") + LOG.warning(f"{path} already exists") tmp_path = path + f".tmp-{os.getpid()}-on-{socket.gethostname()}" with open(tmp_path, "wb") as f: diff --git a/src/anemoi/datasets/create/utils.py b/src/anemoi/datasets/create/utils.py index 3dcd86c7..a98df19c 100644 --- a/src/anemoi/datasets/create/utils.py +++ b/src/anemoi/datasets/create/utils.py @@ -62,6 +62,9 @@ def make_list_int(value): def normalize_and_check_dates(dates, start, end, frequency, dtype="datetime64[s]"): + + dates = [d.hdate if hasattr(d, "hdate") else d for d in dates] + assert isinstance(frequency, datetime.timedelta), frequency start = np.datetime64(start) end = np.datetime64(end) diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 2e95ec4f..2f6af1b1 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -23,7 +23,11 @@ class Dataset: arguments = {} - def mutate(self): + def mutate(self) -> "Dataset": + """ + Give an opportunity to a subclass to return a new Dataset + object of a different class, if needed. 
+ """ return self def swap_with_parent(self, parent): @@ -90,6 +94,12 @@ def _subset(self, **kwargs): rename = kwargs.pop("rename") return Rename(self, rename)._subset(**kwargs).mutate() + if "rescale" in kwargs: + from .rescale import Rescale + + rescale = kwargs.pop("rescale") + return Rescale(self, rescale)._subset(**kwargs).mutate() + if "statistics" in kwargs: from ..data import open_dataset from .statistics import Statistics diff --git a/src/anemoi/datasets/data/debug.py b/src/anemoi/datasets/data/debug.py index 98ea3d9d..42c84e15 100644 --- a/src/anemoi/datasets/data/debug.py +++ b/src/anemoi/datasets/data/debug.py @@ -209,10 +209,14 @@ def wrapper(self, index): return wrapper +def _identity(x): + return x + + if DEBUG_ZARR_INDEXING: debug_indexing = _debug_indexing else: - debug_indexing = lambda x: x # noqa + debug_indexing = _identity def debug_zarr_loading(on_off): diff --git a/src/anemoi/datasets/data/masked.py b/src/anemoi/datasets/data/masked.py index 002bf929..9b799a9b 100644 --- a/src/anemoi/datasets/data/masked.py +++ b/src/anemoi/datasets/data/masked.py @@ -112,5 +112,5 @@ def __init__(self, forward, area): def tree(self): return Node(self, [self.forward.tree()], area=self.area) - def metadata_specific(self, **kwargs): - return super().metadata_specific(area=self.area, **kwargs) + def subclass_metadata_specific(self): + return dict(area=self.area) diff --git a/src/anemoi/datasets/data/rescale.py b/src/anemoi/datasets/data/rescale.py new file mode 100644 index 00000000..299e703b --- /dev/null +++ b/src/anemoi/datasets/data/rescale.py @@ -0,0 +1,147 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import logging +from functools import cached_property + +import numpy as np + +from .debug import Node +from .debug import debug_indexing +from .forwards import Forwards +from .indexing import apply_index_to_slices_changes +from .indexing import expand_list_indexing +from .indexing import index_to_slices +from .indexing import update_tuple + +LOG = logging.getLogger(__name__) + + +def make_rescale(variable, rescale): + + if isinstance(rescale, (tuple, list)): + + assert len(rescale) == 2, rescale + + if isinstance(rescale[0], (int, float)): + return rescale + + from cfunits import Units + + u0 = Units(rescale[0]) + u1 = Units(rescale[1]) + + x1, x2 = 0.0, 1.0 + y1, y2 = Units.conform([x1, x2], u0, u1) + + a = (y2 - y1) / (x2 - x1) + b = y1 - a * x1 + + return a, b + + return rescale + + if isinstance(rescale, dict): + assert "scale" in rescale, rescale + assert "offset" in rescale, rescale + return rescale["scale"], rescale["offset"] + + assert False + + +class Rescale(Forwards): + def __init__(self, dataset, rescale): + super().__init__(dataset) + for n in rescale: + assert n in dataset.variables, n + + variables = dataset.variables + + self._a = np.ones(len(variables)) + self._b = np.zeros(len(variables)) + + self.rescale = {} + for i, v in enumerate(variables): + if v in rescale: + a, b = make_rescale(v, rescale[v]) + self.rescale[v] = a, b + self._a[i], self._b[i] = a, b + + self._a = self._a[np.newaxis, :, np.newaxis, np.newaxis] + self._b = self._b[np.newaxis, :, np.newaxis, np.newaxis] + + self._a = self._a.astype(self.forward.dtype) + self._b = self._b.astype(self.forward.dtype) + + def tree(self): + return Node(self, [self.forward.tree()], rescale=self.rescale) + + def subclass_metadata_specific(self): + return dict(rescale=self.rescale) + + @debug_indexing + @expand_list_indexing + def _get_tuple(self, index): + index, changes = index_to_slices(index, self.shape) + index, previous = update_tuple(index, 1, slice(None)) + result = self.forward[index] + result = result * self._a + self._b + result = result[:, previous] + result = apply_index_to_slices_changes(result, changes) + return result + + @debug_indexing + def __get_slice_(self, n): + data = self.forward[n] + return data * self._a + self._b + + @debug_indexing + def __getitem__(self, n): + + if isinstance(n, tuple): + return self._get_tuple(n) + + if isinstance(n, slice): + return self.__get_slice_(n) + + data = self.forward[n] + + return data * self._a[0] + self._b[0] + + @cached_property + def statistics(self): + result = {} + a = self._a.squeeze() + assert np.all(a >= 0) + + b = self._b.squeeze() + for k, v in self.forward.statistics.items(): + if k in ("maximum", "minimum", "mean"): + result[k] = v * a + b + continue + + if k in ("stdev",): + result[k] = v * a + continue + + raise NotImplementedError("rescale statistics", k) + + return result + + def statistics_tendencies(self, delta=None): + result = {} + a = self._a.squeeze() + assert np.all(a >= 0) + + for k, v in self.forward.statistics_tendencies(delta).items(): + if k in ("maximum", "minimum", "mean", "stdev"): + result[k] = v * a + continue + + raise NotImplementedError("rescale tendencies statistics", k) + + return result diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py index 062079b2..9ced8281 100644 --- a/src/anemoi/datasets/data/stores.py +++ b/src/anemoi/datasets/data/stores.py @@ -5,6 +5,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
+ import logging import os import warnings @@ -83,6 +84,8 @@ def __getitem__(self, key): class DebugStore(ReadOnlyStore): + """A store to debug the zarr loading.""" + def __init__(self, store): assert not isinstance(store, DebugStore) self.store = store @@ -148,6 +151,8 @@ def open_zarr(path, dont_fail=False, cache=None): class Zarr(Dataset): + """A zarr dataset.""" + def __init__(self, path): if isinstance(path, zarr.hierarchy.Group): self.was_zarr = True @@ -244,14 +249,20 @@ def statistics_tendencies(self, delta=None): delta = self.frequency if isinstance(delta, int): delta = f"{delta}h" - from anemoi.datasets.create.loaders import TendenciesStatisticsAddition + from anemoi.utils.dates import frequency_to_string + from anemoi.utils.dates import frequency_to_timedelta + + delta = frequency_to_timedelta(delta) + delta = frequency_to_string(delta) + + def func(k): + return f"statistics_tendencies_{delta}_{k}" - func = TendenciesStatisticsAddition.final_storage_name_from_delta return dict( - mean=self.z[func("mean", delta)][:], - stdev=self.z[func("stdev", delta)][:], - maximum=self.z[func("maximum", delta)][:], - minimum=self.z[func("minimum", delta)][:], + mean=self.z[func("mean")][:], + stdev=self.z[func("stdev")][:], + maximum=self.z[func("maximum")][:], + minimum=self.z[func("minimum")][:], ) @property @@ -322,11 +333,13 @@ def get_dataset_names(self, names): class ZarrWithMissingDates(Zarr): + """A zarr dataset with missing dates.""" + def __init__(self, path): super().__init__(path) missing_dates = self.z.attrs.get("missing_dates", []) - missing_dates = [np.datetime64(x) for x in missing_dates] + missing_dates = set([np.datetime64(x) for x in missing_dates]) self.missing_to_dates = {i: d for i, d in enumerate(self.dates) if d in missing_dates} self.missing = set(self.missing_to_dates) diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index eb7886c8..fe1054ee 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -10,8 +10,10 @@ import warnings # from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import DateTimes from anemoi.utils.dates import as_datetime from anemoi.utils.dates import frequency_to_timedelta +from anemoi.utils.hindcasts import HindcastDatesTimes from anemoi.utils.humanize import print_dates @@ -30,32 +32,32 @@ def extend(x): step = frequency_to_timedelta(step) while start <= end: yield start - start += datetime.timedelta(hours=step) + start += step return yield as_datetime(x) -class Dates: +class DatesProvider: """Base class for date generation. 
- >>> Dates.from_config(**{"start": "2023-01-01 00:00", "end": "2023-01-02 00:00", "frequency": "1d"}).values + >>> DatesProvider.from_config(**{"start": "2023-01-01 00:00", "end": "2023-01-02 00:00", "frequency": "1d"}).values [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 2, 0, 0)] - >>> Dates.from_config(**{"start": "2023-01-01 00:00", "end": "2023-01-03 00:00", "frequency": "18h"}).values + >>> DatesProvider.from_config(**{"start": "2023-01-01 00:00", "end": "2023-01-03 00:00", "frequency": "18h"}).values [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 1, 18, 0), datetime.datetime(2023, 1, 2, 12, 0)] - >>> Dates.from_config(start="2023-01-01 00:00", end="2023-01-02 00:00", frequency=6).as_dict() + >>> DatesProvider.from_config(start="2023-01-01 00:00", end="2023-01-02 00:00", frequency=6).as_dict() {'start': '2023-01-01T00:00:00', 'end': '2023-01-02T00:00:00', 'frequency': '6h'} - >>> len(Dates.from_config(start="2023-01-01 00:00", end="2023-01-02 00:00", frequency=12)) + >>> len(DatesProvider.from_config(start="2023-01-01 00:00", end="2023-01-02 00:00", frequency=12)) 3 - >>> len(Dates.from_config(start="2023-01-01 00:00", + >>> len(DatesProvider.from_config(start="2023-01-01 00:00", ... end="2023-01-02 00:00", ... frequency=12, ... missing=["2023-01-01 12:00"])) 3 - >>> len(Dates.from_config(start="2023-01-01 00:00", + >>> len(DatesProvider.from_config(start="2023-01-01 00:00", ... end="2023-01-02 00:00", ... frequency=12, ... missing=["2099-01-01 12:00"])) @@ -67,12 +69,18 @@ def __init__(self, missing=None): missing = [] self.missing = list(extend(missing)) if set(self.missing) - set(self.values): - warnings.warn(f"Missing dates {self.missing} not in list.") + diff = set(self.missing) - set(self.values) + warnings.warn(f"Missing dates {len(diff)=} not in list.") @classmethod def from_config(cls, **kwargs): + + if kwargs.pop("hindcasts", False): + return HindcastsDates(**kwargs) + if "values" in kwargs: return ValuesDates(**kwargs) + return StartEndDates(**kwargs) def __iter__(self): @@ -89,7 +97,7 @@ def summary(self): return f"πŸ“… {self.values[0]} ... 
{self.values[-1]}" -class ValuesDates(Dates): +class ValuesDates(DatesProvider): def __init__(self, values, **kwargs): self.values = sorted([as_datetime(_) for _ in values]) super().__init__(**kwargs) @@ -101,8 +109,9 @@ def as_dict(self): return {"values": self.values[0]} -class StartEndDates(Dates): - def __init__(self, start, end, frequency=1, months=None, **kwargs): +class StartEndDates(DatesProvider): + def __init__(self, start, end, frequency=1, **kwargs): + frequency = frequency_to_timedelta(frequency) assert isinstance(frequency, datetime.timedelta), frequency @@ -123,35 +132,108 @@ def _(x): start = as_datetime(start) end = as_datetime(end) - # if end <= start: - # raise ValueError(f"End date {end} must be after start date {start}") - - increment = frequency - self.start = start self.end = end self.frequency = frequency - date = start - self.values = [] - while date <= end: + missing = kwargs.pop("missing", []) - if months is not None: - if date.month not in months: - date += increment - continue + self.values = list(DateTimes(start, end, increment=frequency, **kwargs)) + self.kwargs = kwargs - self.values.append(date) - date += increment - - super().__init__(**kwargs) + super().__init__(missing=missing) def as_dict(self): return { "start": self.start.isoformat(), "end": self.end.isoformat(), - "frequency": f"{self.frequency}h", - } + "frequency": frequency_to_string(self.frequency), + }.update(self.kwargs) + + +class Hindcast: + + def __init__(self, date, refdate, hdate, step): + self.date = date + self.refdate = refdate + self.hdate = hdate + self.step = step + + +class HindcastsDates(DatesProvider): + def __init__(self, start, end, steps=[0], years=20, **kwargs): + + if not isinstance(start, list): + start = [start] + end = [end] + + reference_dates = [] + for s, e in zip(start, end): + reference_dates.extend(list(DateTimes(s, e, increment=24, **kwargs))) + # reference_dates = list(DateTimes(start, end, increment=24, **kwargs)) + dates = [] + + seen = {} + + for hdate, refdate in HindcastDatesTimes(reference_dates=reference_dates, years=years): + assert refdate - hdate >= datetime.timedelta(days=365), (refdate - hdate, refdate, hdate) + for step in steps: + + date = hdate + datetime.timedelta(hours=step) + + if date in seen: + raise ValueError(f"Duplicate date {date}={hdate}+{step} for {refdate} and {seen[date]}") + + seen[date] = Hindcast(date, refdate, hdate, step) + + assert refdate - date > datetime.timedelta(days=360), (refdate - date, refdate, date, hdate, step) + + dates.append(date) + + dates = sorted(dates) + + mindelta = None + for a, b in zip(dates, dates[1:]): + delta = b - a + assert isinstance(delta, datetime.timedelta), delta + if mindelta is None: + mindelta = delta + else: + mindelta = min(mindelta, delta) + + self.frequency = mindelta + assert mindelta.total_seconds() > 0, mindelta + + print("πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯", dates[0], dates[-1], mindelta) + + # Use all values between start and end by frequency, and set the ones that are missing + self.values = [] + missing = [] + date = dates[0] + last = date + print("------", date, dates[-1]) + dateset = set(dates) + while date <= dates[-1]: + self.values.append(date) + if date not in dateset: + missing.append(date) + seen[date] = seen[last] + else: + last = date + date = date + mindelta + + self.mapping = seen + + print("πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯", self.values[0], self.values[-1], mindelta) + print("πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯πŸ”₯", 
f"{len(self.values)=} - {len(missing)=}") + + super().__init__(missing=missing) + + def __repr__(self): + return f"{self.__class__.__name__}({self.values[0]}..{self.values[-1]})" + + def as_dict(self): + return {"hindcasts": self.hindcasts} if __name__ == "__main__": diff --git a/src/anemoi/datasets/dates/groups.py b/src/anemoi/datasets/dates/groups.py index 36a071f1..624f308e 100644 --- a/src/anemoi/datasets/dates/groups.py +++ b/src/anemoi/datasets/dates/groups.py @@ -7,11 +7,34 @@ import itertools +from functools import cached_property -from anemoi.datasets.dates import Dates +from anemoi.datasets.create.input import shorten +from anemoi.datasets.dates import DatesProvider from anemoi.datasets.dates import as_datetime +class GroupOfDates: + def __init__(self, dates, provider): + assert isinstance(provider, DatesProvider), type(provider) + assert isinstance(dates, list) + + self.dates = dates + self.provider = provider + + def __len__(self): + return len(self.dates) + + def __iter__(self): + return iter(self.dates) + + def __repr__(self) -> str: + return f"GroupOfDates(dates={shorten(self.dates)})" + + def __eq__(self, other: object) -> bool: + return isinstance(other, GroupOfDates) and self.dates == other.dates + + class Groups: """>>> list(Groups(group_by="daily", start="2023-01-01 00:00", end="2023-01-05 00:00", frequency=12))[0] [datetime.datetime(2023, 1, 1, 0, 0), datetime.datetime(2023, 1, 1, 12, 0)] @@ -41,33 +64,48 @@ class Groups: def __init__(self, **kwargs): group_by = kwargs.pop("group_by") - self.dates = Dates.from_config(**kwargs) - self.grouper = Grouper.from_config(group_by) - self.filter = Filter(self.dates.missing) + self._dates = DatesProvider.from_config(**kwargs) + self._grouper = Grouper.from_config(group_by) + self._filter = Filter(self._dates.missing) + + @property + def provider(self): + return self._dates def __iter__(self): - for dates in self.grouper(self.dates): - dates = self.filter(dates) + for go in self._grouper(self._dates): + dates = self._filter(go.dates) if not dates: continue - yield dates + yield GroupOfDates(dates, go.provider) def __len__(self): - count = 0 - for dates in self.grouper(self.dates): - dates = self.filter(dates) + return self._len + + @cached_property + def _len(self): + n = 0 + for go in self._grouper(self._dates): + dates = self._filter(go.dates) if not dates: continue - count += 1 - return count + n += 1 + return n def __repr__(self): - return f"{self.__class__.__name__}(dates={len(self)})" + return f"{self.__class__.__name__}(dates={len(self)},{shorten(self._dates)})" + + def describe(self): + return self.dates.summary + + def one_date(self): + go = next(iter(self)) + return GroupOfDates([go.dates[0]], go.provider) class Filter: def __init__(self, missing): - self.missing = [as_datetime(m) for m in missing] + self.missing = set(as_datetime(m) for m in missing) def __call__(self, dates): return [d for d in dates if d not in self.missing] @@ -76,10 +114,16 @@ def __call__(self, dates): class Grouper: @classmethod def from_config(cls, group_by): + if isinstance(group_by, int) and group_by > 0: return GrouperByFixedSize(group_by) + if group_by is None: return GrouperOneGroup() + + if group_by == "reference_date": + return ReferenceDateGroup() + key = { "monthly": lambda dt: (dt.year, dt.month), "daily": lambda dt: (dt.year, dt.month, dt.day), @@ -89,30 +133,51 @@ def from_config(cls, group_by): return GrouperByKey(key) +class ReferenceDateGroup(Grouper): + def __call__(self, dates): + assert isinstance(dates, DatesProvider), 
type(dates) + + mapping = dates.mapping + + def same_refdate(dt): + return mapping[dt].refdate + + for _, g in itertools.groupby(sorted(dates, key=same_refdate), key=same_refdate): + yield GroupOfDates(list(g), dates) + + class GrouperOneGroup(Grouper): def __call__(self, dates): - yield dates.values + assert isinstance(dates, DatesProvider), type(dates) + + yield GroupOfDates(dates.values, dates) class GrouperByKey(Grouper): + """Group dates by a key.""" + def __init__(self, key): self.key = key def __call__(self, dates): - for _, g in itertools.groupby(dates, key=self.key): - yield list(g) + for _, g in itertools.groupby(sorted(dates, key=self.key), key=self.key): + yield GroupOfDates(list(g), dates) class GrouperByFixedSize(Grouper): + """Group dates by a fixed size.""" + def __init__(self, size): self.size = size def __call__(self, dates): batch = [] + for d in dates: batch.append(d) if len(batch) == self.size: - yield batch + yield GroupOfDates(batch, dates) batch = [] + if batch: - yield batch + yield GroupOfDates(batch, dates) diff --git a/tests/create/test_create.py b/tests/create/test_create.py index c553ef9b..2112cd78 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -117,24 +117,29 @@ def compare_dot_zattrs(a, b, path, errors): "dataset_status", "total_size", ]: - if type(a[k]) != type(b[k]): # noqa : E721 + if type(a[k]) is not type(b[k]): errors.append(f"❌ {path}.{k} : type differs {type(a[k])} != {type(b[k])}") continue + compare_dot_zattrs(a[k], b[k], f"{path}.{k}", errors) + return if isinstance(a, list): if len(a) != len(b): errors.append(f"❌ {path} : lengths are different {len(a)} != {len(b)}") return + for i, (v, w) in enumerate(zip(a, b)): compare_dot_zattrs(v, w, f"{path}.{i}", errors) + return - if type(a) != type(b): # noqa : E721 + if type(a) is not type(b): msg = f"❌ {path} actual != expected : {a} ({type(a)}) != {b} ({type(b)})" errors.append(msg) return + if a != b: msg = f"❌ {path} actual != expected : {a} != {b}" errors.append(msg) diff --git a/tests/xarray/test_netcdf.py b/tests/xarray/test_netcdf.py index bd1245c6..18c49365 100644 --- a/tests/xarray/test_netcdf.py +++ b/tests/xarray/test_netcdf.py @@ -50,6 +50,8 @@ def skip_test_netcdf(): assert len(fs) == checks["length"], (url, len(fs)) + print(fs[0].datetime()) + if __name__ == "__main__": skip_test_netcdf() diff --git a/tests/xarray/test_opendap.py b/tests/xarray/test_opendap.py index 6ae3981f..6c544b70 100644 --- a/tests/xarray/test_opendap.py +++ b/tests/xarray/test_opendap.py @@ -19,6 +19,8 @@ def test_opendap(): assert len(fs) == 79529 + print(fs[0].datetime()) + if __name__ == "__main__": for name, obj in list(globals().items()): diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index 509e3afe..7f3acec6 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -10,7 +10,7 @@ from anemoi.datasets.create.functions.sources.xarray import XarrayFieldList -def test_arco_era5(): +def test_arco_era5_1(): ds = xr.open_zarr( "gs://gcp-public-data-arco-era5/ar/1959-2022-full_37-1h-0p25deg-chunk-1.zarr-v2", @@ -21,6 +21,25 @@ def test_arco_era5(): fs = XarrayFieldList.from_xarray(ds) print(len(fs)) + print(fs[0].datetime()) + + print(fs[-1].metadata()) + # print(fs[-1].to_numpy()) + + assert len(fs) == 128677526 + + +def test_arco_era5_2(): + + ds = xr.open_zarr( + "gs://gcp-public-data-arco-era5/ar/1959-2022-1h-360x181_equiangular_with_poles_conservative.zarr", + chunks={"time": 48}, + consolidated=True, + ) + + fs = XarrayFieldList.from_xarray(ds) + 
print(len(fs)) + print(fs[-1].metadata()) # print(fs[-1].to_numpy()) @@ -50,6 +69,8 @@ def test_weatherbench(): assert fs[0].metadata("valid_datetime") == "2020-01-01T06:00:00", fs[0].metadata("valid_datetime") assert fs[-1].metadata("valid_datetime") == "2021-01-10T12:00:00", fs[-1].metadata("valid_datetime") + print(fs[0].datetime()) + def test_inca_one_date(): url = "https://object-store.os-api.cci1.ecmwf.int/ml-tests/test-data/example-inca-one-date.zarr" @@ -65,8 +86,12 @@ def test_inca_one_date(): assert f.metadata("number") == 0 assert f.metadata("variable") == vars[i] + print(fs[0].datetime()) + if __name__ == "__main__": + test_arco_era5_2() + exit() for name, obj in list(globals().items()): if name.startswith("test_") and callable(obj): print(f"Running {name}...")
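
A minimal usage sketch for the rescale keyword introduced above (src/anemoi/datasets/data/rescale.py plus the new branch in Dataset._subset). The dataset path and the variable name "2t" are placeholders; the three accepted forms of the mapping mirror make_rescale().

    from anemoi.datasets import open_dataset

    # 1. explicit (scale, offset) pair: values become value * scale + offset
    ds = open_dataset("/path/to/dataset.zarr", rescale={"2t": (1.0, -273.15)})

    # 2. a pair of unit strings, converted through cfunits (extra dependency)
    ds = open_dataset("/path/to/dataset.zarr", rescale={"2t": ("K", "degC")})

    # 3. an explicit mapping with "scale" and "offset" keys
    ds = open_dataset("/path/to/dataset.zarr", rescale={"2t": {"scale": 1.0, "offset": -273.15}})

Statistics are propagated accordingly: minimum, maximum and mean become a * v + b, while stdev is only scaled by a (the implementation asserts a >= 0).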
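
The rewritten hindcasts source no longer derives base times and reference years itself; it relies on the HindcastsDates provider and its mapping from each valid date to a Hindcast(date, refdate, hdate, step) entry. A rough sketch of that lookup, with illustrative dates and steps, assuming a configuration equivalent to hindcasts: true in the dates section:

    import datetime

    from anemoi.datasets.dates import DatesProvider

    provider = DatesProvider.from_config(
        hindcasts=True,
        start=datetime.datetime(2022, 1, 1),
        end=datetime.datetime(2022, 1, 10),
        steps=[6],
        years=5,
    )

    for valid_date in list(provider)[:3]:
        h = provider.mapping[valid_date]  # Hindcast(date, refdate, hdate, step)
        request = {
            "date": h.refdate.strftime("%Y-%m-%d"),   # reference (model) date
            "hdate": h.hdate.strftime("%Y-%m-%d"),    # hindcast date
            "time": h.refdate.strftime("%H"),
            "step": h.step,
        }
        print(valid_date, request)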
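
Groups now yields GroupOfDates objects instead of bare lists, so downstream code (for example FunctionContext.dates_provider in create/input.py) can get back to the provider that generated the dates. A small sketch, reusing the parameters from the class docstring:

    from anemoi.datasets.dates.groups import Groups

    groups = Groups(group_by="daily", start="2023-01-01 00:00", end="2023-01-03 00:00", frequency=12)

    for group in groups:
        # Each group is a GroupOfDates: iterable over datetimes, with a .provider attribute
        print(len(group), list(group)[0], group.provider)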