Skip to content

Commit

Permalink
Merge pull request #116 from pycroscopy/dev_gerd_04
Browse files Browse the repository at this point in the history
Dev gerd 04
  • Loading branch information
gduscher authored Apr 11, 2021
2 parents 3f0c623 + 9aab268 commit bc4db8c
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 59 deletions.
105 changes: 78 additions & 27 deletions sidpy/sid/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from __future__ import division, print_function, absolute_import, unicode_literals
import sys

import dask.array.core
import numpy as np
import matplotlib.pylab as plt
import string
Expand Down Expand Up @@ -158,7 +160,6 @@ def __repr__(self):
rep = rep + '\n with metadata: {}'.format(list(self.metadata.keys()))
return rep


def hdf_close(self):
if self.h5_dataset is not None:
self.h5_dataset.file.close()
Expand Down Expand Up @@ -461,7 +462,7 @@ def get_extent(self, dimensions):
Parameters
----------
dimensions: dictionary of dimensions
dimensions: list of dimensions
Returns
-------
Expand Down Expand Up @@ -718,7 +719,7 @@ def fft(self, dimension_type=None):
name='u', units=units_x, dimension_type=new_dimension_type,
quantity='reciprocal'))
if len(axes) > 1:
units_y = '1/' + self._axes[axes[0]].units
units_y = '1/' + self._axes[axes[1]].units
fft_dset.set_dimension(axes[1],
Dimension(np.fft.fftshift(np.fft.fftfreq(self.shape[axes[1]],
d=get_slope(self._axes[axes[1]].values))),
Expand All @@ -734,7 +735,7 @@ def __eq__(self, other): # TODO: Test __eq__
if not isinstance(other, Dataset):
return False
# if (self.__array__() == other.__array__()).all():
if super().__eq__(other).all():
if (self.__array__().__eq__(other.__array__())).all():
if self._units != other._units:
return False
if self._quantity != other._quantity:
Expand Down Expand Up @@ -775,6 +776,9 @@ def choose(self, choices):
def __abs__(self):
return self.like_data(super().__abs__())

def angle(self):
return self.like_data(np.angle(super()))

def __add__(self, other):
return self.like_data(super().__add__(other))

Expand Down Expand Up @@ -883,8 +887,6 @@ def __matmul__(self, other):
def __rmatmul__(self, other):
return self.like_data(super().__rmatmul__(other))



def min(self, axis=None, keepdims=False, split_every=None, out=None):
if axis is None:
return float(super().min())
Expand Down Expand Up @@ -914,27 +916,45 @@ def sum(self, axis=None, dtype=None, keepdims=False, split_every=None, out=None)
if axis is None:
return float(result)
else:
if not keepdims:
dim = 0
dataset = self.from_array(result)
if isinstance(axis, int):
axis = [axis]

return self.adjust_axis(result, axis, title='Sum_of_', keepdims=keepdims)

def mean(self, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
result = super().mean(axis=axis, dtype=dtype, keepdims=keepdims, split_every=split_every, out=out)
if axis is None:
return float(result)
else:
return self.adjust_axis(result, axis, title='Mean_of_', keepdims=keepdims)

for ax, dimension in self._axes.items():
if int(ax) not in axis:
dataset.set_dimension(dim, dimension)
dim += 1
else:
dataset = self.like_data(result)
dataset.title='Sum_of_'+self.title
dataset.modality=f'sum axis {axis}'
dataset.quantity = self.quantity
dataset.source = self.source
dataset.units = self.units
def squeeze(self, axis=None):
result = super().squeeze(axis=axis)
if axis is None:
axis = []
for index, ax in enumerate(self.shape):
if ax == 1:
axis.append(index)
return self.adjust_axis(result, axis, title='Squeezed_')


def adjust_axis(self, result, axis, title='', keepdims=False):
if not keepdims:
dim = 0
dataset = self.from_array(result)
if isinstance(axis, int):
axis = [axis]

for ax, dimension in self._axes.items():
if int(ax) not in axis:
dataset.set_dimension(dim, dimension)
dim += 1
else:
dataset = self.like_data(result)
dataset.title = title + self.title
dataset.modality = f'sum axis {axis}'
dataset.quantity = self.quantity
dataset.source = self.source
dataset.units = self.units

return dataset
return dataset

def swapaxes(self, axis1, axis2):
result = super().swapaxes(axis1, axis2)
Expand All @@ -950,8 +970,39 @@ def swapaxes(self, axis1, axis2):

return dataset

# def __array_ufunc__(self, numpy_ufunc, method, *inputs, **kwargs):
# result = super().__array_ufunc__(super(),numpy_ufunc, method, *inputs, **kwargs)
# return self.like_data(result)
def __array_ufunc__(self, numpy_ufunc, method, *inputs, **kwargs):
out = kwargs.get("out", ())

if method == "__call__":
if numpy_ufunc is np.matmul:
from dask.array.routines import matmul

# special case until apply_gufunc handles optional dimensions
return self.like_data(matmul(*inputs, **kwargs))
if numpy_ufunc.signature is not None:
from dask.array.gufunc import apply_gufunc

return self.like_data(apply_gufunc(
numpy_ufunc, numpy_ufunc.signature, *inputs, **kwargs))
if numpy_ufunc.nout > 1:
from dask.array import ufunc

try:
da_ufunc = getattr(ufunc, numpy_ufunc.__name__)
except AttributeError:
return NotImplemented
return self.like_data(da_ufunc(*inputs, **kwargs))
else:
return self.like_data(dask.array.core.elemwise(numpy_ufunc, *inputs, **kwargs))
elif method == "outer":
from dask.array import ufunc

try:
da_ufunc = getattr(ufunc, numpy_ufunc.__name__)
except AttributeError:
return NotImplemented
return self.like_data(da_ufunc.outer(*inputs, **kwargs))
else:
return NotImplemented

# def prod(self, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
21 changes: 11 additions & 10 deletions sidpy/viz/dataset_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(self, dset, spectrum_number=None, figure=None, **kwargs):
# Plot real and image
fig, axes = plt.subplots(nrows=2, **fig_args)

axes[0].plot(self.dim.values, np.abs(np.squeeze(self.dset)), **kwargs)
axes[0].plot(self.dim.values, self.dset.squeeze().abs(), **kwargs)

axes[0].set_title(self.dset.title + '\n(Magnitude)', pad=15)
axes[0].set_xlabel(self.dset.labels[self.dim])
Expand Down Expand Up @@ -158,7 +158,7 @@ def __init__(self, dset, figure=None, image_number=0, **kwargs):
# Plot complex image
self.axes = []
self.axes.append(self.fig.add_subplot(121))
self.img = self.axes[0].imshow(np.abs(np.squeeze(np.array(self.dset[tuple(self.selection)]))).T,
self.img = self.axes[0].imshow(self.dset[tuple(self.selection)].squeeze().abs().T,
extent=self.dset.get_extent(self.image_dims), **kwargs)
self.axes[0].set_xlabel(self.dset.labels[self.image_dims[0]])
self.axes[0].set_ylabel(self.dset.labels[self.image_dims[1]])
Expand All @@ -168,7 +168,7 @@ def __init__(self, dset, figure=None, image_number=0, **kwargs):
self.axes[0].ticklabel_format(style='sci', scilimits=(-2, 3))

self.axes.append(self.fig.add_subplot(122, sharex=self.axes[0], sharey=self.axes[0]))
self.img_c = self.axes[1].imshow(np.angle(np.squeeze(np.array(self.dset[tuple(self.selection)]))).T,
self.img_c = self.axes[1].imshow(self.dset[tuple(self.selection)].squeeze().angle().T,
extent=self.dset.get_extent(self.image_dims), **kwargs)
self.axes[1].set_xlabel(self.dset.labels[self.image_dims[0]])
self.axes[1].set_ylabel(self.dset.labels[self.image_dims[1]])
Expand All @@ -192,7 +192,7 @@ def plot_image(self, **kwargs):
else:
self.axis.set_title(self.dset.title)

self.img = self.axis.imshow(np.squeeze(np.array(self.dset[tuple(self.selection)])).T,
self.img = self.axis.imshow(self.dset[tuple(self.selection)].squeeze().T,
extent=self.dset.get_extent(self.image_dims), **kwargs)
self.axis.set_xlabel(self.dset.labels[self.image_dims[0]])
self.axis.set_ylabel(self.dset.labels[self.image_dims[1]])
Expand Down Expand Up @@ -271,7 +271,7 @@ def __init__(self, dset, figure=None, **kwargs):
self.image_dims = []
self.selection = []
for dim, axis in dset._axes.items():
if axis.dimension_type == DimensionType.SPATIAL:
if axis.dimension_type in [DimensionType.SPATIAL, DimensionType.RECIPROCAL]:
self.selection.append(slice(None))
self.image_dims.append(dim)
elif axis.dimension_type == DimensionType.TEMPORAL or len(dset) == 3:
Expand Down Expand Up @@ -347,7 +347,7 @@ def plot_image(self, **kwargs):
self.axis = plt.gca()
self.axis.set_title(self.dset.title)

self.img = self.axis.imshow(np.squeeze(self.dset[tuple(self.selection)]).T,
self.img = self.axis.imshow(self.dset[tuple(self.selection)].squeeze().T,
extent=self.dset.get_extent(self.image_dims), **kwargs)
self.axis.set_xlabel(self.dset.labels[self.image_dims[0]])
self.axis.set_ylabel(self.dset.labels[self.image_dims[1]])
Expand Down Expand Up @@ -391,6 +391,7 @@ def _average_slices(self, event):
self.img.set_data(image_stack.mean(axis=self.stack_dim).T)
self.fig.canvas.draw_idle()
elif event.old:
self.ind = self.ind % self.number_of_slices
self._update(self.ind)

def _onscroll(self, event):
Expand Down Expand Up @@ -439,7 +440,7 @@ def __init__(self, dset, figure=None, horizontal=True, **kwargs):
image_dims = []
spectral_dims = []
for dim, axis in dset._axes.items():
if axis.dimension_type == DimensionType.SPATIAL:
if axis.dimension_type in [DimensionType.SPATIAL, DimensionType.RECIPROCAL]:
selection.append(slice(None))
image_dims.append(dim)
elif axis.dimension_type == DimensionType.SPECTRAL:
Expand Down Expand Up @@ -481,7 +482,7 @@ def __init__(self, dset, figure=None, horizontal=True, **kwargs):
self.axes = self.fig.subplots(nrows=2, **fig_args)

self.fig.canvas.set_window_title(self.dset.title)
self.image = np.average(np.array(self.dset), axis=spectral_dims[0])
self.image = dset.mean(axis=spectral_dims[0])

self.axes[0].imshow(self.image.T, extent=self.extent, **kwargs)
if horizontal:
Expand Down Expand Up @@ -561,9 +562,9 @@ def get_spectrum(self):
else:
selection.append(slice(0, 1))

self.spectrum = np.squeeze(np.average(self.dset[tuple(selection)], axis=tuple(self.image_dims)))
self.spectrum = self.dset[tuple(selection)].mean(axis=tuple(self.image_dims))
# * self.intensity_scale[self.x,self.y]
return np.squeeze(self.spectrum)
return self.spectrum.squeeze()

def _onclick(self, event):
self.event = event
Expand Down
56 changes: 34 additions & 22 deletions tests/sid/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,6 @@ def test_mul(self):
self.assertIsInstance(new_dataset, sidpy.Dataset)
self.assertEqual(np.array(new_dataset)[0, 0, 0], 3)

def test_div(self):
input_spectrum = np.ones([3, 3, 3])
dataset = sidpy.Dataset.from_array(input_spectrum)
new_dataset = dataset/3.
self.assertIsInstance(new_dataset, sidpy.Dataset)
self.assertEqual(np.array(new_dataset)[0, 0, 0], 1/3)

def test_min(self):
input_spectrum = np.zeros([3, 3, 3])
dataset = sidpy.Dataset.from_array(input_spectrum)
Expand All @@ -77,6 +70,15 @@ def test_abs(self):
abs_dataset = dataset.abs()
self.assertIsInstance(abs_dataset, sidpy.Dataset)
self.assertEqual(abs_dataset[0, 0, 0], 1)
new_dataset = dataset.__abs__()
self.assertIsInstance(new_dataset, sidpy.Dataset)

def test_angle(self):
input_spectrum = np.ones([3, 3, 3]) * -1
dataset = sidpy.Dataset.from_array(input_spectrum)
angle_dataset = dataset.angle()
self.assertIsInstance(angle_dataset, sidpy.Dataset)
self.assertEqual(float(angle_dataset[0, 0, 0]), np.pi)

def test_dot(self):
input_spectrum = np.ones([3, 3, 3])
Expand Down Expand Up @@ -162,12 +164,6 @@ def test_lshift(self):
new_dataset = dataset.__lshift__(1)
self.assertIsInstance(new_dataset, sidpy.Dataset)

def test_rshift(self):
input_spectrum = np.ones([3, 3, 3])
dataset = sidpy.Dataset.from_array(input_spectrum)
new_dataset = dataset.__rshift__(1)
self.assertIsInstance(new_dataset, sidpy.Dataset)

def test_lt(self):
input_spectrum = np.ones([3, 3, 3])
dataset = sidpy.Dataset.from_array(input_spectrum)
Expand Down Expand Up @@ -264,12 +260,6 @@ def test_rtruediv(self):
new_dataset = dataset.__rtruediv__(2)
self.assertIsInstance(new_dataset, sidpy.Dataset)

def test_rfloordiv(self):
input_spectrum = np.ones([3, 3, 3])
dataset = sidpy.Dataset.from_array(input_spectrum)
new_dataset = dataset.__floordiv__(2)
self.assertIsInstance(new_dataset, sidpy.Dataset)

def test_rfloordiv(self):
input_spectrum = np.ones([3, 3, 3])
dataset = sidpy.Dataset.from_array(input_spectrum)
Expand Down Expand Up @@ -307,9 +297,9 @@ def test_divmod(self):
self.assertIsInstance(new_dataset, sidpy.Dataset)

def test_rdivmod(self):
input_spectrum = np.ones([3, 3, 3])
input_spectrum = np.ones([3])
dataset = sidpy.Dataset.from_array(input_spectrum)
new_dataset, _ = dataset.__rdivmod__(dataset)
new_dataset, _ = dataset.__rdivmod__(8)
self.assertIsInstance(new_dataset, sidpy.Dataset)

def test_real(self):
Expand Down Expand Up @@ -342,12 +332,34 @@ def test_sum(self):
new_dataset = dataset.sum(axis=1)
self.assertIsInstance(new_dataset, sidpy.Dataset)

def test_mean(self):
input_spectrum = np.ones([3, 3, 3])
dataset = sidpy.Dataset.from_array(input_spectrum)
new_dataset = dataset.mean(axis=1)
self.assertIsInstance(new_dataset, sidpy.Dataset)

def test_squeeze(self):
input_spectrum = np.ones([3, 1, 3])
dataset = sidpy.Dataset.from_array(input_spectrum)
new_dataset = dataset.squeeze()
self.assertIsInstance(new_dataset, sidpy.Dataset)

def test_swapaxes(self):
input_spectrum = np.ones([3, 3, 3])
dataset = sidpy.Dataset.from_array(input_spectrum)
new_dataset = dataset.swapaxes(0,1)
new_dataset = dataset.swapaxes(0, 1)
self.assertIsInstance(new_dataset, sidpy.Dataset)

def test_ufunc(self):
# Todo: More testing for better coverage
input_image = np.ones([3, 3, 3])
dataset = sidpy.Dataset.from_array(input_image)
new_dataset = np.sin(dataset)
self.assertIsInstance(new_dataset, sidpy.Dataset)
new_dataset = dataset @ dataset
self.assertIsInstance(new_dataset, sidpy.Dataset)


class TestFftFunctions(unittest.TestCase):

def test_spectrum_fft(self):
Expand Down

0 comments on commit bc4db8c

Please sign in to comment.