diff --git a/jdaviz/configs/default/default.yaml b/jdaviz/configs/default/default.yaml index 4dda8cbb79..ebf9e8449f 100644 --- a/jdaviz/configs/default/default.yaml +++ b/jdaviz/configs/default/default.yaml @@ -15,4 +15,4 @@ toolbar: tray: - g-subset-plugin - g-gaussian-smooth - - export + - export \ No newline at end of file diff --git a/jdaviz/configs/default/plugins/__init__.py b/jdaviz/configs/default/plugins/__init__.py index e516c2dd82..a79315363a 100644 --- a/jdaviz/configs/default/plugins/__init__.py +++ b/jdaviz/configs/default/plugins/__init__.py @@ -11,3 +11,4 @@ from .export.export import * # noqa from .plot_options.plot_options import * # noqa from .markers.markers import * # noqa +from .data_quality.data_quality import * # noqa diff --git a/jdaviz/configs/default/plugins/data_quality/__init__.py b/jdaviz/configs/default/plugins/data_quality/__init__.py index 8e834ef147..9fd5bcb8ff 100644 --- a/jdaviz/configs/default/plugins/data_quality/__init__.py +++ b/jdaviz/configs/default/plugins/data_quality/__init__.py @@ -1 +1 @@ -from .dq_utils import * # noqa +from .data_quality import * # noqa diff --git a/jdaviz/configs/default/plugins/data_quality/data_quality.py b/jdaviz/configs/default/plugins/data_quality/data_quality.py new file mode 100644 index 0000000000..3523e97b09 --- /dev/null +++ b/jdaviz/configs/default/plugins/data_quality/data_quality.py @@ -0,0 +1,197 @@ +import os +from traitlets import Any, Dict, Bool, List, Unicode, observe + +import numpy as np +from glue_jupyter.common.toolbar_vuetify import read_icon +from echo import delay_callback +from matplotlib.colors import hex2color + +from jdaviz.core.registries import tray_registry +from jdaviz.core.template_mixin import ( + PluginTemplateMixin, ViewerSelect, LayerSelect +) +from jdaviz.core.tools import ICON_DIR +from jdaviz.configs.default.plugins.data_quality.dq_utils import ( + decode_flags, generate_listed_colormap, dq_flag_map_paths, load_flag_map +) + +__all__ = ['DataQuality'] + +telescope_names = { + "jwst": "JWST", + "roman": "Roman" +} + + +@tray_registry('g-data-quality', label="Data Quality", viewer_requirements="image") +class DataQuality(PluginTemplateMixin): + template_file = __file__, "data_quality.vue" + + viewer_multiselect = Bool(False).tag(sync=True) + viewer_items = List().tag(sync=True) + viewer_selected = Any().tag(sync=True) # Any needed for multiselect + viewer_limits = Dict().tag(sync=True) + + # `layer` is the science data layer + science_layer_multiselect = Bool(False).tag(sync=True) + science_layer_items = List().tag(sync=True) + science_layer_selected = Any().tag(sync=True) # Any needed for multiselect + + # `dq_layer` is teh data quality layer corresponding to the + # science data in `layer` + dq_layer_multiselect = Bool(False).tag(sync=True) + dq_layer_items = List().tag(sync=True) + dq_layer_selected = Any().tag(sync=True) # Any needed for multiselect + + flag_map_definitions = Dict().tag(sync=True) + flag_map_selected = Any().tag(sync=True) + flag_map_items = List().tag(sync=True) + viewer_selected = Any().tag(sync=True) # Any needed for multiselect + decoded_flags = List().tag(sync=True) + + icons = Dict().tag(sync=True) + icon_radialtocheck = Unicode(read_icon(os.path.join(ICON_DIR, 'radialtocheck.svg'), 'svg+xml')).tag(sync=True) # noqa + icon_checktoradial = Unicode(read_icon(os.path.join(ICON_DIR, 'checktoradial.svg'), 'svg+xml')).tag(sync=True) # noqa + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.icons = {k: v for k, v in self.app.state.icons.items()} + + self.viewer = ViewerSelect( + self, 'viewer_items', 'viewer_selected', 'viewer_multiselect' + ) + self.science_layer = LayerSelect( + self, 'science_layer_items', 'science_layer_selected', + 'viewer_selected', 'science_layer_multiselect', is_root=True + ) + + self.dq_layer = LayerSelect( + self, 'dq_layer_items', 'dq_layer_selected', + 'viewer_selected', 'dq_layer_multiselect', is_root=False, + is_child_of=self.science_layer.selected + ) + + self.load_default_flag_maps() + self.init_decoding() + + @observe('science_layer_selected') + def update_dq_layer(self, *args): + if not hasattr(self, 'dq_layer'): + return + + self.dq_layer.filter_is_child_of = self.science_layer_selected + self.dq_layer._update_layer_items() + + def load_default_flag_maps(self): + for name in dq_flag_map_paths: + self.flag_map_definitions[name] = load_flag_map(name) + self.flag_map_items = self.flag_map_items + [telescope_names[name]] + + @property + def multiselect(self): + logging.warning(f"DeprecationWarning: multiselect has been replaced by separate viewer_multiselect and layer_multiselect and will be removed in the future. This currently evaluates viewer_multiselect or layer_multiselect") # noqa + return self.viewer_multiselect or self.layer_multiselect + + @multiselect.setter + def multiselect(self, value): + logging.warning(f"DeprecationWarning: multiselect has been replaced by separate viewer_multiselect and layer_multiselect and will be removed in the future. This currently sets viewer_multiselect and layer_multiselect") # noqa + self.viewer_multiselect = value + self.layer_multiselect = value + + def vue_set_value(self, data): + attr_name = data.get('name') + value = data.get('value') + setattr(self, attr_name, value) + + @property + def unique_flags(self): + selected_dq = self.dq_layer.selected_obj + if not len(selected_dq): + return [] + + dq = selected_dq[0].get_image_data() + return np.unique(dq[~np.isnan(dq)]) + + @property + def flag_map_definitions_selected(self): + return self.flag_map_definitions[self.flag_map_selected.lower()] + + @property + def validate_flag_decode_possible(self): + return ( + self.flag_map_selected is not None and + len(self.dq_layer.selected_obj) > 0 + ) + + @observe('dq_layer_selected') + def init_decoding(self, event={}): + if not self.validate_flag_decode_possible: + return + + unique_flags = self.unique_flags + cmap, rgba_colors = generate_listed_colormap(n_flags=len(unique_flags)) + self.decoded_flags = decode_flags( + flag_map=self.flag_map_definitions_selected, + unique_flags=unique_flags, + rgba_colors=rgba_colors + ) + + viewer = self.viewer.selected_obj + [dq_layer] = [ + layer for layer in viewer.layers if + layer.layer.label == self.dq_layer_selected + ] + dq_layer.composite._allow_bad_alpha = True + + flag_bits = np.float32([flag['flag'] for flag in self.decoded_flags]) + + with delay_callback(dq_layer.state, 'alpha', 'cmap', 'v_min', 'v_max', 'stretch'): + dq_layer.state.cmap = cmap + + dq_layer.state.stretch = 'lookup' + stretch_object = dq_layer.state.stretch_object + stretch_object.flags = flag_bits + + dq_layer.state.v_min = min(flag_bits) + dq_layer.state.v_max = max(flag_bits) + dq_layer.state.alpha = 0.9 + + @observe('decoded_flags') + def update_cmap(self, event={}): + viewer = self.viewer.selected_obj + [dq_layer] = [ + layer for layer in viewer.layers if + layer.layer.label == self.dq_layer_selected + ] + flag_bits = np.float32([flag['flag'] for flag in self.decoded_flags]) + rgb_colors = [hex2color(flag['color']) for flag in self.decoded_flags] + + # update the colors of the listed colormap without + # reassigning the layer.state.cmap object + cmap = dq_layer.state.cmap + cmap.colors = rgb_colors + cmap._init() + + with delay_callback(dq_layer.state, 'v_min', 'v_max', 'alpha'): + # trigger updates to cmap in viewer: + dq_layer.update() + + # set correct stretch and limits: + dq_layer.state.stretch = 'lookup' + dq_layer.state.v_min = min(flag_bits) + dq_layer.state.v_max = max(flag_bits) + dq_layer.state.alpha = 0.9 + + @observe('science_layer_selected') + def mission_or_instrument_from_meta(self, event): + if not hasattr(self, 'science_layer'): + return + + layer = self.science_layer.selected_obj + if len(layer): + # this is defined for JWST and ROMAN, should be upper case: + telescope = layer[0].layer.meta.get('telescope', None) + + if telescope is not None: + self.flag_map_selected = telescope_names[telescope.lower()] diff --git a/jdaviz/configs/default/plugins/data_quality/data_quality.vue b/jdaviz/configs/default/plugins/data_quality/data_quality.vue new file mode 100644 index 0000000000..4cfa4064cf --- /dev/null +++ b/jdaviz/configs/default/plugins/data_quality/data_quality.vue @@ -0,0 +1,124 @@ + + + + + + diff --git a/jdaviz/configs/default/plugins/data_quality/dq_utils.py b/jdaviz/configs/default/plugins/data_quality/dq_utils.py index 019fa603a7..940071295a 100644 --- a/jdaviz/configs/default/plugins/data_quality/dq_utils.py +++ b/jdaviz/configs/default/plugins/data_quality/dq_utils.py @@ -45,7 +45,7 @@ def load_flag_map(mission_or_instrument=None, path=None): flag_mapping = {} for flag, name, desc in flag_table.iterrows(): - flag_mapping[flag] = dict(name=name, description=desc) + flag_mapping[int(flag)] = dict(name=name, description=desc) return flag_mapping @@ -77,7 +77,7 @@ def write_flag_map(flag_mapping, csv_path, **kwargs): table.write(csv_path, format='ascii.csv', **kwargs) -def generate_listed_colormap(n_flags, seed=42): +def generate_listed_colormap(n_flags=None, seed=3): """ Generate a list of random "light" colors of length ``n_flags``. @@ -103,14 +103,14 @@ def generate_listed_colormap(n_flags, seed=42): # Generate random colors that are generally "light", i.e. with # RGB values in the upper half of the interval (0, 1): rgba_colors = [ - tuple(rng.uniform(low=0.5, high=1, size=3).tolist() + [default_alpha]) + tuple(np.insert(rng.uniform(size=2), rng.integers(0, 3), 1).tolist() + [default_alpha]) for _ in range(n_flags) ] cmap = ListedColormap(rgba_colors) # setting `bad` alpha=0 will make NaNs transparent: - cmap.set_bad(alpha=0) + cmap.set_bad(color='k', alpha=0) return cmap, rgba_colors diff --git a/jdaviz/configs/default/plugins/plot_options/plot_options.py b/jdaviz/configs/default/plugins/plot_options/plot_options.py index ef78823fdf..77dfa3600e 100644 --- a/jdaviz/configs/default/plugins/plot_options/plot_options.py +++ b/jdaviz/configs/default/plugins/plot_options/plot_options.py @@ -115,10 +115,61 @@ def update_knots(self, x, y): self.spline = PchipInterpolator(self._x, self._y) +class LookupStretch: + """ + Stretch class specific to DQ arrays. + + Attributes + ---------- + flags : array-like + DQ flags. + """ + + def __init__(self, flags=None): + # Default x, y values(0-1) range chosen for a typical initial spline shape. + # Can be modified if required. + if flags is None: + flags = np.linspace(0, 1, 5) + self.flags = np.asarray(flags) + + def __call__(self, values, out=None, clip=False): + # For our uses, we can ignore `out` and `clip`, but those would need + # to be implemented before contributing this class upstream. + + # find closest index in `self.flags` for each value in `values`: + if hasattr(values, 'squeeze'): + values = values.squeeze() + + # renormalize the flags on range (0, 1): + scaled_flags = self.flags / np.max(self.flags) + + # `values` will have already been passed through + # astropy.visualization.ManualInterval and normalized on (0, 1) + # before they arrive here. Now find the index of the closest entry in + # `scaled_flags` for each of `values` using array broadcasting. + min_indices = np.argmin(np.abs( + np.nan_to_num(values, nan=-10).flatten()[None, :] - scaled_flags[:, None] + ), axis=0).reshape(values.shape) + + # normalize by the number of flags, onto interval (0, 1): + renormed = min_indices / (len(self.flags) - 1) + + # preserve nans in the result: + renormed = np.where( + np.isnan(values), + np.nan, + renormed + ) + return renormed + + # Add the spline stretch to the glue stretch registry if not registered if "spline" not in stretches: stretches.add("spline", SplineStretch, display="Spline") +if "lookup" not in stretches: + stretches.add("lookup", LookupStretch, display="DQ") + def _round_step(step): # round the step for a float input diff --git a/jdaviz/configs/imviz/imviz.yaml b/jdaviz/configs/imviz/imviz.yaml index 4e7de049bb..2934cebabf 100644 --- a/jdaviz/configs/imviz/imviz.yaml +++ b/jdaviz/configs/imviz/imviz.yaml @@ -24,6 +24,7 @@ tray: - g-plot-options - g-subset-plugin - g-markers + - g-data-quality - imviz-compass - imviz-line-profile-xy - imviz-aper-phot-simple diff --git a/jdaviz/core/template_mixin.py b/jdaviz/core/template_mixin.py index 3a568b9f8e..a34594505f 100644 --- a/jdaviz/core/template_mixin.py +++ b/jdaviz/core/template_mixin.py @@ -1329,7 +1329,9 @@ def __init__(self, plugin, items, selected, viewer, multiselect=None, default_text=None, manual_options=[], default_mode='first', - only_wcs_layers=False): + only_wcs_layers=False, + is_root=True, + is_child_of=None): """ Parameters ---------- @@ -1382,11 +1384,24 @@ def __init__(self, plugin, items, selected, viewer, self._update_layer_items() self.update_wcs_only_filter(only_wcs_layers) - # ignore layers that are children in associations: - def is_parent(data): - return self.app._get_assoc_data_parent(data.label) is None + self.filter_is_root = is_root + self.filter_is_child_of = is_child_of - self.add_filter(is_parent) + if self.filter_is_root: + # ignore layers that are children in associations: + def filter_is_root(data): + return self.app._get_assoc_data_parent(data.label) is None + + self.add_filter(filter_is_root) + + elif not self.filter_is_root and self.filter_is_child_of is not None: + # only offer layers that are children of the correct parent: + def has_correct_parent(data): + if self.filter_is_child_of == '': + return False + return self.app._get_assoc_data_parent(data.label) == self.filter_is_child_of + + self.add_filter(has_correct_parent) def _get_viewer(self, viewer): # newer will likely be the viewer name in most cases, but viewer id in the case