Skip to content

Commit

Permalink
fix netcdf tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
b8raoult committed Feb 20, 2024
1 parent 316c6bc commit c431542
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 16 deletions.
8 changes: 4 additions & 4 deletions climetlab/readers/grib/index/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, *args, **kwargs):

@classmethod
def new_mask_index(self, *args, **kwargs):
return MaskFieldSet(*args, **kwargs)
return GribMaskFieldSet(*args, **kwargs)

@property
def availability_path(self):
Expand All @@ -53,7 +53,7 @@ def availability_path(self):
@classmethod
def merge(cls, sources):
assert all(isinstance(_, GribFieldSet) for _ in sources)
return MultiFieldSet(sources)
return GribMultiFieldSet(sources)

def available(self, request, as_list_of_dicts=False):
from climetlab.utils.availability import Availability
Expand Down Expand Up @@ -152,12 +152,12 @@ def _normalize_kwargs_names(self, **kwargs):
return kwargs


class MaskFieldSet(GribFieldSet, MaskIndex):
class GribMaskFieldSet(GribFieldSet, MaskIndex):
def __init__(self, *args, **kwargs):
MaskIndex.__init__(self, *args, **kwargs)


class MultiFieldSet(GribFieldSet, MultiIndex):
class GribMultiFieldSet(GribFieldSet, MultiIndex):
def __init__(self, *args, **kwargs):
MultiIndex.__init__(self, *args, **kwargs)

Expand Down
6 changes: 3 additions & 3 deletions climetlab/readers/grib/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,9 @@ def update_metadata(self, handle, metadata, compulsary):
if "number" in metadata:
compulsary += ("numberOfForecastsInEnsemble",)
productDefinitionTemplateNumber = {"tp": 11}
metadata["productDefinitionTemplateNumber"] = (
productDefinitionTemplateNumber.get(handle.get("shortName"), 1)
)
metadata[
"productDefinitionTemplateNumber"
] = productDefinitionTemplateNumber.get(handle.get("shortName"), 1)

if metadata.get("type") in ("pf", "cf"):
metadata.setdefault("typeOfGeneratingProcess", 4)
Expand Down
4 changes: 2 additions & 2 deletions climetlab/readers/grib/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import logging

from climetlab.readers import Reader
from climetlab.readers.grib.index import MultiFieldSet
from climetlab.readers.grib.index import GribMultiFieldSet
from climetlab.readers.grib.index.file import FieldSetInOneFile

LOG = logging.getLogger(__name__)
Expand All @@ -31,7 +31,7 @@ def merge(cls, readers):
assert all(isinstance(s, GRIBReader) for s in readers), readers
assert len(readers) > 1

return MultiFieldSet(readers)
return GribMultiFieldSet(readers)

def mutate_source(self):
# A GRIBReader is a source itself
Expand Down
38 changes: 37 additions & 1 deletion climetlab/readers/netcdf/fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from functools import cached_property
from itertools import product

from climetlab.core.index import MaskIndex, MultiIndex
from climetlab.indexing.fieldset import FieldSet
from climetlab.utils.bbox import BoundingBox
from climetlab.utils.dates import to_datetime
Expand All @@ -24,6 +25,10 @@ def __init__(self, path):
self.path = path
self.opendap = path.startswith("http")

@classmethod
def new_mask_index(self, *args, **kwargs):
    """Build a ``NetCDFMaskFieldSet``, forwarding every argument unchanged.

    NOTE(review): this is a classmethod, so the first parameter receives the
    class even though it is named ``self``.
    """
    mask_fieldset = NetCDFMaskFieldSet(*args, **kwargs)
    return mask_fieldset

def __repr__(self):
return "NetCDFReader(%s)" % (self.path,)

Expand All @@ -40,6 +45,9 @@ def __getitem__(self, n):
def dataset(self):
import xarray as xr

if ".zarr" in self.path:
return xr.open_zarr(self.path)

if self.opendap:
return xr.open_dataset(self.path)
else:
Expand Down Expand Up @@ -146,7 +154,7 @@ def _get_fields(self, ds): # noqa C901
def to_xarray(self, **kwargs):
import xarray as xr

if self.opendap:
if self.path.startswith("http"):
return xr.open_dataset(self.path, **kwargs)
return type(self).to_xarray_multi_from_paths([self.path], **kwargs)

Expand Down Expand Up @@ -185,3 +193,31 @@ def to_datetime_list(self):

