From a1fe8594d7e28c51416ef3a9ccaf228d91feb8f9 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Wed, 5 Jun 2024 11:58:12 +0100 Subject: [PATCH] use plugins --- .pre-commit-config.yaml | 2 +- docs/test.ipynb | 2 +- pyproject.toml | 2 ++ .../datasets/create/functions/__init__.py | 23 +++++++++++++++++++ src/anemoi/datasets/create/input.py | 10 +------- 5 files changed, 28 insertions(+), 11 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 627afff3..4064b4f9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -64,7 +64,7 @@ repos: exclude: 'cli/.*' # Because we use argparse - repo: https://github.com/b8raoult/pre-commit-docconvert - rev: "0.1.4" + rev: "0.1.5" hooks: - id: docconvert args: ["numpy"] diff --git a/docs/test.ipynb b/docs/test.ipynb index a92168bd..1af66598 100644 --- a/docs/test.ipynb +++ b/docs/test.ipynb @@ -113,7 +113,7 @@ } ], "source": [ - "print(len(ds))" + "print(len(ds," ] }, { diff --git a/pyproject.toml b/pyproject.toml index e3d15ec4..59b20629 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ optional-dependencies.all = [ "climetlab>=0.22.1", "earthkit-meteo", "ecmwflibs>=0.6.3", + "entrypoints", "numpy", "pyproj", "pyyaml", @@ -77,6 +78,7 @@ optional-dependencies.create = [ "climetlab>=0.22.1", # "earthkit-data" "earthkit-meteo", "ecmwflibs>=0.6.3", + "entrypoints", "pyproj", ] optional-dependencies.dev = [ diff --git a/src/anemoi/datasets/create/functions/__init__.py b/src/anemoi/datasets/create/functions/__init__.py index 337d3ba1..0dc5dde8 100644 --- a/src/anemoi/datasets/create/functions/__init__.py +++ b/src/anemoi/datasets/create/functions/__init__.py @@ -8,7 +8,30 @@ # +import importlib + +import entrypoints + + def assert_is_fieldset(obj): from climetlab.readers.grib.index import FieldSet assert isinstance(obj, FieldSet), type(obj) + + +def import_function(name, kind): + + name = name.replace("-", "_") + + plugins = {} + for e in entrypoints.get_group_all(f"anemoi.datasets.{kind}s"): + plugins[e.name.replace("_", "-")] = e + + if name in plugins: + return plugins[name].load() + + module = importlib.import_module( + f".{kind}.{name}", + package=__name__, + ) + return module.execute diff --git a/src/anemoi/datasets/create/input.py b/src/anemoi/datasets/create/input.py index faef1b19..173d362e 100644 --- a/src/anemoi/datasets/create/input.py +++ b/src/anemoi/datasets/create/input.py @@ -7,7 +7,6 @@ # nor does it submit to any jurisdiction. # import datetime -import importlib import logging import time from collections import defaultdict @@ -21,6 +20,7 @@ from anemoi.datasets.dates import Dates +from .functions import import_function from .template import Context from .template import notify_result from .template import resolve @@ -65,14 +65,6 @@ def time_delta_to_string(delta): return f"minus_{hours}h" -def import_function(name, kind): - module = importlib.import_module( - f"..functions.{kind}.{name}", - package=__name__, - ) - return module.execute - - def is_function(name, kind): name, delta = parse_function_name(name) # noqa try: