Skip to content

Commit

Permalink
Revert "Feature/merge (#126)" "Feature/new checkpoints (#107)" "upload with ssh (#94)"
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Nov 14, 2024
1 parent a5a9ff7 commit 008a12d
Show file tree
Hide file tree
Showing 12 changed files with 103 additions and 187 deletions.
7 changes: 1 addition & 6 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-datasets/compare/0.5.8...HEAD)


### Added

- Call filters from anemoi-transform
### Changed
- Make the test optional when `adls` is not installed (Pull request #110)


## [0.5.8](https://github.com/ecmwf/anemoi-datasets/compare/0.5.7...0.5.8) - 2024-10-26

### Changed
Expand All @@ -38,7 +34,6 @@ Keep it human-readable, your future self will thank you!

### Changed

- Upload with ssh (experimental)
- Remove upstream dependencies from downstream-ci workflow (temporary) (#83)
- ci: pin python versions to 3.9 ... 3.12 for checks (#93)
- Fix `__version__` import in init
Expand Down
101 changes: 67 additions & 34 deletions src/anemoi/datasets/commands/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@

import logging
import os
import shutil
import sys
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import as_completed

import tqdm
from anemoi.utils.remote import Transfer
from anemoi.utils.remote import TransferMethodNotImplementedError
from anemoi.utils.s3 import download
from anemoi.utils.s3 import upload

from . import Command

Expand All @@ -28,7 +29,54 @@
isatty = False


class ZarrCopier:
class S3Downloader:
    """Download a dataset from S3 to the local filesystem.

    Thin wrapper around ``anemoi.utils.s3.download`` that resolves the local
    target path and optionally removes an existing copy first.  Extra
    ``**kwargs`` (unused CLI arguments) are accepted and ignored so the whole
    ``argparse`` namespace can be splatted in.
    """

    def __init__(self, source, target, transfers, overwrite, resume, verbosity, **kwargs):
        self.source = source  # s3:// URL of the dataset to fetch
        self.target = target  # local path, or "." for "use the source's basename"
        self.transfers = transfers  # number of parallel download threads
        self.overwrite = overwrite  # delete any existing target before downloading
        self.resume = resume  # continue a previously interrupted download
        self.verbosity = verbosity

    def _resolve_target(self):
        """Return the effective local target path.

        A target of "." means "download into the current directory under the
        source's basename".  Trailing slashes are stripped first so that a
        source like ``s3://bucket/data/`` does not yield an empty basename
        (``os.path.basename`` returns ``""`` for paths ending in a separator).
        """
        if self.target == ".":
            return os.path.basename(self.source.rstrip("/"))
        return self.target

    def run(self):
        """Perform the download, honouring the overwrite/resume options."""
        self.target = self._resolve_target()

        if self.overwrite and os.path.exists(self.target):
            LOG.info(f"Deleting {self.target}")
            shutil.rmtree(self.target)

        # download() expects the source prefix to end with a slash.
        download(
            self.source + "/" if not self.source.endswith("/") else self.source,
            self.target,
            overwrite=self.overwrite,
            resume=self.resume,
            verbosity=self.verbosity,
            threads=self.transfers,
        )


class S3Uploader:
    """Upload a local file or folder to S3 via ``anemoi.utils.s3.upload``.

    Extra keyword arguments are accepted and discarded so the full
    ``argparse`` namespace can be passed straight in.
    """

    def __init__(self, source, target, transfers, overwrite, resume, verbosity, **kwargs):
        # Record only the options the upload needs; anything else is ignored.
        self.source, self.target = source, target
        self.transfers = transfers
        self.overwrite, self.resume = overwrite, resume
        self.verbosity = verbosity

    def run(self):
        """Push ``self.source`` to ``self.target`` on S3."""
        options = {
            "overwrite": self.overwrite,
            "resume": self.resume,
            "verbosity": self.verbosity,
            "threads": self.transfers,
        }
        upload(self.source, self.target, **options)


class DefaultCopier:
def __init__(self, source, target, transfers, block_size, overwrite, resume, verbosity, nested, rechunk, **kwargs):
self.source = source
self.target = target
Expand All @@ -42,14 +90,6 @@ def __init__(self, source, target, transfers, block_size, overwrite, resume, ver

self.rechunking = rechunk.split(",") if rechunk else []

source_is_ssh = self.source.startswith("ssh://")
target_is_ssh = self.target.startswith("ssh://")

if source_is_ssh or target_is_ssh:
if self.rechunk:
raise NotImplementedError("Rechunking with SSH not implemented.")
assert NotImplementedError("SSH not implemented.")

def _store(self, path, nested=False):
if nested:
import zarr
Expand Down Expand Up @@ -297,33 +337,26 @@ def run(self, args):
if args.source == args.target:
raise ValueError("Source and target are the same.")

kwargs = vars(args)

if args.overwrite and args.resume:
raise ValueError("Cannot use --overwrite and --resume together.")

if not args.rechunk:
# rechunking is only supported for ZARR datasets, it is implemented in this package
try:
if args.source.startswith("s3://") and not args.source.endswith("/"):
args.source = args.source + "/"
copier = Transfer(
args.source,
args.target,
overwrite=args.overwrite,
resume=args.resume,
verbosity=args.verbosity,
threads=args.transfers,
)
copier.run()
return
except TransferMethodNotImplementedError:
# DataTransfer relies on anemoi-utils which is agnostic to the source and target format
# it transfers file and folders, ignoring that it is zarr data
# if it is not implemented, we fallback to the ZarrCopier
pass

copier = ZarrCopier(**vars(args))
source_in_s3 = args.source.startswith("s3://")
target_in_s3 = args.target.startswith("s3://")

copier = None

if args.rechunk or (source_in_s3 and target_in_s3):
copier = DefaultCopier(**kwargs)
else:
if source_in_s3:
copier = S3Downloader(**kwargs)

if target_in_s3:
copier = S3Uploader(**kwargs)

copier.run()
return


class Copy(CopyMixin, Command):
Expand Down
26 changes: 5 additions & 21 deletions src/anemoi/datasets/create/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ def assert_is_fieldlist(obj):

def import_function(name, kind):

from anemoi.transform.filters import filter_registry
from anemoi.transforms import Transform as Transform

name = name.replace("-", "_")

plugins = {}
Expand All @@ -33,21 +30,8 @@ def import_function(name, kind):
if name in plugins:
return plugins[name].load()

try:
module = importlib.import_module(
f".{kind}.{name}",
package=__name__,
)
return module.execute
except ModuleNotFoundError:
pass

if kind == "filters":
if filter_registry.lookup(name, return_none=True):

def proc(context, data, *args, **kwargs):
return filter_registry.create(name, *args, **kwargs)(data)

return proc

raise ValueError(f"Unknown {kind} '{name}'")
module = importlib.import_module(
f".{kind}.{name}",
package=__name__,
)
return module.execute
4 changes: 1 addition & 3 deletions src/anemoi/datasets/create/functions/filters/rename.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ class RenamedFieldMapping:
def __init__(self, field, what, renaming):
self.field = field
self.what = what
self.renaming = {}
for k, v in renaming.items():
self.renaming[k] = {str(a): str(b) for a, b in v.items()}
self.renaming = renaming

def metadata(self, key=None, **kwargs):
if key is None:
Expand Down
86 changes: 8 additions & 78 deletions src/anemoi/datasets/create/functions/sources/mars.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# nor does it submit to any jurisdiction.

import datetime
import re

from anemoi.utils.humanize import did_you_mean
from earthkit.data import from_source
Expand All @@ -33,25 +32,6 @@ def _date_to_datetime(d):
return datetime.datetime.fromisoformat(d)


def expand_to_by(x):

if isinstance(x, (str, int)):
return expand_to_by(str(x).split("/"))

if len(x) == 3 and x[1] == "to":
start = int(x[0])
end = int(x[2])
return list(range(start, end + 1))

if len(x) == 5 and x[1] == "to" and x[3] == "by":
start = int(x[0])
end = int(x[2])
by = int(x[4])
return list(range(start, end + 1, by))

return x


def normalise_time_delta(t):
if isinstance(t, datetime.timedelta):
assert t == datetime.timedelta(hours=t.hours), t
Expand All @@ -63,48 +43,25 @@ def normalise_time_delta(t):
return t


def _normalise_time(t):
t = int(t)
if t < 100:
t * 100
return "{:04d}".format(t)


def _expand_mars_request(request, date, request_already_using_valid_datetime=False, date_key="date"):
requests = []

user_step = to_list(expand_to_by(request.get("step", [0])))
user_time = None
user_date = None

if not request_already_using_valid_datetime:
user_time = request.get("time")
if user_time is not None:
user_time = to_list(user_time)
user_time = [_normalise_time(t) for t in user_time]

user_date = request.get(date_key)
if user_date is not None:
assert isinstance(user_date, str), user_date
user_date = re.compile("^{}$".format(user_date.replace("-", "").replace("?", ".")))

for step in user_step:
step = to_list(request.get("step", [0]))
for s in step:
r = request.copy()

if not request_already_using_valid_datetime:

if isinstance(step, str) and "-" in step:
assert step.count("-") == 1, step

if isinstance(s, str) and "-" in s:
assert s.count("-") == 1, s
# this takes care of the cases where the step is a period such as 0-24 or 12-24
hours = int(str(step).split("-")[-1])
hours = int(str(s).split("-")[-1])

base = date - datetime.timedelta(hours=hours)
r.update(
{
date_key: base.strftime("%Y%m%d"),
"time": base.strftime("%H%M"),
"step": step,
"step": s,
}
)

Expand All @@ -113,28 +70,12 @@ def _expand_mars_request(request, date, request_already_using_valid_datetime=Fal
if isinstance(r[pproc], (list, tuple)):
r[pproc] = "/".join(str(x) for x in r[pproc])

if user_date is not None:
if not user_date.match(r[date_key]):
continue

if user_time is not None:
# It time is provided by the user, we only keep the requests that match the time
if r["time"] not in user_time:
continue

requests.append(r)

# assert requests, requests

return requests


def factorise_requests(
dates,
*requests,
request_already_using_valid_datetime=False,
date_key="date",
):
def factorise_requests(dates, *requests, request_already_using_valid_datetime=False, date_key="date"):
updates = []
for req in requests:
# req = normalise_request(req)
Expand All @@ -147,9 +88,6 @@ def factorise_requests(
date_key=date_key,
)

if not updates:
return

compressed = Availability(updates)
for r in compressed.iterate():
for k, v in r.items():
Expand Down Expand Up @@ -240,15 +178,7 @@ def use_grib_paramid(r):
]


def mars(
context,
dates,
*requests,
request_already_using_valid_datetime=False,
date_key="date",
**kwargs,
):

def mars(context, dates, *requests, request_already_using_valid_datetime=False, date_key="date", **kwargs):
if not requests:
requests = [kwargs]

Expand Down
1 change: 0 additions & 1 deletion src/anemoi/datasets/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ def _convert(x):
def open_dataset(*args, **kwargs):

# That will get rid of OmegaConf objects

args, kwargs = _convert(args), _convert(kwargs)

ds = _open_dataset(*args, **kwargs)
Expand Down
5 changes: 4 additions & 1 deletion src/anemoi/datasets/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import datetime
import json
import logging
import os
import pprint
import warnings
from functools import cached_property
Expand All @@ -27,6 +28,8 @@ def _tidy(v):
return [_tidy(i) for i in v]
if isinstance(v, dict):
return {k: _tidy(v) for k, v in v.items()}
if isinstance(v, str) and v.startswith("/"):
return os.path.basename(v)
if isinstance(v, datetime.datetime):
return v.isoformat()
if isinstance(v, datetime.date):
Expand Down Expand Up @@ -388,7 +391,7 @@ def _supporting_arrays_and_sources(self):

# Arrays from the input sources
for i, source in enumerate(self._input_sources()):
name = source.name if source.name is not None else f"source{i}"
name = source.name if source.name is not None else i
src_arrays = source._supporting_arrays(name)
source_to_arrays[id(source)] = sorted(src_arrays.keys())

Expand Down
Loading

0 comments on commit 008a12d

Please sign in to comment.