diff --git a/sidpy/proc/fitter.py b/sidpy/proc/fitter.py index 09019ceb..498bc7ac 100644 --- a/sidpy/proc/fitter.py +++ b/sidpy/proc/fitter.py @@ -6,12 +6,15 @@ @author: Rama Vasudevan, Mani Valleti """ +from xml.dom import NotFoundErr from dask.distributed import Client import numpy as np import dask import inspect from ..sid import Dimension, Dataset from ..sid.dimension import DimensionType +from ..viz.dataset_viz import SpectralImageFitVisualizer +from ..sid.dataset import DataType try: from scipy.optimize import curve_fit @@ -47,7 +50,7 @@ def __init__(self, sidpy_dataset, fit_fn, xvec=None, ind_dims=None, guess_fn=Non If NOT provided, it is assumed that all the non-spectral dimensions are independent dimensions. guess_fn: (function) (optional) This optional function should be utilized to generate priors for the full fit - It takes the same arguments as the fitting function and should return the same type of results array. + It takes (xvec,yvec) as inputs and should return the fit parameters. If the guess_fn is NOT provided, then the user MUST input the num_fit_parms. num_fit_parms: (int) Number of fitting parameters. This is needed IF the guess function is not provided to set @@ -170,6 +173,7 @@ def __init__(self, sidpy_dataset, fit_fn, xvec=None, ind_dims=None, guess_fn=Non self.return_std = return_std self.return_cov = return_cov self.return_fit = return_fit + self.fitted_dset = None self.mean_fit_results = [] if self.return_cov: @@ -251,7 +255,7 @@ def do_fit(self, **kwargs): p0 = self.prior[ind, :] lazy_result = dask.delayed(SidFitter.default_curve_fit)(self.fit_fn, self.dep_vec, - self.folded_dataset[ind, :], + self.folded_dataset[ind, :],self.num_fit_parms, return_cov=(self.return_cov or self.return_std), p0=p0, **kwargs) fit_results.append(lazy_result) @@ -263,7 +267,7 @@ def do_fit(self, **kwargs): self.get_km_priors() for ind in range(self.num_computations): lazy_result = dask.delayed(SidFitter.default_curve_fit)(self.fit_fn, self.dep_vec, - self.folded_dataset[ind, :], + self.folded_dataset[ind, :], self.num_fit_parms, return_cov=(self.return_cov or self.return_std), p0=self.km_priors[self.km_labels[ind]], **kwargs) @@ -394,7 +398,7 @@ def get_fitted_dataset(self): fitted_sid_dset_folded = fitted_dset_fold.like_data(np_folded_arr, title=fitted_dset_fold.title) fitted_sid_dset = fitted_sid_dset_folded.unfold() fitted_sid_dset.original_metadata = self.dataset.original_metadata.copy() - + self.fitted_dset = fitted_sid_dset return fitted_sid_dset def get_km_priors(self): @@ -422,14 +426,36 @@ def get_km_priors(self): else: p0 = np.random.normal(loc=0.5, scale=0.1, size=self.num_fit_parms) - km_priors.append(SidFitter.default_curve_fit(self.fit_fn, self.dep_vec, cen, + km_priors.append(SidFitter.default_curve_fit(self.fit_fn, self.dep_vec, cen, self.num_fit_parms, return_cov=False, p0=p0, maxfev=10000)) self.km_priors = np.array(km_priors) self.num_fit_parms = self.km_priors.shape[-1] + def visualize_fit_results(self, figure = None, horizontal = True): + ''' + Calls the interactive visualizer for comparing raw and fit datasets. + + Inputs: + - figure: (Optional, default None) - handle to existing figure + - horiziontal: (Optional, default True) - whether spectrum should be plotted horizontally + + ''' + dset_type = self.dataset.data_type + supported_types = ['SPECTRAL_IMAGE'] + if self.fitted_dset == None: + raise NotFoundErr("No fitted dataset found. Re-run with return_fit=True to use this feature") + if dset_type == DataType.SPECTRAL_IMAGE: + visualizer = SpectralImageFitVisualizer(self.dataset, self.fitted_dset, + figure=figure, horizontal=horizontal) + else: + raise NotImplementedError("Data type is {} but currently we only support types {}".format(dset_type, supported_types)) + + return visualizer + + @staticmethod - def default_curve_fit(fit_fn, xvec, yvec, return_cov=True, **kwargs): + def default_curve_fit(fit_fn, xvec, yvec, num_fit_parms, return_cov=True, **kwargs): xvec = np.array(xvec) yvec = np.array(yvec) yvec = yvec.ravel() @@ -437,8 +463,11 @@ def default_curve_fit(fit_fn, xvec, yvec, return_cov=True, **kwargs): if curve_fit is None: raise ModuleNotFoundError("scipy is not installed") else: - popt, pcov = curve_fit(fit_fn, xvec, yvec, **kwargs) - + try: + popt, pcov = curve_fit(fit_fn, xvec, yvec, **kwargs) + except: + popt = np.zeros(num_fit_parms) + pcov = np.zeros((num_fit_parms, num_fit_parms)) if return_cov: return popt, pcov else: diff --git a/sidpy/viz/dataset_viz.py b/sidpy/viz/dataset_viz.py index 09e81152..b13846cb 100644 --- a/sidpy/viz/dataset_viz.py +++ b/sidpy/viz/dataset_viz.py @@ -72,7 +72,7 @@ def __init__(self, dset, spectrum_number=0, figure=None, **kwargs): self.dim = self.dset._axes[self.spectral_dims[0]] if is_complex_dtype(dset.dtype): - # Plot real and image + # Plot real and imaginary self.fig, self.axes = plt.subplots(nrows=2, **fig_args) self.axes[0].plot(self.dim.values, self.dset.squeeze().abs(), **kwargs) @@ -681,7 +681,7 @@ def _update(self, ev=None): self.axes[1].set_title('spectrum {}, {}'.format(self.x, self.y)) self.axes[1].set_xlim(xlim) - self.axes[1].set_ylim(ylim) + #self.axes[1].set_ylim(ylim) self.axes[1].set_xlabel(self.xlabel) self.axes[1].set_ylabel(self.ylabel) @@ -961,3 +961,79 @@ def set_legend(self, set_legend): def get_xy(self): return [self.x, self.y] + + +#Let's make a curve fit visualizer + +class SpectralImageFitVisualizer(SpectralImageVisualizer): + + def __init__(self, original_dataset, fit_dataset, figure=None, horizontal=True): + ''' + Visualizer for spectral image datasets, fit by the Sidpy Fitter + This class is called by Sidpy Fitter for visualizing the raw/fit dataset interactively. + + Inputs: + - original_dataset: sidpy.Dataset containing the raw data + - fit_dataset: sidpy.Dataset with the fitted data. This is returned by the + Sidpy Fitter after functional fitting. + - figure: (Optional, default None) - handle to existing figure + - horiziontal: (Optional, default True) - whether spectrum should be plotted horizontally + ''' + + super().__init__(original_dataset, figure, horizontal) + + self.fit_dset = fit_dataset + self.axes[1].clear() + self.get_fit_spectrum() + self.axes[1].plot(self.energy_scale, self.spectrum, 'bo') + self.axes[1].plot(self.energy_scale, self.fit_spectrum, 'r-') + + def get_fit_spectrum(self): + + from ..sid.dimension import DimensionType + if self.x > self.dset.shape[self.image_dims[0]] - self.bin_x: + self.x = self.dset.shape[self.image_dims[0]] - self.bin_x + if self.y > self.dset.shape[self.image_dims[1]] - self.bin_y: + self.y = self.dset.shape[self.image_dims[1]] - self.bin_y + selection = [] + + for dim, axis in self.dset._axes.items(): + if axis.dimension_type == DimensionType.SPATIAL: + if dim == self.image_dims[0]: + selection.append(slice(self.x, self.x + self.bin_x)) + else: + selection.append(slice(self.y, self.y + self.bin_y)) + + elif axis.dimension_type == DimensionType.SPECTRAL: + selection.append(slice(None)) + else: + selection.append(slice(0, 1)) + + self.spectrum = self.dset[tuple(selection)].mean(axis=tuple(self.image_dims)) + self.fit_spectrum = self.fit_dset[tuple(selection)].mean(axis=tuple(self.image_dims)) + # * self.intensity_scale[self.x,self.y] + + return self.fit_spectrum.squeeze(), self.spectrum.squeeze() + + + def _update(self, ev=None): + + xlim = self.axes[1].get_xlim() + ylim = self.axes[1].get_ylim() + self.axes[1].clear() + self.get_fit_spectrum() + + self.axes[1].plot(self.energy_scale, self.spectrum, 'bo', label='experiment') + self.axes[1].plot(self.energy_scale, self.fit_spectrum, 'r-', label='fit') + + if self.set_title: + self.axes[1].set_title('spectrum {}, {}'.format(self.x, self.y)) + + self.axes[1].set_xlim(xlim) + #self.axes[1].set_ylim(ylim) + self.axes[1].set_xlabel(self.xlabel) + self.axes[1].set_ylabel(self.ylabel) + + self.fig.canvas.draw_idle() + +