From 03d9377bcf4197ca68ca1937bd1cd2910d21e4e2 Mon Sep 17 00:00:00 2001 From: Jesse Averbukh Date: Mon, 15 Apr 2024 10:23:56 -0400 Subject: [PATCH] Enable spectral conversion in cubeviz (#2758) * Add unit conversion plugin with flux conversion disabled Add message to flux conversion dropdown in cubeviz Have moment map units update after conversion Update continuum, moment maps, and spectral extraction Fix model fitting and aperture plugins and add tests Update slice_values to account for unit conversion Remove print statment and unused import Fix style Fix moment map and slice tests Fix style 2 Add LinkSameWithUnits and slice indicator updates * Use cached_property for slice_values * Use correct cache clearing protocol Running CI without cached_property Fix everything except todo statements Uncommented code needed for uncert viewer to change slice Remove comments * Fix message handlers * Remove space * Minor changes * Override display unit handling in slice mark so we don't change value twice * Simplify and generalize viewer slice values code * Fix style * Move code out of try except block * Address review comments * Generalize display unit call in spectral extraction * Fix test failure * Rename slice_axis property to slice_display_unit_name * Add changelog --------- Co-authored-by: Ricky O'Steen --- CHANGES.rst | 2 + jdaviz/app.py | 2 +- jdaviz/configs/cubeviz/cubeviz.yaml | 1 + .../plugins/moment_maps/moment_maps.py | 16 ++-- .../moment_maps/tests/test_moment_maps.py | 8 +- jdaviz/configs/cubeviz/plugins/slice/slice.py | 21 +++-- .../cubeviz/plugins/slice/tests/test_slice.py | 10 +++ .../spectral_extraction.py | 9 +- jdaviz/configs/cubeviz/plugins/viewers.py | 89 ++++++++++++++++--- .../plugins/model_fitting/model_fitting.py | 6 +- .../unit_conversion/unit_conversion.py | 12 +-- .../unit_conversion/unit_conversion.vue | 6 ++ jdaviz/core/helpers.py | 3 +- jdaviz/core/marks.py | 6 ++ jdaviz/core/template_mixin.py | 13 +-- 15 files changed, 160 insertions(+), 44 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 9dc8b6ce84..691d956768 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,8 @@ New Features Cubeviz ^^^^^^^ +- Enable spectral unit conversion in cubeviz. [#2758] + Imviz ^^^^^ diff --git a/jdaviz/app.py b/jdaviz/app.py index 0258ac7253..7cabcd1200 100644 --- a/jdaviz/app.py +++ b/jdaviz/app.py @@ -653,7 +653,7 @@ def _link_new_data(self, reference_data=None, data_to_be_linked=None): linked_flux_component = dc[-1].components[-1] links = [ - LinkSame(ref_wavelength_component, linked_wavelength_component), + LinkSameWithUnits(ref_wavelength_component, linked_wavelength_component), LinkSame(ref_flux_component, linked_flux_component) ] diff --git a/jdaviz/configs/cubeviz/cubeviz.yaml b/jdaviz/configs/cubeviz/cubeviz.yaml index 92a7ac1a4c..a52a1de34a 100644 --- a/jdaviz/configs/cubeviz/cubeviz.yaml +++ b/jdaviz/configs/cubeviz/cubeviz.yaml @@ -23,6 +23,7 @@ tray: - g-subset-plugin - g-markers - cubeviz-slice + - g-unit-conversion - g-gaussian-smooth - g-collapse - g-model-fitting diff --git a/jdaviz/configs/cubeviz/plugins/moment_maps/moment_maps.py b/jdaviz/configs/cubeviz/plugins/moment_maps/moment_maps.py index 6db49c68f5..8ac8928506 100644 --- a/jdaviz/configs/cubeviz/plugins/moment_maps/moment_maps.py +++ b/jdaviz/configs/cubeviz/plugins/moment_maps/moment_maps.py @@ -9,7 +9,7 @@ from specutils import manipulation, analysis, Spectrum1D from jdaviz.core.custom_traitlets import IntHandleEmpty, FloatHandleEmpty -from jdaviz.core.events import SnackbarMessage +from jdaviz.core.events import SnackbarMessage, GlobalDisplayUnitChanged from jdaviz.core.registries import tray_registry from jdaviz.core.template_mixin import (PluginTemplateMixin, DatasetSelectMixin, @@ -96,6 +96,8 @@ def __init__(self, *args, **kwargs): self.dataset.add_filter('is_cube') self.add_results.viewer.filters = ['is_image_viewer'] + self.hub.subscribe(self, GlobalDisplayUnitChanged, + handler=self._set_data_units) if self.app.state.settings.get('server_is_remote', False): # when the server is remote, saving the file in python would save on the server, not @@ -155,12 +157,16 @@ def _set_data_units(self, event={}): if self.dataset_selected != "": # Spectral axis is first in this list data = self.app.data_collection[self.dataset_selected] - if self.app.data_collection[self.dataset_selected].coords is not None: + if (self.spectrum_viewer and hasattr(self.spectrum_viewer.state, 'x_display_unit') + and self.spectrum_viewer.state.x_display_unit is not None): + sunit = self.spectrum_viewer.state.x_display_unit + elif self.app.data_collection[self.dataset_selected].coords is not None: sunit = data.coords.world_axis_units[0] - self.dataset_spectral_unit = sunit - unit_dict["Spectral Unit"] = sunit else: - self.dataset_spectral_unit = "" + sunit = "" + self.dataset_spectral_unit = sunit + unit_dict["Spectral Unit"] = sunit + unit_dict["Flux"] = data.get_component('flux').units # Update units in selection item dictionary diff --git a/jdaviz/configs/cubeviz/plugins/moment_maps/tests/test_moment_maps.py b/jdaviz/configs/cubeviz/plugins/moment_maps/tests/test_moment_maps.py index c9ac3599af..62a7229f14 100644 --- a/jdaviz/configs/cubeviz/plugins/moment_maps/tests/test_moment_maps.py +++ b/jdaviz/configs/cubeviz/plugins/moment_maps/tests/test_moment_maps.py @@ -147,7 +147,6 @@ def test_moment_velocity_calculation(cubeviz_helper, spectrum1d_cube): uncert_viewer.state._set_axes_aspect_ratio(1) mm = cubeviz_helper.plugins["Moment Maps"] - print(mm._obj.dataset_selected) mm._obj.dataset_selected = 'test[FLUX]' # Test moment 1 in velocity @@ -170,6 +169,13 @@ def test_moment_velocity_calculation(cubeviz_helper, spectrum1d_cube): "World 13h39m59.9731s +27d00m00.3600s (ICRS)", "204.9998877673 27.0001000000 (deg)") + # Add test for unit conversion + assert mm._obj.output_radio_items[0]['unit_str'] == 'm' + uc_plugin = cubeviz_helper.plugins['Unit Conversion']._obj + uc_plugin.spectral_unit.selected = 'Angstrom' + assert mm._obj.output_radio_items[0]['unit_str'] == 'Angstrom' + uc_plugin.spectral_unit.selected = 'm' + # Test moment 2 in velocity mm.n_moment = 2 mm.calculate_moment() diff --git a/jdaviz/configs/cubeviz/plugins/slice/slice.py b/jdaviz/configs/cubeviz/plugins/slice/slice.py index 0f137f5289..8cc5c5bd02 100644 --- a/jdaviz/configs/cubeviz/plugins/slice/slice.py +++ b/jdaviz/configs/cubeviz/plugins/slice/slice.py @@ -1,6 +1,7 @@ import threading import time import warnings +from functools import cached_property import numpy as np from astropy import units as u @@ -12,7 +13,7 @@ CubevizImageView) from jdaviz.configs.cubeviz.helper import _spectral_axis_names from jdaviz.core.custom_traitlets import FloatHandleEmpty -from jdaviz.core.events import (AddDataMessage, SliceToolStateMessage, +from jdaviz.core.events import (AddDataMessage, RemoveDataMessage, SliceToolStateMessage, SliceSelectSliceMessage, SliceValueUpdatedMessage, NewViewerMessage, ViewerAddedMessage, ViewerRemovedMessage, GlobalDisplayUnitChanged) @@ -82,6 +83,8 @@ def __init__(self, *args, **kwargs): handler=self._on_viewer_removed) self.hub.subscribe(self, AddDataMessage, handler=self._on_add_data) + self.hub.subscribe(self, RemoveDataMessage, + handler=self._on_valid_selection_values_changed) # connect any pre-existing viewers for viewer in self.app._viewer_store.values(): @@ -127,7 +130,7 @@ def _initialize_location(self, *args): return @property - def slice_axis(self): + def slice_display_unit_name(self): # global display unit "axis" corresponding to the slice axis return 'spectral' @@ -221,6 +224,7 @@ def _on_add_data(self, msg): if isinstance(msg.viewer, WithSliceSelection): # instead of just setting viewer.slice_value, we'll make sure the "snapping" logic # is updated (if enabled) + self._on_valid_selection_values_changed() self._on_value_updated({'new': self.value}) def _on_select_slice_message(self, msg): @@ -229,16 +233,21 @@ def _on_select_slice_message(self, msg): self.value = msg.value def _on_global_display_unit_changed(self, msg): - if msg.axis != self.slice_axis: + if msg.axis != self.slice_display_unit_name: return if not self.value_unit: self.value_unit = str(msg.unit) return prev_unit = u.Unit(self.value_unit) self.value_unit = str(msg.unit) - self.value = (self.value * prev_unit).to_value(msg.unit) + self._on_valid_selection_values_changed() + self.value = (self.value * prev_unit).to_value(msg.unit, equivalencies=u.spectral()) - @property + def _on_valid_selection_values_changed(self, msg=None): + if 'valid_selection_values' in self.__dict__: + del self.__dict__['valid_selection_values'] + + @cached_property def valid_selection_values(self): # all available slice values from cubes (unsorted) viewers = self.slice_selection_viewers @@ -280,7 +289,6 @@ def _on_value_updated(self, event): except ValueError: return return - if self.snap_to_slice and not self.value_editing: valid_values = self.valid_selection_values if len(valid_values): @@ -292,7 +300,6 @@ def _on_value_updated(self, event): self.value = float(closest_value) # will trigger another call to this method return - for viewer in self.slice_indicator_viewers: viewer._set_slice_indicator_value(self.value) for viewer in self.slice_selection_viewers: diff --git a/jdaviz/configs/cubeviz/plugins/slice/tests/test_slice.py b/jdaviz/configs/cubeviz/plugins/slice/tests/test_slice.py index 3c77484ade..466ab1c297 100644 --- a/jdaviz/configs/cubeviz/plugins/slice/tests/test_slice.py +++ b/jdaviz/configs/cubeviz/plugins/slice/tests/test_slice.py @@ -58,6 +58,16 @@ def test_slice(cubeviz_helper, spectrum1d_cube): cubeviz_helper.select_wavelength(4.62360028e-07) assert sl.value == slice_values[1] + # Add test for unit conversion + uc_plugin = cubeviz_helper.plugins['Unit Conversion']._obj + uc_plugin.spectral_unit.selected = 'Angstrom' + assert sl.value_unit == 'Angstrom' + cubeviz_helper.select_wavelength(4623.60028) + assert sl.value == 4623.600276968349 + + # Retrieve updated slice_values + slice_values = sl.valid_selection_values_sorted + # Test player buttons API sl.vue_goto_first() diff --git a/jdaviz/configs/cubeviz/plugins/spectral_extraction/spectral_extraction.py b/jdaviz/configs/cubeviz/plugins/spectral_extraction/spectral_extraction.py index d21c8c719f..764e0cb7ac 100644 --- a/jdaviz/configs/cubeviz/plugins/spectral_extraction/spectral_extraction.py +++ b/jdaviz/configs/cubeviz/plugins/spectral_extraction/spectral_extraction.py @@ -166,6 +166,10 @@ def user_api(self): return PluginUserApi(self, expose=expose) + @property + def slice_display_unit_name(self): + return 'spectral' + @property @deprecated(since="3.9", alternative="aperture") def spatial_subset(self): @@ -355,8 +359,7 @@ def get_aperture(self): # Retrieve flux cube and create an array to represent the cone mask flux_cube = self._app._jdaviz_helper._loaded_flux_cube.get_object(cls=Spectrum1D, statistic=None) - # TODO: Replace with code for retrieving display_unit in cubeviz when it is available - display_unit = flux_cube.spectral_axis.unit + display_unit = astropy.units.Unit(self.app._get_display_unit(self.slice_display_unit_name)) # Center is reverse coordinates center = (self.aperture.selected_spatial_region.center.y, @@ -373,7 +376,7 @@ def get_aperture(self): f'Spectral axis unit physical type is {display_unit.physical_type}, ' 'must be length for cone aperture') - fac = flux_cube.spectral_axis.value / self.reference_spectral_value + fac = flux_cube.spectral_axis.to_value(display_unit) / self.reference_spectral_value # TODO: Use flux_cube.spectral_axis.to_value(display_unit) when we have unit conversion. if isinstance(aperture, CircularAperture): diff --git a/jdaviz/configs/cubeviz/plugins/viewers.py b/jdaviz/configs/cubeviz/plugins/viewers.py index f2fb530238..20fa227e6d 100644 --- a/jdaviz/configs/cubeviz/plugins/viewers.py +++ b/jdaviz/configs/cubeviz/plugins/viewers.py @@ -1,5 +1,5 @@ import numpy as np - +import astropy.units as u from functools import cached_property from glue.core import BaseData @@ -12,7 +12,7 @@ from jdaviz.configs.cubeviz.helper import layer_is_cube_image_data from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin from jdaviz.configs.specviz.plugins.viewers import SpecvizProfileView -from jdaviz.core.events import AddDataMessage, RemoveDataMessage +from jdaviz.core.events import AddDataMessage, RemoveDataMessage, GlobalDisplayUnitChanged from jdaviz.core.freezable_state import FreezableBqplotImageViewerState from jdaviz.utils import get_subset_type @@ -25,6 +25,10 @@ class WithSliceIndicator: def slice_component_label(self): return str(self.state.x_att) + @property + def slice_display_unit_name(self): + return 'spectral' + @cached_property def slice_indicator(self): # SliceIndicatorMarks does not yet exist @@ -32,16 +36,33 @@ def slice_indicator(self): self.figure.marks = self.figure.marks + slice_indicator.marks return slice_indicator - @property + @cached_property def slice_values(self): + def _get_component(layer): + # Retrieve display units + slice_display_units = self.jdaviz_app._get_display_unit( + self.slice_display_unit_name + ) + try: - return layer.layer.get_component(self.slice_component_label).data + # Retrieve layer data and units + data_obj = layer.layer.data.get_component(self.slice_component_label).data + data_units = layer.layer.data.get_component(self.slice_component_label).units except (AttributeError, KeyError): # layer either does not have get_component (because its a subset) # or slice_component_label is not a component in this layer # either way, return an empty array and skip this layer return np.array([]) + + data_spec_axis = np.asarray(data_obj.data, dtype=float) * u.Unit(data_units) + + # Convert axis if display units are set and are different + if slice_display_units and slice_display_units != data_units: + return data_spec_axis.to_value(slice_display_units, + equivalencies=u.spectral()) + else: + return data_spec_axis try: return np.asarray(np.unique(np.concatenate([_get_component(layer) for layer in self.layers])), # noqa dtype=float) @@ -68,24 +89,50 @@ def slice_component_label(self): return slice_plg._obj.slice_indicator_viewers[0].slice_component_label @property + def slice_display_unit_name(self): + return 'spectral' + + @cached_property def slice_values(self): - # TODO: make a cached property and invalidate cache on add/remove data # TODO: add support for multiple cubes (but then slice selection needs to be more complex) # if slice_index is 0, then we want the equivalent of [:, 0, 0] # if slice_index is 1, then we want the equivalent of [0, :, 0] # if slice_index is 2, then we want the equivalent of [0, 0, :] take_inds = [2, 1, 0] take_inds.remove(self.slice_index) + converted_axis = np.array([]) for layer in self.layers: + world_comp_ids = layer.layer.data.world_component_ids + if self.slice_index >= len(world_comp_ids): + # Case where 2D image is loaded in image viewer + continue + + # Retrieve display units + slice_display_units = self.jdaviz_app._get_display_unit( + self.slice_display_unit_name + ) + try: - data_obj = layer.layer.data.get_component(self.slice_component_label).data + # Retrieve layer data and units using the slice index of the world components ids + data_obj = layer.layer.data.get_component(world_comp_ids[self.slice_index]).data + data_units = layer.layer.data.get_component(world_comp_ids[self.slice_index]).units except (AttributeError, KeyError): continue + + # Find the spectral axis + data_spec_axis = np.asarray(data_obj.take(0, take_inds[0]).take(0, take_inds[1]), # noqa + dtype=float) + + # Convert to display units if applicable + if slice_display_units and slice_display_units != data_units: + converted_axis = (data_spec_axis * u.Unit(data_units)).to_value( + slice_display_units, + equivalencies=u.spectral() + ) else: - break - else: - return np.array([]) - return np.asarray(data_obj.take(0, take_inds[0]).take(0, take_inds[1]), dtype=float) + converted_axis = data_spec_axis + + return converted_axis @property def slice(self): @@ -143,6 +190,13 @@ def __init__(self, *args, **kwargs): # Hide axes by default self.state.show_axes = False + self.hub.subscribe(self, GlobalDisplayUnitChanged, + handler=self._on_global_display_unit_changed + ) + + self.hub.subscribe(self, AddDataMessage, handler=self._on_global_display_unit_changed) + self.hub.subscribe(self, RemoveDataMessage, handler=self._on_global_display_unit_changed) + @property def _default_spectrum_viewer_reference_name(self): return self.jdaviz_helper._default_spectrum_viewer_reference_name @@ -167,6 +221,11 @@ def active_image_layer(self): return visible_layers[-1] + def _on_global_display_unit_changed(self, msg): + # Clear cache of slice values when units change + if 'slice_values' in self.__dict__: + del self.__dict__['slice_values'] + def _initial_x_axis(self, *args): # Make sure that the x_att is correct on data load ref_data = self.state.reference_data @@ -223,16 +282,25 @@ def __init__(self, *args, **kwargs): self.hub.subscribe(self, AddDataMessage, handler=self._check_if_data_added) + self.hub.subscribe(self, GlobalDisplayUnitChanged, + handler=self._on_global_display_unit_changed) + @property def _default_flux_viewer_reference_name(self): return self.jdaviz_helper._default_flux_viewer_reference_name + def _on_global_display_unit_changed(self, msg=None): + # Clear cache of slice values when units change + if 'slice_values' in self.__dict__: + del self.__dict__['slice_values'] + def _check_if_data_removed(self, msg): # isinstance and the data uuid check will be true for the data # that is being removed self.figure.marks = [m for m in self.figure.marks if not (isinstance(m, ShadowSpatialSpectral) and m.data_uuid == msg.data.uuid)] + self._on_global_display_unit_changed() def _check_if_data_added(self, msg=None): # When data is added, make sure that all spatial subset layers @@ -243,6 +311,7 @@ def _check_if_data_added(self, msg=None): if (isinstance(layer.layer, GroupedSubset) and get_subset_type(layer.layer.subset_state) == 'spatial'): self._expected_subset_layer_default(layer) + self._on_global_display_unit_changed() def _is_spatial_subset(self, layer): subset_state = getattr(layer.layer, 'subset_state', None) diff --git a/jdaviz/configs/default/plugins/model_fitting/model_fitting.py b/jdaviz/configs/default/plugins/model_fitting/model_fitting.py index bdce1639db..9c1a5c403a 100644 --- a/jdaviz/configs/default/plugins/model_fitting/model_fitting.py +++ b/jdaviz/configs/default/plugins/model_fitting/model_fitting.py @@ -315,7 +315,8 @@ def _warn_if_no_equation(self): def _get_1d_spectrum(self): # retrieves the 1d spectrum (accounting for spatial subset for cubeviz, if necessary) - return self.dataset.selected_spectrum_for_spatial_subset(self.spatial_subset_selected) # noqa + return self.dataset.selected_spectrum_for_spatial_subset(self.spatial_subset_selected, + use_display_units=True) @observe("dataset_selected", "spatial_subset_selected") def _dataset_selected_changed(self, event=None): @@ -343,7 +344,6 @@ def _dataset_selected_changed(self, event=None): # Replace NaNs from collapsed Spectrum1D in Cubeviz # (won't affect calculations because these locations are masked) selected_spec.flux[np.isnan(selected_spec.flux)] = 0.0 - # TODO: can we simplify this logic? self._units["x"] = str( selected_spec.spectral_axis.unit) @@ -508,7 +508,6 @@ def _initialize_model_component(self, model_comp, comp_label, poly_order=None): "value": param_quant.value, "unit": str(param_quant.unit), "fixed": False}) - self._initialized_models[comp_label] = initialized_model new_model["Initialized"] = True @@ -841,7 +840,6 @@ def _fit_model_to_spectrum(self, add_data): models_to_fit = self._reinitialize_with_fixed() masked_spectrum = self._apply_subset_masks(self._get_1d_spectrum(), self.spectral_subset) - try: fitted_model, fitted_spectrum = fit_model_to_spectrum( masked_spectrum, diff --git a/jdaviz/configs/specviz/plugins/unit_conversion/unit_conversion.py b/jdaviz/configs/specviz/plugins/unit_conversion/unit_conversion.py index d10d2c6770..56f4c95d92 100644 --- a/jdaviz/configs/specviz/plugins/unit_conversion/unit_conversion.py +++ b/jdaviz/configs/specviz/plugins/unit_conversion/unit_conversion.py @@ -127,18 +127,18 @@ def _on_glue_y_display_unit_changed(self, y_unit): @observe('spectral_unit_selected') def _on_spectral_unit_changed(self, *args): - self.hub.broadcast(GlobalDisplayUnitChanged('spectral', - self.spectral_unit.selected, - sender=self)) xunit = _valid_glue_display_unit(self.spectral_unit.selected, self.spectrum_viewer, 'x') if self.spectrum_viewer.state.x_display_unit != xunit: self.spectrum_viewer.state.x_display_unit = xunit + self.hub.broadcast(GlobalDisplayUnitChanged('spectral', + self.spectral_unit.selected, + sender=self)) @observe('flux_unit_selected') def _on_flux_unit_changed(self, *args): - self.hub.broadcast(GlobalDisplayUnitChanged('flux', - self.flux_unit.selected, - sender=self)) yunit = _valid_glue_display_unit(self.flux_unit.selected, self.spectrum_viewer, 'y') if self.spectrum_viewer.state.y_display_unit != yunit: self.spectrum_viewer.state.y_display_unit = yunit + self.hub.broadcast(GlobalDisplayUnitChanged('flux', + self.flux_unit.selected, + sender=self)) diff --git a/jdaviz/configs/specviz/plugins/unit_conversion/unit_conversion.vue b/jdaviz/configs/specviz/plugins/unit_conversion/unit_conversion.vue index 8d5afc6ab8..e952337d97 100644 --- a/jdaviz/configs/specviz/plugins/unit_conversion/unit_conversion.vue +++ b/jdaviz/configs/specviz/plugins/unit_conversion/unit_conversion.vue @@ -28,7 +28,13 @@ label="Flux Unit" hint="Global display unit for flux." persistent-hint + :disabled="config === 'cubeviz'" > + + + Flux conversion is not yet implemented in Cubeviz. + + diff --git a/jdaviz/core/helpers.py b/jdaviz/core/helpers.py index e6f8e51f3d..fe89c94468 100644 --- a/jdaviz/core/helpers.py +++ b/jdaviz/core/helpers.py @@ -494,7 +494,8 @@ def _handle_display_units(data, use_display_units): u.spectral()), flux=data.flux.to(flux_unit, u.spectral_density(data.spectral_axis)), - uncertainty=new_uncert) + uncertainty=new_uncert, + mask=data.mask) else: # pragma: nocover raise NotImplementedError(f"converting {data.__class__.__name__} to display units is not supported") # noqa return data diff --git a/jdaviz/core/marks.py b/jdaviz/core/marks.py index a2634adfd7..e5df8bf736 100644 --- a/jdaviz/core/marks.py +++ b/jdaviz/core/marks.py @@ -307,6 +307,12 @@ def __init__(self, viewer, value=0, **kwargs): def marks(self): return [self, self.label] + def _on_global_display_unit_changed(self, msg): + # Updating the value is handled by the plugin itself, need to update unit string. + if msg.axis in ["spectral", "x"]: + self.xunit = msg.unit + self._update_label() + def _value_handle_oob(self, x=None, update_label=False): if x is None: x = self.value diff --git a/jdaviz/core/template_mixin.py b/jdaviz/core/template_mixin.py index 53e79cec85..8647b4c6b1 100644 --- a/jdaviz/core/template_mixin.py +++ b/jdaviz/core/template_mixin.py @@ -2774,7 +2774,8 @@ def _get_continuum(self, dataset, spatial_subset, spectral_subset, update_marks= if spatial_subset == 'per-pixel': if self.app.config != 'cubeviz': raise ValueError("per-pixel only supported for cubeviz") - full_spectrum = self.dataset.selected_obj + full_spectrum = self.app._jdaviz_helper.get_data(self.dataset.selected, + use_display_units=True) else: full_spectrum = dataset.selected_spectrum_for_spatial_subset(spatial_subset.selected if spatial_subset is not None else None, # noqa use_display_units=True) @@ -2802,8 +2803,8 @@ def _get_continuum(self, dataset, spatial_subset, spectral_subset, update_marks= simplify_spectral=True, use_display_units=True) spectrum = extract_region(full_spectrum, sr, return_single_spectrum=True) - sr_lower = np.nanmin(spectrum.spectral_axis[spectrum.spectral_axis.value >= sr.lower.value]) # noqa - sr_upper = np.nanmax(spectrum.spectral_axis[spectrum.spectral_axis.value <= sr.upper.value]) # noqa + sr_lower = np.nanmin(spectrum.spectral_axis[spectrum.spectral_axis >= sr.lower]) # noqa + sr_upper = np.nanmax(spectrum.spectral_axis[spectrum.spectral_axis <= sr.upper]) # noqa if self.continuum_subset_selected == 'None': self._update_continuum_marks() @@ -2861,7 +2862,7 @@ def _get_continuum(self, dataset, spatial_subset, spectral_subset, update_marks= continuum_mask = ~self._specviz_helper.get_data( dataset.selected, spectral_subset=self.continuum_subset_selected, - use_display_units=False).mask + use_display_units=True).mask spectral_axis_nanmasked = spectral_axis.value.copy() spectral_axis_nanmasked[~continuum_mask] = np.nan if not update_marks: @@ -2871,8 +2872,8 @@ def _get_continuum(self, dataset, spatial_subset, spectral_subset, update_marks= 'center': spectral_axis.value, 'right': []} else: - mark_x = {'left': spectral_axis_nanmasked[spectral_axis.value < sr_lower.value], - 'right': spectral_axis_nanmasked[spectral_axis.value > sr_upper.value]} + mark_x = {'left': spectral_axis_nanmasked[spectral_axis < sr_lower], + 'right': spectral_axis_nanmasked[spectral_axis > sr_upper]} # Center should extend (at least) across the line region between the full # range defined by the continuum subset(s). # OK for mark_x to be all NaNs.