Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable spectral conversion in cubeviz #2758

Merged
Merged
2 changes: 1 addition & 1 deletion jdaviz/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ def _link_new_data(self, reference_data=None, data_to_be_linked=None):
linked_flux_component = dc[-1].components[-1]

links = [
LinkSame(ref_wavelength_component, linked_wavelength_component),
LinkSameWithUnits(ref_wavelength_component, linked_wavelength_component),
LinkSame(ref_flux_component, linked_flux_component)
]

Expand Down
1 change: 1 addition & 0 deletions jdaviz/configs/cubeviz/cubeviz.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ tray:
- g-subset-plugin
- g-markers
- cubeviz-slice
- g-unit-conversion
- g-gaussian-smooth
- g-collapse
- g-model-fitting
Expand Down
16 changes: 11 additions & 5 deletions jdaviz/configs/cubeviz/plugins/moment_maps/moment_maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from specutils import manipulation, analysis, Spectrum1D

from jdaviz.core.custom_traitlets import IntHandleEmpty, FloatHandleEmpty
from jdaviz.core.events import SnackbarMessage
from jdaviz.core.events import SnackbarMessage, GlobalDisplayUnitChanged
from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import (PluginTemplateMixin,
DatasetSelectMixin,
Expand Down Expand Up @@ -96,6 +96,8 @@ def __init__(self, *args, **kwargs):

self.dataset.add_filter('is_cube')
self.add_results.viewer.filters = ['is_image_viewer']
self.hub.subscribe(self, GlobalDisplayUnitChanged,
handler=self._set_data_units)

if self.app.state.settings.get('server_is_remote', False):
# when the server is remote, saving the file in python would save on the server, not
Expand Down Expand Up @@ -155,12 +157,16 @@ def _set_data_units(self, event={}):
if self.dataset_selected != "":
# Spectral axis is first in this list
data = self.app.data_collection[self.dataset_selected]
if self.app.data_collection[self.dataset_selected].coords is not None:
if (self.spectrum_viewer and hasattr(self.spectrum_viewer.state, 'x_display_unit')
and self.spectrum_viewer.state.x_display_unit is not None):
sunit = self.spectrum_viewer.state.x_display_unit
elif self.app.data_collection[self.dataset_selected].coords is not None:
sunit = data.coords.world_axis_units[0]
self.dataset_spectral_unit = sunit
unit_dict["Spectral Unit"] = sunit
else:
self.dataset_spectral_unit = ""
sunit = ""
self.dataset_spectral_unit = sunit
unit_dict["Spectral Unit"] = sunit

unit_dict["Flux"] = data.get_component('flux').units

# Update units in selection item dictionary
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ def test_moment_velocity_calculation(cubeviz_helper, spectrum1d_cube):
uncert_viewer.state._set_axes_aspect_ratio(1)

mm = cubeviz_helper.plugins["Moment Maps"]
print(mm._obj.dataset_selected)
mm._obj.dataset_selected = 'test[FLUX]'

# Test moment 1 in velocity
Expand All @@ -170,6 +169,13 @@ def test_moment_velocity_calculation(cubeviz_helper, spectrum1d_cube):
"World 13h39m59.9731s +27d00m00.3600s (ICRS)",
"204.9998877673 27.0001000000 (deg)")

# Add test for unit conversion
assert mm._obj.output_radio_items[0]['unit_str'] == 'm'
uc_plugin = cubeviz_helper.plugins['Unit Conversion']._obj
uc_plugin.spectral_unit.selected = 'Angstrom'
assert mm._obj.output_radio_items[0]['unit_str'] == 'Angstrom'
uc_plugin.spectral_unit.selected = 'm'

# Test moment 2 in velocity
mm.n_moment = 2
mm.calculate_moment()
Expand Down
17 changes: 12 additions & 5 deletions jdaviz/configs/cubeviz/plugins/slice/slice.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import threading
import time
import warnings
from functools import cached_property