def to_bounding_box(self):
return BoundingBox.multi_merge([s.to_bounding_box() for s in self.fields])

@classmethod
def merge(cls, sources):
    """Combine several NetCDF field sets into one ``NetCDFMultiFieldSet``.

    Asserts that more than one source is given and that every source is a
    ``NetCDFFieldSet`` instance.
    """
    assert len(sources) > 1
    for candidate in sources:
        assert isinstance(candidate, NetCDFFieldSet)
    return NetCDFMultiFieldSet(sources)


class NetCDFMaskFieldSet(NetCDFFieldSet, MaskIndex):
    """NetCDF field set combined with ``MaskIndex``."""

    def __init__(self, *args, **kwargs):
        # Only MaskIndex.__init__ is invoked here; NetCDFFieldSet.__init__
        # is not run for this class.
        MaskIndex.__init__(self, *args, **kwargs)


class NetCDFMultiFieldSet(NetCDFFieldSet, MultiIndex):
    """NetCDF field set built from several sources and opened together."""

    def __init__(self, *args, **kwargs):
        # Only MultiIndex.__init__ is invoked here; NetCDFFieldSet.__init__
        # is not run for this class.
        MultiIndex.__init__(self, *args, **kwargs)
        # The first positional argument is the sequence of sources; their
        # paths are kept so to_xarray() can open them in one call.
        sources = args[0]
        self.paths = [source.path for source in sources]

    def to_xarray(self, **kwargs):
        """Open all stored paths as one dataset via ``xarray.open_mfdataset``.

        When no keyword arguments are given, ``combine="by_coords"`` is used;
        otherwise the caller's keyword arguments are forwarded unchanged.
        """
        import xarray as xr

        options = kwargs if kwargs else {"combine": "by_coords"}
        return xr.open_mfdataset(self.paths, **options)

    @cached_property
    def dataset(self):
        """Result of ``to_xarray(combine="by_coords")``, cached after first access."""
        return self.to_xarray(combine="by_coords")
4 changes: 2 additions & 2 deletions climetlab/sources/indexed_urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import warnings

from climetlab.indexing import PerUrlIndex
from climetlab.readers.grib.index import MultiFieldSet
from climetlab.readers.grib.index import GribMultiFieldSet
from climetlab.readers.grib.index.sql import FieldsetInFilesWithSqlIndex
from climetlab.sources.indexed import IndexedSource
from climetlab.utils.patterns import Pattern
Expand Down Expand Up @@ -60,7 +60,7 @@ def __init__(
# This is to avoid keeping them on the request
request.pop(used)

index = MultiFieldSet(
index = GribMultiFieldSet(
FieldsetInFilesWithSqlIndex.from_url(
get_index_url(url, substitute_extension, index_extension),
selection=request,
Expand Down
9 changes: 8 additions & 1 deletion climetlab/utils/patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,14 @@ def substitute(self, value, name):
return self.format % value


TYPES = {"": Any, "int": Int, "float": Float, "date": Datetime, "strftime": Datetime, "enum": Enum}
TYPES = {
"": Any,
"int": Int,
"float": Float,
"date": Datetime,
"strftime": Datetime,
"enum": Enum,
}


class Constant:
Expand Down
1 change: 0 additions & 1 deletion tests/readers/test_netcdf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def test_dummy_netcdf_4():
@pytest.mark.long_test
@pytest.mark.download
@pytest.mark.skipif(NO_CDS, reason="No access to CDS")
@pytest.mark.skipif(True, reason="Merging of netcdf files does not work yet")
def test_multi():
s1 = load_source(
"cds",
Expand Down
4 changes: 2 additions & 2 deletions tests/sources/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_nc_merge_custom(custom_merger):
target2 = xr.open_mfdataset([s1.path, s2.path])
assert target2.identical(merged)

@pytest.mark.skipif(True, reason="Merging of netcdf files does not work yet")

def test_nc_merge_var():
s1 = load_source(
"climetlab-testing",
Expand Down Expand Up @@ -125,7 +125,7 @@ def _merge_var_different_coords(kind1, kind2):

assert target.identical(merged)

@pytest.mark.skipif(True, reason="Merging of netcdf files does not work yet")

def test_nc_merge_var_different_coords():
_merge_var_different_coords("netcdf", "netcdf")

Expand Down

0 comments on commit c431542

Please sign in to comment.