Skip to content

Commit

Permalink
Bugfix/failing tests (#35)
Browse files Browse the repository at this point in the history
* fix and add tests

* add --cache option

* changelog
  • Loading branch information
floriankrb authored Sep 12, 2024
1 parent ad919bb commit 4c11a33
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 24 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ Keep it human-readable, your future self will thank you!

### Added

- Added incremental building of datasets

### Changed

### Removed

## [0.4.5]
Expand Down
19 changes: 10 additions & 9 deletions src/anemoi/datasets/commands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,37 @@ class Init(Command):
internal = True
timestamp = True

def add_arguments(self, subparser):
    """Register the command-line arguments of the `init` command.

    Parameters
    ----------
    subparser : argparse.ArgumentParser
        The sub-parser dedicated to this command; arguments are added in place.
    """
    subparser.add_argument("config", help="Configuration yaml file defining the recipe to create the dataset.")
    subparser.add_argument("path", help="Path to store the created data.")

    subparser.add_argument(
        "--overwrite",
        action="store_true",
        help="Overwrite existing files. This will delete the target dataset if it already exists.",
    )
    subparser.add_argument(
        "--test",
        action="store_true",
        help="Build a small dataset, using only the first dates. And, when possible, using low resolution and less ensemble members.",
    )
    # --check-name / --no-check-name share one destination; the last flag on
    # the command line wins.
    subparser.add_argument(
        "--check-name",
        dest="check_name",
        action="store_true",
        help="Check if the dataset name is valid before creating it.",
    )
    subparser.add_argument(
        "--no-check-name",
        dest="check_name",
        action="store_false",
        help="Do not check if the dataset name is valid before creating it.",
    )
    # Name checking is opt-in: default to False unless --check-name is given.
    subparser.set_defaults(check_name=False)
    subparser.add_argument("--cache", help="Location to store the downloaded data.", metavar="DIR")

    subparser.add_argument("--trace", action="store_true")

def run(self, args):
options = vars(args)
Expand Down
1 change: 1 addition & 0 deletions src/anemoi/datasets/commands/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def add_arguments(self, subparser):
# )

subparser.add_argument("path", help="Path to store the created data.")
subparser.add_argument("--cache", help="Location to store the downloaded data.", metavar="DIR")
subparser.add_argument("--trace", action="store_true")

def run(self, args):
Expand Down
12 changes: 6 additions & 6 deletions src/anemoi/datasets/create/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,14 +203,14 @@ def __init__(self, path, overwrite=False):


class Actor: # TODO: rename to Creator
cache = None
dataset_class = WritableDataset

def __init__(self, path):
def __init__(self, path, cache=None):
# Catch all floating point errors, including overflow, sqrt(<0), etc
np.seterr(all="raise", under="warn")

self.path = path
self.cache = cache
self.dataset = self.dataset_class(self.path)

def run(self):
Expand Down Expand Up @@ -325,11 +325,11 @@ def build_input_(main_config, output_config):

class Init(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin):
dataset_class = NewDataset
def __init__(self, path, config, check_name=False, overwrite=False, use_threads=False, statistics_temp_dir=None, progress=None, test=False, **kwargs): # fmt: skip
def __init__(self, path, config, check_name=False, overwrite=False, use_threads=False, statistics_temp_dir=None, progress=None, test=False, cache=None, **kwargs): # fmt: skip
if _path_readable(path) and not overwrite:
raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.")

super().__init__(path)
super().__init__(path, cache=cache)
self.config = config
self.check_name = check_name
self.use_threads = use_threads
Expand Down Expand Up @@ -497,8 +497,8 @@ def sanity_check_config(a, b):


class Load(Actor, HasRegistryMixin, HasStatisticTempMixin, HasElementForDataMixin):
def __init__(self, path, parts=None, use_threads=False, statistics_temp_dir=None, progress=None, **kwargs): # fmt: skip
super().__init__(path)
def __init__(self, path, parts=None, use_threads=False, statistics_temp_dir=None, progress=None, cache=None, **kwargs): # fmt: skip
super().__init__(path, cache=cache)
self.use_threads = use_threads
self.statistics_temp_dir = statistics_temp_dir
self.progress = progress
Expand Down
20 changes: 14 additions & 6 deletions src/anemoi/datasets/create/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from anemoi.utils.config import load_any_dict_format
from earthkit.data.core.order import normalize_order_by

from anemoi.datasets.dates.groups import Groups

LOG = logging.getLogger(__name__)


Expand Down Expand Up @@ -212,11 +214,11 @@ def _prepare_serialisation(o):
def set_to_test_mode(cfg):
NUMBER_OF_DATES = 4

dates = cfg.dates
dates = cfg["dates"]
LOG.warn(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.")
groups = Groups(**cfg.dates)
groups = Groups(**LoadersConfig(cfg).dates)
dates = groups.dates
cfg.dates = dict(
cfg["dates"] = dict(
start=dates[0],
end=dates[NUMBER_OF_DATES - 1],
frequency=dates.frequency,
Expand Down Expand Up @@ -250,16 +252,22 @@ def set_element_to_test(obj):

def loader_config(config, is_test=False):
    """Normalise a raw recipe into a validated :class:`LoadersConfig`.

    Parameters
    ----------
    config : dict or str
        The recipe (or a path/dict accepted by :class:`Config`).
    is_test : bool
        When True, shrink the date range in place so only a handful of
        dates are built (fast test mode) before validation.

    Returns
    -------
    LoadersConfig
        The configuration rebuilt from a yaml round trip, proving it is
        serialisable.

    Raises
    ------
    ValueError
        If the configuration does not survive a yaml round trip unchanged.
    """
    config = Config(config)
    if is_test:
        # Mutates `config` in place; must happen before validation below.
        set_to_test_mode(config)
    obj = LoadersConfig(config)

    # yaml round trip to check that serialisation works as expected
    copy = obj.get_serialisable_dict()
    copy = yaml.load(yaml.dump(copy), Loader=yaml.SafeLoader)
    copy = Config(copy)
    # Rebuild from the round-tripped dict (not the original `config`),
    # otherwise the comparison below would be vacuous.
    copy = LoadersConfig(copy)

    a = yaml.dump(obj)
    b = yaml.dump(copy)
    if a != b:
        print(a)
        print(b)
        raise ValueError("Serialisation failed")

    return copy

Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/datasets/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def metadata(self):
import anemoi

def tidy(v):
if isinstance(v, (list, tuple)):
if isinstance(v, (list, tuple, set)):
return [tidy(i) for i in v]
if isinstance(v, dict):
return {k: tidy(v) for k, v in v.items()}
Expand Down
2 changes: 0 additions & 2 deletions src/anemoi/datasets/data/subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,8 @@ def _get_slice(self, s):
@expand_list_indexing
def _get_tuple(self, n):
index, changes = index_to_slices(n, self.shape)
# print('INDEX', index, changes)
indices = [self.indices[i] for i in range(*index[0].indices(self._len))]
indices = make_slice_or_index_from_list_or_tuple(indices)
# print('INDICES', indices)
index, _ = update_tuple(index, 0, indices)
result = self.dataset[index]
result = apply_index_to_slices_changes(result, changes)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,18 @@ def test_dates():
assert as_last_date("2021-01-01", dates) == np.datetime64("2021-01-01T23:59:59")


@mockup_open_zarr
def test_dates_using_list():
    """as_first_date/as_last_date must accept a year given as int or str
    when the dates are supplied as an explicit list of datetime64 values."""
    start = np.datetime64("2021-01-01T00:00:00")
    step = np.timedelta64(6, "h")
    dates = [start + k * step for k in range(3, 365 * 4 - 2)]

    first = np.datetime64("2021-01-01T18:00:00")
    last = np.datetime64("2021-12-31T06:00:00")
    assert dates[0] == first
    assert dates[-1] == last

    for year in (2021, "2021"):
        assert as_first_date(year, dates) == first
        assert as_last_date(year, dates) == last


@mockup_open_zarr
def test_slice_1():
test = DatasetTester("test-2021-2021-6h-o96-abcd")
Expand Down

0 comments on commit 4c11a33

Please sign in to comment.