import numpy as np
from astropy import units as u
Expand All @@ -12,7 +13,7 @@
CubevizImageView)
from jdaviz.configs.cubeviz.helper import _spectral_axis_names
from jdaviz.core.custom_traitlets import FloatHandleEmpty
from jdaviz.core.events import (AddDataMessage, SliceToolStateMessage,
from jdaviz.core.events import (AddDataMessage, RemoveDataMessage, SliceToolStateMessage,
SliceSelectSliceMessage, SliceValueUpdatedMessage,
NewViewerMessage, ViewerAddedMessage, ViewerRemovedMessage,
GlobalDisplayUnitChanged)
Expand Down Expand Up @@ -82,6 +83,8 @@ def __init__(self, *args, **kwargs):
handler=self._on_viewer_removed)
self.hub.subscribe(self, AddDataMessage,
handler=self._on_add_data)
self.hub.subscribe(self, RemoveDataMessage,
handler=self._on_valid_selection_values_changed)

# connect any pre-existing viewers
for viewer in self.app._viewer_store.values():
Expand Down Expand Up @@ -221,6 +224,7 @@ def _on_add_data(self, msg):
if isinstance(msg.viewer, WithSliceSelection):
# instead of just setting viewer.slice_value, we'll make sure the "snapping" logic
# is updated (if enabled)
self._on_valid_selection_values_changed()
self._on_value_updated({'new': self.value})

def _on_select_slice_message(self, msg):
Expand All @@ -236,9 +240,14 @@ def _on_global_display_unit_changed(self, msg):
return
prev_unit = u.Unit(self.value_unit)
self.value_unit = str(msg.unit)
self.value = (self.value * prev_unit).to_value(msg.unit)
self._on_valid_selection_values_changed()
self.value = (self.value * prev_unit).to_value(msg.unit, equivalencies=u.spectral())

@property
def _on_valid_selection_values_changed(self, msg=None):
if 'valid_selection_values' in self.__dict__:
del self.__dict__['valid_selection_values']

@cached_property
def valid_selection_values(self):
# all available slice values from cubes (unsorted)
viewers = self.slice_selection_viewers
Expand Down Expand Up @@ -280,7 +289,6 @@ def _on_value_updated(self, event):
except ValueError:
return
return

if self.snap_to_slice and not self.value_editing:
valid_values = self.valid_selection_values
if len(valid_values):
Expand All @@ -292,7 +300,6 @@ def _on_value_updated(self, event):
self.value = float(closest_value)
# will trigger another call to this method
return

for viewer in self.slice_indicator_viewers:
viewer._set_slice_indicator_value(self.value)
for viewer in self.slice_selection_viewers:
Expand Down
10 changes: 10 additions & 0 deletions jdaviz/configs/cubeviz/plugins/slice/tests/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def test_slice(cubeviz_helper, spectrum1d_cube):
cubeviz_helper.select_wavelength(4.62360028e-07)
assert sl.value == slice_values[1]

# Add test for unit conversion
uc_plugin = cubeviz_helper.plugins['Unit Conversion']._obj
uc_plugin.spectral_unit.selected = 'Angstrom'
assert sl.value_unit == 'Angstrom'
cubeviz_helper.select_wavelength(4623.60028)
assert sl.value == 4623.600276968349

# Retrieve updated slice_values
slice_values = sl.valid_selection_values_sorted

# Test player buttons API

sl.vue_goto_first()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,7 @@ def get_aperture(self):
# Retrieve flux cube and create an array to represent the cone mask
flux_cube = self._app._jdaviz_helper._loaded_flux_cube.get_object(cls=Spectrum1D,
statistic=None)
# TODO: Replace with code for retrieving display_unit in cubeviz when it is available
display_unit = flux_cube.spectral_axis.unit
display_unit = astropy.units.Unit(self.app._get_display_unit('spectral'))

# Center is reverse coordinates
center = (self.aperture.selected_spatial_region.center.y,
Expand All @@ -373,7 +372,7 @@ def get_aperture(self):
f'Spectral axis unit physical type is {display_unit.physical_type}, '
'must be length for cone aperture')

fac = flux_cube.spectral_axis.value / self.reference_spectral_value
fac = flux_cube.spectral_axis.to_value(display_unit) / self.reference_spectral_value

# TODO: Use flux_cube.spectral_axis.to_value(display_unit) when we have unit conversion.
if isinstance(aperture, CircularAperture):
Expand Down
69 changes: 61 additions & 8 deletions jdaviz/configs/cubeviz/plugins/viewers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np

import astropy.units as u
from functools import cached_property
from glue.core import BaseData

