Skip to content

Commit

Permalink
Merge pull request #151 from pycroscopy/master
Browse files Browse the repository at this point in the history
Fixed a couple of bugs in fold and rechunk methods under sidpy.Dataset
  • Loading branch information
ramav87 authored Jun 27, 2022
2 parents 7f2e89d + cd67083 commit 7f2e5f7
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions sidpy/sid/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ def from_array(cls, x, title='generic', chunks='auto', lock=False,
dask_array = x
else:
dask_array = da.from_array(np.array(x), chunks=chunks, lock=lock)

# view as subclass
sid_dataset = view_subclass(dask_array, cls)
sid_dataset.data_type = datatype
Expand Down Expand Up @@ -1165,7 +1164,7 @@ def persist(self, **kwargs):
return self.like_data(super().persist(**kwargs),
title_prefix='persisted_')

def rechunk(self, chunks, threshold, block_size_limit, balance):
def rechunk(self, chunks='auto', threshold=None, block_size_limit=None, balance=False):
return self.like_data(super().rechunk(chunks=chunks,
threshold=threshold,
block_size_limit=block_size_limit,
Expand Down Expand Up @@ -1210,26 +1209,26 @@ def fold(self, dim_order=None, method=None):
dim_order_list = [list(x) for x in dim_order]

# Book-keeping for unfolding
fold_attr = {'shape': self.shape, '_axes': self._axes.copy()}
fold_attr = {'_axes': self._axes.copy()}

if method == 'spaspec':
dim_order_list = [[], []]
for dim, axis in self._axes.items():
if axis.dimension_type == DimensionType.SPATIAL:
dim_order_list[0].extend(dim)
dim_order_list[0].extend([dim])
elif axis.dimension_type == DimensionType.SPECTRAL:
dim_order_list[1].extend(dim)
dim_order_list[1].extend([dim])
else:
warnings.warn('One of the dimensions is neither Spatial\
nor Spectral Type and is considered to be a \
part of the last collapsed dimension')
dim_order_list[1].extend(dim)
dim_order_list[1].extend([dim])

if method == 'spa':
dim_order_list = [[]]
for dim, axis in self._axes.items():
if axis.dimension_type == DimensionType.SPATIAL:
dim_order_list[0].extend(dim)
dim_order_list[0].extend([dim])
else:
dim_order_list.append([dim])

Expand All @@ -1243,7 +1242,7 @@ def fold(self, dim_order=None, method=None):
dim_order_list = [[]]
for dim, axis in self._axes.items():
if axis.dimension_type == DimensionType.Spectral:
dim_order_list[-1].extend(dim)
dim_order_list[-1].extend([dim])
else:
dim_order_list.insert(-1, [dim])

Expand All @@ -1266,6 +1265,7 @@ def fold(self, dim_order=None, method=None):
dim_order_flattened.extend(list(left_dims))

fold_attr['dim_order_flattened'] = dim_order_flattened
fold_attr['dim_order'] = dim_order_list
# Get the shape of the collapsed array
new_shape = np.ones(len(dim_order_list)).astype(int)
for i, dim in enumerate(dim_order_list):
Expand All @@ -1274,7 +1274,8 @@ def fold(self, dim_order=None, method=None):

# Collapsed dask array
transposed_dset = self.transpose(dim_order_flattened)
folded_dset = self.like_data(da.reshape(transposed_dset, tuple(new_shape)),

folded_dset = self.like_data(da.reshape(transposed_dset, tuple(new_shape), merge_chunks=True),
title_prefix='folded_', checkdims=False)

fold_attr['shape_transposed'] = [self.shape[i] for i in dim_order_flattened]
Expand Down Expand Up @@ -1303,7 +1304,7 @@ def unfold(self):
raise NotImplementedError('unfold only works on the dataset that was collapsed/folded by'
' the fold method')

reshaped_dset = da.reshape(self, shape_transposed)
reshaped_dset = da.reshape(self, shape_transposed, merge_chunks=True)
old_order = [dim_order_flattened.index(d) for d in range(len(dim_order_flattened))]

unfolded_dset = self.like_data(da.transpose(reshaped_dset, old_order),
Expand Down

0 comments on commit 7f2e5f7

Please sign in to comment.