Skip to content

Commit

Permalink
add fit_non_finite flag to model_fitting 1D (#2437)
Browse files Browse the repository at this point in the history
  • Loading branch information
cshanahan1 authored Sep 28, 2023
1 parent 11c9139 commit 8bb2bd4
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 25 deletions.
7 changes: 7 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ Specviz2d
Other Changes and Additions
---------------------------

- Better handling of non-finite uncertainties in model fitting. The 'filter_non_finite' flag (for the
LevMarLSQFitter) now filters datapoints with non-finite weights. In Specviz, if a fully-finite spectrum
with non-finite uncertainties is loaded, the uncertainties will be dropped so every datapoint isn't
filtered. For other scenarios with non-finite uncertainties, there are appropriate warning messages
displayed to alert users that data points are being filtered because of non-finite uncertainties (when
flux is finite). [#2437]

3.7.1 (unreleased)
==================

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def fit_model_to_spectrum(spectrum, component_list, expression,
return _fit_1D(initial_model, spectrum, run_fitter, window=window)


def _fit_1D(initial_model, spectrum, run_fitter, window=None):
def _fit_1D(initial_model, spectrum, run_fitter, filter_non_finite=True, window=None):
"""
Fits an astropy CompoundModel to a Spectrum1D instance.
Expand Down Expand Up @@ -104,7 +104,8 @@ def _fit_1D(initial_model, spectrum, run_fitter, window=None):
weights = 'unc'
else:
weights = None
output_model = fit_lines(spectrum, initial_model, weights=weights, window=window)
output_model = fit_lines(spectrum, initial_model, weights=weights,
filter_non_finite=filter_non_finite, window=window)
output_values = output_model(spectrum.spectral_axis)
else:
# Return without fitting.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
SubsetSelect,
DatasetSelectMixin,
DatasetSpectralSubsetValidMixin,
NonFiniteUncertaintyMismatchMixin,
AutoTextField,
AddResultsMixin,
TableMixin)
Expand All @@ -40,6 +41,7 @@ def __init__(self, value, unit=None):
@tray_registry('g-model-fitting', label="Model Fitting", viewer_requirements='spectrum')
class ModelFitting(PluginTemplateMixin, DatasetSelectMixin,
SpectralSubsetSelectMixin, DatasetSpectralSubsetValidMixin,
NonFiniteUncertaintyMismatchMixin,
AddResultsMixin, TableMixin):
"""
See the :ref:`Model Fitting Plugin Documentation <specviz-model-fitting>` for more details.
Expand Down Expand Up @@ -312,7 +314,7 @@ def _warn_if_no_equation(self):

def _get_1d_spectrum(self):
# retrieves the 1d spectrum (accounting for spatial subset for cubeviz, if necessary)
return self.dataset.selected_spectrum_for_spatial_subset(self.spatial_subset_selected) # noqa
return self.dataset.selected_spectrum_for_spatial_subset(self.spatial_subset_selected) # noqa

@observe("dataset_selected", "spatial_subset_selected")
def _dataset_selected_changed(self, event=None):
Expand Down Expand Up @@ -781,6 +783,7 @@ def calculate_fit(self, add_data=True):
fitted spectrum/cube
residuals (if ``residuals_calculate`` is set to ``True``)
"""

if not self.spectral_subset_valid:
valid, spec_range, subset_range = self._check_dataset_spectral_subset_valid(return_ranges=True) # noqa
raise ValueError(f"spectral subset '{self.spectral_subset.selected}' {subset_range} is outside data range of '{self.dataset.selected}' {spec_range}") # noqa
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,12 @@
</span>
</v-row>

<v-row v-if="non_finite_uncertainty_mismatch">
<span class="v-messages v-messages__message text--secondary" style="color: red !important">
"Non-finite uncertainties exist in the selected data, these data points will be excluded from the fit."
</span>
</v-row>

</div>
</plugin-add-results>

Expand Down
53 changes: 53 additions & 0 deletions jdaviz/configs/default/plugins/model_fitting/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,56 @@ def test_invalid_subset(specviz_helper, spectrum1d):

plugin.dataset = 'left_spectrum'
assert plugin._obj.spectral_subset_valid


def test_all_nan_uncert(specviz_helper):

# test that if you have a fully finite data array, and a fully nan/inf
# uncert array, that it is set to None and the fit proceeds (rather than
# being filtered in the fitter, as would normally happen with nans)

uncertainty = StdDevUncertainty([np.nan, np.nan, np.nan, np.nan, np.nan, np.nan] * u.Jy)
spec = Spectrum1D(flux=[1, 2, 3, 4, 5, 6]*u.Jy, uncertainty=uncertainty)
specviz_helper.load_data(spec)

plugin = specviz_helper.plugins['Model Fitting']
plugin.create_model_component('Linear1D')

with pytest.warns(AstropyUserWarning, match='Model is linear in parameters'):
plugin.calculate_fit()

# check that slope and intercept are fit correctly
plugin._obj.component_models[0]['parameters'][0]['value'] == 1.0
plugin._obj.component_models[0]['parameters'][1]['value'] == 1.0

# and that this value is correctly set to false, even though there IS
# a mismatch, since its the entire array it will be reset
assert plugin._obj.non_finite_uncertainty_mismatch is False


def test_all_nan_uncert_subset(specviz_helper):

# test that nans in uncertainty array are filtered from fit (contrary to
# what is tested in test_all_nan_uncert, when its not the entire array they
# SHOULD be filtered even when corresponding data values are finite), and that
# the `non_finite_uncertainty_mismatch` traitlet is True to trigger a warning
# message

uncertainty = StdDevUncertainty([1, 1, np.nan, np.nan, np.nan, np.nan] * u.Jy)
spec = Spectrum1D(flux=[2, 4, 3, 4, 5, 6]*u.Jy, uncertainty=uncertainty)
specviz_helper.load_data(spec)

plugin = specviz_helper.plugins['Model Fitting']
plugin.create_model_component('Linear1D')

with pytest.warns(AstropyUserWarning, match='Model is linear in parameters'):
plugin.calculate_fit()

# check that slope and intercept are fit correctly to just the first 2
# data points
plugin._obj.component_models[0]['parameters'][0]['value'] == 2.0
plugin._obj.component_models[0]['parameters'][1]['value'] == 2.0

# # and that this value is correctly set to false, even though there IS
# # a mismatch, since its the entire array it will be reset
assert plugin._obj.non_finite_uncertainty_mismatch is True
85 changes: 63 additions & 22 deletions jdaviz/configs/specviz/plugins/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from astropy.nddata import StdDevUncertainty
from specutils import Spectrum1D, SpectrumList, SpectrumCollection

from jdaviz.core.events import SnackbarMessage
from jdaviz.core.registries import data_parser_registry
from jdaviz.utils import standardize_metadata


__all__ = ["specviz_spectrum1d_parser"]


Expand All @@ -32,6 +34,7 @@ def specviz_spectrum1d_parser(app, data, data_label=None, format=None, show_in_v
the extensions within each spectrum file passed to the parser and
add a concatenated spectrum to the data collection.
"""

spectrum_viewer_reference_name = app._jdaviz_helper._default_spectrum_viewer_reference_name
# If no data label is assigned, give it a unique name
if not data_label:
Expand All @@ -41,8 +44,8 @@ def specviz_spectrum1d_parser(app, data, data_label=None, format=None, show_in_v
raise TypeError("SpectrumCollection detected."
" Please provide a Spectrum1D or SpectrumList")
elif isinstance(data, Spectrum1D):
data = [data]
data_label = [app.return_data_label(data_label, alt_name="specviz_data")]
data = [data]
# No special processing is needed in this case, but we include it for completeness
elif isinstance(data, SpectrumList):
pass
Expand All @@ -55,6 +58,7 @@ def specviz_spectrum1d_parser(app, data, data_label=None, format=None, show_in_v
try:
data = [Spectrum1D.read(str(path), format=format)]
data_label = [app.return_data_label(data_label, alt_name="specviz_data")]

except IORegistryError:
# Multi-extension files may throw a registry error
data = SpectrumList.read(str(path), format=format)
Expand All @@ -76,17 +80,57 @@ def specviz_spectrum1d_parser(app, data, data_label=None, format=None, show_in_v
raise ValueError(f"Length of data labels list ({len(data_label)}) is different"
f" than length of list of data ({len(data)})")

with app.data_collection.delay_link_manager_update():
# these are used to build a combined spectrum with all
# input spectra included (taken from https://github.com/spacetelescope/
# dat_pyinthesky/blob/main/jdat_notebooks/MRS_Mstar_analysis/
# JWST_Mstar_dataAnalysis_analysis.ipynb)
wlallorig = []
fnuallorig = []
dfnuallorig = []
wlallorig = [] # to collect wavelengths
fnuallorig = [] # fluxes
dfnuallorig = [] # and uncertanties (if present)

for spec in data:
for wlind in range(len(spec.spectral_axis)):
wlallorig.append(spec.spectral_axis[wlind].value)
fnuallorig.append(spec.flux[wlind].value)

# because some spec in the list might have uncertainties and
# others may not, if uncert is None, set to list of NaN. later,
# if the concatenated list of uncertanties is all nan (meaning
# they were all nan to begin with, or all None), it will be set
# to None on the final Spectrum1D
if spec.uncertainty[wlind] is not None:
dfnuallorig.append(spec.uncertainty[wlind].array)
else:
dfnuallorig.append(np.nan)

# if the entire uncert. array is Nan and the data is not, model fitting won't
# work (more generally, if uncert[i] is nan/inf and flux[i] is not, fitting will
# fail, but just deal with the all nan case here since it is straightforward).
# set uncerts. to None if they are all nan/inf, and display a warning message.
set_nans_to_none = False
if isinstance(data, SpectrumList):
uncerts = dfnuallorig # alias these for clarity later on
if uncerts is not None and not np.any(uncerts):
uncerts = None
set_nans_to_none = True
else:
if data[0].uncertainty is not None:
uncerts_finite = np.isfinite(data[0].uncertainty.array)
if not np.any(uncerts_finite):
data[0].uncertainty = None
set_nans_to_none = True

if set_nans_to_none:
# alert user that we have changed their all-nan uncertainty array to None
msg = 'All uncertainties are nonfinite, replacing with uncertainty=None.'
app.hub.broadcast(SnackbarMessage(msg, color="warning", sender=app))

with app.data_collection.delay_link_manager_update():
for i, spec in enumerate(data):

# note: if SpectrumList, this is just going to be the last unit when
# combined in the next block. should put a check here to make sure
# units are all the same or collect them in a list?
wave_units = spec.spectral_axis.unit
flux_units = spec.flux.unit

Expand All @@ -96,17 +140,8 @@ def specviz_spectrum1d_parser(app, data, data_label=None, format=None, show_in_v
app.add_data(spec, data_label[i])

# handle display, with the SpectrumList special case in mind.
if show_in_viewer:
if isinstance(data, SpectrumList):

# add spectrum to combined result
for wlind in range(len(spec.spectral_axis)):
wlallorig.append(spec.spectral_axis[wlind].value)
fnuallorig.append(spec.flux[wlind].value)
dfnuallorig.append(spec.uncertainty[wlind].array)

elif i == 0:
app.add_data_to_viewer(spectrum_viewer_reference_name, data_label[i])
if i == 0 and show_in_viewer:
app.add_data_to_viewer(spectrum_viewer_reference_name, data_label[i])

if concat_by_file and isinstance(data, SpectrumList):
# If >1 spectra in the list were opened from the same FITS file,
Expand All @@ -115,8 +150,9 @@ def specviz_spectrum1d_parser(app, data, data_label=None, format=None, show_in_v
unique_files = group_spectra_by_filename(app.data_collection)
for filename, datasets in unique_files.items():
if len(datasets) > 1:
spec = combine_lists_to_1d_spectrum(wlallorig, fnuallorig, dfnuallorig,
wave_units, flux_units)
spec = combine_lists_to_1d_spectrum(wlallorig, fnuallorig,
dfnuallorig, wave_units,
flux_units)

# Make metadata layout conform with other viz.
spec.meta = standardize_metadata(spec.meta)
Expand Down Expand Up @@ -162,7 +198,7 @@ def combine_lists_to_1d_spectrum(wl, fnu, dfnu, wave_units, flux_units):
Wavelength in each spectral channel
fnu : list of `~astropy.units.Quantity`s
Flux in each spectral channel
dfnu : list of `~astropy.units.Quantity`s
dfnu : list of `~astropy.units.Quantity`s or None
Uncertainty on each flux
Returns
Expand All @@ -172,14 +208,19 @@ def combine_lists_to_1d_spectrum(wl, fnu, dfnu, wave_units, flux_units):
"""
wlallarr = np.array(wl)
fnuallarr = np.array(fnu)
dfnuallarr = np.array(dfnu)
srtind = np.argsort(wlallarr)
if dfnu is not None:
dfnuallarr = np.array(dfnu)
fnuallerr = dfnuallarr[srtind]
wlall = wlallarr[srtind]
fnuall = fnuallarr[srtind]
fnuallerr = dfnuallarr[srtind]

# units are not being handled properly yet.
unc = StdDevUncertainty(fnuallerr * flux_units)
if dfnu is not None:
unc = StdDevUncertainty(fnuallerr * flux_units)
else:
unc = None

spec = Spectrum1D(flux=fnuall * flux_units, spectral_axis=wlall * wave_units,
uncertainty=unc)
return spec
68 changes: 68 additions & 0 deletions jdaviz/core/template_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
'DatasetSpectralSubsetValidMixin',
'ViewerSelect', 'ViewerSelectMixin',
'LayerSelect', 'LayerSelectMixin',
'NonFiniteUncertaintyMismatchMixin',
'DatasetSelect', 'DatasetSelectMixin',
'FileImportSelectPluginComponent', 'HasFileImportSelect',
'Table', 'TableMixin',
Expand Down Expand Up @@ -1662,6 +1663,7 @@ class DatasetSpectralSubsetValidMixin(VuetifyTemplate, HubListener):

@observe("dataset_selected", "spectral_subset_selected")
def _check_dataset_spectral_subset_valid(self, event={}, return_ranges=False):

if self.spectral_subset_selected == "Entire Spectrum":
self.spectral_subset_valid = True
else:
Expand All @@ -1677,6 +1679,72 @@ def _check_dataset_spectral_subset_valid(self, event={}, return_ranges=False):
return self.spectral_subset_valid


class NonFiniteUncertaintyMismatchMixin(VuetifyTemplate, HubListener):
"""Adds a traitlet that identifies if there are any finite data values
that correspond to a non-finite uncertainty at that index.
In model fitting, the presence of finite, fittable data with corresponding
non-finite uncertainties can cause issues. Finite data values will be
filtered out in this case which may be undesirable. This traitlet when
True triggers a warning in the model fitting plugin (in Specviz only,
currently) if there are any finite values with non-finite uncertainties.
Note that if a the uncertainty array is FULLY non-finite and the data is
FULLY finite, uncertainties will be set to None (in the Specviz parser),
so this traitlet will be False in that case (and therefore no warning
message displayed in the plugin).
"""

non_finite_uncertainty_mismatch = Bool(False).tag(sync=True)

# every time a data/subset selection is changed, check the data selection and
# its uncertainties to see if there are any finite data elements with
# uncertainties. Warn in plugin if this occurs.

@observe("dataset_selected", "spectral_subset_selected")
def _check_non_finite_uncertainty_mismatch(self, event={}):

if not hasattr(self, 'dataset') or self.dataset_selected == '':
# during initial init, this can trigger before the component is initialized
return

if not hasattr(self, '_get_1d_spectrum'):
# only model_fitting has _get_1d_spectrum(), but this method is here
# instead of there because it may eventually be used by other plugins.
# if that happens, move _get_1d_spectrum() somewhere more general
raise NotImplementedError("_get_1d_spectrum() must be available in "
"plugin to use NonFiniteUncertaintyMismatchMixin")

spec = self._get_1d_spectrum()

if spec.uncertainty is None:
self.non_finite_uncertainty_mismatch = False
return

if self.spectral_subset_selected == "Entire Spectrum":
flux = spec.flux
uncert = spec.uncertainty
else:
# get selected subset
spec = self._apply_subset_masks(self._get_1d_spectrum(), self.spectral_subset)
flux = spec.flux[~spec.mask]
uncert = spec.uncertainty[~spec.mask]

uncert = uncert.array

if not np.any(uncert):
self.non_finite_uncertainty_mismatch = False
return

flux = flux.value

mismatch = np.any(np.logical_and(~np.isfinite(uncert), np.isfinite(flux)))

# np.any returns numpy bool type, which traitlets doesn't like
# so cast to boolean
self.non_finite_uncertainty_mismatch = bool(mismatch)


class ViewerSelect(SelectPluginComponent):
"""
Plugin select for viewers, with support for single or multi-selection.
Expand Down

0 comments on commit 8bb2bd4

Please sign in to comment.