Expand All @@ -12,7 +12,7 @@
from jdaviz.configs.cubeviz.helper import layer_is_cube_image_data
from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin
from jdaviz.configs.specviz.plugins.viewers import SpecvizProfileView
from jdaviz.core.events import AddDataMessage, RemoveDataMessage
from jdaviz.core.events import AddDataMessage, RemoveDataMessage, GlobalDisplayUnitChanged
from jdaviz.core.freezable_state import FreezableBqplotImageViewerState
from jdaviz.utils import get_subset_type

Expand All @@ -32,11 +32,25 @@
self.figure.marks = self.figure.marks + slice_indicator.marks
return slice_indicator

@property
@cached_property
def slice_values(self):

def _get_component(layer):
try:
return layer.layer.get_component(self.slice_component_label).data
# Retrieve display units
display_spectral_units = self.jdaviz_app._get_display_unit('spectral')
javerbukh marked this conversation as resolved.
Show resolved Hide resolved

# Retrieve layer data and units
data_obj = layer.layer.data.get_component(self.slice_component_label).data
data_units = layer.layer.data.get_component(self.slice_component_label).units
data_spec_axis = np.asarray(data_obj.data, dtype=float) * u.Unit(data_units)

# Convert axis if display units are set and are different
if display_spectral_units and display_spectral_units != data_units:
return data_spec_axis.to_value(display_spectral_units,

Check warning on line 50 in jdaviz/configs/cubeviz/plugins/viewers.py

View check run for this annotation

Codecov / codecov/patch

jdaviz/configs/cubeviz/plugins/viewers.py#L50

Added line #L50 was not covered by tests
equivalencies=u.spectral())
else:
return data_spec_axis
except (AttributeError, KeyError):
# layer either does not have get_component (because its a subset)
# or slice_component_label is not a component in this layer
Expand Down Expand Up @@ -67,25 +81,44 @@
raise ValueError("slice plugin must be activated to access slice_component_label")
return slice_plg._obj.slice_indicator_viewers[0].slice_component_label

@property
@cached_property
def slice_values(self):
# TODO: make a cached property and invalidate cache on add/remove data
# TODO: add support for multiple cubes (but then slice selection needs to be more complex)
# if slice_index is 0, then we want the equivalent of [:, 0, 0]
# if slice_index is 1, then we want the equivalent of [0, :, 0]
# if slice_index is 2, then we want the equivalent of [0, 0, :]
take_inds = [2, 1, 0]
take_inds.remove(self.slice_index)
for layer in self.layers:
world_comp_ids = layer.layer.data.world_component_ids
if self.slice_index >= len(world_comp_ids):
# Case where 2D image is loaded in image viewer
continue
try:
data_obj = layer.layer.data.get_component(self.slice_component_label).data
# Retrieve display units
display_spectral_units = self.jdaviz_app._get_display_unit('spectral')

# Retrieve layer data and units using the slice index of the world components ids
data_obj = layer.layer.data.get_component(world_comp_ids[self.slice_index]).data
data_units = layer.layer.data.get_component(world_comp_ids[self.slice_index]).units

# Find the spectral axis
data_spec_axis = np.asarray(data_obj.take(0, take_inds[0]).take(0, take_inds[1]), # noqa
dtype=float)
if display_spectral_units and display_spectral_units != data_units:
converted_axis = (data_spec_axis * u.Unit(data_units)).to_value(
display_spectral_units,
equivalencies=u.spectral()
)
else:
converted_axis = data_spec_axis
javerbukh marked this conversation as resolved.
Show resolved Hide resolved
javerbukh marked this conversation as resolved.
Show resolved Hide resolved
except (AttributeError, KeyError):
continue
else:
break
else:
return np.array([])
return np.asarray(data_obj.take(0, take_inds[0]).take(0, take_inds[1]), dtype=float)
return converted_axis

@property
def slice(self):
Expand Down Expand Up @@ -143,6 +176,13 @@
# Hide axes by default
self.state.show_axes = False

self.hub.subscribe(self, GlobalDisplayUnitChanged,
handler=self._on_global_display_unit_changed
)

self.hub.subscribe(self, AddDataMessage, handler=self._on_global_display_unit_changed)
self.hub.subscribe(self, RemoveDataMessage, handler=self._on_global_display_unit_changed)

@property
def _default_spectrum_viewer_reference_name(self):
return self.jdaviz_helper._default_spectrum_viewer_reference_name
Expand All @@ -167,6 +207,10 @@

return visible_layers[-1]

