Skip to content

Commit

Permalink
bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasprobst committed Oct 5, 2023
1 parent aab4b13 commit 6ff7f93
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
1 change: 1 addition & 0 deletions h5rdmtoolbox/conventions/standard_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def validate(self, value, parent, attrs=None):
if value is not None:
if isinstance(value, str) and value.startswith('{') and value.endswith('}'):
value = json.loads(value)
if isinstance(value, dict):
try:
model_fields = self.validator.model_fields
if 'value' in model_fields and 'typing.Dict' in str(model_fields['value'].annotation):
Expand Down
22 changes: 18 additions & 4 deletions h5rdmtoolbox/wrapper/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,12 @@ def create_time_dataset(self,
if isinstance(data, np.ndarray):
return self.create_string_dataset(name, data=[t.astype(datetime).isoformat() for t in data],
overwrite=overwrite, attrs=attrs, **kwargs)
return self.create_string_dataset(name, data=[t.isoformat() for t in data],
_data = np.asarray(data)
_orig_shape = _data.shape
_flat_data = _data.flatten()
_flat_data = np.asarray([t.isoformat() for t in _flat_data])
_reshaped_data = _flat_data.reshape(_orig_shape)
return self.create_string_dataset(name, data=_reshaped_data.tolist(),
overwrite=overwrite, attrs=attrs, **kwargs)

def create_string_dataset(self,
Expand All @@ -570,7 +575,7 @@ def create_string_dataset(self,
if isinstance(data, str):
n_letter = len(data)
elif isinstance(data, (tuple, list)):
n_letter = max([len(d) for d in data])
n_letter = max([len(d) for d in np.asarray(data).flatten()])
else:
raise TypeError(f'Unexpected type for parameter "data": {type(data)}. Expected str or List/Tuple of str')
dtype = f'S{n_letter}'
Expand Down Expand Up @@ -1668,7 +1673,8 @@ def __getitem__(self, args, new_dtype=None, nparray=False) -> Union[xr.DataArray
if dim_ds_data.dtype.kind == 'S':
# decode string array
if dim_ds_attrs.get('ISTIMEDS', False):
dim_ds_data = np.array([datetime.fromisoformat(t) for t in dim_ds_data.astype(str)]).astype(datetime)
dim_ds_data = np.array([datetime.fromisoformat(t) for t in dim_ds_data.astype(str)]).astype(
datetime)
if dim_ds_data.ndim == 0:
if isinstance(arg, int):
coords[coord_name] = xr.DataArray(name=coord_name,
Expand Down Expand Up @@ -1718,7 +1724,15 @@ def __getitem__(self, args, new_dtype=None, nparray=False) -> Union[xr.DataArray
# decode string array
_arr = arr.astype(str)
if self.attrs.get('ISTIMEDS', False):
return xr.DataArray([datetime.fromisoformat(t) for t in _arr], attrs=attrs)
if _arr.ndim == 0:
_arr = np.asarray(datetime.fromisoformat(_arr))
elif _arr.ndim == 1:
_arr = [datetime.fromisoformat(t) for t in _arr]
else: # _arr.ndim > 1:
orig_shape = _arr.shape
_flat_arr = np.asarray([datetime.fromisoformat(t) for t in _arr.flatten()])
_arr = _flat_arr.reshape(orig_shape)
return xr.DataArray(_arr, attrs=attrs)
else:
if isinstance(_arr, np.ndarray):
return tuple(_arr)
Expand Down
30 changes: 21 additions & 9 deletions tests/wrapper/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pathlib
import unittest
import xarray as xr
from datetime import datetime
from datetime import datetime, timedelta
from numpy import linspace as ls

import h5rdmtoolbox as h5tbx
Expand Down Expand Up @@ -730,9 +730,8 @@ def test_create_dataset_with_ancillary_ds(self):
h5.create_dataset('vel3', data=[1.5, 2.5], attach_scale='3D')

def test_time(self):
import datetime
tdata = [datetime.datetime.now(),
(datetime.datetime.now() + datetime.timedelta(hours=1))]
tdata = [datetime.now(),
(datetime.now() + timedelta(hours=1))]
tdata_np = np.asarray(tdata, dtype=np.datetime64)
with h5tbx.File() as h5:
h5.create_string_dataset('time', data=[t.isoformat() for t in tdata],
Expand Down Expand Up @@ -761,14 +760,27 @@ def test_time(self):
np.testing.assert_equal(t.values, np.datetime64(tdata[it]))

def test_time_as_coord(self):
import datetime
with h5tbx.File() as h5:
h5.create_time_dataset('time', data=[datetime.datetime.now(),
datetime.datetime.now() + datetime.timedelta(hours=1),
datetime.datetime.now() + datetime.timedelta(hours=3)],
h5.create_time_dataset('time', data=[datetime.now(),
datetime.now() + timedelta(hours=1),
datetime.now() + timedelta(hours=3)],
attrs={'ISTIMEDS': True,
'TIMEFORMAT': 'ISO'}, make_scale=True)
h5.create_dataset('vel', data=[1, 2, -3], attach_scale='time')
v = h5.vel[()]

print(v.time)
def test_multidim_time_ds(self):
with h5tbx.File() as h5:
h5.create_time_dataset('time', data=[[datetime.now(),
datetime.now() + timedelta(hours=1),
datetime.now() + timedelta(hours=3)],
[datetime.now(),
datetime.now() + timedelta(hours=6),
datetime.now() + timedelta(hours=10)]
],
attrs={'ISTIMEDS': True,
'TIMEFORMAT': 'ISO'})
t = h5.time[()]
self.assertIsInstance(t, xr.DataArray)
self.assertEqual(t.shape, (2, 3))
self.assertIsInstance(t[0, 0].values, np.datetime64)

0 comments on commit 6ff7f93

Please sign in to comment.