Skip to content

Commit

Permalink
Merge pull request #152 from pycroscopy/dev_rama
Browse files Browse the repository at this point in the history
fitter changes
  • Loading branch information
ramav87 authored Jul 1, 2022
2 parents 7f2e5f7 + a30ba40 commit a0d059b
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 10 deletions.
45 changes: 37 additions & 8 deletions sidpy/proc/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
@author: Rama Vasudevan, Mani Valleti
"""

from xml.dom import NotFoundErr
from dask.distributed import Client
import numpy as np
import dask
import inspect
from ..sid import Dimension, Dataset
from ..sid.dimension import DimensionType
from ..viz.dataset_viz import SpectralImageFitVisualizer
from ..sid.dataset import DataType

try:
from scipy.optimize import curve_fit
Expand Down Expand Up @@ -47,7 +50,7 @@ def __init__(self, sidpy_dataset, fit_fn, xvec=None, ind_dims=None, guess_fn=Non
If NOT provided, it is assumed that all the non-spectral dimensions are independent dimensions.
guess_fn: (function) (optional) This optional function should be utilized to generate priors for the full fit
It takes the same arguments as the fitting function and should return the same type of results array.
It takes (xvec,yvec) as inputs and should return the fit parameters.
If the guess_fn is NOT provided, then the user MUST input the num_fit_parms.
num_fit_parms: (int) Number of fitting parameters. This is needed IF the guess function is not provided to set
Expand Down Expand Up @@ -170,6 +173,7 @@ def __init__(self, sidpy_dataset, fit_fn, xvec=None, ind_dims=None, guess_fn=Non
self.return_std = return_std
self.return_cov = return_cov
self.return_fit = return_fit
self.fitted_dset = None

self.mean_fit_results = []
if self.return_cov:
Expand Down Expand Up @@ -251,7 +255,7 @@ def do_fit(self, **kwargs):
p0 = self.prior[ind, :]

lazy_result = dask.delayed(SidFitter.default_curve_fit)(self.fit_fn, self.dep_vec,
self.folded_dataset[ind, :],
self.folded_dataset[ind, :],self.num_fit_parms,
return_cov=(self.return_cov or self.return_std),
p0=p0, **kwargs)
fit_results.append(lazy_result)
Expand All @@ -263,7 +267,7 @@ def do_fit(self, **kwargs):
self.get_km_priors()
for ind in range(self.num_computations):
lazy_result = dask.delayed(SidFitter.default_curve_fit)(self.fit_fn, self.dep_vec,
self.folded_dataset[ind, :],
self.folded_dataset[ind, :], self.num_fit_parms,
return_cov=(self.return_cov or self.return_std),
p0=self.km_priors[self.km_labels[ind]],
**kwargs)
Expand Down Expand Up @@ -394,7 +398,7 @@ def get_fitted_dataset(self):
fitted_sid_dset_folded = fitted_dset_fold.like_data(np_folded_arr, title=fitted_dset_fold.title)
fitted_sid_dset = fitted_sid_dset_folded.unfold()
fitted_sid_dset.original_metadata = self.dataset.original_metadata.copy()

self.fitted_dset = fitted_sid_dset
return fitted_sid_dset

def get_km_priors(self):
Expand Down Expand Up @@ -422,23 +426,48 @@ def get_km_priors(self):
else:
p0 = np.random.normal(loc=0.5, scale=0.1, size=self.num_fit_parms)

km_priors.append(SidFitter.default_curve_fit(self.fit_fn, self.dep_vec, cen,
km_priors.append(SidFitter.default_curve_fit(self.fit_fn, self.dep_vec, cen, self.num_fit_parms,
return_cov=False,
p0=p0, maxfev=10000))
self.km_priors = np.array(km_priors)
self.num_fit_parms = self.km_priors.shape[-1]

def visualize_fit_results(self, figure = None, horizontal = True):
'''
Calls the interactive visualizer for comparing raw and fit datasets.
Inputs:
- figure: (Optional, default None) - handle to existing figure
- horiziontal: (Optional, default True) - whether spectrum should be plotted horizontally
'''
dset_type = self.dataset.data_type
supported_types = ['SPECTRAL_IMAGE']
if self.fitted_dset == None:
raise NotFoundErr("No fitted dataset found. Re-run with return_fit=True to use this feature")
if dset_type == DataType.SPECTRAL_IMAGE:
visualizer = SpectralImageFitVisualizer(self.dataset, self.fitted_dset,
figure=figure, horizontal=horizontal)
else:
raise NotImplementedError("Data type is {} but currently we only support types {}".format(dset_type, supported_types))

return visualizer


@staticmethod
def default_curve_fit(fit_fn, xvec, yvec, return_cov=True, **kwargs):
def default_curve_fit(fit_fn, xvec, yvec, num_fit_parms, return_cov=True, **kwargs):
xvec = np.array(xvec)
yvec = np.array(yvec)
yvec = yvec.ravel()
xvec = xvec.ravel()
if curve_fit is None:
raise ModuleNotFoundError("scipy is not installed")
else:
popt, pcov = curve_fit(fit_fn, xvec, yvec, **kwargs)

try:
popt, pcov = curve_fit(fit_fn, xvec, yvec, **kwargs)
except:
popt = np.zeros(num_fit_parms)
pcov = np.zeros((num_fit_parms, num_fit_parms))
if return_cov:
return popt, pcov
else:
Expand Down
80 changes: 78 additions & 2 deletions sidpy/viz/dataset_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, dset, spectrum_number=0, figure=None, **kwargs):
self.dim = self.dset._axes[self.spectral_dims[0]]

if is_complex_dtype(dset.dtype):
# Plot real and image
# Plot real and imaginary
self.fig, self.axes = plt.subplots(nrows=2, **fig_args)

self.axes[0].plot(self.dim.values, self.dset.squeeze().abs(), **kwargs)
Expand Down Expand Up @@ -681,7 +681,7 @@ def _update(self, ev=None):
self.axes[1].set_title('spectrum {}, {}'.format(self.x, self.y))

self.axes[1].set_xlim(xlim)
self.axes[1].set_ylim(ylim)
#self.axes[1].set_ylim(ylim)
self.axes[1].set_xlabel(self.xlabel)
self.axes[1].set_ylabel(self.ylabel)

Expand Down Expand Up @@ -961,3 +961,79 @@ def set_legend(self, set_legend):

def get_xy(self):
return [self.x, self.y]


#Let's make a curve fit visualizer

class SpectralImageFitVisualizer(SpectralImageVisualizer):

def __init__(self, original_dataset, fit_dataset, figure=None, horizontal=True):
'''
Visualizer for spectral image datasets, fit by the Sidpy Fitter
This class is called by Sidpy Fitter for visualizing the raw/fit dataset interactively.
Inputs:
- original_dataset: sidpy.Dataset containing the raw data
- fit_dataset: sidpy.Dataset with the fitted data. This is returned by the
Sidpy Fitter after functional fitting.
- figure: (Optional, default None) - handle to existing figure
- horiziontal: (Optional, default True) - whether spectrum should be plotted horizontally
'''

super().__init__(original_dataset, figure, horizontal)

self.fit_dset = fit_dataset
self.axes[1].clear()
self.get_fit_spectrum()
self.axes[1].plot(self.energy_scale, self.spectrum, 'bo')
self.axes[1].plot(self.energy_scale, self.fit_spectrum, 'r-')

def get_fit_spectrum(self):

from ..sid.dimension import DimensionType
if self.x > self.dset.shape[self.image_dims[0]] - self.bin_x:
self.x = self.dset.shape[self.image_dims[0]] - self.bin_x
if self.y > self.dset.shape[self.image_dims[1]] - self.bin_y:
self.y = self.dset.shape[self.image_dims[1]] - self.bin_y
selection = []

for dim, axis in self.dset._axes.items():
if axis.dimension_type == DimensionType.SPATIAL:
if dim == self.image_dims[0]:
selection.append(slice(self.x, self.x + self.bin_x))
else:
selection.append(slice(self.y, self.y + self.bin_y))

elif axis.dimension_type == DimensionType.SPECTRAL:
selection.append(slice(None))
else:
selection.append(slice(0, 1))

self.spectrum = self.dset[tuple(selection)].mean(axis=tuple(self.image_dims))
self.fit_spectrum = self.fit_dset[tuple(selection)].mean(axis=tuple(self.image_dims))
# * self.intensity_scale[self.x,self.y]

return self.fit_spectrum.squeeze(), self.spectrum.squeeze()


def _update(self, ev=None):

xlim = self.axes[1].get_xlim()
ylim = self.axes[1].get_ylim()
self.axes[1].clear()
self.get_fit_spectrum()

self.axes[1].plot(self.energy_scale, self.spectrum, 'bo', label='experiment')
self.axes[1].plot(self.energy_scale, self.fit_spectrum, 'r-', label='fit')

if self.set_title:
self.axes[1].set_title('spectrum {}, {}'.format(self.x, self.y))

self.axes[1].set_xlim(xlim)
#self.axes[1].set_ylim(ylim)
self.axes[1].set_xlabel(self.xlabel)
self.axes[1].set_ylabel(self.ylabel)

self.fig.canvas.draw_idle()


0 comments on commit a0d059b

Please sign in to comment.