Skip to content

Commit

Permalink
added a spectral fn fit visualizer
Browse files Browse the repository at this point in the history
  • Loading branch information
Rama Vasudevan committed Jul 1, 2022
1 parent b41a1d7 commit a30ba40
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 3 deletions.
28 changes: 27 additions & 1 deletion sidpy/proc/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
@author: Rama Vasudevan, Mani Valleti
"""

from xml.dom import NotFoundErr
from dask.distributed import Client
import numpy as np
import dask
import inspect
from ..sid import Dimension, Dataset
from ..sid.dimension import DimensionType
from ..viz.dataset_viz import SpectralImageFitVisualizer
from ..sid.dataset import DataType

try:
from scipy.optimize import curve_fit
Expand Down Expand Up @@ -170,6 +173,7 @@ def __init__(self, sidpy_dataset, fit_fn, xvec=None, ind_dims=None, guess_fn=Non
self.return_std = return_std
self.return_cov = return_cov
self.return_fit = return_fit
self.fitted_dset = None

self.mean_fit_results = []
if self.return_cov:
Expand Down Expand Up @@ -394,7 +398,7 @@ def get_fitted_dataset(self):
fitted_sid_dset_folded = fitted_dset_fold.like_data(np_folded_arr, title=fitted_dset_fold.title)
fitted_sid_dset = fitted_sid_dset_folded.unfold()
fitted_sid_dset.original_metadata = self.dataset.original_metadata.copy()

self.fitted_dset = fitted_sid_dset
return fitted_sid_dset

def get_km_priors(self):
Expand Down Expand Up @@ -428,6 +432,28 @@ def get_km_priors(self):
self.km_priors = np.array(km_priors)
self.num_fit_parms = self.km_priors.shape[-1]

def visualize_fit_results(self, figure = None, horizontal = True):
'''
Calls the interactive visualizer for comparing raw and fit datasets.
Inputs:
- figure: (Optional, default None) - handle to existing figure
- horiziontal: (Optional, default True) - whether spectrum should be plotted horizontally
'''
dset_type = self.dataset.data_type
supported_types = ['SPECTRAL_IMAGE']
if self.fitted_dset == None:
raise NotFoundErr("No fitted dataset found. Re-run with return_fit=True to use this feature")
if dset_type == DataType.SPECTRAL_IMAGE:
visualizer = SpectralImageFitVisualizer(self.dataset, self.fitted_dset,
figure=figure, horizontal=horizontal)
else:
raise NotImplementedError("Data type is {} but currently we only support types {}".format(dset_type, supported_types))

return visualizer


@staticmethod
def default_curve_fit(fit_fn, xvec, yvec, num_fit_parms, return_cov=True, **kwargs):
xvec = np.array(xvec)
Expand Down
80 changes: 78 additions & 2 deletions sidpy/viz/dataset_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, dset, spectrum_number=0, figure=None, **kwargs):
self.dim = self.dset._axes[self.spectral_dims[0]]

if is_complex_dtype(dset.dtype):
# Plot real and image
# Plot real and imaginary
self.fig, self.axes = plt.subplots(nrows=2, **fig_args)

self.axes[0].plot(self.dim.values, self.dset.squeeze().abs(), **kwargs)
Expand Down Expand Up @@ -681,7 +681,7 @@ def _update(self, ev=None):
self.axes[1].set_title('spectrum {}, {}'.format(self.x, self.y))

self.axes[1].set_xlim(xlim)
self.axes[1].set_ylim(ylim)
#self.axes[1].set_ylim(ylim)
self.axes[1].set_xlabel(self.xlabel)
self.axes[1].set_ylabel(self.ylabel)

Expand Down Expand Up @@ -961,3 +961,79 @@ def set_legend(self, set_legend):

def get_xy(self):
return [self.x, self.y]


#Let's make a curve fit visualizer

class SpectralImageFitVisualizer(SpectralImageVisualizer):

def __init__(self, original_dataset, fit_dataset, figure=None, horizontal=True):
'''
Visualizer for spectral image datasets, fit by the Sidpy Fitter
This class is called by Sidpy Fitter for visualizing the raw/fit dataset interactively.
Inputs:
- original_dataset: sidpy.Dataset containing the raw data
- fit_dataset: sidpy.Dataset with the fitted data. This is returned by the
Sidpy Fitter after functional fitting.
- figure: (Optional, default None) - handle to existing figure
- horiziontal: (Optional, default True) - whether spectrum should be plotted horizontally
'''

super().__init__(original_dataset, figure, horizontal)

self.fit_dset = fit_dataset
self.axes[1].clear()
self.get_fit_spectrum()
self.axes[1].plot(self.energy_scale, self.spectrum, 'bo')
self.axes[1].plot(self.energy_scale, self.fit_spectrum, 'r-')

def get_fit_spectrum(self):

from ..sid.dimension import DimensionType
if self.x > self.dset.shape[self.image_dims[0]] - self.bin_x:
self.x = self.dset.shape[self.image_dims[0]] - self.bin_x
if self.y > self.dset.shape[self.image_dims[1]] - self.bin_y:
self.y = self.dset.shape[self.image_dims[1]] - self.bin_y
selection = []

for dim, axis in self.dset._axes.items():
if axis.dimension_type == DimensionType.SPATIAL:
if dim == self.image_dims[0]:
selection.append(slice(self.x, self.x + self.bin_x))
else:
selection.append(slice(self.y, self.y + self.bin_y))

elif axis.dimension_type == DimensionType.SPECTRAL:
selection.append(slice(None))
else:
selection.append(slice(0, 1))

self.spectrum = self.dset[tuple(selection)].mean(axis=tuple(self.image_dims))
self.fit_spectrum = self.fit_dset[tuple(selection)].mean(axis=tuple(self.image_dims))
# * self.intensity_scale[self.x,self.y]

return self.fit_spectrum.squeeze(), self.spectrum.squeeze()


def _update(self, ev=None):

xlim = self.axes[1].get_xlim()
ylim = self.axes[1].get_ylim()
self.axes[1].clear()
self.get_fit_spectrum()

self.axes[1].plot(self.energy_scale, self.spectrum, 'bo', label='experiment')
self.axes[1].plot(self.energy_scale, self.fit_spectrum, 'r-', label='fit')

if self.set_title:
self.axes[1].set_title('spectrum {}, {}'.format(self.x, self.y))

self.axes[1].set_xlim(xlim)
#self.axes[1].set_ylim(ylim)
self.axes[1].set_xlabel(self.xlabel)
self.axes[1].set_ylabel(self.ylabel)

self.fig.canvas.draw_idle()


0 comments on commit a30ba40

Please sign in to comment.