From 4c11a334e04b9b57e1a106aaea062d030c29bb36 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Thu, 12 Sep 2024 09:53:11 +0200 Subject: [PATCH] Bugfix/failing tests (#35) * fix and add tests * add --cache option * changelog --- CHANGELOG.md | 3 +++ src/anemoi/datasets/commands/init.py | 19 ++++++++++--------- src/anemoi/datasets/commands/load.py | 1 + src/anemoi/datasets/create/__init__.py | 12 ++++++------ src/anemoi/datasets/create/config.py | 20 ++++++++++++++------ src/anemoi/datasets/data/dataset.py | 2 +- src/anemoi/datasets/data/subset.py | 2 -- tests/test_data.py | 12 ++++++++++++ 8 files changed, 47 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index acf87eff..e12e0d8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ Keep it human-readable, your future self will thank you! ### Added ### Changed + +- Added incremental building of datasets + ### Removed ## [0.4.5] diff --git a/src/anemoi/datasets/commands/init.py b/src/anemoi/datasets/commands/init.py index b83d5033..d03d61a0 100644 --- a/src/anemoi/datasets/commands/init.py +++ b/src/anemoi/datasets/commands/init.py @@ -16,36 +16,37 @@ class Init(Command): internal = True timestamp = True - def add_arguments(self, init): + def add_arguments(self, subparser): - init.add_argument("config", help="Configuration yaml file defining the recipe to create the dataset.") - init.add_argument("path", help="Path to store the created data.") + subparser.add_argument("config", help="Configuration yaml file defining the recipe to create the dataset.") + subparser.add_argument("path", help="Path to store the created data.") - init.add_argument( + subparser.add_argument( "--overwrite", action="store_true", help="Overwrite existing files. This will delete the target dataset if it already exists.", ) - init.add_argument( + subparser.add_argument( "--test", action="store_true", help="Build a small dataset, using only the first dates. And, when possible, using low resolution and less ensemble members.", ) - init.add_argument( + subparser.add_argument( "--check-name", dest="check_name", action="store_true", help="Check if the dataset name is valid before creating it.", ) - init.add_argument( + subparser.add_argument( "--no-check-name", dest="check_name", action="store_false", help="Do not check if the dataset name is valid before creating it.", ) - init.set_defaults(check_name=False) + subparser.set_defaults(check_name=False) + subparser.add_argument("--cache", help="Location to store the downloaded data.", metavar="DIR") - init.add_argument("--trace", action="store_true") + subparser.add_argument("--trace", action="store_true") def run(self, args): options = vars(args) diff --git a/src/anemoi/datasets/commands/load.py b/src/anemoi/datasets/commands/load.py index adb4f7ff..d776ea10 100644 --- a/src/anemoi/datasets/commands/load.py +++ b/src/anemoi/datasets/commands/load.py @@ -25,6 +25,7 @@ def add_arguments(self, subparser): # ) subparser.add_argument("path", help="Path to store the created data.") + subparser.add_argument("--cache", help="Location to store the downloaded data.", metavar="DIR") subparser.add_argument("--trace", action="store_true") def run(self, args): diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index d4a2f504..d12df1ba 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -203,14 +203,14 @@ def __init__(self, path, overwrite=False): class Actor: # TODO: rename to Creator - cache = None dataset_class = WritableDataset - def __init__(self, path): + def __init__(self, path, cache=None): # Catch all floating point errors, including overflow, sqrt(<0), etc np.seterr(all="raise", under="warn") self.path = path + self.cache = cache self.dataset = self.dataset_class(self.path) def run(self): @@ -325,11 +325,11 @@ def build_input_(main_config, output_config): class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): dataset_class = NewDataset - def __init__(self, path, config, check_name=False, overwrite=False, use_threads=False, statistics_temp_dir=None, progress=None, test=False, **kwargs): # fmt: skip + def __init__(self, path, config, check_name=False, overwrite=False, use_threads=False, statistics_temp_dir=None, progress=None, test=False, cache=None, **kwargs): # fmt: skip if _path_readable(path) and not overwrite: raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.") - super().__init__(path) + super().__init__(path, cache=cache) self.config = config self.check_name = check_name self.use_threads = use_threads @@ -497,8 +497,8 @@ def sanity_check_config(a, b): class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin): - def __init__(self, path, parts=None, use_threads=False, statistics_temp_dir=None, progress=None, **kwargs): # fmt: skip - super().__init__(path) + def __init__(self, path, parts=None, use_threads=False, statistics_temp_dir=None, progress=None, cache=None, **kwargs): # fmt: skip + super().__init__(path, cache=cache) self.use_threads = use_threads self.statistics_temp_dir = statistics_temp_dir self.progress = progress diff --git a/src/anemoi/datasets/create/config.py b/src/anemoi/datasets/create/config.py index 6f0bad53..e292a6e5 100644 --- a/src/anemoi/datasets/create/config.py +++ b/src/anemoi/datasets/create/config.py @@ -16,6 +16,8 @@ from anemoi.utils.config import load_any_dict_format from earthkit.data.core.order import normalize_order_by +from anemoi.datasets.dates.groups import Groups + LOG = logging.getLogger(__name__) @@ -212,11 +214,11 @@ def _prepare_serialisation(o): def set_to_test_mode(cfg): NUMBER_OF_DATES = 4 - dates = cfg.dates + dates = cfg["dates"] LOG.warn(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.") - groups = Groups(**cfg.dates) + groups = Groups(**LoadersConfig(cfg).dates) dates = groups.dates - cfg.dates = dict( + cfg["dates"] = dict( start=dates[0], end=dates[NUMBER_OF_DATES - 1], frequency=dates.frequency, @@ -250,16 +252,22 @@ def set_element_to_test(obj): def loader_config(config, is_test=False): config = Config(config) - obj = LoadersConfig(config) if is_test: - obj = set_to_test_mode(obj) + set_to_test_mode(config) + obj = LoadersConfig(config) # yaml round trip to check that serialisation works as expected copy = obj.get_serialisable_dict() copy = yaml.load(yaml.dump(copy), Loader=yaml.SafeLoader) copy = Config(copy) copy = LoadersConfig(config) - assert yaml.dump(obj) == yaml.dump(copy), (obj, copy) + + a = yaml.dump(obj) + b = yaml.dump(copy) + if a != b: + print(a) + print(b) + raise ValueError("Serialisation failed") return copy diff --git a/src/anemoi/datasets/data/dataset.py b/src/anemoi/datasets/data/dataset.py index 2dfc9816..2e95ec4f 100644 --- a/src/anemoi/datasets/data/dataset.py +++ b/src/anemoi/datasets/data/dataset.py @@ -216,7 +216,7 @@ def metadata(self): import anemoi def tidy(v): - if isinstance(v, (list, tuple)): + if isinstance(v, (list, tuple, set)): return [tidy(i) for i in v] if isinstance(v, dict): return {k: tidy(v) for k, v in v.items()} diff --git a/src/anemoi/datasets/data/subset.py b/src/anemoi/datasets/data/subset.py index 3ace75be..3c36a0c2 100644 --- a/src/anemoi/datasets/data/subset.py +++ b/src/anemoi/datasets/data/subset.py @@ -111,10 +111,8 @@ def _get_slice(self, s): @expand_list_indexing def _get_tuple(self, n): index, changes = index_to_slices(n, self.shape) - # print('INDEX', index, changes) indices = [self.indices[i] for i in range(*index[0].indices(self._len))] indices = make_slice_or_index_from_list_or_tuple(indices) - # print('INDICES', indices) index, _ = update_tuple(index, 0, indices) result = self.dataset[index] result = apply_index_to_slices_changes(result, changes) diff --git a/tests/test_data.py b/tests/test_data.py index 42e2a001..2925aa8c 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -861,6 +861,18 @@ def test_dates(): assert as_last_date("2021-01-01", dates) == np.datetime64("2021-01-01T23:59:59") +@mockup_open_zarr +def test_dates_using_list(): + dates = [np.datetime64("2021-01-01T00:00:00") + i * np.timedelta64(6, "h") for i in range(3, 365 * 4 - 2)] + assert dates[0] == np.datetime64("2021-01-01T18:00:00") + assert dates[-1] == np.datetime64("2021-12-31T06:00:00") + + assert as_first_date(2021, dates) == np.datetime64("2021-01-01T18:00:00") + assert as_last_date(2021, dates) == np.datetime64("2021-12-31T06:00:00") + assert as_first_date("2021", dates) == np.datetime64("2021-01-01T18:00:00") + assert as_last_date("2021", dates) == np.datetime64("2021-12-31T06:00:00") + + @mockup_open_zarr def test_slice_1(): test = DatasetTester("test-2021-2021-6h-o96-abcd")