def _on_global_display_unit_changed(self, msg):
if 'slice_values' in self.__dict__:
del self.__dict__['slice_values']
javerbukh marked this conversation as resolved.
Show resolved Hide resolved

def _initial_x_axis(self, *args):
# Make sure that the x_att is correct on data load
ref_data = self.state.reference_data
Expand Down Expand Up @@ -223,16 +267,24 @@
self.hub.subscribe(self, AddDataMessage,
handler=self._check_if_data_added)

self.hub.subscribe(self, GlobalDisplayUnitChanged,
handler=self._on_global_display_unit_changed)

@property
def _default_flux_viewer_reference_name(self):
return self.jdaviz_helper._default_flux_viewer_reference_name

def _on_global_display_unit_changed(self, msg=None):
if 'slice_values' in self.__dict__:
del self.__dict__['slice_values']

def _check_if_data_removed(self, msg):
# isinstance and the data uuid check will be true for the data
# that is being removed
self.figure.marks = [m for m in self.figure.marks
if not (isinstance(m, ShadowSpatialSpectral)
and m.data_uuid == msg.data.uuid)]
self._on_global_display_unit_changed()

def _check_if_data_added(self, msg=None):
# When data is added, make sure that all spatial subset layers
Expand All @@ -243,6 +295,7 @@
if (isinstance(layer.layer, GroupedSubset) and
get_subset_type(layer.layer.subset_state) == 'spatial'):
self._expected_subset_layer_default(layer)
self._on_global_display_unit_changed()

def _is_spatial_subset(self, layer):
subset_state = getattr(layer.layer, 'subset_state', None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,8 @@ def _warn_if_no_equation(self):

def _get_1d_spectrum(self):
# retrieves the 1d spectrum (accounting for spatial subset for cubeviz, if necessary)
return self.dataset.selected_spectrum_for_spatial_subset(self.spatial_subset_selected) # noqa
return self.dataset.selected_spectrum_for_spatial_subset(self.spatial_subset_selected,
use_display_units=True)

@observe("dataset_selected", "spatial_subset_selected")
def _dataset_selected_changed(self, event=None):
Expand Down Expand Up @@ -343,7 +344,6 @@ def _dataset_selected_changed(self, event=None):
# Replace NaNs from collapsed Spectrum1D in Cubeviz
# (won't affect calculations because these locations are masked)
selected_spec.flux[np.isnan(selected_spec.flux)] = 0.0

# TODO: can we simplify this logic?
self._units["x"] = str(
selected_spec.spectral_axis.unit)
Expand Down Expand Up @@ -508,7 +508,6 @@ def _initialize_model_component(self, model_comp, comp_label, poly_order=None):
"value": param_quant.value,
"unit": str(param_quant.unit),
"fixed": False})

self._initialized_models[comp_label] = initialized_model

new_model["Initialized"] = True
Expand Down Expand Up @@ -841,7 +840,6 @@ def _fit_model_to_spectrum(self, add_data):
models_to_fit = self._reinitialize_with_fixed()

masked_spectrum = self._apply_subset_masks(self._get_1d_spectrum(), self.spectral_subset)

try:
fitted_model, fitted_spectrum = fit_model_to_spectrum(
masked_spectrum,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,18 @@ def _on_glue_y_display_unit_changed(self, y_unit):

@observe('spectral_unit_selected')
def _on_spectral_unit_changed(self, *args):
self.hub.broadcast(GlobalDisplayUnitChanged('spectral',
self.spectral_unit.selected,
sender=self))
xunit = _valid_glue_display_unit(self.spectral_unit.selected, self.spectrum_viewer, 'x')
if self.spectrum_viewer.state.x_display_unit != xunit:
self.spectrum_viewer.state.x_display_unit = xunit
self.hub.broadcast(GlobalDisplayUnitChanged('spectral',
self.spectral_unit.selected,
sender=self))

@observe('flux_unit_selected')
def _on_flux_unit_changed(self, *args):
self.hub.broadcast(GlobalDisplayUnitChanged('flux',
self.flux_unit.selected,
sender=self))
yunit = _valid_glue_display_unit(self.flux_unit.selected, self.spectrum_viewer, 'y')
if self.spectrum_viewer.state.y_display_unit != yunit:
self.spectrum_viewer.state.y_display_unit = yunit
self.hub.broadcast(GlobalDisplayUnitChanged('flux',
self.flux_unit.selected,
sender=self))
Loading