Skip to content

Commit

Permalink
Enable spectral conversion in cubeviz (#2758)
Browse files Browse the repository at this point in the history
* Add unit conversion plugin with flux conversion disabled

Add message to flux conversion dropdown in cubeviz

Have moment map units update after conversion

Update continuum, moment maps, and spectral extraction

Fix model fitting and aperture plugins and add tests

Update slice_values to account for unit conversion

Remove print statment and unused import

Fix style

Fix moment map and slice tests

Fix style 2

Add LinkSameWithUnits and slice indicator updates

* Use cached_property for slice_values

* Use correct cache clearing protocol

Running CI without cached_property

Fix everything except todo statements

Uncommented code needed for uncert viewer to change slice

Remove comments

* Fix message handlers

* Remove space

* Minor changes

* Override display unit handling in slice mark so we don't change value twice

* Simplify and generalize viewer slice values code

* Fix style

* Move code out of try except block

* Address review comments

* Generalize display unit call in spectral extraction

* Fix test failure

* Rename slice_axis property to slice_display_unit_name

* Add changelog

---------

Co-authored-by: Ricky O'Steen <rosteen@stsci.edu>
  • Loading branch information
javerbukh and rosteen authored Apr 15, 2024
1 parent d2f4278 commit 6c57b66
Show file tree
Hide file tree
Showing 15 changed files with 160 additions and 44 deletions.
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ New Features
Cubeviz
^^^^^^^

- Enable spectral unit conversion in cubeviz. [#2758]

Imviz
^^^^^

Expand Down
2 changes: 1 addition & 1 deletion jdaviz/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def _link_new_data(self, reference_data=None, data_to_be_linked=None):
linked_flux_component = dc[-1].components[-1]

links = [
LinkSame(ref_wavelength_component, linked_wavelength_component),
LinkSameWithUnits(ref_wavelength_component, linked_wavelength_component),
LinkSame(ref_flux_component, linked_flux_component)
]

Expand Down
1 change: 1 addition & 0 deletions jdaviz/configs/cubeviz/cubeviz.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ tray:
- g-subset-plugin
- g-markers
- cubeviz-slice
- g-unit-conversion
- g-gaussian-smooth
- g-collapse
- g-model-fitting
Expand Down
16 changes: 11 additions & 5 deletions jdaviz/configs/cubeviz/plugins/moment_maps/moment_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from specutils import manipulation, analysis, Spectrum1D

from jdaviz.core.custom_traitlets import IntHandleEmpty, FloatHandleEmpty
from jdaviz.core.events import SnackbarMessage
from jdaviz.core.events import SnackbarMessage, GlobalDisplayUnitChanged
from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import (PluginTemplateMixin,
DatasetSelectMixin,
Expand Down Expand Up @@ -96,6 +96,8 @@ def __init__(self, *args, **kwargs):

self.dataset.add_filter('is_cube')
self.add_results.viewer.filters = ['is_image_viewer']
self.hub.subscribe(self, GlobalDisplayUnitChanged,
handler=self._set_data_units)

if self.app.state.settings.get('server_is_remote', False):
# when the server is remote, saving the file in python would save on the server, not
Expand Down Expand Up @@ -155,12 +157,16 @@ def _set_data_units(self, event={}):
if self.dataset_selected != "":
# Spectral axis is first in this list
data = self.app.data_collection[self.dataset_selected]
if self.app.data_collection[self.dataset_selected].coords is not None:
if (self.spectrum_viewer and hasattr(self.spectrum_viewer.state, 'x_display_unit')
and self.spectrum_viewer.state.x_display_unit is not None):
sunit = self.spectrum_viewer.state.x_display_unit
elif self.app.data_collection[self.dataset_selected].coords is not None:
sunit = data.coords.world_axis_units[0]
self.dataset_spectral_unit = sunit
unit_dict["Spectral Unit"] = sunit
else:
self.dataset_spectral_unit = ""
sunit = ""
self.dataset_spectral_unit = sunit
unit_dict["Spectral Unit"] = sunit

unit_dict["Flux"] = data.get_component('flux').units

# Update units in selection item dictionary
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ def test_moment_velocity_calculation(cubeviz_helper, spectrum1d_cube):
uncert_viewer.state._set_axes_aspect_ratio(1)

mm = cubeviz_helper.plugins["Moment Maps"]
print(mm._obj.dataset_selected)
mm._obj.dataset_selected = 'test[FLUX]'

# Test moment 1 in velocity
Expand All @@ -170,6 +169,13 @@ def test_moment_velocity_calculation(cubeviz_helper, spectrum1d_cube):
"World 13h39m59.9731s +27d00m00.3600s (ICRS)",
"204.9998877673 27.0001000000 (deg)")

# Add test for unit conversion
assert mm._obj.output_radio_items[0]['unit_str'] == 'm'
uc_plugin = cubeviz_helper.plugins['Unit Conversion']._obj
uc_plugin.spectral_unit.selected = 'Angstrom'
assert mm._obj.output_radio_items[0]['unit_str'] == 'Angstrom'
uc_plugin.spectral_unit.selected = 'm'

# Test moment 2 in velocity
mm.n_moment = 2
mm.calculate_moment()
Expand Down
21 changes: 14 additions & 7 deletions jdaviz/configs/cubeviz/plugins/slice/slice.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import threading
import time
import warnings
from functools import cached_property

import numpy as np
from astropy import units as u
Expand All @@ -12,7 +13,7 @@
CubevizImageView)
from jdaviz.configs.cubeviz.helper import _spectral_axis_names
from jdaviz.core.custom_traitlets import FloatHandleEmpty
from jdaviz.core.events import (AddDataMessage, SliceToolStateMessage,
from jdaviz.core.events import (AddDataMessage, RemoveDataMessage, SliceToolStateMessage,
SliceSelectSliceMessage, SliceValueUpdatedMessage,
NewViewerMessage, ViewerAddedMessage, ViewerRemovedMessage,
GlobalDisplayUnitChanged)
Expand Down Expand Up @@ -82,6 +83,8 @@ def __init__(self, *args, **kwargs):
handler=self._on_viewer_removed)
self.hub.subscribe(self, AddDataMessage,
handler=self._on_add_data)
self.hub.subscribe(self, RemoveDataMessage,
handler=self._on_valid_selection_values_changed)

# connect any pre-existing viewers
for viewer in self.app._viewer_store.values():
Expand Down Expand Up @@ -127,7 +130,7 @@ def _initialize_location(self, *args):
return

@property
def slice_axis(self):
def slice_display_unit_name(self):
# global display unit "axis" corresponding to the slice axis
return 'spectral'

Expand Down Expand Up @@ -221,6 +224,7 @@ def _on_add_data(self, msg):
if isinstance(msg.viewer, WithSliceSelection):
# instead of just setting viewer.slice_value, we'll make sure the "snapping" logic
# is updated (if enabled)
self._on_valid_selection_values_changed()
self._on_value_updated({'new': self.value})

def _on_select_slice_message(self, msg):
Expand All @@ -229,16 +233,21 @@ def _on_select_slice_message(self, msg):
self.value = msg.value

def _on_global_display_unit_changed(self, msg):
if msg.axis != self.slice_axis:
if msg.axis != self.slice_display_unit_name:
return
if not self.value_unit:
self.value_unit = str(msg.unit)
return
prev_unit = u.Unit(self.value_unit)
self.value_unit = str(msg.unit)
self.value = (self.value * prev_unit).to_value(msg.unit)
self._on_valid_selection_values_changed()
self.value = (self.value * prev_unit).to_value(msg.unit, equivalencies=u.spectral())

@property
def _on_valid_selection_values_changed(self, msg=None):
if 'valid_selection_values' in self.__dict__:
del self.__dict__['valid_selection_values']

@cached_property
def valid_selection_values(self):
# all available slice values from cubes (unsorted)
viewers = self.slice_selection_viewers
Expand Down Expand Up @@ -280,7 +289,6 @@ def _on_value_updated(self, event):
except ValueError:
return
return

if self.snap_to_slice and not self.value_editing:
valid_values = self.valid_selection_values
if len(valid_values):
Expand All @@ -292,7 +300,6 @@ def _on_value_updated(self, event):
self.value = float(closest_value)
# will trigger another call to this method
return

for viewer in self.slice_indicator_viewers:
viewer._set_slice_indicator_value(self.value)
for viewer in self.slice_selection_viewers:
Expand Down
10 changes: 10 additions & 0 deletions jdaviz/configs/cubeviz/plugins/slice/tests/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def test_slice(cubeviz_helper, spectrum1d_cube):
cubeviz_helper.select_wavelength(4.62360028e-07)
assert sl.value == slice_values[1]

# Add test for unit conversion
uc_plugin = cubeviz_helper.plugins['Unit Conversion']._obj
uc_plugin.spectral_unit.selected = 'Angstrom'
assert sl.value_unit == 'Angstrom'
cubeviz_helper.select_wavelength(4623.60028)
assert sl.value == 4623.600276968349

# Retrieve updated slice_values
slice_values = sl.valid_selection_values_sorted

# Test player buttons API

sl.vue_goto_first()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,10 @@ def user_api(self):

return PluginUserApi(self, expose=expose)

@property
def slice_display_unit_name(self):
return 'spectral'

@property
@deprecated(since="3.9", alternative="aperture")
def spatial_subset(self):
Expand Down Expand Up @@ -355,8 +359,7 @@ def get_aperture(self):
# Retrieve flux cube and create an array to represent the cone mask
flux_cube = self._app._jdaviz_helper._loaded_flux_cube.get_object(cls=Spectrum1D,
statistic=None)
# TODO: Replace with code for retrieving display_unit in cubeviz when it is available
display_unit = flux_cube.spectral_axis.unit
display_unit = astropy.units.Unit(self.app._get_display_unit(self.slice_display_unit_name))

# Center is reverse coordinates
center = (self.aperture.selected_spatial_region.center.y,
Expand All @@ -373,7 +376,7 @@ def get_aperture(self):
f'Spectral axis unit physical type is {display_unit.physical_type}, '
'must be length for cone aperture')

fac = flux_cube.spectral_axis.value / self.reference_spectral_value
fac = flux_cube.spectral_axis.to_value(display_unit) / self.reference_spectral_value

# TODO: Use flux_cube.spectral_axis.to_value(display_unit) when we have unit conversion.
if isinstance(aperture, CircularAperture):
Expand Down
89 changes: 79 additions & 10 deletions jdaviz/configs/cubeviz/plugins/viewers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np

import astropy.units as u
from functools import cached_property
from glue.core import BaseData

Expand All @@ -12,7 +12,7 @@
from jdaviz.configs.cubeviz.helper import layer_is_cube_image_data
from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin
from jdaviz.configs.specviz.plugins.viewers import SpecvizProfileView
from jdaviz.core.events import AddDataMessage, RemoveDataMessage
from jdaviz.core.events import AddDataMessage, RemoveDataMessage, GlobalDisplayUnitChanged
from jdaviz.core.freezable_state import FreezableBqplotImageViewerState
from jdaviz.utils import get_subset_type

Expand All @@ -25,23 +25,44 @@ class WithSliceIndicator:
def slice_component_label(self):
return str(self.state.x_att)

@property
def slice_display_unit_name(self):
return 'spectral'

@cached_property
def slice_indicator(self):
# SliceIndicatorMarks does not yet exist
slice_indicator = SliceIndicatorMarks(self)
self.figure.marks = self.figure.marks + slice_indicator.marks
return slice_indicator

@property
@cached_property
def slice_values(self):

def _get_component(layer):
# Retrieve display units
slice_display_units = self.jdaviz_app._get_display_unit(
self.slice_display_unit_name
)

try:
return layer.layer.get_component(self.slice_component_label).data
# Retrieve layer data and units
data_obj = layer.layer.data.get_component(self.slice_component_label).data
data_units = layer.layer.data.get_component(self.slice_component_label).units
except (AttributeError, KeyError):
# layer either does not have get_component (because its a subset)
# or slice_component_label is not a component in this layer
# either way, return an empty array and skip this layer
return np.array([])

data_spec_axis = np.asarray(data_obj.data, dtype=float) * u.Unit(data_units)

# Convert axis if display units are set and are different
if slice_display_units and slice_display_units != data_units:
return data_spec_axis.to_value(slice_display_units,
equivalencies=u.spectral())
else:
return data_spec_axis
try:
return np.asarray(np.unique(np.concatenate([_get_component(layer) for layer in self.layers])), # noqa
dtype=float)
Expand All @@ -68,24 +89,50 @@ def slice_component_label(self):
return slice_plg._obj.slice_indicator_viewers[0].slice_component_label

@property
def slice_display_unit_name(self):
return 'spectral'

@cached_property
def slice_values(self):
# TODO: make a cached property and invalidate cache on add/remove data
# TODO: add support for multiple cubes (but then slice selection needs to be more complex)
# if slice_index is 0, then we want the equivalent of [:, 0, 0]
# if slice_index is 1, then we want the equivalent of [0, :, 0]
# if slice_index is 2, then we want the equivalent of [0, 0, :]
take_inds = [2, 1, 0]
take_inds.remove(self.slice_index)
converted_axis = np.array([])
for layer in self.layers:
world_comp_ids = layer.layer.data.world_component_ids
if self.slice_index >= len(world_comp_ids):
# Case where 2D image is loaded in image viewer
continue

# Retrieve display units
slice_display_units = self.jdaviz_app._get_display_unit(
self.slice_display_unit_name
)

try:
data_obj = layer.layer.data.get_component(self.slice_component_label).data
# Retrieve layer data and units using the slice index of the world components ids
data_obj = layer.layer.data.get_component(world_comp_ids[self.slice_index]).data
data_units = layer.layer.data.get_component(world_comp_ids[self.slice_index]).units
except (AttributeError, KeyError):
continue

# Find the spectral axis
data_spec_axis = np.asarray(data_obj.take(0, take_inds[0]).take(0, take_inds[1]), # noqa
dtype=float)

# Convert to display units if applicable
if slice_display_units and slice_display_units != data_units:
converted_axis = (data_spec_axis * u.Unit(data_units)).to_value(
slice_display_units,
equivalencies=u.spectral()
)
else:
break
else:
return np.array([])
return np.asarray(data_obj.take(0, take_inds[0]).take(0, take_inds[1]), dtype=float)
converted_axis = data_spec_axis

return converted_axis

@property
def slice(self):
Expand Down Expand Up @@ -143,6 +190,13 @@ def __init__(self, *args, **kwargs):
# Hide axes by default
self.state.show_axes = False

self.hub.subscribe(self, GlobalDisplayUnitChanged,
handler=self._on_global_display_unit_changed
)

self.hub.subscribe(self, AddDataMessage, handler=self._on_global_display_unit_changed)
self.hub.subscribe(self, RemoveDataMessage, handler=self._on_global_display_unit_changed)

@property
def _default_spectrum_viewer_reference_name(self):
return self.jdaviz_helper._default_spectrum_viewer_reference_name
Expand All @@ -167,6 +221,11 @@ def active_image_layer(self):

return visible_layers[-1]

def _on_global_display_unit_changed(self, msg):
# Clear cache of slice values when units change
if 'slice_values' in self.__dict__:
del self.__dict__['slice_values']

def _initial_x_axis(self, *args):
# Make sure that the x_att is correct on data load
ref_data = self.state.reference_data
Expand Down Expand Up @@ -223,16 +282,25 @@ def __init__(self, *args, **kwargs):
self.hub.subscribe(self, AddDataMessage,
handler=self._check_if_data_added)

self.hub.subscribe(self, GlobalDisplayUnitChanged,
handler=self._on_global_display_unit_changed)

@property
def _default_flux_viewer_reference_name(self):
return self.jdaviz_helper._default_flux_viewer_reference_name

def _on_global_display_unit_changed(self, msg=None):
# Clear cache of slice values when units change
if 'slice_values' in self.__dict__:
del self.__dict__['slice_values']

def _check_if_data_removed(self, msg):
# isinstance and the data uuid check will be true for the data
# that is being removed
self.figure.marks = [m for m in self.figure.marks
if not (isinstance(m, ShadowSpatialSpectral)
and m.data_uuid == msg.data.uuid)]
self._on_global_display_unit_changed()

def _check_if_data_added(self, msg=None):
# When data is added, make sure that all spatial subset layers
Expand All @@ -243,6 +311,7 @@ def _check_if_data_added(self, msg=None):
if (isinstance(layer.layer, GroupedSubset) and
get_subset_type(layer.layer.subset_state) == 'spatial'):
self._expected_subset_layer_default(layer)
self._on_global_display_unit_changed()

def _is_spatial_subset(self, layer):
subset_state = getattr(layer.layer, 'subset_state', None)
Expand Down
Loading

0 comments on commit 6c57b66

Please sign in to comment.