update and unify extension implementation
1 parent fd9256b · commit c44b398
Showing 13 changed files with 1,004 additions and 719 deletions.
@@ -1,77 +1,67 @@
 """Extension to compute magnitude of xarray datasets"""
+import h5py
 import numpy as np
 import xarray as xr
-from typing import Union, Dict
+from typing import Dict, Optional

+from h5rdmtoolbox.protocols import H5TbxDataset
+from h5rdmtoolbox.wrapper.accessory import Accessory, register_special_dataset

-@xr.register_dataset_accessor("magnitude")
-class MagnitudeAccessor:
-    """Accessor to convert units of data array. It is
-    also possible to convert its coordinates"""

-    def __init__(self, xarray_obj):
-        """Initialize the accessor"""
-        self._obj = xarray_obj
+class MagnitudeInterface:
+    def __init__(self,
+                 datasets: Dict[str, H5TbxDataset],
+                 name: Optional[str] = None,
+                 keep_attrs: bool = False):
+        self.datasets = datasets
+        self.name = name
+        self.keep_attrs = keep_attrs

+    def _compute_magnitude(self, datasets):
+        assert len(datasets) > 1, 'At least two datasets are required to compute magnitude'
+        keys = list(datasets.keys())
+        mag2 = datasets[keys[0]].pint.quantify() ** 2
+        with xr.set_options(keep_attrs=self.keep_attrs):
+            for key in keys[1:]:
+                mag2 += datasets[key].pint.quantify() ** 2

-    def compute_from(self,
-                     *data_vars,
-                     name: Union[str, None] = None,
-                     inplace: bool = True,
-                     attrs: Union[Dict, None] = None,
-                     overwrite: bool = False):
-        """compute magnitude from data variable names
-        Parameters
-        ----------
-        data_vars: str
-            Names of data variables to compute magnitude from.
-        name: str
-            Name of the magnitude variable to be used in the dataset.
-            If None, the name is automatically generated.
-            Example: if data_vars = ['u', 'v', 'w'], then name is 'magnitude_of_u_v_w'
-        inplace: bool
-            If True, the magnitude variable is added to the dataset.
-            Otherwise, a new dataset is returned.
-        attrs: dict
-            Attributes to be added to the magnitude variable
-        overwrite: bool
-            If True, the magnitude variable is overwritten if it already exists in the dataset.
-        """
-        mag2 = self._obj[data_vars[0]].pint.quantify() ** 2
-        from .. import consts
-        # anc_ds = []
-        # anc_ds.append(self._obj[data_vars[0]].attrs.get(consts.ANCILLARY_DATASET, ()))
-        for data_var in data_vars[1:]:
-            mag2 += self._obj[data_var].pint.quantify() ** 2
-            # anc_ds.append(self._obj[data_var].attrs.get(consts.ANCILLARY_DATASET, ()))
-        # with xr.set_options(keep_attrs=True):
         mag = np.sqrt(mag2).pint.dequantify()
+        if self.name is None:
+            mag.name = 'magnitude_of_' + '_and_'.join(k.replace(' ', '_') for k in keys)
+        else:
+            mag.name = self.name
+        return mag

-        # drop ancillary dataset information:
-        mag.attrs.pop(consts.ANCILLARY_DATASET, None)
+    def __getitem__(self, *args, **kwargs):
+        return self._compute_magnitude(
+            {k: v.__getitem__(*args, **kwargs) for k, v in self.datasets.items()}
+        )

-        # gather ancillary dataset information from vector components:
-        _anc = [self._obj[da].attrs.get(consts.ANCILLARY_DATASET, None) for da in data_vars]
+    def isel(self, **indexers):
+        return self._compute_magnitude(
+            {k: v.isel(**indexers) for k, v in self.datasets.items()}
+        )

-        _anc = [a for a in _anc if a is not None]
-        if _anc:
-            mag.attrs[consts.ANCILLARY_DATASET] = list(set([item for sublist in _anc for item in sublist]))
+    def sel(self, method=None, **coords):
+        return self._compute_magnitude(
+            {k: v.sel(method=method, **coords) for k, v in self.datasets.items()}
+        )

-        joined_names = '_'.join(data_vars)
-        if name is None:
-            name = f'magnitude_of_{joined_names}'
-        if name in self._obj:
-            if not overwrite:
-                raise KeyError(f'The name of variable "{name}" is already exists in the dataset.')
-            del self._obj[name]
-        mag.name = name
-        processing_comment = 'processing_comment'
-        while processing_comment in mag.attrs:
-            processing_comment = f'_{processing_comment}'
-        mag.attrs['processing_comment'] = f'computed from: {joined_names.replace("_", ", ")}'
-        if attrs:
-            mag.attrs.update(attrs)

-        if inplace:
-            self._obj[name] = mag
-            return self._obj
-        return mag
+@register_special_dataset("Magnitude", "Group")
+@register_special_dataset("Magnitude", "File")
+class Magnitude(Accessory):
+    def __call__(self, *datasets, name: Optional[str] = None, keep_attrs: bool = False) -> MagnitudeInterface:
+        if len(datasets) < 2:
+            raise ValueError('Please provide at least two datasets to compute magnitude')
+        hdf_datasets = {}
+        for dataset in datasets:
+            if isinstance(dataset, str):
+                ds = self._obj[dataset]
+            elif isinstance(dataset, h5py.Dataset):
+                ds = dataset
+            else:
+                raise TypeError(f'Invalid type: {type(dataset)}')
+            hdf_datasets[ds.name.strip('/')] = ds

+        return MagnitudeInterface(hdf_datasets, name=name, keep_attrs=keep_attrs)
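
For orientation, a minimal usage sketch of the unified interface introduced above. It is not part of the commit: the file name, the component dataset names 'u' and 'v', and the 'time' dimension are illustrative assumptions, and it assumes the registered accessory is reachable as a Magnitude attribute on File and Group objects, as the register_special_dataset("Magnitude", ...) calls suggest.

    import h5rdmtoolbox as h5tbx

    # Hypothetical file layout: two velocity components "u" and "v" stored as
    # datasets with a "time" dimension (names are assumptions, not from the diff).
    with h5tbx.File('velocity.hdf', 'r') as h5:
        # The accessory returns a MagnitudeInterface without computing anything yet;
        # datasets can be passed as names or as h5py.Dataset objects.
        mag = h5.Magnitude('u', 'v', name='speed', keep_attrs=True)

        # Each selection is forwarded to every component first; the interface then
        # pint-quantifies the slices, sums their squares, takes the square root,
        # dequantifies the result, and names it 'speed'.
        speed_t0 = mag.isel(time=0)   # integer-location selection per component
        speed_sub = mag[0:10, ...]    # plain __getitem__ slicing per component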