Skip to content

Commit

Permalink
make to_units() a accessory
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasprobst committed Apr 29, 2024
1 parent c0432c2 commit fd9256b
Show file tree
Hide file tree
Showing 9 changed files with 551 additions and 159 deletions.
489 changes: 427 additions & 62 deletions docs/userguide/misc/Extensions.ipynb

Large diffs are not rendered by default.

14 changes: 12 additions & 2 deletions h5rdmtoolbox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""h5rdtoolbox repository"""

import appdirs
import logging
import pathlib
from logging.handlers import RotatingFileHandler

import appdirs

_logdir = pathlib.Path(appdirs.user_log_dir('h5rdmtoolbox'))
_logdir.mkdir(parents=True, exist_ok=True)

Expand Down Expand Up @@ -161,10 +160,12 @@ def get_filesize(hdf_filename: Union[str, pathlib.Path]) -> int:
"""Get the size of the HDF5 file in bytes"""
return utils.get_filesize(hdf_filename)


def get_checksum(hdf_filename: Union[str, pathlib.Path]) -> str:
"""Get the checksum of the HDF5 file"""
return utils.get_checksum(hdf_filename)


def register_dataset_decoder(decoder: Callable, decoder_name: str = None, overwrite: bool = False):
"""A decoder function takes a xarray.DataArray and a dataset as input and returns a xarray.DataArray
It is called after the dataset is loaded into memory and before being returned to the user. Be careful:
Expand All @@ -184,6 +185,15 @@ def register_dataset_decoder(decoder: Callable, decoder_name: str = None, overwr
atexit_verbose = False


def set_loglevel(level: Union[int, str]):
"""Set the logging level of the h5rdmtoolbox logger"""
import logging
_logger = logging.getLogger('h5rdmtoolbox')
_logger.setLevel(level)
for h in _logger.handlers:
h.setLevel(level)


@atexit.register
def clean_temp_data(full: bool = False):
"""cleaning up the tmp directory"""
Expand Down
61 changes: 41 additions & 20 deletions h5rdmtoolbox/extensions/units.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,45 @@
import xarray as xr
from typing import Optional

from h5rdmtoolbox import get_ureg
from h5rdmtoolbox.protocols import H5TbxDataset
from ..wrapper.accessory import Accessory, register_special_dataset


class ToUnitsInterface:
def __init__(self,
dataset: H5TbxDataset,
dataset_unit: Optional[str] = None,
**coord_units):
self.dataset = dataset
self.dataset_unit = dataset_unit
self.coord_units = coord_units

def _convert_units(self, data: xr.DataArray):
assert isinstance(data, xr.DataArray)
assert 'units' in data.attrs, 'No units attribute found in the dataset'
for c, cn in self.coord_units.items():
assert 'units' in data.coords[c].attrs, f'No units attribute found in the coordinate {c}'
data.coords[c] = data.coords[c].pint.quantify(unit_registry=get_ureg()).pint.to(
self.coord_units[c]).pint.dequantify()
# convert units
if self.dataset_unit is None:
return data
return data.pint.quantify(unit_registry=get_ureg()).pint.to(self.dataset_unit).pint.dequantify()

def sel(self, method=None, **coords) -> xr.DataArray:
return self._convert_units(self.dataset.sel(method=method, **coords))

def isel(self, **indexers) -> xr.DataArray:
return self._convert_units(self.dataset.isel(**indexers))

def __getitem__(self, *args, **kwargs):
return self._convert_units(self.dataset.__getitem__(*args, **kwargs))


@register_special_dataset("to_units", "Dataset")
class ToUnitsAccessory(Accessory):
"""Accessor to await selected data to be converted to a new units"""

@xr.register_dataarray_accessor("to")
class UnitConversionAccessor:
"""Accessor to convert units of data array. It is
also possible to convert its coordinates"""

def __init__(self, xarray_obj):
self._obj = xarray_obj

def __call__(self, *args, **kwargs):
new_obj = self._obj.copy()
if len(args) > 0:
for arg in args:
if isinstance(arg, str):
new_obj = new_obj.pint.quantify(unit_registry=get_ureg()).pint.quantify(unit_registry=get_ureg()).pint.to(arg).pint.dequantify()
elif isinstance(arg, dict):
for k, v in arg.items():
new_obj.coords[k] = self._obj.coords[k].pint.quantify(unit_registry=get_ureg()).pint.to(v).pint.dequantify()
for k, v in kwargs.items():
new_obj.coords[k] = self._obj.coords[k].pint.quantify(unit_registry=get_ureg()).pint.to(v).pint.dequantify()
return new_obj
def __call__(self, dataset_unit: Optional[str] = None, **coord_units) -> ToUnitsInterface:
return ToUnitsInterface(self._obj, dataset_unit=dataset_unit, **coord_units)
18 changes: 8 additions & 10 deletions h5rdmtoolbox/extensions/vector.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import h5py
import xarray as xr
from typing import List, Tuple

# noinspection PyUnresolvedReferences
from . import magnitude # automatically make magnitude available
from ..wrapper.accessory import Accessor, register_special_dataset
from ..wrapper.accessory import Accessory, register_special_dataset
from ..wrapper.core import Group, File


Expand All @@ -19,32 +20,29 @@ def __init__(self, **datasets):
self._data_vars = list(self._datasets.keys())
self._shape = self._datasets[self._data_vars[0]].shape

def __getitem__(self, item) -> xr.DataArray:
def __getitem__(self, item) -> xr.Dataset:
return xr.merge([da.__getitem__(item).rename(k) for k, da in self._datasets.items()])

def __repr__(self):
return f'<HDF-XrDataset (shape {self.shape} data_vars: {self.data_vars})>'

@property
def data_vars(self):
def data_vars(self) -> List[str]:
"""List of data variables in the dataset"""
return self._data_vars

@property
def shape(self):
def shape(self) -> Tuple[int]:
"""Shape of the dataset (taken from the first dataset)"""
return self._shape


@register_special_dataset("Vector", Group)
@register_special_dataset("Vector", File)
class VectorDataset(Accessor):
class VectorDataset(Accessory):
"""A special dataset for vector data.
The vector components are stored in the group as datasets."""

def __init__(self, h5grp: h5py.Group):
self._grp = h5grp

def __call__(self, *args, **kwargs) -> HDFXrDataset:
"""Returns a xarray dataset with the vector components as data variables.
Expand All @@ -68,7 +66,7 @@ def __call__(self, *args, **kwargs) -> HDFXrDataset:
hdf_datasets = {}
for arg in args:
if isinstance(arg, str):
ds = self._grp[arg]
ds = self._obj[arg]
elif isinstance(arg, h5py.Dataset):
ds = arg
else:
Expand All @@ -77,7 +75,7 @@ def __call__(self, *args, **kwargs) -> HDFXrDataset:

for name, ds in kwargs.items():
if isinstance(ds, str):
ds = self._grp[ds]
ds = self._obj[ds]
elif not isinstance(ds, h5py.Dataset):
raise TypeError(f'Invalid type: {type(ds)}')
hdf_datasets[name.strip('/')] = ds
Expand Down
31 changes: 27 additions & 4 deletions h5rdmtoolbox/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
must have the same method signatures.
"""

import pathlib
from typing import Protocol, Optional, Union, Dict, List, Any, Tuple

import h5py
import numpy as np
import pathlib
import rdflib
import xarray as xr
from typing import Protocol, Optional, Union, Dict, List, Any, Tuple


class NamedObject(Protocol):
Expand Down Expand Up @@ -154,7 +155,13 @@ def attrs(self) -> H5TbxAttributeManager:
...

def __delitem__(self, key): ...
class H5TbxGroup(H5TbxHLObject):


class H5TbxFile(H5TbxHLObject):
"""Protocol for the h5tbx.File class."""


class H5TbxGroup(H5TbxFile):
"""Protocol for the h5tbx.Group class."""

def __getitem__(self, name: str):
Expand All @@ -172,6 +179,22 @@ def coords(self):
def hdf_filename(self) -> pathlib.Path:
"""Return the filename as a pathlib.Path object."""

def sel(self, method=None, **coords) -> xr.DataArray:
"""Return the Dataset selected by the coordinates"""
...

def isel(self, **indexers) -> xr.DataArray:
"""Return the Dataset indexed by the indexers"""
...

def __getitem__(self,
args,
new_dtype=None,
nparray=False,
links_as_strings: bool = False) -> Union[xr.DataArray, np.ndarray]:
"""Return the data array by the item name"""
...


class StandardAttribute(Protocol):

Expand Down
37 changes: 30 additions & 7 deletions h5rdmtoolbox/wrapper/accessory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Module to register attributes of wrapper classes without touching the implementation"""
import h5py
from typing import Union
import logging
from typing import Union, Type

from .core import Group
from ..protocols import H5TbxHLObject

logger = logging.getLogger('h5rdmtoolbox')


class SpecialDatasetRegistrationWarning(Warning):
Expand Down Expand Up @@ -53,22 +55,43 @@ def _register_special_dataset(name, cls, special_dataset, overwrite):
if not overwrite:
raise RuntimeError(f'Cannot register the accessor {special_dataset!r} under name {name!r} '
f'because it already exists and overwrite is set to {overwrite}')
logger.debug(f'Registering special dataset {name!r} for class {cls!r}')
setattr(cls, name, _CachedHDFAccessor(name, special_dataset))
return special_dataset


def register_special_dataset(name, cls: Union["Dataset", "Group"], overwrite=False):
def register_special_dataset(name, cls: Union[str, Type[H5TbxHLObject]], overwrite=False):
"""registers a special dataset to a wrapper class"""

if isinstance(cls, str):
if cls.lower() == 'dataset':
from .core import Dataset
cls = Dataset
elif cls.lower() == 'group':
from .core import Group
cls = Group
elif cls.lower() == 'file':
from .core import File
cls = File
else:
raise ValueError(f'Invalid class type {cls!r}')

def decorator(accessor):
"""decorator"""
return _register_special_dataset(name, cls, accessor, overwrite)

return decorator


class Accessor:
class Accessory:
"""Base class for all special datasets"""

def __init__(self, h5grp: h5py.Group):
self._grp = h5grp
def __init__(self, obj: H5TbxHLObject):
"""Initialize the accessor with the object to be accessed
Parameters
----------
obj : H5TbxHLObject
The object to which the accessor is attached
"""
self._obj = obj
48 changes: 0 additions & 48 deletions h5rdmtoolbox/wrapper/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,32 +1302,6 @@ def wrapper(*args):
return obj


class UnitConversionInterface:
def __init__(self, dataset, dataset_unit, **coord_units):
self.dataset = dataset
self.dataset_unit = dataset_unit
self.coord_units = coord_units

def _convert_units(self, data: xr.DataArray):
assert isinstance(data, xr.DataArray)
assert 'units' in data.attrs, 'No units attribute found in the dataset'
for c, cn in self.coord_units.items():
assert 'units' in data.coords[c].attrs, f'No units attribute found in the coordinate {c}'
data.coords[c] = data.coords[c].pint.quantify(unit_registry=get_ureg()).pint.to(
self.coord_units[c]).pint.dequantify()
# convert units
return data.pint.quantify(unit_registry=get_ureg()).pint.to(self.dataset_unit).pint.dequantify()

def sel(self, method=None, **coords) -> xr.DataArray:
return self._convert_units(self.dataset.sel(method=method, **coords))

def isel(self, **indexers) -> xr.DataArray:
return self._convert_units(self.dataset.isel(**indexers))

def __getitem__(self, *args, **kwargs):
return self._convert_units(self.dataset.__getitem__(*args, **kwargs))


class Dataset(h5py.Dataset):
"""Wrapper around the h5py.Dataset. Some useful methods are added on top of
the underlying *h5py* package.
Expand All @@ -1349,8 +1323,6 @@ class Dataset(h5py.Dataset):
* dumps(): string representation of group
* isel(): Select data by named dimension and index, mimics xarray.isel.
* sel(): Select data by named dimension and values, mimics xarray.sel.
* to_units(): Convert the dataset to a new unit.
* write_iso_timestamp(): Write an ISO 8601 timestamp to the current dataset attribute.
The following properties are added to the h5py.Dataset object:
Expand Down Expand Up @@ -1929,26 +1901,6 @@ def __init__(self, _id):
super().__init__(_id)
self._hdf_filename = Path(self.file.filename)

def to_units(self, dataset_unit, **coord_units) -> UnitConversionInterface:
"""Return interface, which allows to convert the dataset and/or its dimension scales
(coordinates) to a new unit. On the return object, the methods isel() and sel() can be
used to select data based on named dimension and index or values - just in the new
units.
Parameters
----------
dataset_unit : str
The new unit for the dataset.
coord_units : Dict
The new units for the coordinates.
Examples
--------
>>> with h5tbx.File('test.h5', 'r') as h5:
>>> h5.vel.to_units('m/s', time='s', z='m')
"""
return UnitConversionInterface(self, dataset_unit, **coord_units)

def set_primary_scale(self, axis, iscale: int):
"""Set the primary scale for a specific axis.
Expand Down
8 changes: 4 additions & 4 deletions tests/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,16 +220,16 @@ def test_units_to(self):
with h5tbx.File(mode='w') as h5:
ds = h5.create_dataset('x', data=[1, 2, 3], make_scale=True, attrs={'units': 'm'})
y = h5.create_dataset('y', data=[1, 0, 1], attach_scale='x', attrs={'units': 'mm'})
ds_cm = ds[()].to('cm')
ds_cm = ds.to_units('cm')[()]
self.assertEqual('cm', ds_cm.attrs['units'])

y_cm = y[()].to('cm')
y_cm = y.to_units('cm')[()]
self.assertEqual('cm', y_cm.attrs['units'])

y_xcm = y[()].to({'x': 'cm'})
y_xcm = y.to_units({'x': 'cm'})[()]
self.assertEqual('mm', y_xcm.attrs['units'])
self.assertEqual('cm', y_xcm.x.attrs['units'])

y_xcm = y[()].to(x='cm')
y_xcm = y.to_units(x='cm')[()]
self.assertEqual('mm', y_xcm.attrs['units'])
self.assertEqual('cm', y_xcm.x.attrs['units'])
4 changes: 2 additions & 2 deletions tests/wrapper/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from h5rdmtoolbox.wrapper import h5yaml
from h5rdmtoolbox.wrapper.h5attr import AttributeString

logger = h5tbx.logger
# logger.setLevel('ERROR')
logger = h5tbx.set_loglevel('ERROR')

__this_dir__ = pathlib.Path(__file__).parent


Expand Down

0 comments on commit fd9256b

Please sign in to comment.